Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Sebastian Bindgen
pystencils
Commits
3bcfac93
Commit
3bcfac93
authored
Apr 02, 2018
by
Martin Bauer
Browse files
PEP8 naming
parent
ef924b18
Changes
34
Expand all
Hide whitespace changes
Inline
Side-by-side
__init__.py
View file @
3bcfac93
...
...
@@ -2,7 +2,7 @@ from pystencils.field import Field, FieldType, extractCommonSubexpressions
from
pystencils.data_types
import
TypedSymbol
from
pystencils.slicing
import
makeSlice
from
pystencils.kernelcreation
import
createKernel
,
createIndexedKernel
from
pystencils.display_utils
import
show
C
ode
,
to
D
ot
from
pystencils.display_utils
import
show
_c
ode
,
to
_d
ot
from
pystencils.assignment_collection
import
AssignmentCollection
from
pystencils.assignment
import
Assignment
from
pystencils.sympyextensions
import
SymbolCreator
...
...
@@ -11,7 +11,7 @@ __all__ = ['Field', 'FieldType', 'extractCommonSubexpressions',
'TypedSymbol'
,
'makeSlice'
,
'createKernel'
,
'createIndexedKernel'
,
'show
C
ode'
,
'to
D
ot'
,
'show
_c
ode'
,
'to
_d
ot'
,
'AssignmentCollection'
,
'Assignment'
,
'SymbolCreator'
]
assignment.py
View file @
3bcfac93
# -*- coding: utf-8 -*-
from
sympy.codegen.ast
import
Assignment
from
sympy.printing.latex
import
LatexPrinter
...
...
@@ -11,4 +12,9 @@ def print_assignment_latex(printer, expr):
return
f
"
{
printed_lhs
}
\leftarrow
{
printed_rhs
}
"
def
assignment_str
(
assignment
):
return
f
"
{
assignment
.
lhs
}
←
{
assignment
.
rhs
}
"
Assignment
.
__str__
=
assignment_str
LatexPrinter
.
_print_Assignment
=
print_assignment_latex
assignment_collection/__init__.py
View file @
3bcfac93
from
pystencils.assignment_collection.assignment_collection
import
AssignmentCollection
from
pystencils.assignment_collection.simplificationstrategy
import
SimplificationStrategy
from
pystencils.assignment_collection.simplifications
import
sympy_cse
,
sympy_cse_on_assignment_list
,
\
apply_to_all_assignments
,
apply_on_all_subexpressions
,
subexpression_substitution_in_existing_subexpressions
,
\
subexpression_substitution_in_main_assignments
,
add_subexpressions_for_divisions
__all__
=
[
'AssignmentCollection'
,
'SimplificationStrategy'
,
'sympy_cse'
,
'sympy_cse_on_assignment_list'
,
'apply_to_all_assignments'
,
'apply_on_all_subexpressions'
,
'subexpression_substitution_in_existing_subexpressions'
,
'subexpression_substitution_in_main_assignments'
,
'add_subexpressions_for_divisions'
]
assignment_collection/assignment_collection.py
View file @
3bcfac93
...
...
@@ -61,7 +61,7 @@ class AssignmentCollection:
left hand side symbol (which could have been generated)
"""
if
lhs
is
None
:
lhs
=
sp
.
Dummy
(
)
lhs
=
next
(
self
.
subexpression_symbol_generator
)
eq
=
Assignment
(
lhs
,
rhs
)
self
.
subexpressions
.
append
(
eq
)
if
topological_sort
:
...
...
@@ -135,25 +135,6 @@ class AssignmentCollection:
return
handled_symbols
def
get
(
self
,
symbols
:
Sequence
[
sp
.
Symbol
],
from_main_assignments_only
=
False
)
->
List
[
Assignment
]:
"""Extracts all assignments that have a left hand side that is contained in the symbols parameter.
Args:
symbols: return assignments that have one of these symbols as left hand side
from_main_assignments_only: search only in main assignments (exclude subexpressions)
"""
if
not
hasattr
(
symbols
,
"__len__"
):
symbols
=
set
(
symbols
)
else
:
symbols
=
set
(
symbols
)
if
not
from_main_assignments_only
:
assignments_to_search
=
self
.
all_assignments
else
:
assignments_to_search
=
self
.
main_assignments
return
[
assignment
for
assignment
in
assignments_to_search
if
assignment
.
lhs
in
symbols
]
def
lambdify
(
self
,
symbols
:
Sequence
[
sp
.
Symbol
],
fixed_symbols
:
Optional
[
Dict
[
sp
.
Symbol
,
Any
]]
=
None
,
module
=
None
):
"""Returns a python function to evaluate this equation collection.
...
...
@@ -343,26 +324,25 @@ class AssignmentCollection:
return
"Equation Collection for "
+
","
.
join
([
str
(
eq
.
lhs
)
for
eq
in
self
.
main_assignments
])
def
__str__
(
self
):
result
=
"Subexpressions
\n
"
result
=
"Subexpressions
:
\n
"
for
eq
in
self
.
subexpressions
:
result
+=
str
(
eq
)
+
"
\n
"
result
+=
"Main Assignments
\n
"
result
+=
f
"
\t
{
eq
}
\n
"
result
+=
"Main Assignments
:
\n
"
for
eq
in
self
.
main_assignments
:
result
+=
str
(
eq
)
+
"
\n
"
result
+=
f
"
{
eq
}
\n
"
return
result
class
SymbolGen
:
"""Default symbol generator producing number symbols ζ_0, ζ_1, ..."""
def
__init__
(
self
):
def
__init__
(
self
,
symbol
=
"xi"
):
self
.
_ctr
=
0
self
.
_symbol
=
symbol
def
__iter__
(
self
):
return
self
def
__next__
(
self
):
name
=
f
"
{
self
.
_symbol
}
_
{
self
.
_ctr
}
"
self
.
_ctr
+=
1
return
sp
.
Symbol
(
"xi_"
+
str
(
self
.
_ctr
))
def
next
(
self
):
return
self
.
__next__
()
return
sp
.
Symbol
(
name
)
assignment_collection/simplifications.py
View file @
3bcfac93
import
sympy
as
sp
from
typing
import
Callable
,
List
from
pystencils
import
Assignment
,
AssignmentCollection
from
pystencils.assignment
import
Assignment
from
pystencils.assignment_collection.assignment_collection
import
AssignmentCollection
from
pystencils.sympyextensions
import
subs_additive
def
sympy_cse_on_assignment_list
(
assignments
:
List
[
Assignment
])
->
List
[
Assignment
]:
"""Extracts common subexpressions from a list of assignments."""
ec
=
AssignmentCollection
(
assignments
,
[])
return
sympy_cse
(
ec
).
all_assignments
def
sympy_cse
(
ac
:
AssignmentCollection
)
->
AssignmentCollection
:
"""Searches for common subexpressions inside the equation collection.
...
...
@@ -32,21 +27,28 @@ def sympy_cse(ac: AssignmentCollection) -> AssignmentCollection:
return
ac
.
copy
(
modified_update_equations
,
new_subexpressions
)
def
sympy_cse_on_assignment_list
(
assignments
:
List
[
Assignment
])
->
List
[
Assignment
]:
"""Extracts common subexpressions from a list of assignments."""
ec
=
AssignmentCollection
(
assignments
,
[])
return
sympy_cse
(
ec
).
all_assignments
def
apply_to_all_assignments
(
assignment_collection
:
AssignmentCollection
,
operation
:
Callable
[[
sp
.
Expr
],
sp
.
Expr
])
->
AssignmentCollection
:
"""Applies sympy expand operation to all equations in collection"""
"""Applies sympy expand operation to all equations in collection
.
"""
result
=
[
Assignment
(
eq
.
lhs
,
operation
(
eq
.
rhs
))
for
eq
in
assignment_collection
.
main_assignments
]
return
assignment_collection
.
copy
(
result
)
def
apply_on_all_subexpressions
(
ac
:
AssignmentCollection
,
operation
:
Callable
[[
sp
.
Expr
],
sp
.
Expr
])
->
AssignmentCollection
:
"""Applies the given operation on all subexpressions of the AssignmentCollection."""
result
=
[
Assignment
(
eq
.
lhs
,
operation
(
eq
.
rhs
))
for
eq
in
ac
.
subexpressions
]
return
ac
.
copy
(
ac
.
main_assignments
,
result
)
def
subexpression_substitution_in_existing_subexpressions
(
ac
:
AssignmentCollection
)
->
AssignmentCollection
:
"""Goes through the subexpressions list and replaces the term in the following subexpressions"""
"""Goes through the subexpressions list and replaces the term in the following subexpressions
.
"""
result
=
[]
for
outerCtr
,
s
in
enumerate
(
ac
.
subexpressions
):
new_rhs
=
s
.
rhs
...
...
astnodes.py
View file @
3bcfac93
This diff is collapsed.
Click to expand it.
backends/__init__.py
View file @
3bcfac93
from
.cbackend
import
generateC
from
.cbackend
import
print_c
try
:
from
.dot
import
dot
print
from
.dot
import
print
_dot
from
.llvm
import
generateLLVM
except
ImportError
:
pass
backends/cbackend.py
View file @
3bcfac93
import
sympy
as
sp
from
pystencils.bitoperations
import
xor
,
rightShift
,
leftShift
,
bitwiseAnd
,
bitwiseOr
from
collections
import
namedtuple
from
sympy.core
import
S
from
typing
import
Optional
try
:
from
sympy.utilities.codegen
import
CCodePrinter
except
ImportError
:
from
sympy.printing.ccode
import
C99CodePrinter
as
CCodePrinter
except
ImportError
:
from
sympy.printing.ccode
import
CCodePrinter
# for sympy versions < 1.1
from
collections
import
namedtuple
from
sympy.core.mul
import
_keep_coeff
from
sympy.core
import
S
from
pystencils.bitoperations
import
xor
,
rightShift
,
leftShift
,
bitwiseAnd
,
bitwiseOr
from
pystencils.astnodes
import
Node
,
ResolvedFieldAccess
,
SympyAssignment
from
pystencils.data_types
import
create
T
ype
,
PointerType
,
get
T
ype
OfE
xpression
,
VectorType
,
castFunc
from
pystencils.data_types
import
create
_t
ype
,
PointerType
,
get
_t
ype
_of_e
xpression
,
VectorType
,
castFunc
from
pystencils.backends.simd_instruction_sets
import
selectedInstructionSet
__all__
=
[
'print_c'
]
def
generateC
(
astNode
,
signatureOnly
=
False
):
"""
Prints the abstract syntax tree as C function
def
print_c
(
ast_node
:
Node
,
signature_only
:
bool
=
False
,
use_float_constants
:
Optional
[
bool
]
=
None
)
->
str
:
"""Prints an abstract syntax tree node as C or CUDA code.
This function does not need to distinguish between C, C++ or CUDA code, it just prints 'C-like' code as encoded
in the abstract syntax tree (AST). The AST is built differently for C or CUDA by calling different create_kernel
functions.
Args:
ast_node:
signature_only:
use_float_constants:
Returns:
C-like code for the ast node and its descendants
"""
fieldTypes
=
set
([
f
.
dtype
for
f
in
astNode
.
fieldsAccessed
])
useFloatConstants
=
createType
(
"double"
)
not
in
fieldTypes
if
use_float_constants
is
None
:
field_types
=
set
(
o
.
field
.
dtype
for
o
in
ast_node
.
atoms
(
ResolvedFieldAccess
))
double
=
create_type
(
'double'
)
use_float_constants
=
double
not
in
field_types
vectorIS
=
selectedInstructionSet
[
'double'
]
printer
=
CBackend
(
constantsAsFloats
=
useFloatConstants
,
signatureOnly
=
signatureOnly
,
vectorInstructionSet
=
vectorIS
)
return
printer
(
astNode
)
vector_is
=
selectedInstructionSet
[
'double'
]
printer
=
CBackend
(
constants_as_floats
=
use_float_constants
,
signature_only
=
signature_only
,
vector_instruction_set
=
vector_is
)
return
printer
(
ast_node
)
def
get
H
eaders
(
ast
N
ode
):
def
get
_h
eaders
(
ast
_n
ode
):
headers
=
set
()
if
hasattr
(
ast
N
ode
,
'headers'
):
headers
.
update
(
ast
N
ode
.
headers
)
elif
isinstance
(
ast
N
ode
,
SympyAssignment
):
if
type
(
get
T
ype
OfE
xpression
(
ast
N
ode
.
rhs
))
is
VectorType
:
if
hasattr
(
ast
_n
ode
,
'headers'
):
headers
.
update
(
ast
_n
ode
.
headers
)
elif
isinstance
(
ast
_n
ode
,
SympyAssignment
):
if
type
(
get
_t
ype
_of_e
xpression
(
ast
_n
ode
.
rhs
))
is
VectorType
:
headers
.
update
(
selectedInstructionSet
[
'double'
][
'headers'
])
for
a
in
ast
N
ode
.
args
:
for
a
in
ast
_n
ode
.
args
:
if
isinstance
(
a
,
Node
):
headers
.
update
(
get
H
eaders
(
a
))
headers
.
update
(
get
_h
eaders
(
a
))
return
headers
...
...
@@ -48,10 +62,11 @@ def getHeaders(astNode):
class
CustomCppCode
(
Node
):
def
__init__
(
self
,
code
,
symbolsRead
,
symbolsDefined
):
def
__init__
(
self
,
code
,
symbols_read
,
symbols_defined
,
parent
=
None
):
super
(
CustomCppCode
,
self
).
__init__
(
parent
=
parent
)
self
.
_code
=
"
\n
"
+
code
self
.
_symbolsRead
=
set
(
symbols
R
ead
)
self
.
_symbolsDefined
=
set
(
symbols
D
efined
)
self
.
_symbolsRead
=
set
(
symbols
_r
ead
)
self
.
_symbolsDefined
=
set
(
symbols
_d
efined
)
self
.
headers
=
[]
@
property
...
...
@@ -63,75 +78,78 @@ class CustomCppCode(Node):
return
[]
@
property
def
symbols
D
efined
(
self
):
def
symbols
_d
efined
(
self
):
return
self
.
_symbolsDefined
@
property
def
undefined
S
ymbols
(
self
):
return
self
.
symbols
D
efined
-
self
.
_symbolsRead
def
undefined
_s
ymbols
(
self
):
return
self
.
symbols
_d
efined
-
self
.
_symbolsRead
class
PrintNode
(
CustomCppCode
):
def
__init__
(
self
,
symbolToPrint
):
code
=
'
\n
std::cout << "%s = " << %s << std::endl;
\n
'
%
(
symbolToPrint
.
name
,
symbolToPrint
.
name
)
super
(
PrintNode
,
self
).
__init__
(
code
,
symbolsRead
=
[
symbolToPrint
],
symbolsDefined
=
set
())
# noinspection SpellCheckingInspection
def
__init__
(
self
,
symbol_to_print
):
code
=
'
\n
std::cout << "%s = " << %s << std::endl;
\n
'
%
(
symbol_to_print
.
name
,
symbol_to_print
.
name
)
super
(
PrintNode
,
self
).
__init__
(
code
,
symbols_read
=
[
symbol_to_print
],
symbols_defined
=
set
())
self
.
headers
.
append
(
"<iostream>"
)
# ------------------------------------------- Printer ------------------------------------------------------------------
class
CBackend
(
object
):
# noinspection PyPep8Naming
class
CBackend
:
def
__init__
(
self
,
constantsAsFloats
=
False
,
sympyPrinter
=
None
,
signatureOnly
=
False
,
vectorInstructionSet
=
None
):
if
sympyPrinter
is
None
:
self
.
sympyPrinter
=
CustomSympyPrinter
(
constantsAsFloats
)
if
vectorInstructionSet
is
not
None
:
self
.
sympyPrinter
=
VectorizedCustomSympyPrinter
(
vectorInstructionSet
,
constantsAsFloats
)
def
__init__
(
self
,
constants_as_floats
=
False
,
sympy_printer
=
None
,
signature_only
=
False
,
vector_instruction_set
=
None
):
if
sympy_printer
is
None
:
self
.
sympyPrinter
=
CustomSympyPrinter
(
constants_as_floats
)
if
vector_instruction_set
is
not
None
:
self
.
sympyPrinter
=
VectorizedCustomSympyPrinter
(
vector_instruction_set
,
constants_as_floats
)
else
:
self
.
sympyPrinter
=
CustomSympyPrinter
(
constants
AsF
loats
)
self
.
sympyPrinter
=
CustomSympyPrinter
(
constants
_as_f
loats
)
else
:
self
.
sympyPrinter
=
sympy
P
rinter
self
.
sympyPrinter
=
sympy
_p
rinter
self
.
_vectorInstructionSet
=
vector
I
nstruction
S
et
self
.
_vectorInstructionSet
=
vector
_i
nstruction
_s
et
self
.
_indent
=
" "
self
.
_signatureOnly
=
signature
O
nly
self
.
_signatureOnly
=
signature
_o
nly
def
__call__
(
self
,
node
):
prev
I
s
=
VectorType
.
instructionSet
prev
_i
s
=
VectorType
.
instructionSet
VectorType
.
instructionSet
=
self
.
_vectorInstructionSet
result
=
str
(
self
.
_print
(
node
))
VectorType
.
instructionSet
=
prev
I
s
VectorType
.
instructionSet
=
prev
_i
s
return
result
def
_print
(
self
,
node
):
for
cls
in
type
(
node
).
__mro__
:
method
N
ame
=
"_print_"
+
cls
.
__name__
if
hasattr
(
self
,
method
N
ame
):
return
getattr
(
self
,
method
N
ame
)(
node
)
raise
NotImplementedError
(
"CBackend does not support node of type "
+
cls
.
__name__
)
method
_n
ame
=
"_print_"
+
cls
.
__name__
if
hasattr
(
self
,
method
_n
ame
):
return
getattr
(
self
,
method
_n
ame
)(
node
)
raise
NotImplementedError
(
"CBackend does not support node of type "
+
str
(
type
(
node
))
)
def
_print_KernelFunction
(
self
,
node
):
function
A
rguments
=
[
"%s %s"
%
(
str
(
s
.
dtype
),
s
.
name
)
for
s
in
node
.
parameters
]
func
D
eclaration
=
"FUNC_PREFIX void %s(%s)"
%
(
node
.
functionName
,
", "
.
join
(
function
A
rguments
))
function
_a
rguments
=
[
"%s %s"
%
(
str
(
s
.
dtype
),
s
.
name
)
for
s
in
node
.
parameters
]
func
_d
eclaration
=
"FUNC_PREFIX void %s(%s)"
%
(
node
.
functionName
,
", "
.
join
(
function
_a
rguments
))
if
self
.
_signatureOnly
:
return
func
D
eclaration
return
func
_d
eclaration
body
=
self
.
_print
(
node
.
body
)
return
func
D
eclaration
+
"
\n
"
+
body
return
func
_d
eclaration
+
"
\n
"
+
body
def
_print_Block
(
self
,
node
):
block
C
ontents
=
"
\n
"
.
join
([
self
.
_print
(
child
)
for
child
in
node
.
args
])
return
"{
\n
%s
\n
}"
%
(
self
.
_indent
+
self
.
_indent
.
join
(
block
C
ontents
.
splitlines
(
True
)))
block
_c
ontents
=
"
\n
"
.
join
([
self
.
_print
(
child
)
for
child
in
node
.
args
])
return
"{
\n
%s
\n
}"
%
(
self
.
_indent
+
self
.
_indent
.
join
(
block
_c
ontents
.
splitlines
(
True
)))
def
_print_PragmaBlock
(
self
,
node
):
return
"%s
\n
%s"
%
(
node
.
pragmaLine
,
self
.
_print_Block
(
node
))
def
_print_LoopOverCoordinate
(
self
,
node
):
counter
Var
=
node
.
loop
C
ounter
N
ame
start
=
"int %s = %s"
%
(
counter
Var
,
self
.
sympyPrinter
.
doprint
(
node
.
start
))
condition
=
"%s < %s"
%
(
counter
Var
,
self
.
sympyPrinter
.
doprint
(
node
.
stop
))
update
=
"%s += %s"
%
(
counter
Var
,
self
.
sympyPrinter
.
doprint
(
node
.
step
),)
counter
_symbol
=
node
.
loop
_c
ounter
_n
ame
start
=
"int %s = %s"
%
(
counter
_symbol
,
self
.
sympyPrinter
.
doprint
(
node
.
start
))
condition
=
"%s < %s"
%
(
counter
_symbol
,
self
.
sympyPrinter
.
doprint
(
node
.
stop
))
update
=
"%s += %s"
%
(
counter
_symbol
,
self
.
sympyPrinter
.
doprint
(
node
.
step
),)
loopStr
=
"for (%s; %s; %s)"
%
(
start
,
condition
,
update
)
prefix
=
"
\n
"
.
join
(
node
.
prefixLines
)
...
...
@@ -140,12 +158,12 @@ class CBackend(object):
return
"%s%s
\n
%s"
%
(
prefix
,
loopStr
,
self
.
_print
(
node
.
body
))
def
_print_SympyAssignment
(
self
,
node
):
if
node
.
is
D
eclaration
:
dtype
=
"const "
+
str
(
node
.
lhs
.
dtype
)
+
" "
if
node
.
is
C
onst
else
str
(
node
.
lhs
.
dtype
)
+
" "
return
"%s %s = %s;"
%
(
dtype
,
self
.
sympyPrinter
.
doprint
(
node
.
lhs
),
self
.
sympyPrinter
.
doprint
(
node
.
rhs
))
if
node
.
is
_d
eclaration
:
d
ata_
type
=
"const "
+
str
(
node
.
lhs
.
dtype
)
+
" "
if
node
.
is
_c
onst
else
str
(
node
.
lhs
.
dtype
)
+
" "
return
"%s %s = %s;"
%
(
d
ata_
type
,
self
.
sympyPrinter
.
doprint
(
node
.
lhs
),
self
.
sympyPrinter
.
doprint
(
node
.
rhs
))
else
:
lhs
T
ype
=
get
T
ype
OfE
xpression
(
node
.
lhs
)
if
type
(
lhs
T
ype
)
is
VectorType
and
node
.
lhs
.
func
==
castFunc
:
lhs
_t
ype
=
get
_t
ype
_of_e
xpression
(
node
.
lhs
)
if
type
(
lhs
_t
ype
)
is
VectorType
and
node
.
lhs
.
func
==
castFunc
:
return
self
.
_vectorInstructionSet
[
'storeU'
].
format
(
"&"
+
self
.
sympyPrinter
.
doprint
(
node
.
lhs
.
args
[
0
]),
self
.
sympyPrinter
.
doprint
(
node
.
rhs
))
+
';'
else
:
...
...
@@ -153,31 +171,33 @@ class CBackend(object):
def
_print_TemporaryMemoryAllocation
(
self
,
node
):
return
"%s %s = new %s[%s];"
%
(
node
.
symbol
.
dtype
,
self
.
sympyPrinter
.
doprint
(
node
.
symbol
.
name
),
node
.
symbol
.
dtype
.
base
T
ype
,
self
.
sympyPrinter
.
doprint
(
node
.
size
))
node
.
symbol
.
dtype
.
base
_t
ype
,
self
.
sympyPrinter
.
doprint
(
node
.
size
))
def
_print_TemporaryMemoryFree
(
self
,
node
):
return
"delete [] %s;"
%
(
self
.
sympyPrinter
.
doprint
(
node
.
symbol
.
name
),)
def
_print_CustomCppCode
(
self
,
node
):
@
staticmethod
def
_print_CustomCppCode
(
node
):
return
node
.
code
def
_print_Conditional
(
self
,
node
):
condition
E
xpr
=
self
.
sympyPrinter
.
doprint
(
node
.
conditionExpr
)
true
B
lock
=
self
.
_print_Block
(
node
.
trueBlock
)
result
=
"if (%s)
\n
%s "
%
(
condition
E
xpr
,
true
B
lock
)
condition
_e
xpr
=
self
.
sympyPrinter
.
doprint
(
node
.
conditionExpr
)
true
_b
lock
=
self
.
_print_Block
(
node
.
trueBlock
)
result
=
"if (%s)
\n
%s "
%
(
condition
_e
xpr
,
true
_b
lock
)
if
node
.
falseBlock
:
false
B
lock
=
self
.
_print_Block
(
node
.
falseBlock
)
result
+=
"else "
+
false
B
lock
false
_b
lock
=
self
.
_print_Block
(
node
.
falseBlock
)
result
+=
"else "
+
false
_b
lock
return
result
# ------------------------------------------ Helper function & classes -------------------------------------------------
# noinspection PyPep8Naming
class
CustomSympyPrinter
(
CCodePrinter
):
def
__init__
(
self
,
constants
AsF
loats
=
False
):
self
.
_constantsAsFloats
=
constants
AsF
loats
def
__init__
(
self
,
constants
_as_f
loats
=
False
):
self
.
_constantsAsFloats
=
constants
_as_f
loats
super
(
CustomSympyPrinter
,
self
).
__init__
()
def
_print_Pow
(
self
,
expr
):
...
...
@@ -210,7 +230,7 @@ class CustomSympyPrinter(CCodePrinter):
return
res
def
_print_Function
(
self
,
expr
):
function
M
ap
=
{
function
_m
ap
=
{
xor
:
'^'
,
rightShift
:
'>>'
,
leftShift
:
'<<'
,
...
...
@@ -218,33 +238,34 @@ class CustomSympyPrinter(CCodePrinter):
bitwiseAnd
:
'&'
,
}
if
expr
.
func
==
castFunc
:
arg
,
type
=
expr
.
args
return
"*((%s)(& %s))"
%
(
PointerType
(
type
),
self
.
_print
(
arg
))
elif
expr
.
func
in
function
M
ap
:
return
"(%s %s %s)"
%
(
self
.
_print
(
expr
.
args
[
0
]),
function
M
ap
[
expr
.
func
],
self
.
_print
(
expr
.
args
[
1
]))
arg
,
data_
type
=
expr
.
args
return
"*((%s)(& %s))"
%
(
PointerType
(
data_
type
),
self
.
_print
(
arg
))
elif
expr
.
func
in
function
_m
ap
:
return
"(%s %s %s)"
%
(
self
.
_print
(
expr
.
args
[
0
]),
function
_m
ap
[
expr
.
func
],
self
.
_print
(
expr
.
args
[
1
]))
else
:
return
super
(
CustomSympyPrinter
,
self
).
_print_Function
(
expr
)
# noinspection PyPep8Naming
class
VectorizedCustomSympyPrinter
(
CustomSympyPrinter
):
SummandInfo
=
namedtuple
(
"SummandInfo"
,
[
'sign'
,
'term'
])
def
__init__
(
self
,
instruction
S
et
,
constants
AsF
loats
=
False
):
super
(
VectorizedCustomSympyPrinter
,
self
).
__init__
(
constants
AsF
loats
)
self
.
instructionSet
=
instruction
S
et
def
__init__
(
self
,
instruction
_s
et
,
constants
_as_f
loats
=
False
):
super
(
VectorizedCustomSympyPrinter
,
self
).
__init__
(
constants
_as_f
loats
)
self
.
instructionSet
=
instruction
_s
et
def
_scalarFallback
(
self
,
func
N
ame
,
expr
,
*
args
,
**
kwargs
):
expr
T
ype
=
get
T
ype
OfE
xpression
(
expr
)
if
type
(
expr
T
ype
)
is
not
VectorType
:
return
getattr
(
super
(
VectorizedCustomSympyPrinter
,
self
),
func
N
ame
)(
expr
,
*
args
,
**
kwargs
)
def
_scalarFallback
(
self
,
func
_n
ame
,
expr
,
*
args
,
**
kwargs
):
expr
_t
ype
=
get
_t
ype
_of_e
xpression
(
expr
)
if
type
(
expr
_t
ype
)
is
not
VectorType
:
return
getattr
(
super
(
VectorizedCustomSympyPrinter
,
self
),
func
_n
ame
)(
expr
,
*
args
,
**
kwargs
)
else
:
assert
self
.
instructionSet
[
'width'
]
==
expr
T
ype
.
width
assert
self
.
instructionSet
[
'width'
]
==
expr
_t
ype
.
width
return
None
def
_print_Function
(
self
,
expr
):
if
expr
.
func
==
castFunc
:
arg
,
dtype
=
expr
.
args
if
type
(
dtype
)
is
VectorType
:
arg
,
d
ata_
type
=
expr
.
args
if
type
(
d
ata_
type
)
is
VectorType
:
if
type
(
arg
)
is
ResolvedFieldAccess
:
return
self
.
instructionSet
[
'loadU'
].
format
(
"& "
+
self
.
_print
(
arg
))
else
:
...
...
@@ -257,10 +278,10 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
if
result
:
return
result
arg
S
trings
=
[
self
.
_print
(
a
)
for
a
in
expr
.
args
]
assert
len
(
arg
S
trings
)
>
0
result
=
arg
S
trings
[
0
]
for
item
in
arg
S
trings
[
1
:]:
arg
_s
trings
=
[
self
.
_print
(
a
)
for
a
in
expr
.
args
]
assert
len
(
arg
_s
trings
)
>
0
result
=
arg
_s
trings
[
0
]
for
item
in
arg
_s
trings
[
1
:]:
result
=
self
.
instructionSet
[
'&'
].
format
(
result
,
item
)
return
result
...
...
@@ -269,10 +290,10 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
if
result
:
return
result
arg
S
trings
=
[
self
.
_print
(
a
)
for
a
in
expr
.
args
]
assert
len
(
arg
S
trings
)
>
0
result
=
arg
S
trings
[
0
]
for
item
in
arg
S
trings
[
1
:]:
arg
_s
trings
=
[
self
.
_print
(
a
)
for
a
in
expr
.
args
]
assert
len
(
arg
_s
trings
)
>
0
result
=
arg
_s
trings
[
0
]
for
item
in
arg
_s
trings
[
1
:]:
result
=
self
.
instructionSet
[
'|'
].
format
(
result
,
item
)
return
result
...
...
@@ -284,7 +305,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
summands
=
[]
for
term
in
expr
.
args
:
if
term
.
func
==
sp
.
Mul
:
sign
,
t
=
self
.
_print_Mul
(
term
,
inside
A
dd
=
True
)
sign
,
t
=
self
.
_print_Mul
(
term
,
inside
_a
dd
=
True
)
else
:
t
=
self
.
_print
(
term
)
sign
=
1
...
...
@@ -318,7 +339,10 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
else
:
raise
ValueError
(
"Generic exponential not supported"
)
def
_print_Mul
(
self
,
expr
,
insideAdd
=
False
):
def
_print_Mul
(
self
,
expr
,
inside_add
=
False
):
# noinspection PyProtectedMember
from
sympy.core.mul
import
_keep_coeff
result
=
self
.
_scalarFallback
(
'_print_Mul'
,
expr
)
if
result
:
return
result
...
...
@@ -359,7 +383,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
denominator_str
=
self
.
instructionSet
[
'*'
].
format
(
denominator_str
,
item
)
result
=
self
.
instructionSet
[
'/'
].
format
(
result
,
denominator_str
)
if
inside
A
dd
:
if
inside
_a
dd
:
return
sign
,
result
else
:
if
sign
<
0
:
...
...
@@ -384,7 +408,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
if
result
:
return
result
if
expr
.
args
[
-
1
].
cond
!=
True
:
if
expr
.
args
[
-
1
].
cond
.
args
[
0
]
is
not
sp
.
sympify
(
True
)
:
# We need the last conditional to be a True, otherwise the resulting
# function may not return a result.
raise
ValueError
(
"All Piecewise expressions must contain an "
...
...
@@ -395,5 +419,6 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
result
=
self
.
_print
(
expr
.
args
[
-
1
][
0
])
for
trueExpr
,
condition
in
reversed
(
expr
.
args
[:
-
1
]):
# noinspection SpellCheckingInspection
result
=
self
.
instructionSet
[
'blendv'
].
format
(
result
,
self
.
_print
(
trueExpr
),
self
.
_print
(
condition
))
return
result
backends/dot.py
View file @
3bcfac93
...
...
@@ -3,13 +3,14 @@ from graphviz import Digraph, lang
import
graphviz
# noinspection PyPep8Naming
class
DotPrinter
(
Printer
):
"""
A printer which converts ast to DOT (graph description language).
"""
def
__init__
(
self
,
node
ToStrF
unction
,
full
,
**
kwargs
):
def
__init__
(
self
,
node
_to_str_f
unction
,
full
,
**
kwargs
):
super
(
DotPrinter
,
self
).
__init__
()
self
.
_nodeToStrFunction
=
node
ToStrF
unction
self
.
_nodeToStrFunction
=
node
_to_str_f
unction
self
.
full
=
full
self
.
dot
=
Digraph
(
**
kwargs
)
self
.
dot
.
quote_edge
=
lang
.
quote
...
...
@@ -33,7 +34,8 @@ class DotPrinter(Printer):
self
.
dot
.
edge
(
str
(
id
(
block
)),
str
(
id
(
node
)))