Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Jonas Plewinski
pystencils
Commits
1c0665c4
Commit
1c0665c4
authored
Sep 24, 2019
by
Stephan Seitz
Browse files
Implement interpolation (without CubicInterpolationCUDA)
parent
e871e864
Changes
18
Expand all
Hide whitespace changes
Inline
Side-by-side
pystencils/backends/cbackend.py
View file @
1c0665c4
...
@@ -103,6 +103,10 @@ def get_headers(ast_node: Node) -> Set[str]:
...
@@ -103,6 +103,10 @@ def get_headers(ast_node: Node) -> Set[str]:
if
isinstance
(
a
,
Node
):
if
isinstance
(
a
,
Node
):
headers
.
update
(
get_headers
(
a
))
headers
.
update
(
get_headers
(
a
))
for
g
in
get_global_declarations
(
ast_node
):
if
isinstance
(
g
,
Node
):
headers
.
update
(
get_headers
(
g
))
return
sorted
(
headers
)
return
sorted
(
headers
)
...
...
pystencils/backends/cuda_backend.py
View file @
1c0665c4
...
@@ -3,6 +3,7 @@ from os.path import dirname, join
...
@@ -3,6 +3,7 @@ from os.path import dirname, join
from
pystencils.astnodes
import
Node
from
pystencils.astnodes
import
Node
from
pystencils.backends.cbackend
import
CBackend
,
CustomSympyPrinter
,
generate_c
from
pystencils.backends.cbackend
import
CBackend
,
CustomSympyPrinter
,
generate_c
from
pystencils.fast_approximation
import
fast_division
,
fast_inv_sqrt
,
fast_sqrt
from
pystencils.fast_approximation
import
fast_division
,
fast_inv_sqrt
,
fast_sqrt
from
pystencils.interpolation_astnodes
import
InterpolationMode
with
open
(
join
(
dirname
(
__file__
),
'cuda_known_functions.txt'
))
as
f
:
with
open
(
join
(
dirname
(
__file__
),
'cuda_known_functions.txt'
))
as
f
:
lines
=
f
.
readlines
()
lines
=
f
.
readlines
()
...
@@ -43,11 +44,19 @@ class CudaBackend(CBackend):
...
@@ -43,11 +44,19 @@ class CudaBackend(CBackend):
return
code
return
code
def
_print_TextureDeclaration
(
self
,
node
):
def
_print_TextureDeclaration
(
self
,
node
):
code
=
"texture<%s, cudaTextureType%iD, cudaReadModeElementType> %s;"
%
(
str
(
node
.
texture
.
field
.
dtype
),
if
node
.
texture
.
field
.
dtype
.
numpy_dtype
.
itemsize
>
4
:
node
.
texture
.
field
.
spatial_dimensions
,
code
=
"texture<fp_tex_%s, cudaTextureType%iD, cudaReadModeElementType> %s;"
%
(
node
.
texture
str
(
node
.
texture
.
field
.
dtype
),
)
node
.
texture
.
field
.
spatial_dimensions
,
node
.
texture
)
else
:
code
=
"texture<%s, cudaTextureType%iD, cudaReadModeElementType> %s;"
%
(
str
(
node
.
texture
.
field
.
dtype
),
node
.
texture
.
field
.
spatial_dimensions
,
node
.
texture
)
return
code
return
code
def
_print_SkipIteration
(
self
,
_
):
def
_print_SkipIteration
(
self
,
_
):
...
@@ -62,17 +71,23 @@ class CudaSympyPrinter(CustomSympyPrinter):
...
@@ -62,17 +71,23 @@ class CudaSympyPrinter(CustomSympyPrinter):
self
.
known_functions
.
update
(
CUDA_KNOWN_FUNCTIONS
)
self
.
known_functions
.
update
(
CUDA_KNOWN_FUNCTIONS
)
def
_print_TextureAccess
(
self
,
node
):
def
_print_TextureAccess
(
self
,
node
):
dtype
=
node
.
texture
.
field
.
dtype
.
numpy_dtype
if
node
.
texture
.
cubic_bspline_interpolation
:
if
node
.
texture
.
interpolation_mode
==
InterpolationMode
.
CUBIC_SPLINE
:
template
=
"cubicTex%iDSimple
<%s>
(%s, %s)"
template
=
"cubicTex%iDSimple(%s, %s)"
else
:
else
:
template
=
"tex%iD<%s>(%s, %s)"
if
dtype
.
itemsize
>
4
:
# Use PyCuda hack!
# https://github.com/inducer/pycuda/blob/master/pycuda/cuda/pycuda-helpers.hpp
template
=
"fp_tex%iD(%s, %s)"
else
:
template
=
"tex%iD(%s, %s)"
code
=
template
%
(
code
=
template
%
(
node
.
texture
.
field
.
spatial_dimensions
,
node
.
texture
.
field
.
spatial_dimensions
,
str
(
node
.
texture
.
field
.
dtype
),
str
(
node
.
texture
),
str
(
node
.
texture
),
', '
.
join
(
self
.
_print
(
o
)
for
o
in
node
.
offsets
)
# + 0.5 comes from Nvidia's staggered indexing
', '
.
join
(
self
.
_print
(
o
+
0.5
)
for
o
in
reversed
(
node
.
offsets
))
)
)
return
code
return
code
...
...
pystencils/backends/cuda_known_functions.txt
View file @
1c0665c4
...
@@ -45,6 +45,7 @@ tex1D
...
@@ -45,6 +45,7 @@ tex1D
tex2D
tex2D
tex3D
tex3D
sqrtf
rsqrtf
rsqrtf
cbrtf
cbrtf
rcbrtf
rcbrtf
...
...
pystencils/cpu/kernelcreation.py
View file @
1c0665c4
...
@@ -10,8 +10,8 @@ from pystencils.data_types import BasicType, StructType, TypedSymbol, create_typ
...
@@ -10,8 +10,8 @@ from pystencils.data_types import BasicType, StructType, TypedSymbol, create_typ
from
pystencils.field
import
Field
,
FieldType
from
pystencils.field
import
Field
,
FieldType
from
pystencils.transformations
import
(
from
pystencils.transformations
import
(
add_types
,
filtered_tree_iteration
,
get_base_buffer_index
,
get_optimal_loop_ordering
,
add_types
,
filtered_tree_iteration
,
get_base_buffer_index
,
get_optimal_loop_ordering
,
make_loop_over_domain
,
move_constants_before_loop
,
parse_base_pointer_info
,
implement_interpolations
,
make_loop_over_domain
,
move_constants_before_loop
,
resolve_buffer_accesses
,
resolve_field_accesses
,
split_inner_loop
)
parse_base_pointer_info
,
resolve_buffer_accesses
,
resolve_field_accesses
,
split_inner_loop
)
AssignmentOrAstNodeList
=
List
[
Union
[
Assignment
,
ast
.
Node
]]
AssignmentOrAstNodeList
=
List
[
Union
[
Assignment
,
ast
.
Node
]]
...
@@ -67,6 +67,7 @@ def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "ke
...
@@ -67,6 +67,7 @@ def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "ke
ghost_layers
=
ghost_layers
,
loop_order
=
loop_order
)
ghost_layers
=
ghost_layers
,
loop_order
=
loop_order
)
ast_node
=
KernelFunction
(
loop_node
,
'cpu'
,
'c'
,
compile_function
=
make_python_function
,
ast_node
=
KernelFunction
(
loop_node
,
'cpu'
,
'c'
,
compile_function
=
make_python_function
,
ghost_layers
=
ghost_layer_info
,
function_name
=
function_name
)
ghost_layers
=
ghost_layer_info
,
function_name
=
function_name
)
implement_interpolations
(
body
)
if
split_groups
:
if
split_groups
:
typed_split_groups
=
[[
type_symbol
(
s
)
for
s
in
split_group
]
for
split_group
in
split_groups
]
typed_split_groups
=
[[
type_symbol
(
s
)
for
s
in
split_group
]
for
split_group
in
split_groups
]
...
@@ -139,6 +140,8 @@ def create_indexed_kernel(assignments: AssignmentOrAstNodeList, index_fields, fu
...
@@ -139,6 +140,8 @@ def create_indexed_kernel(assignments: AssignmentOrAstNodeList, index_fields, fu
loop_body
=
Block
([])
loop_body
=
Block
([])
loop_node
=
LoopOverCoordinate
(
loop_body
,
coordinate_to_loop_over
=
0
,
start
=
0
,
stop
=
index_fields
[
0
].
shape
[
0
])
loop_node
=
LoopOverCoordinate
(
loop_body
,
coordinate_to_loop_over
=
0
,
start
=
0
,
stop
=
index_fields
[
0
].
shape
[
0
])
implement_interpolations
(
loop_node
)
for
assignment
in
assignments
:
for
assignment
in
assignments
:
loop_body
.
append
(
assignment
)
loop_body
.
append
(
assignment
)
...
...
pystencils/data_types.py
View file @
1c0665c4
import
ctypes
import
ctypes
from
collections
import
defaultdict
from
collections
import
defaultdict
from
functools
import
partial
from
functools
import
partial
from
typing
import
Tuple
import
numpy
as
np
import
numpy
as
np
import
sympy
as
sp
import
sympy
as
sp
import
sympy.codegen.ast
from
sympy.core.cache
import
cacheit
from
sympy.core.cache
import
cacheit
from
sympy.logic.boolalg
import
Boolean
from
sympy.logic.boolalg
import
Boolean
import
pystencils
from
pystencils.cache
import
memorycache
,
memorycache_if_hashable
from
pystencils.cache
import
memorycache
,
memorycache_if_hashable
from
pystencils.utils
import
all_equal
from
pystencils.utils
import
all_equal
...
@@ -17,6 +20,26 @@ except ImportError as e:
...
@@ -17,6 +20,26 @@ except ImportError as e:
_ir_importerror
=
e
_ir_importerror
=
e
def
typed_symbols
(
names
,
dtype
,
*
args
):
symbols
=
sp
.
symbols
(
names
,
*
args
)
if
isinstance
(
symbols
,
Tuple
):
return
tuple
(
TypedSymbol
(
str
(
s
),
dtype
)
for
s
in
symbols
)
else
:
return
TypedSymbol
(
str
(
symbols
),
dtype
)
def
matrix_symbols
(
names
,
dtype
,
rows
,
cols
):
if
isinstance
(
names
,
str
):
names
=
names
.
replace
(
' '
,
''
).
split
(
','
)
matrices
=
[]
for
n
in
names
:
symbols
=
typed_symbols
(
"%s:%i"
%
(
n
,
rows
*
cols
),
dtype
)
matrices
.
append
(
sp
.
Matrix
(
rows
,
cols
,
lambda
i
,
j
:
symbols
[
i
*
cols
+
j
]))
return
tuple
(
matrices
)
# noinspection PyPep8Naming
# noinspection PyPep8Naming
class
address_of
(
sp
.
Function
):
class
address_of
(
sp
.
Function
):
is_Atom
=
True
is_Atom
=
True
...
@@ -86,6 +109,11 @@ class cast_func(sp.Function):
...
@@ -86,6 +109,11 @@ class cast_func(sp.Function):
@
property
@
property
def
is_integer
(
self
):
def
is_integer
(
self
):
"""
Uses Numpy type hierarchy to determine :func:`sympy.Expr.is_integer` predicate
For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
"""
if
hasattr
(
self
.
dtype
,
'numpy_dtype'
):
if
hasattr
(
self
.
dtype
,
'numpy_dtype'
):
return
np
.
issubdtype
(
self
.
dtype
.
numpy_dtype
,
np
.
integer
)
or
super
().
is_integer
return
np
.
issubdtype
(
self
.
dtype
.
numpy_dtype
,
np
.
integer
)
or
super
().
is_integer
else
:
else
:
...
@@ -93,6 +121,9 @@ class cast_func(sp.Function):
...
@@ -93,6 +121,9 @@ class cast_func(sp.Function):
@
property
@
property
def
is_negative
(
self
):
def
is_negative
(
self
):
"""
See :func:`.TypedSymbol.is_integer`
"""
if
hasattr
(
self
.
dtype
,
'numpy_dtype'
):
if
hasattr
(
self
.
dtype
,
'numpy_dtype'
):
if
np
.
issubdtype
(
self
.
dtype
.
numpy_dtype
,
np
.
unsignedinteger
):
if
np
.
issubdtype
(
self
.
dtype
.
numpy_dtype
,
np
.
unsignedinteger
):
return
False
return
False
...
@@ -101,6 +132,9 @@ class cast_func(sp.Function):
...
@@ -101,6 +132,9 @@ class cast_func(sp.Function):
@
property
@
property
def
is_nonnegative
(
self
):
def
is_nonnegative
(
self
):
"""
See :func:`.TypedSymbol.is_integer`
"""
if
self
.
is_negative
is
False
:
if
self
.
is_negative
is
False
:
return
True
return
True
else
:
else
:
...
@@ -108,6 +142,9 @@ class cast_func(sp.Function):
...
@@ -108,6 +142,9 @@ class cast_func(sp.Function):
@
property
@
property
def
is_real
(
self
):
def
is_real
(
self
):
"""
See :func:`.TypedSymbol.is_integer`
"""
if
hasattr
(
self
.
dtype
,
'numpy_dtype'
):
if
hasattr
(
self
.
dtype
,
'numpy_dtype'
):
return
np
.
issubdtype
(
self
.
dtype
.
numpy_dtype
,
np
.
integer
)
or
\
return
np
.
issubdtype
(
self
.
dtype
.
numpy_dtype
,
np
.
integer
)
or
\
np
.
issubdtype
(
self
.
dtype
.
numpy_dtype
,
np
.
floating
)
or
\
np
.
issubdtype
(
self
.
dtype
.
numpy_dtype
,
np
.
floating
)
or
\
...
@@ -171,6 +208,11 @@ class TypedSymbol(sp.Symbol):
...
@@ -171,6 +208,11 @@ class TypedSymbol(sp.Symbol):
# For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
# For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
@
property
@
property
def
is_integer
(
self
):
def
is_integer
(
self
):
"""
Uses Numpy type hierarchy to determine :func:`sympy.Expr.is_integer` predicate
For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
"""
if
hasattr
(
self
.
dtype
,
'numpy_dtype'
):
if
hasattr
(
self
.
dtype
,
'numpy_dtype'
):
return
np
.
issubdtype
(
self
.
dtype
.
numpy_dtype
,
np
.
integer
)
or
super
().
is_integer
return
np
.
issubdtype
(
self
.
dtype
.
numpy_dtype
,
np
.
integer
)
or
super
().
is_integer
else
:
else
:
...
@@ -178,6 +220,9 @@ class TypedSymbol(sp.Symbol):
...
@@ -178,6 +220,9 @@ class TypedSymbol(sp.Symbol):
@
property
@
property
def
is_negative
(
self
):
def
is_negative
(
self
):
"""
See :func:`.TypedSymbol.is_integer`
"""
if
hasattr
(
self
.
dtype
,
'numpy_dtype'
):
if
hasattr
(
self
.
dtype
,
'numpy_dtype'
):
if
np
.
issubdtype
(
self
.
dtype
.
numpy_dtype
,
np
.
unsignedinteger
):
if
np
.
issubdtype
(
self
.
dtype
.
numpy_dtype
,
np
.
unsignedinteger
):
return
False
return
False
...
@@ -186,6 +231,9 @@ class TypedSymbol(sp.Symbol):
...
@@ -186,6 +231,9 @@ class TypedSymbol(sp.Symbol):
@
property
@
property
def
is_nonnegative
(
self
):
def
is_nonnegative
(
self
):
"""
See :func:`.TypedSymbol.is_integer`
"""
if
self
.
is_negative
is
False
:
if
self
.
is_negative
is
False
:
return
True
return
True
else
:
else
:
...
@@ -193,6 +241,9 @@ class TypedSymbol(sp.Symbol):
...
@@ -193,6 +241,9 @@ class TypedSymbol(sp.Symbol):
@
property
@
property
def
is_real
(
self
):
def
is_real
(
self
):
"""
See :func:`.TypedSymbol.is_integer`
"""
if
hasattr
(
self
.
dtype
,
'numpy_dtype'
):
if
hasattr
(
self
.
dtype
,
'numpy_dtype'
):
return
np
.
issubdtype
(
self
.
dtype
.
numpy_dtype
,
np
.
integer
)
or
\
return
np
.
issubdtype
(
self
.
dtype
.
numpy_dtype
,
np
.
integer
)
or
\
np
.
issubdtype
(
self
.
dtype
.
numpy_dtype
,
np
.
floating
)
or
\
np
.
issubdtype
(
self
.
dtype
.
numpy_dtype
,
np
.
floating
)
or
\
...
@@ -370,12 +421,17 @@ def peel_off_type(dtype, type_to_peel_off):
...
@@ -370,12 +421,17 @@ def peel_off_type(dtype, type_to_peel_off):
return
dtype
return
dtype
def
collate_types
(
types
):
def
collate_types
(
types
,
forbid_collation_to_float
=
False
):
"""
"""
Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
Uses the collation rules from numpy.
Uses the collation rules from numpy.
"""
"""
if
forbid_collation_to_float
:
types
=
[
t
for
t
in
types
if
not
(
hasattr
(
t
,
'is_float'
)
and
t
.
is_float
())]
if
not
types
:
return
[
create_type
(
'int32'
)]
# Pointer arithmetic case i.e. pointer + integer is allowed
# Pointer arithmetic case i.e. pointer + integer is allowed
if
any
(
type
(
t
)
is
PointerType
for
t
in
types
):
if
any
(
type
(
t
)
is
PointerType
for
t
in
types
):
pointer_type
=
None
pointer_type
=
None
...
@@ -433,6 +489,8 @@ def get_type_of_expression(expr,
...
@@ -433,6 +489,8 @@ def get_type_of_expression(expr,
return
create_type
(
default_float_type
)
return
create_type
(
default_float_type
)
elif
isinstance
(
expr
,
ResolvedFieldAccess
):
elif
isinstance
(
expr
,
ResolvedFieldAccess
):
return
expr
.
field
.
dtype
return
expr
.
field
.
dtype
elif
isinstance
(
expr
,
pystencils
.
field
.
Field
.
AbstractAccess
):
return
expr
.
field
.
dtype
elif
isinstance
(
expr
,
TypedSymbol
):
elif
isinstance
(
expr
,
TypedSymbol
):
return
expr
.
dtype
return
expr
.
dtype
elif
isinstance
(
expr
,
sp
.
Symbol
):
elif
isinstance
(
expr
,
sp
.
Symbol
):
...
@@ -525,6 +583,10 @@ class BasicType(Type):
...
@@ -525,6 +583,10 @@ class BasicType(Type):
def
numpy_dtype
(
self
):
def
numpy_dtype
(
self
):
return
self
.
_dtype
return
self
.
_dtype
@
property
def
sympy_dtype
(
self
):
return
getattr
(
sympy
.
codegen
.
ast
,
str
(
self
.
numpy_dtype
))
@
property
@
property
def
item_size
(
self
):
def
item_size
(
self
):
return
1
return
1
...
...
pystencils/field.py
View file @
1c0665c4
import
functools
import
hashlib
import
hashlib
import
operator
import
pickle
import
pickle
import
re
import
re
from
enum
import
Enum
from
enum
import
Enum
...
@@ -9,6 +11,7 @@ import numpy as np
...
@@ -9,6 +11,7 @@ import numpy as np
import
sympy
as
sp
import
sympy
as
sp
from
sympy.core.cache
import
cacheit
from
sympy.core.cache
import
cacheit
import
pystencils
from
pystencils.alignedarray
import
aligned_empty
from
pystencils.alignedarray
import
aligned_empty
from
pystencils.data_types
import
StructType
,
TypedSymbol
,
create_type
from
pystencils.data_types
import
StructType
,
TypedSymbol
,
create_type
from
pystencils.kernelparameters
import
FieldShapeSymbol
,
FieldStrideSymbol
from
pystencils.kernelparameters
import
FieldShapeSymbol
,
FieldStrideSymbol
...
@@ -38,7 +41,6 @@ def fields(description=None, index_dimensions=0, layout=None, **kwargs):
...
@@ -38,7 +41,6 @@ def fields(description=None, index_dimensions=0, layout=None, **kwargs):
>>> assert s.index_dimensions == 0 and s.dtype.numpy_dtype == arr_s.dtype
>>> assert s.index_dimensions == 0 and s.dtype.numpy_dtype == arr_s.dtype
>>> assert v.index_shape == (2,)
>>> assert v.index_shape == (2,)
Format string can be left out, field names are taken from keyword arguments.
Format string can be left out, field names are taken from keyword arguments.
>>> fields(f1=arr_s, f2=arr_s)
>>> fields(f1=arr_s, f2=arr_s)
[f1, f2]
[f1, f2]
...
@@ -292,6 +294,10 @@ class Field(AbstractField):
...
@@ -292,6 +294,10 @@ class Field(AbstractField):
self
.
shape
=
shape
self
.
shape
=
shape
self
.
strides
=
strides
self
.
strides
=
strides
self
.
latex_name
=
None
# type: Optional[str]
self
.
latex_name
=
None
# type: Optional[str]
self
.
coordinate_origin
=
sp
.
Matrix
(
tuple
(
0
for
_
in
range
(
self
.
spatial_dimensions
)
))
# type: tuple[float,sp.Symbol]
self
.
coordinate_transform
=
sp
.
eye
(
self
.
spatial_dimensions
)
def
new_field_with_different_name
(
self
,
new_name
):
def
new_field_with_different_name
(
self
,
new_name
):
if
self
.
has_fixed_shape
:
if
self
.
has_fixed_shape
:
...
@@ -312,6 +318,9 @@ class Field(AbstractField):
...
@@ -312,6 +318,9 @@ class Field(AbstractField):
def
ndim
(
self
)
->
int
:
def
ndim
(
self
)
->
int
:
return
len
(
self
.
shape
)
return
len
(
self
.
shape
)
def
values_per_cell
(
self
)
->
int
:
return
functools
.
reduce
(
operator
.
mul
,
self
.
index_shape
,
1
)
@
property
@
property
def
layout
(
self
):
def
layout
(
self
):
return
self
.
_layout
return
self
.
_layout
...
@@ -393,6 +402,27 @@ class Field(AbstractField):
...
@@ -393,6 +402,27 @@ class Field(AbstractField):
assert
FieldType
.
is_custom
(
self
)
assert
FieldType
.
is_custom
(
self
)
return
Field
.
Access
(
self
,
offset
,
index
,
is_absolute_access
=
True
)
return
Field
.
Access
(
self
,
offset
,
index
,
is_absolute_access
=
True
)
def
interpolated_access
(
self
,
offset
:
Tuple
,
interpolation_mode
=
'linear'
,
address_mode
=
'BORDER'
,
allow_textures
=
True
):
"""Provides access to field values at non-integer positions
``interpolated_access`` is similar to :func:`Field.absolute_access` except that
it allows non-integer offsets and automatic handling of out-of-bound accesses.
:param offset: Tuple of spatial coordinates (can be floats)
:param interpolation_mode: One of :class:`pystencils.interpolation_astnodes.InterpolationMode`
:param address_mode: How boundaries are handled can be 'border', 'wrap', 'mirror', 'clamp'
:param allow_textures: Allow implementation by texture accesses on GPUs
"""
from
pystencils.interpolation_astnodes
import
Interpolator
return
Interpolator
(
self
,
interpolation_mode
,
address_mode
,
allow_textures
=
allow_textures
).
at
(
offset
)
def
__call__
(
self
,
*
args
,
**
kwargs
):
def
__call__
(
self
,
*
args
,
**
kwargs
):
center
=
tuple
([
0
]
*
self
.
spatial_dimensions
)
center
=
tuple
([
0
]
*
self
.
spatial_dimensions
)
return
Field
.
Access
(
self
,
center
)(
*
args
,
**
kwargs
)
return
Field
.
Access
(
self
,
center
)(
*
args
,
**
kwargs
)
...
@@ -409,6 +439,34 @@ class Field(AbstractField):
...
@@ -409,6 +439,34 @@ class Field(AbstractField):
return
False
return
False
return
self
.
hashable_contents
()
==
other
.
hashable_contents
()
return
self
.
hashable_contents
()
==
other
.
hashable_contents
()
@
property
def
physical_coordinates
(
self
):
return
self
.
coordinate_transform
@
(
self
.
coordinate_origin
+
pystencils
.
x_vector
(
self
.
spatial_dimensions
))
@
property
def
physical_coordinates_staggered
(
self
):
return
self
.
coordinate_transform
@
\
(
self
.
coordinate_origin
+
pystencils
.
x_staggered_vector
(
self
.
spatial_dimensions
))
def
index_to_physical
(
self
,
index_coordinates
,
staggered
=
False
):
if
staggered
:
index_coordinates
=
sp
.
Matrix
([
i
+
0.5
for
i
in
index_coordinates
])
return
self
.
coordinate_transform
@
(
self
.
coordinate_origin
+
index_coordinates
)
def
physical_to_index
(
self
,
physical_coordinates
,
staggered
=
False
):
rtn
=
self
.
coordinate_transform
.
inv
()
@
physical_coordinates
-
self
.
coordinate_origin
if
staggered
:
rtn
=
sp
.
Matrix
([
i
-
0.5
for
i
in
rtn
])
return
rtn
def
index_to_staggered_physical_coordinates
(
self
,
symbol_vector
):
symbol_vector
+=
sp
.
Matrix
([
0.5
]
*
self
.
spatial_dimensions
)
return
self
.
create_physical_coordinates
(
symbol_vector
)
def
set_coordinate_origin_to_field_center
(
self
):
self
.
coordinate_origin
=
-
sp
.
Matrix
([
i
/
2
for
i
in
self
.
spatial_shape
])
# noinspection PyAttributeOutsideInit,PyUnresolvedReferences
# noinspection PyAttributeOutsideInit,PyUnresolvedReferences
class
Access
(
TypedSymbol
,
AbstractField
.
AbstractAccess
):
class
Access
(
TypedSymbol
,
AbstractField
.
AbstractAccess
):
"""Class representing a relative access into a `Field`.
"""Class representing a relative access into a `Field`.
...
@@ -429,11 +487,12 @@ class Field(AbstractField):
...
@@ -429,11 +487,12 @@ class Field(AbstractField):
>>> central_y_component.at_index(0) # change component
>>> central_y_component.at_index(0) # change component
v_C^0
v_C^0
"""
"""
def
__new__
(
cls
,
name
,
*
args
,
**
kwargs
):
def
__new__
(
cls
,
name
,
*
args
,
**
kwargs
):
obj
=
Field
.
Access
.
__xnew_cached_
(
cls
,
name
,
*
args
,
**
kwargs
)
obj
=
Field
.
Access
.
__xnew_cached_
(
cls
,
name
,
*
args
,
**
kwargs
)
return
obj
return
obj
def
__new_stage2__
(
self
,
field
,
offsets
=
(
0
,
0
,
0
),
idx
=
None
,
is_absolute_access
=
False
):
def
__new_stage2__
(
self
,
field
,
offsets
=
(
0
,
0
,
0
),
idx
=
None
,
is_absolute_access
=
False
,
dtype
=
None
):
field_name
=
field
.
name
field_name
=
field
.
name
offsets_and_index
=
(
*
offsets
,
*
idx
)
if
idx
is
not
None
else
offsets
offsets_and_index
=
(
*
offsets
,
*
idx
)
if
idx
is
not
None
else
offsets
constant_offsets
=
not
any
([
isinstance
(
o
,
sp
.
Basic
)
and
not
o
.
is_Integer
for
o
in
offsets_and_index
])
constant_offsets
=
not
any
([
isinstance
(
o
,
sp
.
Basic
)
and
not
o
.
is_Integer
for
o
in
offsets_and_index
])
...
@@ -484,7 +543,7 @@ class Field(AbstractField):
...
@@ -484,7 +543,7 @@ class Field(AbstractField):
return
obj
return
obj
def
__getnewargs__
(
self
):
def
__getnewargs__
(
self
):
return
self
.
field
,
self
.
offsets
,
self
.
index
,
self
.
is_absolute_access
return
self
.
field
,
self
.
offsets
,
self
.
index
,
self
.
is_absolute_access
,
self
.
dtype
# noinspection SpellCheckingInspection
# noinspection SpellCheckingInspection
__xnew__
=
staticmethod
(
__new_stage2__
)
__xnew__
=
staticmethod
(
__new_stage2__
)
...
@@ -503,7 +562,7 @@ class Field(AbstractField):
...
@@ -503,7 +562,7 @@ class Field(AbstractField):
if
len
(
idx
)
!=
self
.
field
.
index_dimensions
:
if
len
(
idx
)
!=
self
.
field
.
index_dimensions
:
raise
ValueError
(
"Wrong number of indices: "
raise
ValueError
(
"Wrong number of indices: "
"Got %d, expected %d"
%
(
len
(
idx
),
self
.
field
.
index_dimensions
))
"Got %d, expected %d"
%
(
len
(
idx
),
self
.
field
.
index_dimensions
))
return
Field
.
Access
(
self
.
field
,
self
.
_offsets
,
idx
)
return
Field
.
Access
(
self
.
field
,
self
.
_offsets
,
idx
,
dtype
=
self
.
dtype
)
def
__getitem__
(
self
,
*
idx
):
def
__getitem__
(
self
,
*
idx
):
return
self
.
__call__
(
*
idx
)
return
self
.
__call__
(
*
idx
)
...
@@ -562,7 +621,7 @@ class Field(AbstractField):
...
@@ -562,7 +621,7 @@ class Field(AbstractField):
"""
"""
offset_list
=
list
(
self
.
offsets
)
offset_list
=
list
(
self
.
offsets
)
offset_list
[
coord_id
]
+=
offset
offset_list
[
coord_id
]
+=
offset
return
Field
.
Access
(
self
.
field
,
tuple
(
offset_list
),
self
.
index
)
return
Field
.
Access
(
self
.
field
,
tuple
(
offset_list
),
self
.
index
,
dtype
=
self
.
dtype
)
def
get_shifted
(
self
,
*
shift
)
->
'Field.Access'
:
def
get_shifted
(
self
,
*
shift
)
->
'Field.Access'
:
"""Returns a new Access with changed spatial coordinates
"""Returns a new Access with changed spatial coordinates
...
@@ -572,7 +631,10 @@ class Field(AbstractField):
...
@@ -572,7 +631,10 @@ class Field(AbstractField):
>>> f[0,0].get_shifted(1, 1)
>>> f[0,0].get_shifted(1, 1)
f_NE
f_NE
"""
"""
return
Field
.
Access
(
self
.
field
,
tuple
(
a
+
b
for
a
,
b
in
zip
(
shift
,
self
.
offsets
)),
self
.
index
)
return
Field
.
Access
(
self
.
field
,
tuple
(
a
+
b
for
a
,
b
in
zip
(
shift
,
self
.
offsets
)),
self
.
index
,
dtype
=
self
.
dtype
)
def
at_index
(
self
,
*
idx_tuple
)
->
'Field.Access'
:
def
at_index
(
self
,
*
idx_tuple
)
->
'Field.Access'
:
"""Returns new Access with changed index.
"""Returns new Access with changed index.
...
@@ -582,7 +644,7 @@ class Field(AbstractField):
...
@@ -582,7 +644,7 @@ class Field(AbstractField):
>>> f(0).at_index(8)
>>> f(0).at_index(8)
f_C^8
f_C^8
"""
"""
return
Field
.
Access
(
self
.
field
,
self
.
offsets
,
idx_tuple
)
return
Field
.
Access
(
self
.
field
,
self
.
offsets
,
idx_tuple
,
dtype
=
self
.
dtype
)
@
property
@
property
def
is_absolute_access
(
self
)
->
bool
: