Commit dcc76a01 authored by Michael Kuron's avatar Michael Kuron
Browse files

Merge remote-tracking branch 'origin/master' into staggered_kernel

parents fcdfc127 8ca8b2ee
......@@ -426,13 +426,15 @@ class Field(AbstractField):
index_shape = self.index_shape
if len(index_shape) == 0:
return sp.Matrix([self.center])
if len(index_shape) == 1:
elif len(index_shape) == 1:
return sp.Matrix([self(i) for i in range(index_shape[0])])
elif len(index_shape) == 2:
def cb(*args):
r = self.__call__(*args)
return r
return sp.Matrix(*index_shape, cb)
return sp.Matrix([[self(i, j) for j in range(index_shape[1])] for i in range(index_shape[0])])
elif len(index_shape) == 3:
return sp.Matrix([[[self(i, j, k) for k in range(index_shape[2])]
for j in range(index_shape[1])] for i in range(index_shape[0])])
else:
raise NotImplementedError("center_vector is not implemented for more than 3 index dimensions")
@property
def center(self):
......@@ -534,6 +536,20 @@ class Field(AbstractField):
return prefactor * Field.Access(self, offset, (idx, *index))
def staggered_vector_access(self, offset):
"""Like staggered_access, but returns the entire vector/tensor stored at offset."""
assert FieldType.is_staggered(self)
if self.index_dimensions == 1:
return sp.Matrix([self.staggered_access(offset)])
elif self.index_dimensions == 2:
return sp.Matrix([self.staggered_access(offset, i) for i in range(self.index_shape[1])])
elif self.index_dimensions == 3:
return sp.Matrix([[self.staggered_access(offset, (i, k)) for k in range(self.index_shape[2])]
for i in range(self.index_shape[1])])
else:
raise NotImplementedError("staggered_vector_access is not implemented for more than 3 index dimensions")
@property
def staggered_stencil(self):
assert FieldType.is_staggered(self)
......
......@@ -28,6 +28,9 @@ def test_field_basic():
assert neighbor.offsets == (-1, 1)
assert '_' in neighbor._latex('dummy')
f = Field.create_fixed_size('f', (8, 8, 2, 2, 2), index_dimensions=3)
assert f.center_vector == sp.Matrix([[[f(i, j, k) for k in range(2)] for j in range(2)] for i in range(2)])
f = Field.create_generic('f', spatial_dimensions=5, index_dimensions=2)
field_access = f[1, -1, 2, -3, 0](1, 0)
assert field_access.offsets == (1, -1, 2, -3, 0)
......@@ -139,12 +142,16 @@ def test_staggered():
assert j1[0, 2](1) == j1.staggered_access((0, sp.Rational(3, 2)))
assert j1[0, 1](1) == j1.staggered_access("N")
assert j1[0, 0](1) == j1.staggered_access("S")
assert j1.staggered_vector_access("N") == sp.Matrix([j1.staggered_access("N")])
assert j2[0, 1](1, 1) == j2.staggered_access((0, sp.Rational(1, 2)), 1)
assert j2[0, 1](1, 1) == j2.staggered_access("N", 1)
assert j2.staggered_vector_access("N") == sp.Matrix([j2.staggered_access("N", 0), j2.staggered_access("N", 1)])
assert j3[0, 1](1, 1, 1) == j3.staggered_access((0, sp.Rational(1, 2)), (1, 1))
assert j3[0, 1](1, 1, 1) == j3.staggered_access("N", (1, 1))
assert j3.staggered_vector_access("N") == sp.Matrix([[j3.staggered_access("N", (i, j))
for j in range(2)] for i in range(2)])
# D2Q9
k = ps.fields('k(4) : double[2D]', field_type=FieldType.STAGGERED)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment