Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Stephan Seitz
pystencils
Commits
dcf2c6f4
Commit
dcf2c6f4
authored
Sep 25, 2019
by
Martin Bauer
Browse files
Merge branch 'interpolation-24.0.9' into 'master'
Interpolation 24.0.9 See merge request
pycodegen/pystencils!56
parents
472f6f6c
84d81234
Changes
25
Expand all
Hide whitespace changes
Inline
Side-by-side
pystencils/__init__.py
View file @
dcf2c6f4
...
...
@@ -12,6 +12,8 @@ from .kernelcreation import create_indexed_kernel, create_kernel, create_stagger
from
.simp
import
AssignmentCollection
from
.slicing
import
make_slice
from
.sympyextensions
import
SymbolCreator
from
.spatial_coordinates
import
(
x_
,
x_staggered
,
x_staggered_vector
,
x_vector
,
y_
,
y_staggered
,
z_
,
z_staggered
)
try
:
import
pystencils_autodiff
...
...
@@ -30,5 +32,8 @@ __all__ = ['Field', 'FieldType', 'fields',
'SymbolCreator'
,
'create_data_handling'
,
'kernel'
,
'x_'
,
'y_'
,
'z_'
,
'x_staggered'
,
'y_staggered'
,
'z_staggered'
,
'x_vector'
,
'x_staggered_vector'
,
'fd'
,
'stencil'
]
pystencils/astnodes.py
View file @
dcf2c6f4
import
collections.abc
import
itertools
import
uuid
from
typing
import
Any
,
List
,
Optional
,
Sequence
,
Set
,
Union
...
...
@@ -33,7 +35,7 @@ class Node:
raise
NotImplementedError
()
def
subs
(
self
,
subs_dict
)
->
None
:
"""Inplace!
s
ubstitute, similar to sympy's but modifies the AST inplace."""
"""Inplace!
S
ubstitute, similar to sympy's but modifies the AST inplace."""
for
a
in
self
.
args
:
a
.
subs
(
subs_dict
)
...
...
@@ -102,7 +104,8 @@ class Conditional(Node):
result
=
self
.
true_block
.
undefined_symbols
if
self
.
false_block
:
result
.
update
(
self
.
false_block
.
undefined_symbols
)
result
.
update
(
self
.
condition_expr
.
atoms
(
sp
.
Symbol
))
if
hasattr
(
self
.
condition_expr
,
'atoms'
):
result
.
update
(
self
.
condition_expr
.
atoms
(
sp
.
Symbol
))
return
result
def
__str__
(
self
):
...
...
@@ -212,9 +215,16 @@ class KernelFunction(Node):
"""Set of Field instances: fields which are accessed inside this kernel function"""
return
set
(
o
.
field
for
o
in
self
.
atoms
(
ResolvedFieldAccess
))
def
fields_written
(
self
):
assigments
=
self
.
atoms
(
SympyAssignment
)
return
{
a
.
lhs
.
field
for
a
in
assigments
if
isinstance
(
a
.
lhs
,
ResolvedFieldAccess
)}
@
property
def
fields_written
(
self
)
->
Set
[
'ResolvedFieldAccess'
]:
assignments
=
self
.
atoms
(
SympyAssignment
)
return
{
a
.
lhs
.
field
for
a
in
assignments
if
isinstance
(
a
.
lhs
,
ResolvedFieldAccess
)}
@
property
def
fields_read
(
self
)
->
Set
[
'ResolvedFieldAccess'
]:
assignments
=
self
.
atoms
(
SympyAssignment
)
return
set
().
union
(
itertools
.
chain
.
from_iterable
([
f
.
field
for
f
in
a
.
rhs
.
free_symbols
if
hasattr
(
f
,
'field'
)]
for
a
in
assignments
))
def
get_parameters
(
self
)
->
Sequence
[
'KernelFunction.Parameter'
]:
"""Returns list of parameters for this function.
...
...
@@ -283,8 +293,15 @@ class Block(Node):
a
.
subs
(
subs_dict
)
def
insert_front
(
self
,
node
):
node
.
parent
=
self
self
.
_nodes
.
insert
(
0
,
node
)
if
isinstance
(
node
,
collections
.
abc
.
Iterable
):
node
=
list
(
node
)
for
n
in
node
:
n
.
parent
=
self
self
.
_nodes
=
node
+
self
.
_nodes
else
:
node
.
parent
=
self
self
.
_nodes
.
insert
(
0
,
node
)
def
insert_before
(
self
,
new_node
,
insert_before
):
new_node
.
parent
=
self
...
...
@@ -485,7 +502,7 @@ class SympyAssignment(Node):
def
__init__
(
self
,
lhs_symbol
,
rhs_expr
,
is_const
=
True
):
super
(
SympyAssignment
,
self
).
__init__
(
parent
=
None
)
self
.
_lhs_symbol
=
lhs_symbol
self
.
rhs
=
rhs_expr
self
.
rhs
=
sp
.
simplify
(
rhs_expr
)
self
.
_is_const
=
is_const
self
.
_is_declaration
=
self
.
__is_declaration
()
...
...
@@ -678,3 +695,49 @@ def early_out(condition):
def
get_dummy_symbol
(
dtype
=
'bool'
):
return
TypedSymbol
(
'dummy%s'
%
uuid
.
uuid4
().
hex
,
create_type
(
dtype
))
class
SourceCodeComment
(
Node
):
def
__init__
(
self
,
text
):
self
.
text
=
text
@
property
def
args
(
self
):
return
[]
@
property
def
symbols_defined
(
self
):
return
set
()
@
property
def
undefined_symbols
(
self
):
return
set
()
def
__str__
(
self
):
return
"/* "
+
self
.
text
+
" */"
def
__repr__
(
self
):
return
self
.
__str__
()
class
EmptyLine
(
Node
):
def
__init__
(
self
):
pass
@
property
def
args
(
self
):
return
[]
@
property
def
symbols_defined
(
self
):
return
set
()
@
property
def
undefined_symbols
(
self
):
return
set
()
def
__str__
(
self
):
return
""
def
__repr__
(
self
):
return
self
.
__str__
()
pystencils/backends/cbackend.py
View file @
dcf2c6f4
...
...
@@ -102,6 +102,10 @@ def get_headers(ast_node: Node) -> Set[str]:
if
isinstance
(
a
,
Node
):
headers
.
update
(
get_headers
(
a
))
for
g
in
get_global_declarations
(
ast_node
):
if
isinstance
(
g
,
Node
):
headers
.
update
(
get_headers
(
g
))
return
sorted
(
headers
)
...
...
@@ -131,6 +135,12 @@ class CustomCodeNode(Node):
def
undefined_symbols
(
self
):
return
self
.
_symbols_read
-
self
.
_symbols_defined
def
__eq___
(
self
,
other
):
return
self
.
_code
==
other
.
_code
def
__hash__
(
self
):
return
hash
(
self
.
_code
)
class
PrintNode
(
CustomCodeNode
):
# noinspection SpellCheckingInspection
...
...
@@ -263,6 +273,12 @@ class CBackend:
def
_print_CustomCodeNode
(
self
,
node
):
return
node
.
get_code
(
self
.
_dialect
,
self
.
_vector_instruction_set
)
def
_print_SourceCodeComment
(
self
,
node
):
return
"/* "
+
node
.
text
+
" */"
def
_print_EmptyLine
(
self
,
node
):
return
""
def
_print_Conditional
(
self
,
node
):
cond_type
=
get_type_of_expression
(
node
.
condition_expr
)
if
isinstance
(
cond_type
,
VectorType
):
...
...
@@ -409,6 +425,7 @@ class CustomSympyPrinter(CCodePrinter):
condition
=
self
.
_print
(
var
)
+
' <= '
+
self
.
_print
(
end
)
# if start < end else '>='
)
return
code
_print_Max
=
C89CodePrinter
.
_print_Max
_print_Min
=
C89CodePrinter
.
_print_Min
...
...
pystencils/backends/cuda_backend.py
View file @
dcf2c6f4
...
...
@@ -3,6 +3,7 @@ from os.path import dirname, join
from
pystencils.astnodes
import
Node
from
pystencils.backends.cbackend
import
CBackend
,
CustomSympyPrinter
,
generate_c
from
pystencils.fast_approximation
import
fast_division
,
fast_inv_sqrt
,
fast_sqrt
from
pystencils.interpolation_astnodes
import
InterpolationMode
with
open
(
join
(
dirname
(
__file__
),
'cuda_known_functions.txt'
))
as
f
:
lines
=
f
.
readlines
()
...
...
@@ -43,11 +44,19 @@ class CudaBackend(CBackend):
return
code
def
_print_TextureDeclaration
(
self
,
node
):
code
=
"texture<%s, cudaTextureType%iD, cudaReadModeElementType> %s;"
%
(
str
(
node
.
texture
.
field
.
dtype
),
node
.
texture
.
field
.
spatial_dimensions
,
node
.
texture
)
if
node
.
texture
.
field
.
dtype
.
numpy_dtype
.
itemsize
>
4
:
code
=
"texture<fp_tex_%s, cudaTextureType%iD, cudaReadModeElementType> %s;"
%
(
str
(
node
.
texture
.
field
.
dtype
),
node
.
texture
.
field
.
spatial_dimensions
,
node
.
texture
)
else
:
code
=
"texture<%s, cudaTextureType%iD, cudaReadModeElementType> %s;"
%
(
str
(
node
.
texture
.
field
.
dtype
),
node
.
texture
.
field
.
spatial_dimensions
,
node
.
texture
)
return
code
def
_print_SkipIteration
(
self
,
_
):
...
...
@@ -62,17 +71,23 @@ class CudaSympyPrinter(CustomSympyPrinter):
self
.
known_functions
.
update
(
CUDA_KNOWN_FUNCTIONS
)
def
_print_TextureAccess
(
self
,
node
):
dtype
=
node
.
texture
.
field
.
dtype
.
numpy_dtype
if
node
.
texture
.
cubic_bspline_interpolation
:
template
=
"cubicTex%iDSimple
<%s>
(%s, %s)"
if
node
.
texture
.
interpolation_mode
==
InterpolationMode
.
CUBIC_SPLINE
:
template
=
"cubicTex%iDSimple(%s, %s)"
else
:
template
=
"tex%iD<%s>(%s, %s)"
if
dtype
.
itemsize
>
4
:
# Use PyCuda hack!
# https://github.com/inducer/pycuda/blob/master/pycuda/cuda/pycuda-helpers.hpp
template
=
"fp_tex%iD(%s, %s)"
else
:
template
=
"tex%iD(%s, %s)"
code
=
template
%
(
node
.
texture
.
field
.
spatial_dimensions
,
str
(
node
.
texture
.
field
.
dtype
),
str
(
node
.
texture
),
', '
.
join
(
self
.
_print
(
o
)
for
o
in
node
.
offsets
)
# + 0.5 comes from Nvidia's staggered indexing
', '
.
join
(
self
.
_print
(
o
+
0.5
)
for
o
in
reversed
(
node
.
offsets
))
)
return
code
...
...
pystencils/backends/cuda_known_functions.txt
View file @
dcf2c6f4
...
...
@@ -45,6 +45,7 @@ tex1D
tex2D
tex3D
sqrtf
rsqrtf
cbrtf
rcbrtf
...
...
pystencils/backends/json.py
0 → 100644
View file @
dcf2c6f4
# -*- coding: utf-8 -*-
#
# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
#
# Distributed under terms of the GPLv3 license.
"""
"""
import
json
from
pystencils.backends.cbackend
import
CustomSympyPrinter
,
generate_c
try
:
import
toml
except
Exception
:
class
toml
:
def
dumps
(
self
,
*
args
):
raise
ImportError
(
'toml not installed'
)
def
dump
(
self
,
*
args
):
raise
ImportError
(
'toml not installed'
)
try
:
import
yaml
except
Exception
:
class
yaml
:
def
dumps
(
self
,
*
args
):
raise
ImportError
(
'pyyaml not installed'
)
def
dump
(
self
,
*
args
):
raise
ImportError
(
'pyyaml not installed'
)
def
expr_to_dict
(
expr_or_node
,
with_c_code
=
True
,
full_class_names
=
False
):
self
=
{
'str'
:
str
(
expr_or_node
)}
if
with_c_code
:
try
:
self
.
update
({
'c'
:
generate_c
(
expr_or_node
)})
except
Exception
:
try
:
self
.
update
({
'c'
:
CustomSympyPrinter
().
doprint
(
expr_or_node
)})
except
Exception
:
pass
for
a
in
expr_or_node
.
args
:
self
.
update
({
str
(
a
.
__class__
if
full_class_names
else
a
.
__class__
.
__name__
):
expr_to_dict
(
a
)})
return
self
def
print_json
(
expr_or_node
):
dict
=
expr_to_dict
(
expr_or_node
)
return
json
.
dumps
(
dict
,
indent
=
4
)
def
write_json
(
filename
,
expr_or_node
):
dict
=
expr_to_dict
(
expr_or_node
)
with
open
(
filename
,
'w'
)
as
f
:
json
.
dump
(
dict
,
f
,
indent
=
4
)
def
print_toml
(
expr_or_node
):
dict
=
expr_to_dict
(
expr_or_node
,
full_class_names
=
False
)
return
toml
.
dumps
(
dict
)
def
write_toml
(
filename
,
expr_or_node
):
dict
=
expr_to_dict
(
expr_or_node
)
with
open
(
filename
,
'w'
)
as
f
:
toml
.
dump
(
dict
,
f
)
def
print_yaml
(
expr_or_node
):
dict
=
expr_to_dict
(
expr_or_node
,
full_class_names
=
False
)
return
yaml
.
dump
(
dict
)
def
write_yaml
(
filename
,
expr_or_node
):
dict
=
expr_to_dict
(
expr_or_node
)
with
open
(
filename
,
'w'
)
as
f
:
yaml
.
dump
(
dict
,
f
)
pystencils/cpu/kernelcreation.py
View file @
dcf2c6f4
...
...
@@ -10,8 +10,8 @@ from pystencils.data_types import BasicType, StructType, TypedSymbol, create_typ
from
pystencils.field
import
Field
,
FieldType
from
pystencils.transformations
import
(
add_types
,
filtered_tree_iteration
,
get_base_buffer_index
,
get_optimal_loop_ordering
,
make_loop_over_domain
,
move_constants_before_loop
,
parse_base_pointer_info
,
resolve_buffer_accesses
,
resolve_field_accesses
,
split_inner_loop
)
implement_interpolations
,
make_loop_over_domain
,
move_constants_before_loop
,
parse_base_pointer_info
,
resolve_buffer_accesses
,
resolve_field_accesses
,
split_inner_loop
)
AssignmentOrAstNodeList
=
List
[
Union
[
Assignment
,
ast
.
Node
]]
...
...
@@ -67,6 +67,7 @@ def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "ke
ghost_layers
=
ghost_layers
,
loop_order
=
loop_order
)
ast_node
=
KernelFunction
(
loop_node
,
'cpu'
,
'c'
,
compile_function
=
make_python_function
,
ghost_layers
=
ghost_layer_info
,
function_name
=
function_name
)
implement_interpolations
(
body
)
if
split_groups
:
typed_split_groups
=
[[
type_symbol
(
s
)
for
s
in
split_group
]
for
split_group
in
split_groups
]
...
...
@@ -139,6 +140,8 @@ def create_indexed_kernel(assignments: AssignmentOrAstNodeList, index_fields, fu
loop_body
=
Block
([])
loop_node
=
LoopOverCoordinate
(
loop_body
,
coordinate_to_loop_over
=
0
,
start
=
0
,
stop
=
index_fields
[
0
].
shape
[
0
])
implement_interpolations
(
loop_node
)
for
assignment
in
assignments
:
loop_body
.
append
(
assignment
)
...
...
pystencils/data_types.py
View file @
dcf2c6f4
import
ctypes
from
collections
import
defaultdict
from
functools
import
partial
from
typing
import
Tuple
import
numpy
as
np
import
sympy
as
sp
import
sympy.codegen.ast
from
sympy.core.cache
import
cacheit
from
sympy.logic.boolalg
import
Boolean
import
pystencils
from
pystencils.cache
import
memorycache
,
memorycache_if_hashable
from
pystencils.utils
import
all_equal
...
...
@@ -17,6 +20,26 @@ except ImportError as e:
_ir_importerror
=
e
def
typed_symbols
(
names
,
dtype
,
*
args
):
symbols
=
sp
.
symbols
(
names
,
*
args
)
if
isinstance
(
symbols
,
Tuple
):
return
tuple
(
TypedSymbol
(
str
(
s
),
dtype
)
for
s
in
symbols
)
else
:
return
TypedSymbol
(
str
(
symbols
),
dtype
)
def
matrix_symbols
(
names
,
dtype
,
rows
,
cols
):
if
isinstance
(
names
,
str
):
names
=
names
.
replace
(
' '
,
''
).
split
(
','
)
matrices
=
[]
for
n
in
names
:
symbols
=
typed_symbols
(
"%s:%i"
%
(
n
,
rows
*
cols
),
dtype
)
matrices
.
append
(
sp
.
Matrix
(
rows
,
cols
,
lambda
i
,
j
:
symbols
[
i
*
cols
+
j
]))
return
tuple
(
matrices
)
# noinspection PyPep8Naming
class
address_of
(
sp
.
Function
):
is_Atom
=
True
...
...
@@ -86,6 +109,11 @@ class cast_func(sp.Function):
@
property
def
is_integer
(
self
):
"""
Uses Numpy type hierarchy to determine :func:`sympy.Expr.is_integer` predicate
For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
"""
if
hasattr
(
self
.
dtype
,
'numpy_dtype'
):
return
np
.
issubdtype
(
self
.
dtype
.
numpy_dtype
,
np
.
integer
)
or
super
().
is_integer
else
:
...
...
@@ -93,6 +121,9 @@ class cast_func(sp.Function):
@
property
def
is_negative
(
self
):
"""
See :func:`.TypedSymbol.is_integer`
"""
if
hasattr
(
self
.
dtype
,
'numpy_dtype'
):
if
np
.
issubdtype
(
self
.
dtype
.
numpy_dtype
,
np
.
unsignedinteger
):
return
False
...
...
@@ -101,6 +132,9 @@ class cast_func(sp.Function):
@
property
def
is_nonnegative
(
self
):
"""
See :func:`.TypedSymbol.is_integer`
"""
if
self
.
is_negative
is
False
:
return
True
else
:
...
...
@@ -108,6 +142,9 @@ class cast_func(sp.Function):
@
property
def
is_real
(
self
):
"""
See :func:`.TypedSymbol.is_integer`
"""
if
hasattr
(
self
.
dtype
,
'numpy_dtype'
):
return
np
.
issubdtype
(
self
.
dtype
.
numpy_dtype
,
np
.
integer
)
or
\
np
.
issubdtype
(
self
.
dtype
.
numpy_dtype
,
np
.
floating
)
or
\
...
...
@@ -171,6 +208,11 @@ class TypedSymbol(sp.Symbol):
# For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
@
property
def
is_integer
(
self
):
"""
Uses Numpy type hierarchy to determine :func:`sympy.Expr.is_integer` predicate
For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
"""
if
hasattr
(
self
.
dtype
,
'numpy_dtype'
):
return
np
.
issubdtype
(
self
.
dtype
.
numpy_dtype
,
np
.
integer
)
or
super
().
is_integer
else
:
...
...
@@ -178,6 +220,9 @@ class TypedSymbol(sp.Symbol):
@
property
def
is_negative
(
self
):
"""
See :func:`.TypedSymbol.is_integer`
"""
if
hasattr
(
self
.
dtype
,
'numpy_dtype'
):
if
np
.
issubdtype
(
self
.
dtype
.
numpy_dtype
,
np
.
unsignedinteger
):
return
False
...
...
@@ -186,6 +231,9 @@ class TypedSymbol(sp.Symbol):
@
property
def
is_nonnegative
(
self
):
"""
See :func:`.TypedSymbol.is_integer`
"""
if
self
.
is_negative
is
False
:
return
True
else
:
...
...
@@ -193,6 +241,9 @@ class TypedSymbol(sp.Symbol):
@
property
def
is_real
(
self
):
"""
See :func:`.TypedSymbol.is_integer`
"""
if
hasattr
(
self
.
dtype
,
'numpy_dtype'
):
return
np
.
issubdtype
(
self
.
dtype
.
numpy_dtype
,
np
.
integer
)
or
\
np
.
issubdtype
(
self
.
dtype
.
numpy_dtype
,
np
.
floating
)
or
\
...
...
@@ -370,12 +421,17 @@ def peel_off_type(dtype, type_to_peel_off):
return
dtype
def
collate_types
(
types
):
def
collate_types
(
types
,
forbid_collation_to_float
=
False
):
"""
Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
Uses the collation rules from numpy.
"""
if
forbid_collation_to_float
:
types
=
[
t
for
t
in
types
if
not
(
hasattr
(
t
,
'is_float'
)
and
t
.
is_float
())]
if
not
types
:
return
[
create_type
(
'int32'
)]
# Pointer arithmetic case i.e. pointer + integer is allowed
if
any
(
type
(
t
)
is
PointerType
for
t
in
types
):
pointer_type
=
None
...
...
@@ -433,6 +489,8 @@ def get_type_of_expression(expr,
return
create_type
(
default_float_type
)
elif
isinstance
(
expr
,
ResolvedFieldAccess
):
return
expr
.
field
.
dtype
elif
isinstance
(
expr
,
pystencils
.
field
.
Field
.
AbstractAccess
):
return
expr
.
field
.
dtype
elif
isinstance
(
expr
,
TypedSymbol
):
return
expr
.
dtype
elif
isinstance
(
expr
,
sp
.
Symbol
):
...
...
@@ -525,6 +583,10 @@ class BasicType(Type):
def
numpy_dtype
(
self
):
return
self
.
_dtype
@
property
def
sympy_dtype
(
self
):
return
getattr
(
sympy
.
codegen
.
ast
,
str
(
self
.
numpy_dtype
))
@
property
def
item_size
(
self
):
return
1
...
...
pystencils/field.py
View file @
dcf2c6f4
import
functools
import
hashlib
import
operator
import
pickle
import
re
from
enum
import
Enum
...
...
@@ -9,6 +11,7 @@ import numpy as np
import
sympy
as
sp
from
sympy.core.cache
import
cacheit
import
pystencils
from
pystencils.alignedarray
import
aligned_empty
from
pystencils.data_types
import
StructType
,
TypedSymbol
,
create_type
from
pystencils.kernelparameters
import
FieldShapeSymbol
,
FieldStrideSymbol
...
...
@@ -38,7 +41,6 @@ def fields(description=None, index_dimensions=0, layout=None, **kwargs):
>>> assert s.index_dimensions == 0 and s.dtype.numpy_dtype == arr_s.dtype
>>> assert v.index_shape == (2,)
Format string can be left out, field names are taken from keyword arguments.
>>> fields(f1=arr_s, f2=arr_s)
[f1, f2]
...
...
@@ -292,6 +294,10 @@ class Field(AbstractField):
self
.
shape
=
shape
self
.
strides
=
strides
self
.
latex_name
=
None
# type: Optional[str]
self
.
coordinate_origin
=
sp
.
Matrix
(
tuple
(
0
for
_
in
range
(
self
.
spatial_dimensions
)
))
# type: tuple[float,sp.Symbol]
self
.
coordinate_transform
=
sp
.
eye
(
self
.
spatial_dimensions
)
def
new_field_with_different_name
(
self
,
new_name
):
if
self
.
has_fixed_shape
:
...
...
@@ -312,6 +318,9 @@ class Field(AbstractField):
def
ndim
(
self
)
->
int
:
return
len
(
self
.
shape
)
def
values_per_cell
(
self
)
->
int
:
return
functools
.
reduce
(
operator
.
mul
,
self
.
index_shape
,
1
)
@
property
def
layout
(
self
):
return
self
.
_layout
...
...
@@ -393,6 +402,27 @@ class Field(AbstractField):
assert
FieldType
.
is_custom
(
self
)
return
Field
.
Access
(
self
,
offset
,
index
,
is_absolute_access
=
True
)
def
interpolated_access
(
self
,
offset
:
Tuple
,
interpolation_mode
=
'linear'
,
address_mode
=
'BORDER'
,
allow_textures
=
True
):
"""Provides access to field values at non-integer positions
``interpolated_access`` is similar to :func:`Field.absolute_access` except that
it allows non-integer offsets and automatic handling of out-of-bound accesses.
:param offset: Tuple of spatial coordinates (can be floats)
:param interpolation_mode: One of :class:`pystencils.interpolation_astnodes.InterpolationMode`
:param address_mode: How boundaries are handled can be 'border', 'wrap', 'mirror', 'clamp'
:param allow_textures: Allow implementation by texture accesses on GPUs
"""
from
pystencils.interpolation_astnodes
import
Interpolator
return
Interpolator
(
self
,
interpolation_mode
,
address_mode
,
allow_textures
=
allow_textures
).
at
(
offset
)
def
__call__
(
self
,
*
args
,
**
kwargs
):
center
=
tuple
([
0
]
*
self
.
spatial_dimensions
)
return
Field
.
Access
(
self
,
center
)(
*
args
,
**
kwargs
)
...
...
@@ -409,6 +439,34 @@ class Field(AbstractField):
return
False
return
self
.
hashable_contents
()
==
other
.
hashable_contents
()
@
property
def
physical_coordinates
(
self
):
return
self
.
coordinate_transform
@
(
self
.
coordinate_origin
+
pystencils
.
x_vector
(
self
.
spatial_dimensions
))
@
property
def
physical_coordinates_staggered
(
self
):
return
self
.
coordinate_transform
@
\
(
self
.
coordinate_origin
+
pystencils
.
x_staggered_vector
(
self
.
spatial_dimensions
))
def
index_to_physical
(
self
,
index_coordinates
,
staggered
=
False
):
if
staggered
:
index_coordinates
=
sp
.
Matrix
([
i
+
0.5
for
i
in
index_coordinates
])
return
self
.
coordinate_transform
@
(
self
.
coordinate_origin
+
index_coordinates
)
def
physical_to_index
(
self
,
physical_coordinates
,
staggered
=
False
):
rtn
=
self
.
coordinate_transform
.
inv
()
@
physical_coordinates
-
self
.
coordinate_origin
if
staggered
:
rtn
=
sp
.
Matrix
([
i
-
0.5
for
i
in
rtn
])
return
rtn
def
index_to_staggered_physical_coordinates
(
self
,
symbol_vector
):