Commit 08540d5f authored by Martin Bauer's avatar Martin Bauer
Browse files

No double-write check inside conditionals

parent 85053df1
......@@ -437,7 +437,6 @@ def run_compile_step(command):
config_env = compiler_config['env'] if 'env' in compiler_config else {}
compile_environment = os.environ.copy()
shell = True if compiler_config['os'].lower() == 'windows' else False
subprocess.check_output(command, env=compile_environment, stderr=subprocess.STDOUT, shell=shell)
......@@ -143,7 +143,9 @@ def mask_conditionals(loop_body):
def visit_node(node, mask):
if isinstance(node, ast.Conditional):
true_mask = sp.And(node.condition_expr, mask)
cond = node.condition_expr
cond = cond if loop_body.loop_counter_symbol in cond.atoms(sp.Symbol) else True
true_mask = sp.And(cond, mask)
visit_node(node.true_block, true_mask)
if node.false_block:
false_mask = sp.And(sp.Not(node.condition_expr), mask)
......@@ -27,6 +27,9 @@ def typed_symbols(names, dtype, *args):
return TypedSymbol(str(symbols), dtype)
def type_all_numbers(expr, dtype):
substitutions = {a: cast_func(a, dtype) for a in expr.atoms(sp.Number)}
return expr.subs(substitutions)
def matrix_symbols(names, dtype, rows, cols):
if isinstance(names, str):
......@@ -345,7 +345,8 @@ class AssignmentCollection:
return result
def __repr__(self):
return "Assignment Collection for " + ",".join([str(eq.lhs) for eq in self.main_assignments])
return "Assignment Collection for " + ",".join([str(eq.lhs) for eq in self.main_assignments
if isinstance(eq, Assignment)])
def __str__(self):
result = "Subexpressions:\n"
......@@ -659,7 +659,7 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups):
outer_loop = outer_loop[0]
symbols_with_temporary_array = OrderedDict()
assignment_map = OrderedDict((a.lhs, a) for a in inner_loop.body.args)
assignment_map = OrderedDict((a.lhs, a) for a in inner_loop.body.args if hasattr(a, 'lhs'))
assignment_groups = []
for symbol_group in symbol_groups:
......@@ -690,13 +690,10 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups):
if assignment.lhs in symbols_resolved:
new_rhs = assignment.rhs.subs(
if not isinstance(assignment.lhs, AbstractField.AbstractAccess
) and assignment.lhs in symbol_group:
if not isinstance(assignment.lhs, AbstractField.AbstractAccess) and assignment.lhs in symbol_group:
assert type(assignment.lhs) is TypedSymbol
new_ts = TypedSymbol(,
new_lhs = sp.IndexedBase(
new_ts, shape=(1, ))[inner_loop.loop_counter_symbol]
new_ts = TypedSymbol(, PointerType(assignment.lhs.dtype))
new_lhs = sp.IndexedBase(new_ts, shape=(1, ))[inner_loop.loop_counter_symbol]
new_lhs = assignment.lhs
assignment_group.append(ast.SympyAssignment(new_lhs, new_rhs))
......@@ -805,13 +802,14 @@ class KernelConstraintsCheck:
FieldAndIndex = namedtuple('FieldAndIndex', ['field', 'index'])
def __init__(self, type_for_symbol, check_independence_condition):
def __init__(self, type_for_symbol, check_independence_condition, check_double_write_condition=True):
self._type_for_symbol = type_for_symbol
self.scopes = NestedScopes()
self._field_writes = defaultdict(set)
self.fields_read = set()
self.check_independence_condition = check_independence_condition
self.check_double_write_condition = check_double_write_condition
def process_assignment(self, assignment):
# for checks it is crucial to process rhs before lhs to catch e.g. a = a + 1
......@@ -891,19 +889,15 @@ class KernelConstraintsCheck:
if isinstance(lhs, AbstractField.AbstractAccess):
fai = self.FieldAndIndex(lhs.field, lhs.index)
#if len(self._field_writes[fai]) > 1:
# raise ValueError(
# "Field {} is written at two different locations".format(
if self.check_double_write_condition and len(self._field_writes[fai]) > 1:
raise ValueError(
"Field {} is written at two different locations".format(
elif isinstance(lhs, sp.Symbol):
if self.scopes.is_defined_locally(lhs):
raise ValueError(
"Assignments not in SSA form, multiple assignments to {}".
raise ValueError("Assignments not in SSA form, multiple assignments to {}".format(
if lhs in self.scopes.free_parameters:
raise ValueError(
"Symbol {} is written, after it has been read".format(
raise ValueError("Symbol {} is written, after it has been read".format(
def _update_accesses_rhs(self, rhs):
......@@ -947,12 +941,16 @@ def add_types(eqs, type_for_symbol, check_independence_condition):
return check.process_assignment(obj)
elif isinstance(obj, ast.Conditional):
# Disable double write check inside conditionals
# would be triggered by e.g. in-kernel boundaries
check.check_double_write_condition = False
false_block = None if obj.false_block is None else visit(
result = ast.Conditional(check.process_expression(
obj.condition_expr, type_constants=False),
check.check_double_write_condition = True
return result
elif isinstance(obj, ast.Block):
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment