Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Stephan Seitz
pystencils
Commits
c512c755
Commit
c512c755
authored
Jul 05, 2019
by
Stephan Seitz
Browse files
Enable usage of templated Field type
parent
8e63c9ff
Changes
3
Hide whitespace changes
Inline
Side-by-side
pystencils/astnodes.py
View file @
c512c755
from
typing
import
Any
,
List
,
Optional
,
Sequence
,
Set
,
Union
import
jinja2
import
sympy
as
sp
from
pystencils.data_types
import
TypedSymbol
,
cast_func
,
create_type
...
...
@@ -655,7 +656,12 @@ class DestructuringBindingsForFieldClass(Node):
FieldShapeSymbol
:
"shape"
,
FieldStrideSymbol
:
"stride"
}
CLASS_NAME
=
"Field"
CLASS_NAME_TEMPLATE
=
jinja2
.
Template
(
"Field<{{ dtype }}, {{ ndim }}>"
)
@
property
def
fields_accessed
(
self
)
->
Set
[
'ResolvedFieldAccess'
]:
"""Set of Field instances: fields which are accessed inside this kernel function"""
return
set
(
o
.
field
for
o
in
self
.
atoms
(
ResolvedFieldAccess
))
def
__init__
(
self
,
body
):
super
(
DestructuringBindingsForFieldClass
,
self
).
__init__
()
...
...
@@ -676,10 +682,12 @@ class DestructuringBindingsForFieldClass(Node):
@
property
def
undefined_symbols
(
self
)
->
Set
[
sp
.
Symbol
]:
field_map
=
{
f
.
name
:
f
for
f
in
self
.
fields_accessed
}
undefined_field_symbols
=
self
.
symbols_defined
corresponding_field_names
=
{
s
.
field_name
for
s
in
undefined_field_symbols
if
hasattr
(
s
,
'field_name'
)}
corresponding_field_names
|=
{
s
.
field_names
[
0
]
for
s
in
undefined_field_symbols
if
hasattr
(
s
,
'field_names'
)}
return
{
TypedSymbol
(
f
,
self
.
CLASS_NAME
+
'&'
)
for
f
in
corresponding_field_names
}
|
\
return
{
TypedSymbol
(
f
,
self
.
CLASS_NAME_TEMPLATE
.
render
(
dtype
=
field_map
[
f
].
dtype
,
ndim
=
field_map
[
f
].
ndim
)
+
'&'
)
for
f
in
corresponding_field_names
}
|
\
(
self
.
body
.
undefined_symbols
-
undefined_field_symbols
)
def
subs
(
self
,
subs_dict
)
->
None
:
...
...
pystencils/data_types.py
View file @
c512c755
...
...
@@ -108,7 +108,7 @@ class TypedSymbol(sp.Symbol):
obj
=
super
(
TypedSymbol
,
cls
).
__xnew__
(
cls
,
name
)
try
:
obj
.
_dtype
=
create_type
(
dtype
)
except
TypeError
:
except
(
TypeError
,
ValueError
)
:
# on error keep the string
obj
.
_dtype
=
dtype
return
obj
...
...
pystencils/field.py
View file @
c512c755
...
...
@@ -306,6 +306,10 @@ class Field(AbstractField):
def
index_dimensions
(
self
)
->
int
:
return
len
(
self
.
shape
)
-
len
(
self
.
_layout
)
@
property
def
ndim
(
self
)
->
int
:
return
len
(
self
.
shape
)
@
property
def
layout
(
self
):
return
self
.
_layout
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment