Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Jonas Plewinski
pystencils
Commits
9b9a4b54
Commit
9b9a4b54
authored
Jul 10, 2019
by
Martin Bauer
Browse files
Merge branch 'destructuring-field-binding' into 'master'
Destructuring field binding See merge request
pycodegen/pystencils!4
parents
8c4a6f1e
2313eda2
Changes
5
Hide whitespace changes
Inline
Side-by-side
pystencils/astnodes.py
View file @
9b9a4b54
from
typing
import
Any
,
List
,
Optional
,
Sequence
,
Set
,
Union
import
jinja2
import
sympy
as
sp
from
pystencils.data_types
import
TypedSymbol
,
cast_func
,
create_type
from
pystencils.field
import
Field
from
pystencils.
data_types
import
TypedSymbol
,
create_type
,
cast_func
from
pystencils.kernelparameters
import
FieldStrideSymbol
,
FieldPointerSymbol
,
FieldShap
eSymbol
from
pystencils.
kernelparameters
import
(
FieldPointerSymbol
,
FieldShapeSymbol
,
FieldStrid
eSymbol
)
from
pystencils.sympyextensions
import
fast_subs
from
typing
import
List
,
Set
,
Optional
,
Union
,
Any
,
Sequence
NodeOrExpr
=
Union
[
'Node'
,
sp
.
Expr
]
...
...
@@ -130,6 +134,7 @@ class KernelFunction(Node):
defined in pystencils.kernelparameters.
If the parameter is related to one or multiple fields, these fields are referenced in the fields property.
"""
def
__init__
(
self
,
symbol
,
fields
):
self
.
symbol
=
symbol
# type: TypedSymbol
self
.
fields
=
fields
# type: Sequence[Field]
...
...
@@ -582,6 +587,7 @@ class TemporaryMemoryAllocation(Node):
size: number of elements to allocate
align_offset: the align_offset's element is aligned
"""
def
__init__
(
self
,
typed_symbol
:
TypedSymbol
,
size
,
align_offset
):
super
(
TemporaryMemoryAllocation
,
self
).
__init__
(
parent
=
None
)
self
.
symbol
=
typed_symbol
...
...
@@ -639,3 +645,58 @@ class TemporaryMemoryFree(Node):
def
early_out
(
condition
):
from
pystencils.cpu.vectorization
import
vec_all
return
Conditional
(
vec_all
(
condition
),
Block
([
SkipIteration
()]))
class
DestructuringBindingsForFieldClass
(
Node
):
"""
Defines all variables needed for describing a field (shape, pointer, strides)
"""
CLASS_TO_MEMBER_DICT
=
{
FieldPointerSymbol
:
"data"
,
FieldShapeSymbol
:
"shape"
,
FieldStrideSymbol
:
"stride"
}
CLASS_NAME_TEMPLATE
=
jinja2
.
Template
(
"Field<{{ dtype }}, {{ ndim }}>"
)
@
property
def
fields_accessed
(
self
)
->
Set
[
'ResolvedFieldAccess'
]:
"""Set of Field instances: fields which are accessed inside this kernel function"""
return
set
(
o
.
field
for
o
in
self
.
atoms
(
ResolvedFieldAccess
))
def
__init__
(
self
,
body
):
super
(
DestructuringBindingsForFieldClass
,
self
).
__init__
()
self
.
headers
=
[
'<Field.h>'
]
self
.
body
=
body
@
property
def
args
(
self
)
->
List
[
NodeOrExpr
]:
"""Returns all arguments/children of this node."""
return
set
()
@
property
def
symbols_defined
(
self
)
->
Set
[
sp
.
Symbol
]:
"""Set of symbols which are defined by this node."""
undefined_field_symbols
=
{
s
for
s
in
self
.
body
.
undefined_symbols
if
isinstance
(
s
,
(
FieldPointerSymbol
,
FieldShapeSymbol
,
FieldStrideSymbol
))}
return
undefined_field_symbols
@
property
def
undefined_symbols
(
self
)
->
Set
[
sp
.
Symbol
]:
field_map
=
{
f
.
name
:
f
for
f
in
self
.
fields_accessed
}
undefined_field_symbols
=
self
.
symbols_defined
corresponding_field_names
=
{
s
.
field_name
for
s
in
undefined_field_symbols
if
hasattr
(
s
,
'field_name'
)}
corresponding_field_names
|=
{
s
.
field_names
[
0
]
for
s
in
undefined_field_symbols
if
hasattr
(
s
,
'field_names'
)}
return
{
TypedSymbol
(
f
,
self
.
CLASS_NAME_TEMPLATE
.
render
(
dtype
=
field_map
[
f
].
dtype
,
ndim
=
field_map
[
f
].
ndim
)
+
'&'
)
for
f
in
corresponding_field_names
}
|
\
(
self
.
body
.
undefined_symbols
-
undefined_field_symbols
)
def
subs
(
self
,
subs_dict
)
->
None
:
"""Inplace! substitute, similar to sympy's but modifies the AST inplace."""
self
.
body
.
subs
(
subs_dict
)
@
property
def
func
(
self
):
return
self
.
__class__
def
atoms
(
self
,
arg_type
)
->
Set
[
Any
]:
return
self
.
body
.
atoms
(
arg_type
)
|
{
s
for
s
in
self
.
symbols_defined
if
isinstance
(
s
,
arg_type
)}
pystencils/backends/cbackend.py
View file @
9b9a4b54
import
sympy
as
sp
from
collections
import
namedtuple
from
sympy.core
import
S
from
typing
import
Set
import
jinja2
import
sympy
as
sp
from
sympy.core
import
S
from
sympy.printing.ccode
import
C89CodePrinter
from
pystencils.cpu.vectorization
import
vec_any
,
vec_all
from
pystencils.astnodes
import
(
DestructuringBindingsForFieldClass
,
KernelFunction
,
Node
)
from
pystencils.cpu.vectorization
import
vec_all
,
vec_any
from
pystencils.data_types
import
(
PointerType
,
VectorType
,
address_of
,
cast_func
,
create_type
,
reinterpret_cast_func
,
cast_func
,
create_type
,
get_type_of_expression
,
vector_memory_access
)
from
pystencils.fast_approximation
import
fast_division
,
fast_inv_sqrt
,
fast_sqrt
reinterpret_cast_func
,
vector_memory_access
)
from
pystencils.fast_approximation
import
(
fast_division
,
fast_inv_sqrt
,
fast_sqrt
)
from
pystencils.integer_functions
import
(
bit_shift_left
,
bit_shift_right
,
bitwise_and
,
bitwise_or
,
bitwise_xor
,
int_div
,
int_power_of_2
,
modulo_ceil
)
from
pystencils.kernelparameters
import
FieldPointerSymbol
try
:
from
sympy.printing.ccode
import
C99CodePrinter
as
CCodePrinter
except
ImportError
:
from
sympy.printing.ccode
import
CCodePrinter
# for sympy versions < 1.1
from
pystencils.integer_functions
import
bitwise_xor
,
bit_shift_right
,
bit_shift_left
,
bitwise_and
,
\
bitwise_or
,
modulo_ceil
,
int_div
,
int_power_of_2
from
pystencils.astnodes
import
Node
,
KernelFunction
__all__
=
[
'generate_c'
,
'CustomCodeNode'
,
'PrintNode'
,
'get_headers'
,
'CustomSympyPrinter'
]
...
...
@@ -255,6 +261,30 @@ class CBackend:
result
+=
"else "
+
false_block
return
result
def
_print_DestructuringBindingsForFieldClass
(
self
,
node
:
Node
):
# Define all undefined symbols
undefined_field_symbols
=
node
.
symbols_defined
destructuring_bindings
=
[
"%s = %s.%s%s;"
%
(
u
.
name
,
u
.
field_name
if
hasattr
(
u
,
'field_name'
)
else
u
.
field_names
[
0
],
DestructuringBindingsForFieldClass
.
CLASS_TO_MEMBER_DICT
[
u
.
__class__
],
""
if
type
(
u
)
==
FieldPointerSymbol
else
(
"[%i]"
%
u
.
coordinate
))
for
u
in
undefined_field_symbols
]
destructuring_bindings
.
sort
()
# only for code aesthetics
template
=
jinja2
.
Template
(
"""{
{% for binding in bindings -%}
{{ binding | indent(3) }}
{% endfor -%}
{{ block | indent(3) }}
}
"""
)
code
=
template
.
render
(
bindings
=
destructuring_bindings
,
block
=
self
.
_print
(
node
.
body
))
return
code
# ------------------------------------------ Helper function & classes -------------------------------------------------
...
...
pystencils/data_types.py
View file @
9b9a4b54
...
...
@@ -108,7 +108,7 @@ class TypedSymbol(sp.Symbol):
obj
=
super
(
TypedSymbol
,
cls
).
__xnew__
(
cls
,
name
)
try
:
obj
.
_dtype
=
create_type
(
dtype
)
except
TypeError
:
except
(
TypeError
,
ValueError
)
:
# on error keep the string
obj
.
_dtype
=
dtype
return
obj
...
...
pystencils/field.py
View file @
9b9a4b54
...
...
@@ -306,6 +306,10 @@ class Field(AbstractField):
def
index_dimensions
(
self
)
->
int
:
return
len
(
self
.
shape
)
-
len
(
self
.
_layout
)
@
property
def
ndim
(
self
)
->
int
:
return
len
(
self
.
shape
)
@
property
def
layout
(
self
):
return
self
.
_layout
...
...
pystencils_tests/test_destructuring_field_class.py
0 → 100644
View file @
9b9a4b54
# -*- coding: utf-8 -*-
#
# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
#
# Distributed under terms of the GPLv3 license.
"""
"""
import
sympy
import
pystencils
from
pystencils.astnodes
import
DestructuringBindingsForFieldClass
def
test_destructuring_field_class
():
z
,
x
,
y
=
pystencils
.
fields
(
"z, y, x: [2d]"
)
normal_assignments
=
pystencils
.
AssignmentCollection
([
pystencils
.
Assignment
(
z
[
0
,
0
],
x
[
0
,
0
]
*
sympy
.
log
(
x
[
0
,
0
]
*
y
[
0
,
0
]))],
[])
ast
=
pystencils
.
create_kernel
(
normal_assignments
)
print
(
pystencils
.
show_code
(
ast
))
ast
.
body
=
DestructuringBindingsForFieldClass
(
ast
.
body
)
print
(
pystencils
.
show_code
(
ast
))
def
main
():
test_destructuring_field_class
()
if
__name__
==
'__main__'
:
main
()
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment