Commit 8833346c authored by Martin Bauer's avatar Martin Bauer
Browse files

Fixes in vectorization to also support float kernels

parent c378ca19
...@@ -707,7 +707,7 @@ class KernelConstraintsCheck: ...@@ -707,7 +707,7 @@ class KernelConstraintsCheck:
new_lhs = self._process_lhs(assignment.lhs) new_lhs = self._process_lhs(assignment.lhs)
return ast.SympyAssignment(new_lhs, new_rhs) return ast.SympyAssignment(new_lhs, new_rhs)
def process_expression(self, rhs): def process_expression(self, rhs, type_constants=True):
self._update_accesses_rhs(rhs) self._update_accesses_rhs(rhs)
if isinstance(rhs, Field.Access): if isinstance(rhs, Field.Access):
self.fields_read.add(rhs.field) self.fields_read.add(rhs.field)
...@@ -716,19 +716,19 @@ class KernelConstraintsCheck: ...@@ -716,19 +716,19 @@ class KernelConstraintsCheck:
return rhs return rhs
elif isinstance(rhs, sp.Symbol): elif isinstance(rhs, sp.Symbol):
return TypedSymbol(symbol_name_to_variable_name(rhs.name), self._type_for_symbol[rhs.name]) return TypedSymbol(symbol_name_to_variable_name(rhs.name), self._type_for_symbol[rhs.name])
elif isinstance(rhs, sp.Number): elif type_constants and isinstance(rhs, sp.Number):
return cast_func(rhs, create_type(self._type_for_symbol['_constant'])) return cast_func(rhs, create_type(self._type_for_symbol['_constant']))
elif isinstance(rhs, sp.Mul): elif isinstance(rhs, sp.Mul):
new_args = [self.process_expression(arg) if arg not in (-1, 1) else arg for arg in rhs.args] new_args = [self.process_expression(arg, type_constants) if arg not in (-1, 1) else arg for arg in rhs.args]
return rhs.func(*new_args) if new_args else rhs return rhs.func(*new_args) if new_args else rhs
elif isinstance(rhs, sp.Indexed): elif isinstance(rhs, sp.Indexed):
return rhs return rhs
else: else:
if isinstance(rhs, sp.Pow): if isinstance(rhs, sp.Pow):
# don't process exponents -> they should remain integers # don't process exponents -> they should remain integers
return sp.Pow(self.process_expression(rhs.args[0]), rhs.args[1]) return sp.Pow(self.process_expression(rhs.args[0], type_constants), rhs.args[1])
else: else:
new_args = [self.process_expression(arg) for arg in rhs.args] new_args = [self.process_expression(arg, type_constants) for arg in rhs.args]
return rhs.func(*new_args) if new_args else rhs return rhs.func(*new_args) if new_args else rhs
@property @property
...@@ -796,7 +796,7 @@ def add_types(eqs, type_for_symbol, check_independence_condition): ...@@ -796,7 +796,7 @@ def add_types(eqs, type_for_symbol, check_independence_condition):
return check.process_assignment(obj) return check.process_assignment(obj)
elif isinstance(obj, ast.Conditional): elif isinstance(obj, ast.Conditional):
false_block = None if obj.false_block is None else visit(obj.false_block) false_block = None if obj.false_block is None else visit(obj.false_block)
return ast.Conditional(check.process_expression(obj.condition_expr), return ast.Conditional(check.process_expression(obj.condition_expr, type_constants=False),
true_block=visit(obj.true_block), false_block=false_block) true_block=visit(obj.true_block), false_block=false_block)
elif isinstance(obj, ast.Block): elif isinstance(obj, ast.Block):
return ast.Block([visit(e) for e in obj.args]) return ast.Block([visit(e) for e in obj.args])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment