Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Sebastian Bindgen
pystencils
Commits
4a7299f1
Commit
4a7299f1
authored
Apr 10, 2018
by
Martin Bauer
Browse files
Rest of PEP8 renaming
parent
7acdc31c
Changes
40
Hide whitespace changes
Inline
Side-by-side
alignedarray.py
View file @
4a7299f1
...
...
@@ -5,9 +5,9 @@ def aligned_empty(shape, byte_alignment=32, dtype=np.float64, byte_offset=0, ord
"""
Creates an aligned empty numpy array
:param shape: size of the array
:param byte_alignment: alignment in bytes, for the start address of the array holds (a % byte
A
lignment) == 0
:param byte_alignment: alignment in bytes, for the start address of the array holds (a % byte
_a
lignment) == 0
:param dtype: numpy data type
:param byte_offset: offset in bytes for position that should be aligned i.e. (a+byte_offset) % byte
A
lignment == 0
:param byte_offset: offset in bytes for position that should be aligned i.e. (a+byte_offset) % byte
_a
lignment == 0
typically used to align first inner cell instead of ghost layer
:param order: storage linearization order
:param align_inner_coordinate: if True, the start of the innermost coordinate lines are aligned as well
...
...
assignment_collection/assignment_collection.py
View file @
4a7299f1
...
...
@@ -111,7 +111,7 @@ class AssignmentCollection:
def
dependent_symbols
(
self
,
symbols
:
Iterable
[
sp
.
Symbol
])
->
Set
[
sp
.
Symbol
]:
"""Returns all symbols that depend on one of the passed symbols.
A symbol 'a' depends on a symbol 'b', if there is an assignment 'a <- some
E
xpression(b)' i.e. when
A symbol 'a' depends on a symbol 'b', if there is an assignment 'a <- some
_e
xpression(b)' i.e. when
'b' is required to compute 'a'.
"""
...
...
@@ -217,18 +217,18 @@ class AssignmentCollection:
substitution_dict
=
{}
processed_other_subexpression_equations
=
[]
for
other
S
ubexpression
E
q
in
other
.
subexpressions
:
if
other
S
ubexpression
E
q
.
lhs
in
own_subexpression_symbols
:
if
other
S
ubexpression
E
q
.
rhs
==
own_subexpression_symbols
[
other
S
ubexpression
E
q
.
lhs
]:
for
other
_s
ubexpression
_e
q
in
other
.
subexpressions
:
if
other
_s
ubexpression
_e
q
.
lhs
in
own_subexpression_symbols
:
if
other
_s
ubexpression
_e
q
.
rhs
==
own_subexpression_symbols
[
other
_s
ubexpression
_e
q
.
lhs
]:
continue
# exact the same subexpression equation exists already
else
:
# different definition - a new name has to be introduced
new_lhs
=
next
(
self
.
subexpression_symbol_generator
)
new_eq
=
Assignment
(
new_lhs
,
fast_subs
(
other
S
ubexpression
E
q
.
rhs
,
substitution_dict
))
new_eq
=
Assignment
(
new_lhs
,
fast_subs
(
other
_s
ubexpression
_e
q
.
rhs
,
substitution_dict
))
processed_other_subexpression_equations
.
append
(
new_eq
)
substitution_dict
[
other
S
ubexpression
E
q
.
lhs
]
=
new_lhs
substitution_dict
[
other
_s
ubexpression
_e
q
.
lhs
]
=
new_lhs
else
:
processed_other_subexpression_equations
.
append
(
fast_subs
(
other
S
ubexpression
E
q
,
substitution_dict
))
processed_other_subexpression_equations
.
append
(
fast_subs
(
other
_s
ubexpression
_e
q
,
substitution_dict
))
processed_other_main_assignments
=
[
fast_subs
(
eq
,
substitution_dict
)
for
eq
in
other
.
main_assignments
]
return
self
.
copy
(
self
.
main_assignments
+
processed_other_main_assignments
,
...
...
assignment_collection/simplifications.py
View file @
4a7299f1
...
...
@@ -50,10 +50,10 @@ def apply_on_all_subexpressions(ac: AssignmentCollection,
def
subexpression_substitution_in_existing_subexpressions
(
ac
:
AssignmentCollection
)
->
AssignmentCollection
:
"""Goes through the subexpressions list and replaces the term in the following subexpressions."""
result
=
[]
for
outer
C
tr
,
s
in
enumerate
(
ac
.
subexpressions
):
for
outer
_c
tr
,
s
in
enumerate
(
ac
.
subexpressions
):
new_rhs
=
s
.
rhs
for
inner
C
tr
in
range
(
outer
C
tr
):
sub_expr
=
ac
.
subexpressions
[
inner
C
tr
]
for
inner
_c
tr
in
range
(
outer
_c
tr
):
sub_expr
=
ac
.
subexpressions
[
inner
_c
tr
]
new_rhs
=
subs_additive
(
new_rhs
,
sub_expr
.
lhs
,
sub_expr
.
rhs
,
required_match_replacement
=
1.0
)
new_rhs
=
new_rhs
.
subs
(
sub_expr
.
rhs
,
sub_expr
.
lhs
)
result
.
append
(
Assignment
(
s
.
lhs
,
new_rhs
))
...
...
@@ -66,8 +66,8 @@ def subexpression_substitution_in_main_assignments(ac: AssignmentCollection) ->
result
=
[]
for
s
in
ac
.
main_assignments
:
new_rhs
=
s
.
rhs
for
sub
E
xpr
in
ac
.
subexpressions
:
new_rhs
=
subs_additive
(
new_rhs
,
sub
E
xpr
.
lhs
,
sub
E
xpr
.
rhs
,
required_match_replacement
=
1.0
)
for
sub
_e
xpr
in
ac
.
subexpressions
:
new_rhs
=
subs_additive
(
new_rhs
,
sub
_e
xpr
.
lhs
,
sub
_e
xpr
.
rhs
,
required_match_replacement
=
1.0
)
result
.
append
(
Assignment
(
s
.
lhs
,
new_rhs
))
return
ac
.
copy
(
result
)
...
...
@@ -91,5 +91,5 @@ def add_subexpressions_for_divisions(ac: AssignmentCollection) -> AssignmentColl
search_divisors
(
eq
.
rhs
)
new_symbol_gen
=
ac
.
subexpression_symbol_generator
substitutions
=
{
divisor
:
new
S
ymbol
for
new
S
ymbol
,
divisor
in
zip
(
new_symbol_gen
,
divisors
)}
substitutions
=
{
divisor
:
new
_s
ymbol
for
new
_s
ymbol
,
divisor
in
zip
(
new_symbol_gen
,
divisors
)}
return
ac
.
new_with_substitutions
(
substitutions
,
True
)
astnodes.py
View file @
4a7299f1
...
...
@@ -64,7 +64,7 @@ class Conditional(Node):
super
(
Conditional
,
self
).
__init__
(
parent
=
None
)
assert
condition_expr
.
is_Boolean
or
condition_expr
.
is_Relational
self
.
condition
E
xpr
=
condition_expr
self
.
condition
_e
xpr
=
condition_expr
def
handle_child
(
c
):
if
c
is
None
:
...
...
@@ -74,20 +74,20 @@ class Conditional(Node):
c
.
parent
=
self
return
c
self
.
true
B
lock
=
handle_child
(
true_block
)
self
.
false
B
lock
=
handle_child
(
false_block
)
self
.
true
_b
lock
=
handle_child
(
true_block
)
self
.
false
_b
lock
=
handle_child
(
false_block
)
def
subs
(
self
,
*
args
,
**
kwargs
):
self
.
true
B
lock
.
subs
(
*
args
,
**
kwargs
)
if
self
.
false
B
lock
:
self
.
false
B
lock
.
subs
(
*
args
,
**
kwargs
)
self
.
condition
E
xpr
=
self
.
condition
E
xpr
.
subs
(
*
args
,
**
kwargs
)
self
.
true
_b
lock
.
subs
(
*
args
,
**
kwargs
)
if
self
.
false
_b
lock
:
self
.
false
_b
lock
.
subs
(
*
args
,
**
kwargs
)
self
.
condition
_e
xpr
=
self
.
condition
_e
xpr
.
subs
(
*
args
,
**
kwargs
)
@
property
def
args
(
self
):
result
=
[
self
.
condition
E
xpr
,
self
.
true
B
lock
]
if
self
.
false
B
lock
:
result
.
append
(
self
.
false
B
lock
)
result
=
[
self
.
condition
_e
xpr
,
self
.
true
_b
lock
]
if
self
.
false
_b
lock
:
result
.
append
(
self
.
false
_b
lock
)
return
result
@
property
...
...
@@ -96,17 +96,17 @@ class Conditional(Node):
@
property
def
undefined_symbols
(
self
):
result
=
self
.
true
B
lock
.
undefined_symbols
if
self
.
false
B
lock
:
result
.
update
(
self
.
false
B
lock
.
undefined_symbols
)
result
.
update
(
self
.
condition
E
xpr
.
atoms
(
sp
.
Symbol
))
result
=
self
.
true
_b
lock
.
undefined_symbols
if
self
.
false
_b
lock
:
result
.
update
(
self
.
false
_b
lock
.
undefined_symbols
)
result
.
update
(
self
.
condition
_e
xpr
.
atoms
(
sp
.
Symbol
))
return
result
def
__str__
(
self
):
return
'if:({!s}) '
.
format
(
self
.
condition
E
xpr
)
return
'if:({!s}) '
.
format
(
self
.
condition
_e
xpr
)
def
__repr__
(
self
):
return
'if:({!r}) '
.
format
(
self
.
condition
E
xpr
)
return
'if:({!r}) '
.
format
(
self
.
condition
_e
xpr
)
class
KernelFunction
(
Node
):
...
...
@@ -116,39 +116,39 @@ class KernelFunction(Node):
from
pystencils.transformations
import
symbol_name_to_variable_name
self
.
name
=
name
self
.
dtype
=
dtype
self
.
is
F
ield
PtrA
rgument
=
False
self
.
is
F
ield
S
hape
A
rgument
=
False
self
.
is
F
ield
S
tride
A
rgument
=
False
self
.
is
F
ield
A
rgument
=
False
self
.
is
_f
ield
_ptr_a
rgument
=
False
self
.
is
_f
ield
_s
hape
_a
rgument
=
False
self
.
is
_f
ield
_s
tride
_a
rgument
=
False
self
.
is
_f
ield
_a
rgument
=
False
self
.
field_name
=
""
self
.
coordinate
=
None
self
.
symbol
=
symbol
if
name
.
startswith
(
Field
.
DATA_PREFIX
):
self
.
is
F
ield
PtrA
rgument
=
True
self
.
is
F
ield
A
rgument
=
True
self
.
is
_f
ield
_ptr_a
rgument
=
True
self
.
is
_f
ield
_a
rgument
=
True
self
.
field_name
=
name
[
len
(
Field
.
DATA_PREFIX
):]
elif
name
.
startswith
(
Field
.
SHAPE_PREFIX
):
self
.
is
F
ield
S
hape
A
rgument
=
True
self
.
is
F
ield
A
rgument
=
True
self
.
is
_f
ield
_s
hape
_a
rgument
=
True
self
.
is
_f
ield
_a
rgument
=
True
self
.
field_name
=
name
[
len
(
Field
.
SHAPE_PREFIX
):]
elif
name
.
startswith
(
Field
.
STRIDE_PREFIX
):
self
.
is
F
ield
S
tride
A
rgument
=
True
self
.
is
F
ield
A
rgument
=
True
self
.
is
_f
ield
_s
tride
_a
rgument
=
True
self
.
is
_f
ield
_a
rgument
=
True
self
.
field_name
=
name
[
len
(
Field
.
STRIDE_PREFIX
):]
self
.
field
=
None
if
self
.
is
F
ield
A
rgument
:
if
self
.
is
_f
ield
_a
rgument
:
field_map
=
{
symbol_name_to_variable_name
(
f
.
name
):
f
for
f
in
kernel_function_node
.
fields_accessed
}
self
.
field
=
field_map
[
self
.
field_name
]
def
__lt__
(
self
,
other
):
def
score
(
l
):
if
l
.
is
F
ield
PtrA
rgument
:
if
l
.
is
_f
ield
_ptr_a
rgument
:
return
-
4
elif
l
.
is
F
ield
S
hape
A
rgument
:
elif
l
.
is
_f
ield
_s
hape
_a
rgument
:
return
-
3
elif
l
.
is
F
ield
S
tride
A
rgument
:
elif
l
.
is
_f
ield
_s
tride
_a
rgument
:
return
-
2
return
0
...
...
@@ -298,12 +298,12 @@ class Block(Node):
class
PragmaBlock
(
Block
):
def
__init__
(
self
,
pragma_line
,
nodes
):
super
(
PragmaBlock
,
self
).
__init__
(
nodes
)
self
.
pragma
L
ine
=
pragma_line
self
.
pragma
_l
ine
=
pragma_line
for
n
in
nodes
:
n
.
parent
=
self
def
__repr__
(
self
):
return
self
.
pragma
L
ine
return
self
.
pragma
_l
ine
class
LoopOverCoordinate
(
Node
):
...
...
@@ -313,16 +313,16 @@ class LoopOverCoordinate(Node):
super
(
LoopOverCoordinate
,
self
).
__init__
(
parent
=
None
)
self
.
body
=
body
body
.
parent
=
self
self
.
coordinate
ToL
oop
O
ver
=
coordinate_to_loop_over
self
.
coordinate
_to_l
oop
_o
ver
=
coordinate_to_loop_over
self
.
start
=
start
self
.
stop
=
stop
self
.
step
=
step
self
.
body
.
parent
=
self
self
.
prefix
L
ines
=
[]
self
.
prefix
_l
ines
=
[]
def
new_loop_with_different_body
(
self
,
new_body
):
result
=
LoopOverCoordinate
(
new_body
,
self
.
coordinate
ToL
oop
O
ver
,
self
.
start
,
self
.
stop
,
self
.
step
)
result
.
prefix
L
ines
=
[
l
for
l
in
self
.
prefix
L
ines
]
result
=
LoopOverCoordinate
(
new_body
,
self
.
coordinate
_to_l
oop
_o
ver
,
self
.
start
,
self
.
stop
,
self
.
step
)
result
.
prefix
_l
ines
=
[
l
for
l
in
self
.
prefix
_l
ines
]
return
result
def
subs
(
self
,
*
args
,
**
kwargs
):
...
...
@@ -359,9 +359,9 @@ class LoopOverCoordinate(Node):
@
property
def
undefined_symbols
(
self
):
result
=
self
.
body
.
undefined_symbols
for
possible
S
ymbol
in
[
self
.
start
,
self
.
stop
,
self
.
step
]:
if
isinstance
(
possible
S
ymbol
,
Node
)
or
isinstance
(
possible
S
ymbol
,
sp
.
Basic
):
result
.
update
(
possible
S
ymbol
.
atoms
(
sp
.
Symbol
))
for
possible
_s
ymbol
in
[
self
.
start
,
self
.
stop
,
self
.
step
]:
if
isinstance
(
possible
_s
ymbol
,
Node
)
or
isinstance
(
possible
_s
ymbol
,
sp
.
Basic
):
result
.
update
(
possible
_s
ymbol
.
atoms
(
sp
.
Symbol
))
return
result
-
{
self
.
loop_counter_symbol
}
@
staticmethod
...
...
@@ -370,7 +370,7 @@ class LoopOverCoordinate(Node):
@
property
def
loop_counter_name
(
self
):
return
LoopOverCoordinate
.
get_loop_counter_name
(
self
.
coordinate
ToL
oop
O
ver
)
return
LoopOverCoordinate
.
get_loop_counter_name
(
self
.
coordinate
_to_l
oop
_o
ver
)
@
staticmethod
def
is_loop_counter_symbol
(
symbol
):
...
...
@@ -388,7 +388,7 @@ class LoopOverCoordinate(Node):
@
property
def
loop_counter_symbol
(
self
):
return
LoopOverCoordinate
.
get_loop_counter_symbol
(
self
.
coordinate
ToL
oop
O
ver
)
return
LoopOverCoordinate
.
get_loop_counter_symbol
(
self
.
coordinate
_to_l
oop
_o
ver
)
@
property
def
is_outermost_loop
(
self
):
...
...
@@ -414,25 +414,25 @@ class LoopOverCoordinate(Node):
class
SympyAssignment
(
Node
):
def
__init__
(
self
,
lhs_symbol
,
rhs_expr
,
is_const
=
True
):
super
(
SympyAssignment
,
self
).
__init__
(
parent
=
None
)
self
.
_lhs
S
ymbol
=
lhs_symbol
self
.
_lhs
_s
ymbol
=
lhs_symbol
self
.
rhs
=
rhs_expr
self
.
_is
D
eclaration
=
True
is_cast
=
self
.
_lhs
S
ymbol
.
func
==
cast_func
if
isinstance
(
self
.
_lhs
S
ymbol
,
Field
.
Access
)
or
isinstance
(
self
.
_lhs
S
ymbol
,
ResolvedFieldAccess
)
or
is_cast
:
self
.
_is
D
eclaration
=
False
self
.
_is
C
onst
=
is_const
self
.
_is
_d
eclaration
=
True
is_cast
=
self
.
_lhs
_s
ymbol
.
func
==
cast_func
if
isinstance
(
self
.
_lhs
_s
ymbol
,
Field
.
Access
)
or
isinstance
(
self
.
_lhs
_s
ymbol
,
ResolvedFieldAccess
)
or
is_cast
:
self
.
_is
_d
eclaration
=
False
self
.
_is
_c
onst
=
is_const
@
property
def
lhs
(
self
):
return
self
.
_lhs
S
ymbol
return
self
.
_lhs
_s
ymbol
@
lhs
.
setter
def
lhs
(
self
,
new_value
):
self
.
_lhs
S
ymbol
=
new_value
self
.
_is
D
eclaration
=
True
is_cast
=
self
.
_lhs
S
ymbol
.
func
==
cast_func
if
isinstance
(
self
.
_lhs
S
ymbol
,
Field
.
Access
)
or
isinstance
(
self
.
_lhs
S
ymbol
,
sp
.
Indexed
)
or
is_cast
:
self
.
_is
D
eclaration
=
False
self
.
_lhs
_s
ymbol
=
new_value
self
.
_is
_d
eclaration
=
True
is_cast
=
self
.
_lhs
_s
ymbol
.
func
==
cast_func
if
isinstance
(
self
.
_lhs
_s
ymbol
,
Field
.
Access
)
or
isinstance
(
self
.
_lhs
_s
ymbol
,
sp
.
Indexed
)
or
is_cast
:
self
.
_is
_d
eclaration
=
False
def
subs
(
self
,
*
args
,
**
kwargs
):
self
.
lhs
=
fast_subs
(
self
.
lhs
,
*
args
,
**
kwargs
)
...
...
@@ -440,13 +440,13 @@ class SympyAssignment(Node):
@
property
def
args
(
self
):
return
[
self
.
_lhs
S
ymbol
,
self
.
rhs
]
return
[
self
.
_lhs
_s
ymbol
,
self
.
rhs
]
@
property
def
symbols_defined
(
self
):
if
not
self
.
_is
D
eclaration
:
if
not
self
.
_is
_d
eclaration
:
return
set
()
return
{
self
.
_lhs
S
ymbol
}
return
{
self
.
_lhs
_s
ymbol
}
@
property
def
undefined_symbols
(
self
):
...
...
@@ -458,16 +458,16 @@ class SympyAssignment(Node):
for
i
in
range
(
len
(
symbol
.
offsets
)):
loop_counters
.
add
(
LoopOverCoordinate
.
get_loop_counter_symbol
(
i
))
result
.
update
(
loop_counters
)
result
.
update
(
self
.
_lhs
S
ymbol
.
atoms
(
sp
.
Symbol
))
result
.
update
(
self
.
_lhs
_s
ymbol
.
atoms
(
sp
.
Symbol
))
return
result
@
property
def
is_declaration
(
self
):
return
self
.
_is
D
eclaration
return
self
.
_is
_d
eclaration
@
property
def
is_const
(
self
):
return
self
.
_is
C
onst
return
self
.
_is
_c
onst
def
replace
(
self
,
child
,
replacement
):
if
child
==
self
.
lhs
:
...
...
@@ -495,24 +495,24 @@ class ResolvedFieldAccess(sp.Indexed):
obj
=
super
(
ResolvedFieldAccess
,
cls
).
__new__
(
cls
,
base
,
linearized_index
)
obj
.
field
=
field
obj
.
offsets
=
offsets
obj
.
idx
C
oordinate
V
alues
=
idx_coordinate_values
obj
.
idx
_c
oordinate
_v
alues
=
idx_coordinate_values
return
obj
def
_eval_subs
(
self
,
old
,
new
):
return
ResolvedFieldAccess
(
self
.
args
[
0
],
self
.
args
[
1
].
subs
(
old
,
new
),
self
.
field
,
self
.
offsets
,
self
.
idx
C
oordinate
V
alues
)
self
.
field
,
self
.
offsets
,
self
.
idx
_c
oordinate
_v
alues
)
def
fast_subs
(
self
,
substitutions
):
if
self
in
substitutions
:
return
substitutions
[
self
]
return
ResolvedFieldAccess
(
self
.
args
[
0
].
subs
(
substitutions
),
self
.
args
[
1
].
subs
(
substitutions
),
self
.
field
,
self
.
offsets
,
self
.
idx
C
oordinate
V
alues
)
self
.
field
,
self
.
offsets
,
self
.
idx
_c
oordinate
_v
alues
)
def
_hashable_content
(
self
):
super_class_contents
=
super
(
ResolvedFieldAccess
,
self
).
_hashable_content
()
return
super_class_contents
+
tuple
(
self
.
offsets
)
+
(
repr
(
self
.
idx
C
oordinate
V
alues
),
hash
(
self
.
field
))
return
super_class_contents
+
tuple
(
self
.
offsets
)
+
(
repr
(
self
.
idx
_c
oordinate
_v
alues
),
hash
(
self
.
field
))
@
property
def
typed_symbol
(
self
):
...
...
@@ -523,7 +523,7 @@ class ResolvedFieldAccess(sp.Indexed):
return
"%s (%s)"
%
(
top
,
self
.
typed_symbol
.
dtype
)
def
__getnewargs__
(
self
):
return
self
.
base
,
self
.
indices
[
0
],
self
.
field
,
self
.
offsets
,
self
.
idx
C
oordinate
V
alues
return
self
.
base
,
self
.
indices
[
0
],
self
.
field
,
self
.
offsets
,
self
.
idx
_c
oordinate
_v
alues
class
TemporaryMemoryAllocation
(
Node
):
...
...
backends/__init__.py
View file @
4a7299f1
...
...
@@ -2,6 +2,6 @@ from .cbackend import generate_c
try
:
from
.dot
import
print_dot
from
.llvm
import
generate
LLVM
from
.llvm
import
generate
_llvm
except
ImportError
:
pass
backends/cbackend.py
View file @
4a7299f1
...
...
@@ -11,7 +11,7 @@ except ImportError:
from
pystencils.bitoperations
import
bitwise_xor
,
bit_shift_right
,
bit_shift_left
,
bitwise_and
,
bitwise_or
from
pystencils.astnodes
import
Node
,
ResolvedFieldAccess
,
SympyAssignment
from
pystencils.data_types
import
create_type
,
PointerType
,
get_type_of_expression
,
VectorType
,
cast_func
from
pystencils.backends.simd_instruction_sets
import
selected
I
nstruction
S
et
from
pystencils.backends.simd_instruction_sets
import
selected
_i
nstruction
_s
et
__all__
=
[
'generate_c'
,
'CustomCppCode'
,
'PrintNode'
,
'get_headers'
]
...
...
@@ -36,7 +36,7 @@ def generate_c(ast_node: Node, signature_only: bool = False, use_float_constants
double
=
create_type
(
'double'
)
use_float_constants
=
double
not
in
field_types
vector_is
=
selected
I
nstruction
S
et
[
'double'
]
vector_is
=
selected
_i
nstruction
_s
et
[
'double'
]
printer
=
CBackend
(
constants_as_floats
=
use_float_constants
,
signature_only
=
signature_only
,
vector_instruction_set
=
vector_is
)
return
printer
(
ast_node
)
...
...
@@ -50,7 +50,7 @@ def get_headers(ast_node: Node) -> Set[str]:
headers
.
update
(
ast_node
.
headers
)
elif
isinstance
(
ast_node
,
SympyAssignment
):
if
type
(
get_type_of_expression
(
ast_node
.
rhs
))
is
VectorType
:
headers
.
update
(
selected
I
nstruction
S
et
[
'double'
][
'headers'
])
headers
.
update
(
selected
_i
nstruction
_s
et
[
'double'
][
'headers'
])
for
a
in
ast_node
.
args
:
if
isinstance
(
a
,
Node
):
...
...
@@ -104,23 +104,23 @@ class CBackend:
def
__init__
(
self
,
constants_as_floats
=
False
,
sympy_printer
=
None
,
signature_only
=
False
,
vector_instruction_set
=
None
):
if
sympy_printer
is
None
:
self
.
sympy
P
rinter
=
CustomSympyPrinter
(
constants_as_floats
)
self
.
sympy
_p
rinter
=
CustomSympyPrinter
(
constants_as_floats
)
if
vector_instruction_set
is
not
None
:
self
.
sympy
P
rinter
=
VectorizedCustomSympyPrinter
(
vector_instruction_set
,
constants_as_floats
)
self
.
sympy
_p
rinter
=
VectorizedCustomSympyPrinter
(
vector_instruction_set
,
constants_as_floats
)
else
:
self
.
sympy
P
rinter
=
CustomSympyPrinter
(
constants_as_floats
)
self
.
sympy
_p
rinter
=
CustomSympyPrinter
(
constants_as_floats
)
else
:
self
.
sympy
P
rinter
=
sympy_printer
self
.
sympy
_p
rinter
=
sympy_printer
self
.
_vectorInstructionSet
=
vector_instruction_set
self
.
_indent
=
" "
self
.
_signatureOnly
=
signature_only
def
__call__
(
self
,
node
):
prev_is
=
VectorType
.
instruction
S
et
VectorType
.
instruction
S
et
=
self
.
_vectorInstructionSet
prev_is
=
VectorType
.
instruction
_s
et
VectorType
.
instruction
_s
et
=
self
.
_vectorInstructionSet
result
=
str
(
self
.
_print
(
node
))
VectorType
.
instruction
S
et
=
prev_is
VectorType
.
instruction
_s
et
=
prev_is
return
result
def
_print
(
self
,
node
):
...
...
@@ -144,49 +144,49 @@ class CBackend:
return
"{
\n
%s
\n
}"
%
(
self
.
_indent
+
self
.
_indent
.
join
(
block_contents
.
splitlines
(
True
)))
def
_print_PragmaBlock
(
self
,
node
):
return
"%s
\n
%s"
%
(
node
.
pragma
L
ine
,
self
.
_print_Block
(
node
))
return
"%s
\n
%s"
%
(
node
.
pragma
_l
ine
,
self
.
_print_Block
(
node
))
def
_print_LoopOverCoordinate
(
self
,
node
):
counter_symbol
=
node
.
loop_counter_name
start
=
"int %s = %s"
%
(
counter_symbol
,
self
.
sympy
P
rinter
.
doprint
(
node
.
start
))
condition
=
"%s < %s"
%
(
counter_symbol
,
self
.
sympy
P
rinter
.
doprint
(
node
.
stop
))
update
=
"%s += %s"
%
(
counter_symbol
,
self
.
sympy
P
rinter
.
doprint
(
node
.
step
),)
loop
S
tr
=
"for (%s; %s; %s)"
%
(
start
,
condition
,
update
)
start
=
"int %s = %s"
%
(
counter_symbol
,
self
.
sympy
_p
rinter
.
doprint
(
node
.
start
))
condition
=
"%s < %s"
%
(
counter_symbol
,
self
.
sympy
_p
rinter
.
doprint
(
node
.
stop
))
update
=
"%s += %s"
%
(
counter_symbol
,
self
.
sympy
_p
rinter
.
doprint
(
node
.
step
),)
loop
_s
tr
=
"for (%s; %s; %s)"
%
(
start
,
condition
,
update
)
prefix
=
"
\n
"
.
join
(
node
.
prefix
L
ines
)
prefix
=
"
\n
"
.
join
(
node
.
prefix
_l
ines
)
if
prefix
:
prefix
+=
"
\n
"
return
"%s%s
\n
%s"
%
(
prefix
,
loop
S
tr
,
self
.
_print
(
node
.
body
))
return
"%s%s
\n
%s"
%
(
prefix
,
loop
_s
tr
,
self
.
_print
(
node
.
body
))
def
_print_SympyAssignment
(
self
,
node
):
if
node
.
is_declaration
:
data_type
=
"const "
+
str
(
node
.
lhs
.
dtype
)
+
" "
if
node
.
is_const
else
str
(
node
.
lhs
.
dtype
)
+
" "
return
"%s %s = %s;"
%
(
data_type
,
self
.
sympy
P
rinter
.
doprint
(
node
.
lhs
),
self
.
sympy
P
rinter
.
doprint
(
node
.
rhs
))
return
"%s %s = %s;"
%
(
data_type
,
self
.
sympy
_p
rinter
.
doprint
(
node
.
lhs
),
self
.
sympy
_p
rinter
.
doprint
(
node
.
rhs
))
else
:
lhs_type
=
get_type_of_expression
(
node
.
lhs
)
if
type
(
lhs_type
)
is
VectorType
and
node
.
lhs
.
func
==
cast_func
:
return
self
.
_vectorInstructionSet
[
'storeU'
].
format
(
"&"
+
self
.
sympy
P
rinter
.
doprint
(
node
.
lhs
.
args
[
0
]),
self
.
sympy
P
rinter
.
doprint
(
node
.
rhs
))
+
';'
return
self
.
_vectorInstructionSet
[
'storeU'
].
format
(
"&"
+
self
.
sympy
_p
rinter
.
doprint
(
node
.
lhs
.
args
[
0
]),
self
.
sympy
_p
rinter
.
doprint
(
node
.
rhs
))
+
';'
else
:
return
"%s = %s;"
%
(
self
.
sympy
P
rinter
.
doprint
(
node
.
lhs
),
self
.
sympy
P
rinter
.
doprint
(
node
.
rhs
))
return
"%s = %s;"
%
(
self
.
sympy
_p
rinter
.
doprint
(
node
.
lhs
),
self
.
sympy
_p
rinter
.
doprint
(
node
.
rhs
))
def
_print_TemporaryMemoryAllocation
(
self
,
node
):
return
"%s %s = new %s[%s];"
%
(
node
.
symbol
.
dtype
,
self
.
sympy
P
rinter
.
doprint
(
node
.
symbol
.
name
),
node
.
symbol
.
dtype
.
base_type
,
self
.
sympy
P
rinter
.
doprint
(
node
.
size
))
return
"%s %s = new %s[%s];"
%
(
node
.
symbol
.
dtype
,
self
.
sympy
_p
rinter
.
doprint
(
node
.
symbol
.
name
),
node
.
symbol
.
dtype
.
base_type
,
self
.
sympy
_p
rinter
.
doprint
(
node
.
size
))
def
_print_TemporaryMemoryFree
(
self
,
node
):
return
"delete [] %s;"
%
(
self
.
sympy
P
rinter
.
doprint
(
node
.
symbol
.
name
),)
return
"delete [] %s;"
%
(
self
.
sympy
_p
rinter
.
doprint
(
node
.
symbol
.
name
),)
@
staticmethod
def
_print_CustomCppCode
(
node
):
return
node
.
code
def
_print_Conditional
(
self
,
node
):
condition_expr
=
self
.
sympy
P
rinter
.
doprint
(
node
.
condition
E
xpr
)
true_block
=
self
.
_print_Block
(
node
.
true
B
lock
)
condition_expr
=
self
.
sympy
_p
rinter
.
doprint
(
node
.
condition
_e
xpr
)
true_block
=
self
.
_print_Block
(
node
.
true
_b
lock
)
result
=
"if (%s)
\n
%s "
%
(
condition_expr
,
true_block
)
if
node
.
false
B
lock
:
false_block
=
self
.
_print_Block
(
node
.
false
B
lock
)
if
node
.
false
_b
lock
:
false_block
=
self
.
_print_Block
(
node
.
false
_b
lock
)
result
+=
"else "
+
false_block
return
result
...
...
@@ -253,14 +253,14 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
def
__init__
(
self
,
instruction_set
,
constants_as_floats
=
False
):
super
(
VectorizedCustomSympyPrinter
,
self
).
__init__
(
constants_as_floats
)
self
.
instruction
S
et
=
instruction_set
self
.
instruction
_s
et
=
instruction_set
def
_scalarFallback
(
self
,
func_name
,
expr
,
*
args
,
**
kwargs
):
expr_type
=
get_type_of_expression
(
expr
)
if
type
(
expr_type
)
is
not
VectorType
:
return
getattr
(
super
(
VectorizedCustomSympyPrinter
,
self
),
func_name
)(
expr
,
*
args
,
**
kwargs
)
else
:
assert
self
.
instruction
S
et
[
'width'
]
==
expr_type
.
width
assert
self
.
instruction
_s
et
[
'width'
]
==
expr_type
.
width
return
None
def
_print_Function
(
self
,
expr
):
...
...
@@ -268,9 +268,9 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
arg
,
data_type
=
expr
.
args
if
type
(
data_type
)
is
VectorType
:
if
type
(
arg
)
is
ResolvedFieldAccess
:
return
self
.
instruction
S
et
[
'loadU'
].
format
(
"& "
+
self
.
_print
(
arg
))
return
self
.
instruction
_s
et
[
'loadU'
].
format
(
"& "
+
self
.
_print
(
arg
))
else
:
return
self
.
instruction
S
et
[
'makeVec'
].
format
(
self
.
_print
(
arg
))
return
self
.
instruction
_s
et
[
'makeVec'
].
format
(
self
.
_print
(
arg
))
return
super
(
VectorizedCustomSympyPrinter
,
self
).
_print_Function
(
expr
)
...
...
@@ -283,7 +283,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
assert
len
(
arg_strings
)
>
0
result
=
arg_strings
[
0
]
for
item
in
arg_strings
[
1
:]:
result
=
self
.
instruction
S
et
[
'&'
].
format
(
result
,
item
)
result
=
self
.
instruction
_s
et
[
'&'
].
format
(
result
,
item
)
return
result
def
_print_Or
(
self
,
expr
):
...
...
@@ -295,7 +295,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
assert
len
(
arg_strings
)
>
0
result
=
arg_strings
[
0
]
for
item
in
arg_strings
[
1
:]:
result
=
self
.
instruction
S
et
[
'|'
].
format
(
result
,
item
)
result
=
self
.
instruction
_s
et
[
'|'
].
format
(
result
,
item
)
return
result
def
_print_Add
(
self
,
expr
,
order
=
None
):
...
...
@@ -320,7 +320,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
assert
len
(
summands
)
>=
2
processed
=
summands
[
0
].
term
for
summand
in
summands
[
1
:]:
func
=
self
.
instruction
S
et
[
'-'
]
if
summand
.
sign
==
-
1
else
self
.
instruction
S
et
[
'+'
]
func
=
self
.
instruction
_s
et
[
'-'
]
if
summand
.
sign
==
-
1
else
self
.
instruction
_s
et
[
'+'
]
processed
=
func
.
format
(
processed
,
summand
.
term
)
return
processed
...
...
@@ -333,10 +333,10 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
return
"("
+
self
.
_print
(
sp
.
Mul
(
*
[
expr
.
base
]
*
expr
.
exp
,
evaluate
=
False
))
+
")"
else
:
if
expr
.
exp
==
-
1
:
one
=
self
.
instruction
S
et
[
'makeVec'
].
format
(
1.0
)
return
self
.
instruction
S
et
[
'/'
].
format
(
one
,
self
.
_print
(
expr
.
base
))
one
=
self
.
instruction
_s
et
[
'makeVec'
].
format
(
1.0
)
return
self
.
instruction
_s
et
[
'/'
].
format
(
one
,
self
.
_print
(
expr
.
base
))
elif
expr
.
exp
==
0.5
:
return
self
.
instruction
S
et
[
'sqrt'
].
format
(
self
.
_print
(
expr
.
base
))
return
self
.
instruction
_s
et
[
'sqrt'
].
format
(
self
.
_print
(
expr
.
base
))
else
:
raise
ValueError
(
"Generic exponential not supported"
)
...
...
@@ -369,26 +369,26 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
a
.
append
(
item
)
a
=
a
or
[
S
.
One
]
# a = a or [cast
F
unc(S.One, VectorType(create
T
ype
F
rom
S
tring("double"), expr
T
ype.width))]
# a = a or [cast
_f
unc(S.One, VectorType(create
_t
ype
_f
rom
_s
tring("double"), expr
_t
ype.width))]
a_str
=
[
self
.
_print
(
x
)
for
x
in
a
]
b_str
=
[
self
.
_print
(
x
)
for
x
in
b
]
result
=
a_str
[
0
]
for
item
in
a_str
[
1
:]:
result
=
self
.
instruction
S
et
[
'*'
].
format
(
result
,
item
)
result
=
self
.
instruction
_s
et
[
'*'
].
format
(
result
,
item
)
if
len
(
b
)
>
0
:
denominator_str
=
b_str
[
0
]
for
item
in
b_str
[
1
:]:
denominator_str
=
self
.
instruction
S
et
[
'*'
].
format
(
denominator_str
,
item
)
result
=
self
.
instruction
S
et
[
'/'
].