Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Stephan Seitz
pystencils
Commits
ee9587df
Commit
ee9587df
authored
Jan 10, 2020
by
Stephan Seitz
Browse files
Allow differentation of InterpolatorAccess
parent
07c87bc7
Changes
2
Hide whitespace changes
Inline
Side-by-side
pystencils/interpolation_astnodes.py
View file @
ee9587df
...
...
@@ -143,7 +143,7 @@ class NearestNeightborInterpolator(Interpolator):
class
InterpolatorAccess
(
TypedSymbol
):
def
__new__
(
cls
,
field
,
*
offsets
,
**
kwargs
):
obj
=
Texture
Access
.
__xnew_cached_
(
cls
,
field
,
*
offsets
,
**
kwargs
)
obj
=
Interpolator
Access
.
__xnew_cached_
(
cls
,
field
,
*
offsets
,
**
kwargs
)
return
obj
def
__new_stage2__
(
self
,
symbol
,
*
offsets
):
...
...
@@ -201,6 +201,17 @@ class InterpolatorAccess(TypedSymbol):
def
interpolation_mode
(
self
):
return
self
.
interpolator
.
interpolation_mode
@
property
def
_diff_interpolation_vec
(
self
):
return
sp
.
Matrix
([
DiffInterpolatorAccess
(
self
.
symbol
,
i
,
*
self
.
offsets
)
for
i
in
range
(
len
(
self
.
offsets
))])
def
diff
(
self
,
*
symbols
,
**
kwargs
):
rtn
=
self
.
_diff_interpolation_vec
.
T
*
sp
.
Matrix
(
self
.
offsets
).
diff
(
*
symbols
,
**
kwargs
)
if
rtn
.
shape
==
(
1
,
1
):
rtn
=
rtn
[
0
,
0
]
return
rtn
def
implementation_with_stencils
(
self
):
field
=
self
.
field
...
...
@@ -255,7 +266,7 @@ class InterpolatorAccess(TypedSymbol):
for
(
dim
,
i
)
in
enumerate
(
index
)]
index
=
[
cast_func
(
sp
.
Piecewise
((
i
,
i
>
0
),
(
sp
.
Abs
(
cast_func
(
field
.
shape
[
dim
]
-
1
+
i
,
default_int_type
)),
True
)),
default_int_type
)
True
)),
default_int_type
)
for
(
dim
,
i
)
in
enumerate
(
index
)]
sum
[
channel_idx
]
+=
weight
*
\
absolute_access
(
index
,
channel_idx
if
field
.
index_dimensions
else
())
...
...
@@ -290,6 +301,46 @@ class InterpolatorAccess(TypedSymbol):
def
__getnewargs__
(
self
):
return
tuple
(
self
.
symbol
,
*
self
.
offsets
)
class
DiffInterpolatorAccess
(
InterpolatorAccess
):
def
__new__
(
cls
,
symbol
,
diff_coordinate_idx
,
*
offsets
,
**
kwargs
):
obj
=
DiffInterpolatorAccess
.
__xnew_cached_
(
cls
,
symbol
,
diff_coordinate_idx
,
*
offsets
,
**
kwargs
)
return
obj
def
__new_stage2__
(
self
,
symbol
:
sp
.
Symbol
,
diff_coordinate_idx
,
*
offsets
):
assert
offsets
is
not
None
obj
=
super
().
__xnew__
(
self
,
symbol
,
*
offsets
)
obj
.
diff_coordinate_idx
=
diff_coordinate_idx
return
obj
def
__hash__
(
self
):
return
hash
((
self
.
symbol
,
self
.
field
,
self
.
diff_coordinate_idx
,
tuple
(
self
.
offsets
),
self
.
interpolator
))
def
__str__
(
self
):
return
'%s_diff%i_interpolator(%s)'
%
(
self
.
field
.
name
,
self
.
diff_coordinate_idx
,
','
.
join
(
str
(
o
)
for
o
in
self
.
offsets
))
@
property
def
args
(
self
):
return
[
self
.
symbol
,
self
.
diff_coordinate_idx
,
*
self
.
offsets
]
@
property
def
symbols_defined
(
self
)
->
Set
[
sp
.
Symbol
]:
return
{
self
}
@
property
def
interpolation_mode
(
self
):
return
self
.
interpolator
.
interpolation_mode
# noinspection SpellCheckingInspection
__xnew__
=
staticmethod
(
__new_stage2__
)
# noinspection SpellCheckingInspection
__xnew_cached_
=
staticmethod
(
cacheit
(
__new_stage2__
))
def
__getnewargs__
(
self
):
return
tuple
(
self
.
symbol
,
self
.
diff_coordinate_idx
,
*
self
.
offsets
)
##########################################################################################
# GPU-specific fast specializations (for precision GPUs can also use above nodes/symbols #
##########################################################################################
...
...
pystencils_tests/test_interpolation.py
View file @
ee9587df
...
...
@@ -234,5 +234,13 @@ def test_field_interpolated(address_mode, target):
out
=
np
.
zeros_like
(
lenna
)
kernel
(
x
=
lenna
,
y
=
out
)
pyconrad
.
imshow
(
out
,
"out "
+
address_mode
)
kernel
(
x
=
lenna
,
y
=
out
)
pyconrad
.
imshow
(
out
,
"out "
+
address_mode
)
def
test_spatial_derivative
():
x
,
y
=
pystencils
.
fields
(
'x, y: float32[2d]'
)
tx
,
ty
=
pystencils
.
fields
(
't_x, t_y: float32[2d]'
)
diff
=
sympy
.
diff
(
x
.
interpolated_access
((
tx
.
center
,
ty
.
center
)),
tx
.
center
)
print
(
"diff: "
+
str
(
diff
))
diff
=
sympy
.
diff
(
x
.
interpolated_access
((
tx
.
center
,
2
*
ty
.
center
)),
sympy
.
Matrix
((
tx
.
center
,
ty
.
center
)))
print
(
"diff: "
+
str
(
diff
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment