Commit 6be6ba63 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Making cool stuff with interpolators

parent 064147b6
......@@ -75,7 +75,7 @@ class CudaSympyPrinter(CustomSympyPrinter):
if type(node) == DiffInterpolatorAccess:
# cubicTex3D_1st_derivative_x(texture tex, float3 coord)
template = f"cubicTex%iD_1st_derivative_{'xyz'[node.diff_coordinate_idx]}(%s, %s)"
template = f"cubicTex%iD_1st_derivative_{list(reversed('xyz'[:node.ndim]))[node.diff_coordinate_idx]}(%s, %s)" # noqa
elif node.interpolator.interpolation_mode == InterpolationMode.CUBIC_SPLINE:
template = "cubicTex%iDSimple(%s, %s)"
else:
......
......@@ -109,7 +109,9 @@ class Discretization2ndOrder:
return self._discretize_advection(e)
elif isinstance(e, Diff):
arg, *indices = diff_args(e)
if not isinstance(arg, Field.Access):
from pystencils.interpolation_astnodes import InterpolatorAccess
if not isinstance(arg, (Field.Access, InterpolatorAccess)):
raise ValueError("Only derivatives with field or field accesses as arguments can be discretized")
return self.spatial_stencil(indices, self.dx, arg)
else:
......
......@@ -170,6 +170,14 @@ class InterpolatorAccess(TypedSymbol):
def __repr__(self):
return self.__str__()
@property
def ndim(self):
return len(self.offsets)
@property
def is_texture(self):
return isinstance(self.interpolator, TextureCachedField)
def atoms(self, *types):
if self.offsets:
offsets = set(o for o in self.offsets if isinstance(o, types))
......@@ -182,6 +190,11 @@ class InterpolatorAccess(TypedSymbol):
else:
return set()
def neighbor(self, coord_id, offset):
offset_list = list(self.offsets)
offset_list[coord_id] += offset
return self.interpolator.at(tuple(offset_list))
@property
def free_symbols(self):
symbols = set()
......@@ -318,6 +331,9 @@ class InterpolatorAccess(TypedSymbol):
class DiffInterpolatorAccess(InterpolatorAccess):
def __new__(cls, symbol, diff_coordinate_idx, *offsets, **kwargs):
if symbol.interpolator.interpolation_mode == InterpolationMode.LINEAR:
from pystencils.fd import Diff, Discretization2ndOrder
return Discretization2ndOrder(1)(Diff(symbol.interpolator.at(offsets), diff_coordinate_idx))
obj = DiffInterpolatorAccess.__xnew_cached_(cls, symbol, diff_coordinate_idx, *offsets, **kwargs)
return obj
......@@ -363,7 +379,7 @@ class DiffInterpolatorAccess(InterpolatorAccess):
##########################################################################################
class TextureCachedField:
class TextureCachedField(Interpolator):
def __init__(self, parent_field,
address_mode=None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment