Commit 39d171b3 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Split trivial translations of indices and offsets

parent 32e796ff
Pipeline #27719 waiting for manual action with stage
in 26 minutes and 37 seconds
......@@ -65,8 +65,9 @@ class BetweenTimestepsIndexing:
self._index_dtype = create_type(index_dtype)
self._offsets_dtype = create_type(offsets_dtype)
self._required_arrays = set()
self._trivial_translations = self._collect_trivial_translations()
self._required_index_arrays = set()
self._required_offset_arrays = set()
self._trivial_index_translations, self._trivial_offset_translations = self._collect_trivial_translations()
def _index_array_symbol(self, f_dir, inverse):
assert f_dir in ['in', 'out']
......@@ -82,17 +83,21 @@ class BetweenTimestepsIndexing:
return symbols
def _array_symbols(self, f_dir, inverse, index):
if (f_dir, inverse) in self._trivial_translations:
return {'index': index, 'offsets': (0, ) * self._dim}
index_array_symbol = self._index_array_symbol(f_dir, inverse)
offset_array_symbols = self._offset_array_symbols(f_dir, inverse)
self._required_arrays.add((f_dir, inverse))
return {
'index': sp.IndexedBase(index_array_symbol, shape=(1,))[index],
'offsets': tuple(sp.IndexedBase(s, shape=(1,))[index]
for s in offset_array_symbols)
}
if (f_dir, inverse) in self._trivial_index_translations:
translated_index = index
else:
index_array_symbol = self._index_array_symbol(f_dir, inverse)
translated_index = sp.IndexedBase(index_array_symbol, shape=(1,))[index]
self._required_index_arrays.add((f_dir, inverse))
if (f_dir, inverse) in self._trivial_offset_translations:
offsets = (0, ) * self._dim
else:
offset_array_symbols = self._offset_array_symbols(f_dir, inverse)
offsets = tuple(sp.IndexedBase(s, shape=(1,))[index] for s in offset_array_symbols)
self._required_offset_arrays.add((f_dir, inverse))
return {'index': translated_index, 'offsets': offsets}
# =================================
# Proxy fields substitution
......@@ -154,14 +159,17 @@ class BetweenTimestepsIndexing:
return indices, offsets
def _collect_trivial_translations(self):
trivials = set()
trivial_index_translations = set()
trivial_offset_translations = set()
trivial_indices = list(range(self._q))
trivial_offsets = [[0] * self._q] * self._dim
for f_dir, inv in product(['in', 'out'], [False, True]):
indices, offsets = self._get_translated_indices_and_offsets(f_dir, inv)
if indices == trivial_indices and offsets == trivial_offsets:
trivials.add((f_dir, inv))
return trivials
if indices == trivial_indices:
trivial_index_translations.add((f_dir, inv))
if offsets == trivial_offsets:
trivial_offset_translations.add((f_dir, inv))
return trivial_index_translations, trivial_offset_translations
def create_code_node(self):
return BetweenTimestepsIndexing.TranslationArraysNode(self)
......@@ -172,13 +180,14 @@ class BetweenTimestepsIndexing:
code = ''
symbols_defined = set()
for f_dir, inv in indexing._required_arrays:
for f_dir, inv in indexing._required_index_arrays:
indices, offsets = indexing._get_translated_indices_and_offsets(f_dir, inv)
index_array_symbol = indexing._index_array_symbol(f_dir, inv)
symbols_defined.add(index_array_symbol)
code += _array_pattern(indexing._index_dtype, index_array_symbol.name, indices)
for f_dir, inv in indexing._required_offset_arrays:
indices, offsets = indexing._get_translated_indices_and_offsets(f_dir, inv)
offset_array_symbols = indexing._offset_array_symbols(f_dir, inv)
symbols_defined |= set(offset_array_symbols)
for d, arrsymb in enumerate(offset_array_symbols):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment