diff --git a/transformations.py b/transformations.py index d108b7011c1e4582b966b928820cf0d491d060b9..8b288bb4c9f3a0f67b47fd27362cf55a8f423b41 100644 --- a/transformations.py +++ b/transformations.py @@ -707,7 +707,7 @@ class KernelConstraintsCheck: new_lhs = self._process_lhs(assignment.lhs) return ast.SympyAssignment(new_lhs, new_rhs) - def process_expression(self, rhs): + def process_expression(self, rhs, type_constants=True): self._update_accesses_rhs(rhs) if isinstance(rhs, Field.Access): self.fields_read.add(rhs.field) @@ -716,19 +716,19 @@ class KernelConstraintsCheck: return rhs elif isinstance(rhs, sp.Symbol): return TypedSymbol(symbol_name_to_variable_name(rhs.name), self._type_for_symbol[rhs.name]) - elif isinstance(rhs, sp.Number): + elif type_constants and isinstance(rhs, sp.Number): return cast_func(rhs, create_type(self._type_for_symbol['_constant'])) elif isinstance(rhs, sp.Mul): - new_args = [self.process_expression(arg) if arg not in (-1, 1) else arg for arg in rhs.args] + new_args = [self.process_expression(arg, type_constants) if arg not in (-1, 1) else arg for arg in rhs.args] return rhs.func(*new_args) if new_args else rhs elif isinstance(rhs, sp.Indexed): return rhs else: if isinstance(rhs, sp.Pow): # don't process exponents -> they should remain integers - return sp.Pow(self.process_expression(rhs.args[0]), rhs.args[1]) + return sp.Pow(self.process_expression(rhs.args[0], type_constants), rhs.args[1]) else: - new_args = [self.process_expression(arg) for arg in rhs.args] + new_args = [self.process_expression(arg, type_constants) for arg in rhs.args] return rhs.func(*new_args) if new_args else rhs @property @@ -796,7 +796,7 @@ def add_types(eqs, type_for_symbol, check_independence_condition): return check.process_assignment(obj) elif isinstance(obj, ast.Conditional): false_block = None if obj.false_block is None else visit(obj.false_block) - return ast.Conditional(check.process_expression(obj.condition_expr), + return ast.Conditional(check.process_expression(obj.condition_expr, type_constants=False), true_block=visit(obj.true_block), false_block=false_block) elif isinstance(obj, ast.Block): return ast.Block([visit(e) for e in obj.args])