Commit 6ef3d7d7 authored by Martin Bauer
Browse files

Bugfixes and improvements in grandchem staggered μ sweep

parent 7df37e9f
class NestedScopes:
    """Symbol visibility model using nested scopes

    - every accessed symbol that was not defined before, is added as a "free parameter"
    - free parameters are global, i.e. they are not in scopes
    - push/pop adds or removes a scope

    The outermost (global) scope is created by the constructor and can never
    be removed; trying to pop it raises ``AssertionError`` and leaves the
    object unchanged.

    >>> s = NestedScopes()
    >>> s.access_symbol("a")
    >>> s.is_defined("a")
    False
    >>> s.free_parameters
    {'a'}
    >>> s.define_symbol("b")
    >>> s.is_defined("b")
    True
    >>> s.push()
    >>> s.is_defined_locally("b")
    False
    >>> s.define_symbol("c")
    >>> s.pop()
    >>> s.is_defined("c")
    False
    """

    def __init__(self):
        # Symbols that were accessed while not being defined in any scope.
        self.free_parameters = set()
        # Stack of scopes, innermost last; index 0 is the global scope.
        self._defined = [set()]

    def access_symbol(self, symbol) -> None:
        """Record a read of `symbol`; undefined symbols become free parameters."""
        if not self.is_defined(symbol):
            self.free_parameters.add(symbol)

    def define_symbol(self, symbol) -> None:
        """Mark `symbol` as defined in the innermost scope."""
        self._defined[-1].add(symbol)

    def is_defined(self, symbol) -> bool:
        """Return True if `symbol` is defined in any currently open scope."""
        return any(symbol in scope for scope in self._defined)

    def is_defined_locally(self, symbol) -> bool:
        """Return True if `symbol` is defined in the innermost scope."""
        return symbol in self._defined[-1]

    def push(self) -> None:
        """Open a new, empty innermost scope."""
        self._defined.append(set())

    def pop(self) -> None:
        """Close the innermost scope, dropping all symbols defined in it.

        Raises:
            AssertionError: if only the global scope is left. The check runs
                before the stack is modified, so the object stays usable.
        """
        # Explicit raise instead of `assert`: an assert is stripped under
        # `python -O`, which would let the scope stack become empty.
        if self.depth <= 1:
            raise AssertionError("Cannot pop the global scope")
        self._defined.pop()

    @property
    def depth(self) -> int:
        """Number of currently open scopes, including the global one."""
        return len(self._defined)
......@@ -30,7 +30,7 @@ def create_kernel(assignments, target='cpu', data_type="double", iteration_slice
cpu_openmp: True or number of threads for OpenMP parallelization, False for no OpenMP
cpu_vectorize_info: a dictionary with keys, 'vector_instruction_set', 'assume_aligned' and 'nontemporal'
for documentation of these parameters see vectorize function. Example:
'{'vector_instruction_set': 'avx512', 'assume_aligned': True, 'nontemporal':True}'
'{'instruction_set': 'avx512', 'assume_aligned': True, 'nontemporal':True}'
gpu_indexing: either 'block' or 'line' , or custom indexing class, see `AbstractIndexing`
gpu_indexing_params: dict with indexing parameters (constructor parameters of indexing class)
e.g. for 'block' one can specify '{'block_size': (20, 20, 10) }'
......@@ -171,7 +171,7 @@ def create_staggered_kernel(staggered_field, expressions, subexpressions=(), tar
where e.g. ``f[0,0](0)`` is interpreted as value at the left cell boundary, ``f[1,0](0)`` the right cell
boundary and ``f[0,0](1)`` the southern cell boundary etc.
expressions: sequence of expressions of length dim, defining how the east, southern, (bottom) cell boundary
should be update.
should be updated.
subexpressions: optional sequence of Assignments, that define subexpressions used in the main expressions
target: 'cpu' or 'gpu'
kwargs: passed directly to create_kernel, iteration slice and ghost_layers parameters are not allowed
......
......@@ -7,6 +7,7 @@ import sympy as sp
from sympy.logic.boolalg import Boolean
from sympy.tensor import IndexedBase
from pystencils.assignment import Assignment
from pystencils.assignment_collection.nestedscopes import NestedScopes
from pystencils.field import Field, FieldType
from pystencils.data_types import TypedSymbol, PointerType, StructType, get_base_type, cast_func, \
pointer_arithmetic_func, get_type_of_expression, collate_types, create_type
......@@ -727,9 +728,8 @@ class KernelConstraintsCheck:
def __init__(self, type_for_symbol, check_independence_condition):
self._type_for_symbol = type_for_symbol
self._defined_pure_symbols = set()
self._accessed_pure_symbols = set()
self.scopes = NestedScopes()
self._field_writes = defaultdict(set)
self.fields_read = set()
self.check_independence_condition = check_independence_condition
......@@ -784,11 +784,11 @@ class KernelConstraintsCheck:
if len(self._field_writes[fai]) > 1:
raise ValueError("Field {} is written at two different locations".format(lhs.field.name))
elif isinstance(lhs, sp.Symbol):
if lhs in self._defined_pure_symbols:
if self.scopes.is_defined_locally(lhs):
raise ValueError("Assignments not in SSA form, multiple assignments to {}".format(lhs.name))
if lhs in self._accessed_pure_symbols:
if lhs in self.scopes.free_parameters:
raise ValueError("Symbol {} is written, after it has been read".format(lhs.name))
self._defined_pure_symbols.add(lhs)
self.scopes.define_symbol(lhs)
def _update_accesses_rhs(self, rhs):
if isinstance(rhs, Field.Access) and self.check_independence_condition:
......@@ -800,7 +800,7 @@ class KernelConstraintsCheck:
"{} is read at {} and written at {}".format(rhs.field, rhs.offsets, write_offset))
self.fields_read.add(rhs.field)
elif isinstance(rhs, sp.Symbol):
self._accessed_pure_symbols.add(rhs)
self.scopes.access_symbol(rhs)
def add_types(eqs, type_for_symbol, check_independence_condition):
......@@ -829,11 +829,17 @@ def add_types(eqs, type_for_symbol, check_independence_condition):
if isinstance(obj, sp.Eq) or isinstance(obj, ast.SympyAssignment) or isinstance(obj, Assignment):
return check.process_assignment(obj)
elif isinstance(obj, ast.Conditional):
check.scopes.push()
false_block = None if obj.false_block is None else visit(obj.false_block)
return ast.Conditional(check.process_expression(obj.condition_expr, type_constants=False),
true_block=visit(obj.true_block), false_block=false_block)
result = ast.Conditional(check.process_expression(obj.condition_expr, type_constants=False),
true_block=visit(obj.true_block), false_block=false_block)
check.scopes.pop()
return result
elif isinstance(obj, ast.Block):
return ast.Block([visit(e) for e in obj.args])
check.scopes.push()
result = ast.Block([visit(e) for e in obj.args])
check.scopes.pop()
return result
elif isinstance(obj, ast.Node) and not isinstance(obj, ast.LoopOverCoordinate):
return obj
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment