Skip to content
Snippets Groups Projects
Commit 34732f61 authored by Michael Kuron's avatar Michael Kuron :mortar_board:
Browse files

staggered_access: fix access to directions with mixed sign

NW (-1/2, 1/2) and the like were previously mapped to the wrong cell
parent 9888c32f
Branches
Tags
No related merge requests found
...@@ -15,7 +15,7 @@ import pystencils ...@@ -15,7 +15,7 @@ import pystencils
from pystencils.alignedarray import aligned_empty from pystencils.alignedarray import aligned_empty
from pystencils.data_types import StructType, TypedSymbol, create_type from pystencils.data_types import StructType, TypedSymbol, create_type
from pystencils.kernelparameters import FieldShapeSymbol, FieldStrideSymbol from pystencils.kernelparameters import FieldShapeSymbol, FieldStrideSymbol
from pystencils.stencil import direction_string_to_offset, offset_to_direction_string from pystencils.stencil import direction_string_to_offset, offset_to_direction_string, inverse_direction
from pystencils.sympyextensions import is_integer_sequence from pystencils.sympyextensions import is_integer_sequence
__all__ = ['Field', 'fields', 'FieldType', 'AbstractField'] __all__ = ['Field', 'fields', 'FieldType', 'AbstractField']
...@@ -490,19 +490,21 @@ class Field(AbstractField): ...@@ -490,19 +490,21 @@ class Field(AbstractField):
raise ValueError("Wrong number of spatial indices: " raise ValueError("Wrong number of spatial indices: "
"Got %d, expected %d" % (len(offset), self.spatial_dimensions)) "Got %d, expected %d" % (len(offset), self.spatial_dimensions))
offset = list(offset) neighbor_vec = [0] * len(offset)
neighbor = [0] * len(offset) for i in range(self.spatial_dimensions):
for i, o in enumerate(offset): if (offset[i] + sp.Rational(1, 2)).is_Integer:
if (o + sp.Rational(1, 2)).is_Integer: neighbor_vec[i] = sp.sign(offset[i])
offset[i] += sp.Rational(1, 2) neighbor = offset_to_direction_string(neighbor_vec)
neighbor[i] = -1 if neighbor not in self.staggered_stencil:
neighbor = offset_to_direction_string(neighbor) neighbor_vec = inverse_direction(neighbor_vec)
try: neighbor = offset_to_direction_string(neighbor_vec)
idx = self.staggered_stencil.index(neighbor) if neighbor not in self.staggered_stencil:
except ValueError:
raise ValueError("{} is not a valid neighbor for the {} stencil".format(offset_orig, raise ValueError("{} is not a valid neighbor for the {} stencil".format(offset_orig,
self.staggered_stencil_name)) self.staggered_stencil_name))
offset = tuple(offset)
offset = tuple(sp.Matrix(offset) - sp.Rational(1, 2) * sp.Matrix(neighbor_vec))
idx = self.staggered_stencil.index(neighbor)
if self.index_dimensions == 1: # this field stores a scalar value at each staggered position if self.index_dimensions == 1: # this field stores a scalar value at each staggered position
if index is not None: if index is not None:
......
...@@ -135,6 +135,8 @@ def test_staggered(): ...@@ -135,6 +135,8 @@ def test_staggered():
j1, j2, j3 = ps.fields('j1(2), j2(2,2), j3(2,2,2) : double[2D]', field_type=FieldType.STAGGERED) j1, j2, j3 = ps.fields('j1(2), j2(2,2), j3(2,2,2) : double[2D]', field_type=FieldType.STAGGERED)
assert j1[0, 1](1) == j1.staggered_access((0, sp.Rational(1, 2))) assert j1[0, 1](1) == j1.staggered_access((0, sp.Rational(1, 2)))
assert j1[1, 1](1) == j1.staggered_access((1, sp.Rational(1, 2)))
assert j1[0, 2](1) == j1.staggered_access((0, sp.Rational(3, 2)))
assert j1[0, 1](1) == j1.staggered_access("N") assert j1[0, 1](1) == j1.staggered_access("N")
assert j1[0, 0](1) == j1.staggered_access("S") assert j1[0, 0](1) == j1.staggered_access("S")
...@@ -149,3 +151,5 @@ def test_staggered(): ...@@ -149,3 +151,5 @@ def test_staggered():
assert k[1, 1](2) == k.staggered_access("NE") assert k[1, 1](2) == k.staggered_access("NE")
assert k[0, 0](2) == k.staggered_access("SW") assert k[0, 0](2) == k.staggered_access("SW")
assert k[0, 0](3) == k.staggered_access("NW")
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment