Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Tom Harke
pystencils
Commits
4f43b51a
Commit
4f43b51a
authored
Mar 27, 2019
by
Nils Kohl
🌝
Committed by
Martin Bauer
Apr 26, 2019
Browse files
Improved support for arbitrary field classes.
- introduced AbstractField and AbstractAccess Fixes #28
parent
eec4dc4b
Changes
2
Hide whitespace changes
Inline
Side-by-side
pystencils/field.py
View file @
4f43b51a
...
...
@@ -13,7 +13,7 @@ from pystencils.sympyextensions import is_integer_sequence
import
pickle
import
hashlib
__all__
=
[
'Field'
,
'fields'
,
'FieldType'
]
__all__
=
[
'Field'
,
'fields'
,
'FieldType'
,
'AbstractField'
]
def
fields
(
description
=
None
,
index_dimensions
=
0
,
layout
=
None
,
**
kwargs
):
...
...
@@ -116,7 +116,13 @@ class FieldType(Enum):
return
field
.
field_type
==
FieldType
.
CUSTOM
class
Field
:
class
AbstractField
:
class
AbstractAccess
:
pass
class
Field
(
AbstractField
):
"""
With fields one can formulate stencil-like update rules on structured grids.
This Field class knows about the dimension, memory layout (strides) and optionally about the size of an array.
...
...
@@ -394,7 +400,7 @@ class Field:
return
self
.
hashable_contents
()
==
other
.
hashable_contents
()
# noinspection PyAttributeOutsideInit,PyUnresolvedReferences
class
Access
(
sp
.
Symbol
):
class
Access
(
sp
.
Symbol
,
AbstractField
.
AbstractAccess
):
"""Class representing a relative access into a `Field`.
This class behaves like a normal sympy Symbol, it is actually derived from it. One can built up
...
...
pystencils/transformations.py
View file @
4f43b51a
...
...
@@ -9,7 +9,7 @@ from sympy.logic.boolalg import Boolean
from
sympy.tensor
import
IndexedBase
from
pystencils.simp.assignment_collection
import
AssignmentCollection
from
pystencils.assignment
import
Assignment
from
pystencils.field
import
Field
,
FieldType
from
pystencils.field
import
Abstract
Field
,
FieldType
,
Field
from
pystencils.data_types
import
TypedSymbol
,
PointerType
,
StructType
,
get_base_type
,
reinterpret_cast_func
,
\
cast_func
,
pointer_arithmetic_func
,
get_type_of_expression
,
collate_types
,
create_type
from
pystencils.kernelparameters
import
FieldPointerSymbol
...
...
@@ -160,7 +160,7 @@ def make_loop_over_domain(body, function_name, iteration_slice=None, ghost_layer
:class:`LoopOverCoordinate` instance with nested loops, ordered according to field layouts
"""
# find correct ordering by inspecting participating FieldAccesses
field_accesses
=
body
.
atoms
(
Field
.
Access
)
field_accesses
=
body
.
atoms
(
AbstractField
.
Abstract
Access
)
field_accesses
=
{
e
for
e
in
field_accesses
if
not
e
.
is_absolute_access
}
# exclude accesses to buffers from field_list, because buffers are treated separately
...
...
@@ -353,7 +353,7 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None):
loop_iterations
=
[(
l
.
stop
-
l
.
start
)
/
l
.
step
for
l
in
loops
]
loop_counters
=
[
l
.
loop_counter_symbol
for
l
in
loops
]
field_accesses
=
ast_node
.
atoms
(
Field
.
Access
)
field_accesses
=
ast_node
.
atoms
(
AbstractField
.
Abstract
Access
)
buffer_accesses
=
{
fa
for
fa
in
field_accesses
if
FieldType
.
is_buffer
(
fa
.
field
)}
loop_counters
=
[
v
*
len
(
buffer_accesses
)
for
v
in
loop_counters
]
...
...
@@ -369,7 +369,7 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None):
def
resolve_buffer_accesses
(
ast_node
,
base_buffer_index
,
read_only_field_names
=
set
()):
def
visit_sympy_expr
(
expr
,
enclosing_block
,
sympy_assignment
):
if
isinstance
(
expr
,
Field
.
Access
):
if
isinstance
(
expr
,
AbstractField
.
Abstract
Access
):
field_access
=
expr
# Do not apply transformation if field is not a buffer
...
...
@@ -433,7 +433,7 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
field_to_fixed_coordinates
=
OrderedDict
(
sorted
(
field_to_fixed_coordinates
.
items
(),
key
=
lambda
pair
:
pair
[
0
]))
def
visit_sympy_expr
(
expr
,
enclosing_block
,
sympy_assignment
):
if
isinstance
(
expr
,
Field
.
Access
):
if
isinstance
(
expr
,
AbstractField
.
Abstract
Access
):
field_access
=
expr
field
=
field_access
.
field
...
...
@@ -654,12 +654,13 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups):
if
s
in
assignment_map
:
# if there is no assignment inside the loop body it is independent already
for
new_symbol
in
assignment_map
[
s
].
rhs
.
atoms
(
sp
.
Symbol
):
if
type
(
new_symbol
)
is
not
Field
.
Access
and
new_symbol
not
in
symbols_with_temporary_array
:
if
not
isinstance
(
new_symbol
,
AbstractField
.
AbstractAccess
)
and
\
new_symbol
not
in
symbols_with_temporary_array
:
symbols_to_process
.
append
(
new_symbol
)
symbols_resolved
.
add
(
s
)
for
symbol
in
symbol_group
:
if
type
(
symbol
)
is
not
Field
.
Access
:
if
not
isinstance
(
symbol
,
AbstractField
.
Abstract
Access
)
:
assert
type
(
symbol
)
is
TypedSymbol
new_ts
=
TypedSymbol
(
symbol
.
name
,
PointerType
(
symbol
.
dtype
))
symbols_with_temporary_array
[
symbol
]
=
IndexedBase
(
new_ts
,
shape
=
(
1
,))[
inner_loop
.
loop_counter_symbol
]
...
...
@@ -668,7 +669,7 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups):
for
assignment
in
inner_loop
.
body
.
args
:
if
assignment
.
lhs
in
symbols_resolved
:
new_rhs
=
assignment
.
rhs
.
subs
(
symbols_with_temporary_array
.
items
())
if
typ
e
(
assignment
.
lhs
)
is
not
Field
.
Access
and
assignment
.
lhs
in
symbol_group
:
if
not
isinstanc
e
(
assignment
.
lhs
,
AbstractField
.
Abstract
Access
)
and
assignment
.
lhs
in
symbol_group
:
assert
type
(
assignment
.
lhs
)
is
TypedSymbol
new_ts
=
TypedSymbol
(
assignment
.
lhs
.
name
,
PointerType
(
assignment
.
lhs
.
dtype
))
new_lhs
=
IndexedBase
(
new_ts
,
shape
=
(
1
,))[
inner_loop
.
loop_counter_symbol
]
...
...
@@ -792,7 +793,7 @@ class KernelConstraintsCheck:
def
process_expression
(
self
,
rhs
,
type_constants
=
True
):
self
.
_update_accesses_rhs
(
rhs
)
if
isinstance
(
rhs
,
Field
.
Access
):
if
isinstance
(
rhs
,
AbstractField
.
Abstract
Access
):
self
.
fields_read
.
add
(
rhs
.
field
)
self
.
fields_read
.
update
(
rhs
.
indirect_addressing_fields
)
return
rhs
...
...
@@ -822,13 +823,13 @@ class KernelConstraintsCheck:
def
_process_lhs
(
self
,
lhs
):
assert
isinstance
(
lhs
,
sp
.
Symbol
)
self
.
_update_accesses_lhs
(
lhs
)
if
not
isinstance
(
lhs
,
Field
.
Access
)
and
not
isinstance
(
lhs
,
TypedSymbol
):
if
not
isinstance
(
lhs
,
AbstractField
.
Abstract
Access
)
and
not
isinstance
(
lhs
,
TypedSymbol
):
return
TypedSymbol
(
lhs
.
name
,
self
.
_type_for_symbol
[
lhs
.
name
])
else
:
return
lhs
def
_update_accesses_lhs
(
self
,
lhs
):
if
isinstance
(
lhs
,
Field
.
Access
):
if
isinstance
(
lhs
,
AbstractField
.
Abstract
Access
):
fai
=
self
.
FieldAndIndex
(
lhs
.
field
,
lhs
.
index
)
self
.
_field_writes
[
fai
].
add
(
lhs
.
offsets
)
if
len
(
self
.
_field_writes
[
fai
])
>
1
:
...
...
@@ -841,7 +842,7 @@ class KernelConstraintsCheck:
self
.
scopes
.
define_symbol
(
lhs
)
def
_update_accesses_rhs
(
self
,
rhs
):
if
isinstance
(
rhs
,
Field
.
Access
)
and
self
.
check_independence_condition
:
if
isinstance
(
rhs
,
AbstractField
.
Abstract
Access
)
and
self
.
check_independence_condition
:
writes
=
self
.
_field_writes
[
self
.
FieldAndIndex
(
rhs
.
field
,
rhs
.
index
)]
for
write_offset
in
writes
:
assert
len
(
writes
)
==
1
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment