Commit 318f7c90 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Format Fixes

parent b1ee4c0b
from .indexing import AdvancedStreamingIndexing
__all__ = [ 'AdvancedStreamingIndexing' ]
\ No newline at end of file
__all__ = ['AdvancedStreamingIndexing']
......@@ -61,7 +61,8 @@ class AdvancedStreamingBoundaryOffsetInfo(CustomCodeNode):
def __init__(self, pdf_field, stencil, between_timesteps='both', kernel_type='pull'):
if between_timesteps not in ['both', 'odd_to_even', 'even_to_odd']:
raise ValueError(
"Invalid value of parameter 'between_timesteps'. Must be one of 'both', 'odd_to_even' or 'even_to_odd'.",
"Invalid value of parameter 'between_timesteps'."
+ " Must be one of 'both', 'odd_to_even' or 'even_to_odd'.",
between_timesteps)
if between_timesteps == 'both' and kernel_type in ['aa', 'esotwist']:
......@@ -137,8 +138,8 @@ class AdvancedStreamingBoundaryOffsetInfo(CustomCodeNode):
code += "const int64_t %s [] = { %s };\n" % (
self.INV_INWARD_INDEX_SYMBOL.name, ", ".join(inv_in_acc_indices))
defined_symbols = set(outward_offset_sym + inv_inward_offset_sym +
[self.OUTWARD_INDEX_SYMBOL, self.INV_INWARD_INDEX_SYMBOL])
defined_symbols = set(outward_offset_sym + inv_inward_offset_sym
+ [self.OUTWARD_INDEX_SYMBOL, self.INV_INWARD_INDEX_SYMBOL])
super(AdvancedStreamingBoundaryOffsetInfo, self).__init__(
code, symbols_read=set(), symbols_defined=defined_symbols)
......
......@@ -11,6 +11,7 @@ from lbmpy.fieldaccess import StreamPullTwoFieldsAccessor, \
EsoTwistEvenTimeStepAccessor, \
EsoTwistOddTimeStepAccessor
class AdvancedStreamingIndexing:
# =======================================
......@@ -32,10 +33,11 @@ class AdvancedStreamingIndexing:
# Constructor and State
# =============================
def __init__(self, pdf_field, stencil, between_timesteps = 'both', kernel_type='pull'):
def __init__(self, pdf_field, stencil, between_timesteps='both', kernel_type='pull'):
if between_timesteps not in ['both', 'odd_to_even', 'even_to_odd']:
raise ValueError(
"Invalid value of parameter 'between_timesteps'. Must be one of 'both', 'odd_to_even' or 'even_to_odd'.",
"Invalid value of parameter 'between_timesteps'."
+ " Must be one of 'both', 'odd_to_even' or 'even_to_odd'.",
between_timesteps)
if between_timesteps == 'both' and kernel_type in ['aa', 'esotwist']:
......@@ -93,20 +95,18 @@ class AdvancedStreamingIndexing:
assert f_dir in ['in', 'out']
inv = '_inv' if inverse else ''
name_base = f"f_{f_dir}{inv}_offsets_"
names = [ name_base + d for d in self.directions ]
names = [name_base + d for d in self.directions]
self.required_arrays |= set(names)
return [ sp.IndexedBase(n)[dir_symbol] for n in names ]
return [sp.IndexedBase(n)[dir_symbol] for n in names]
# =================================
# Proxy fields substitution
# =================================
def substitute_proxies(self, assignments):
# Get the accessor lists for the streaming pattern (here AA)
outward_accesses = self.outward_accesses
inward_accesses = self.inward_accesses
f_out, f_in = self.proxy_fields()
dir_symbol = self.dir_symbol()
inv_dir = self.inverse_dir_symbol()
......@@ -115,18 +115,18 @@ class AdvancedStreamingIndexing:
if not isinstance(assignments, ps.AssignmentCollection):
assignments = ps.AssignmentCollection([assignments])
accessor_subs = dict()
for fa in assignments.atoms(ps.Field.Access):
if fa.field == f_out:
if fa.offsets == (0,0):
if fa.offsets == (0, 0):
if isinstance(fa.index[0], int):
accessor_subs[fa] = outward_accesses[fa.index[0]]
elif fa.index[0] == dir_symbol:
accessor_subs[fa] = self.pdf_field[
accessor_subs[fa] = self.pdf_field[
self._offset_array_symbols('out', False, dir_symbol)
](self._index_array_symbol('out', False, dir_symbol))
](self._index_array_symbol('out', False, dir_symbol))
else:
# other cases like inverse direction, etc.
pass
......@@ -135,13 +135,13 @@ class AdvancedStreamingIndexing:
pass
elif fa.field == f_in:
if fa.offsets == (0,0):
if fa.offsets == (0, 0):
if isinstance(fa.index[0], int):
accessor_subs[fa] = inward_accesses[fa.index[0]]
elif fa.index[0] == inv_dir[dir_symbol]:
accessor_subs[fa] = self.pdf_field[
accessor_subs[fa] = self.pdf_field[
self._offset_array_symbols('in', True, dir_symbol)
](self._index_array_symbol('in', True, dir_symbol))
](self._index_array_symbol('in', True, dir_symbol))
else:
# other cases
pass
......@@ -155,5 +155,3 @@ class AdvancedStreamingIndexing:
return assignments.new_with_substitutions(accessor_subs)
# end class AdvancedStreamingIndexing
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment