Commit 39d171b3 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Split trivial translations of indices and offsets

parent 32e796ff
Pipeline #27719 waiting for manual action with stage
in 26 minutes and 37 seconds
...@@ -65,8 +65,9 @@ class BetweenTimestepsIndexing: ...@@ -65,8 +65,9 @@ class BetweenTimestepsIndexing:
self._index_dtype = create_type(index_dtype) self._index_dtype = create_type(index_dtype)
self._offsets_dtype = create_type(offsets_dtype) self._offsets_dtype = create_type(offsets_dtype)
self._required_arrays = set() self._required_index_arrays = set()
self._trivial_translations = self._collect_trivial_translations() self._required_offset_arrays = set()
self._trivial_index_translations, self._trivial_offset_translations = self._collect_trivial_translations()
def _index_array_symbol(self, f_dir, inverse): def _index_array_symbol(self, f_dir, inverse):
assert f_dir in ['in', 'out'] assert f_dir in ['in', 'out']
...@@ -82,17 +83,21 @@ class BetweenTimestepsIndexing: ...@@ -82,17 +83,21 @@ class BetweenTimestepsIndexing:
return symbols return symbols
def _array_symbols(self, f_dir, inverse, index): def _array_symbols(self, f_dir, inverse, index):
if (f_dir, inverse) in self._trivial_translations: if (f_dir, inverse) in self._trivial_index_translations:
return {'index': index, 'offsets': (0, ) * self._dim} translated_index = index
else:
index_array_symbol = self._index_array_symbol(f_dir, inverse) index_array_symbol = self._index_array_symbol(f_dir, inverse)
offset_array_symbols = self._offset_array_symbols(f_dir, inverse) translated_index = sp.IndexedBase(index_array_symbol, shape=(1,))[index]
self._required_arrays.add((f_dir, inverse)) self._required_index_arrays.add((f_dir, inverse))
return {
'index': sp.IndexedBase(index_array_symbol, shape=(1,))[index], if (f_dir, inverse) in self._trivial_offset_translations:
'offsets': tuple(sp.IndexedBase(s, shape=(1,))[index] offsets = (0, ) * self._dim
for s in offset_array_symbols) else:
} offset_array_symbols = self._offset_array_symbols(f_dir, inverse)
offsets = tuple(sp.IndexedBase(s, shape=(1,))[index] for s in offset_array_symbols)
self._required_offset_arrays.add((f_dir, inverse))
return {'index': translated_index, 'offsets': offsets}
# ================================= # =================================
# Proxy fields substitution # Proxy fields substitution
...@@ -154,14 +159,17 @@ class BetweenTimestepsIndexing: ...@@ -154,14 +159,17 @@ class BetweenTimestepsIndexing:
return indices, offsets return indices, offsets
def _collect_trivial_translations(self): def _collect_trivial_translations(self):
trivials = set() trivial_index_translations = set()
trivial_offset_translations = set()
trivial_indices = list(range(self._q)) trivial_indices = list(range(self._q))
trivial_offsets = [[0] * self._q] * self._dim trivial_offsets = [[0] * self._q] * self._dim
for f_dir, inv in product(['in', 'out'], [False, True]): for f_dir, inv in product(['in', 'out'], [False, True]):
indices, offsets = self._get_translated_indices_and_offsets(f_dir, inv) indices, offsets = self._get_translated_indices_and_offsets(f_dir, inv)
if indices == trivial_indices and offsets == trivial_offsets: if indices == trivial_indices:
trivials.add((f_dir, inv)) trivial_index_translations.add((f_dir, inv))
return trivials if offsets == trivial_offsets:
trivial_offset_translations.add((f_dir, inv))
return trivial_index_translations, trivial_offset_translations
def create_code_node(self): def create_code_node(self):
return BetweenTimestepsIndexing.TranslationArraysNode(self) return BetweenTimestepsIndexing.TranslationArraysNode(self)
...@@ -172,13 +180,14 @@ class BetweenTimestepsIndexing: ...@@ -172,13 +180,14 @@ class BetweenTimestepsIndexing:
code = '' code = ''
symbols_defined = set() symbols_defined = set()
for f_dir, inv in indexing._required_arrays: for f_dir, inv in indexing._required_index_arrays:
indices, offsets = indexing._get_translated_indices_and_offsets(f_dir, inv) indices, offsets = indexing._get_translated_indices_and_offsets(f_dir, inv)
index_array_symbol = indexing._index_array_symbol(f_dir, inv) index_array_symbol = indexing._index_array_symbol(f_dir, inv)
symbols_defined.add(index_array_symbol) symbols_defined.add(index_array_symbol)
code += _array_pattern(indexing._index_dtype, index_array_symbol.name, indices) code += _array_pattern(indexing._index_dtype, index_array_symbol.name, indices)
for f_dir, inv in indexing._required_offset_arrays:
indices, offsets = indexing._get_translated_indices_and_offsets(f_dir, inv)
offset_array_symbols = indexing._offset_array_symbols(f_dir, inv) offset_array_symbols = indexing._offset_array_symbols(f_dir, inv)
symbols_defined |= set(offset_array_symbols) symbols_defined |= set(offset_array_symbols)
for d, arrsymb in enumerate(offset_array_symbols): for d, arrsymb in enumerate(offset_array_symbols):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment