Commit dd236fb0 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

refactored private members of AdvancedStreamingBoundaryIndexing

parent 759a3921
......@@ -22,9 +22,7 @@ class AdvancedStreamingBoundaryIndexing:
# =======================================
def proxy_fields(self):
q = len(self.stencil)
d = len(self.stencil[0])
return ps.fields(f"f_out({q}), f_in({q}): [{d}D]")
return ps.fields(f"f_out({self._q}), f_in({self._q}): [{self._dim}D]")
def dir_symbol(self):
return TypedSymbol('dir', create_type(np.int64))
......@@ -77,16 +75,16 @@ class AdvancedStreamingBoundaryIndexing:
inward_accesses = (
even_accessor if odd_to_even else odd_accessor).read(pdf_field, stencil)
self.accesses = {'out': outward_accesses, 'in': inward_accesses}
self._accesses = {'out': outward_accesses, 'in': inward_accesses}
self.pdf_field = pdf_field
self.stencil = stencil
self.dim = len(stencil[0])
self.q = len(stencil)
self.coordinate_names = ['x', 'y', 'z'][:self.dim]
self._pdf_field = pdf_field
self._stencil = stencil
self._dim = len(stencil[0])
self._q = len(stencil)
self._coordinate_names = ['x', 'y', 'z'][:self._dim]
self.required_arrays = set()
self.trivial_translations = self._collect_trivial_translations()
self._required_arrays = set()
self._trivial_translations = self._collect_trivial_translations()
def _index_array_name(self, f_dir, inverse):
assert f_dir in ['in', 'out']
......@@ -98,16 +96,16 @@ class AdvancedStreamingBoundaryIndexing:
assert f_dir in ['in', 'out']
inv = '_inv' if inverse else ''
name_base = f"f_{f_dir}{inv}_offsets_"
names = [name_base + d for d in self.coordinate_names]
names = [name_base + d for d in self._coordinate_names]
return names
def _array_symbols(self, f_dir, inverse, index):
if (f_dir, inverse) in self.trivial_translations:
return {'index': index, 'offsets': (0, ) * self.dim}
if (f_dir, inverse) in self._trivial_translations:
return {'index': index, 'offsets': (0, ) * self._dim}
index_array_name = self._index_array_name(f_dir, inverse)
offset_array_names = self._offset_array_names(f_dir, inverse)
self.required_arrays.add((f_dir, inverse))
self._required_arrays.add((f_dir, inverse))
return {
'index': sp.IndexedBase(index_array_name)[index],
'offsets': tuple(sp.IndexedBase(n)[index] for n in offset_array_names)
......@@ -124,7 +122,7 @@ class AdvancedStreamingBoundaryIndexing:
if not isinstance(assignments, ps.AssignmentCollection):
assignments = ps.AssignmentCollection(assignments)
accesses = self.accesses
accesses = self._accesses
f_out, f_in = self.proxy_fields()
inv_dir = self.inverse_dir_symbol()
......@@ -150,7 +148,7 @@ class AdvancedStreamingBoundaryIndexing:
accessor_subs[fa] = accesses[f_dir][idx].get_shifted(*(fa.offsets))
else:
arr = self._array_symbols(f_dir, inv, idx)
accessor_subs[fa] = self.pdf_field[arr['offsets']](arr['index']).get_shifted(*(fa.offsets))
accessor_subs[fa] = self._pdf_field[arr['offsets']](arr['index']).get_shifted(*(fa.offsets))
return assignments.new_with_substitutions(accessor_subs)
......@@ -159,26 +157,26 @@ class AdvancedStreamingBoundaryIndexing:
# =================
def _inverse_integer_dir_index(self, idx):
return self.stencil.index(tuple(-d for d in self.stencil[idx]))
return self._stencil.index(tuple(-d for d in self._stencil[idx]))
def _get_translated_indices_and_offsets(self, f_dir, inv):
accesses = self.accesses[f_dir]
accesses = self._accesses[f_dir]
if inv:
inverse_indices = [self._inverse_integer_dir_index(i)
for i in range(len(self.stencil))]
for i in range(len(self._stencil))]
accesses = [accesses[idx] for idx in inverse_indices]
indices = [a.index[0] for a in accesses]
offsets = []
for d in range(self.dim):
for d in range(self._dim):
offsets.append([a.offsets[d] for a in accesses])
return indices, offsets
def _collect_trivial_translations(self):
trivials = set()
trivial_indices = list(range(self.q))
trivial_offsets = [[0] * self.q] * self.dim
trivial_indices = list(range(self._q))
trivial_offsets = [[0] * self._q] * self._dim
for f_dir, inv in product(['in', 'out'], [False, True]):
indices, offsets = self._get_translated_indices_and_offsets(f_dir, inv)
if indices == trivial_indices and offsets == trivial_offsets:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment