diff --git a/kernelcreation.py b/kernelcreation.py index 69a80cba1478c992e4d9d1849a5cad9f81c37dce..504404920b5588d04c52aea4bc9d631f07e9f903 100644 --- a/kernelcreation.py +++ b/kernelcreation.py @@ -216,8 +216,9 @@ def create_staggered_kernel(staggered_field, expressions, subexpressions=(), tar for d in dimensions]) sp_assignments = [SympyAssignment(a.lhs, a.rhs) for a in a_coll.all_assignments] if as_else_block and last_conditional: - last_conditional.false_block = Conditional(condition, Block(sp_assignments)) - last_conditional = last_conditional.false_block + new_cond = Conditional(condition, Block(sp_assignments)) + last_conditional.false_block = Block([new_cond]) + last_conditional = new_cond else: last_conditional = Conditional(condition, Block(sp_assignments)) final_assignments.append(last_conditional) diff --git a/transformations.py b/transformations.py index 787438f37cf0f959567f338f70dd51f57dfb2f97..fa683e1d59c35073ab622a2f5b323b3d56ba20e8 100644 --- a/transformations.py +++ b/transformations.py @@ -923,12 +923,17 @@ def typing_from_sympy_inspection(eqs, default_type="double"): """ result = defaultdict(lambda: default_type) for eq in eqs: - if isinstance(eq, ast.Node): + if isinstance(eq, ast.Conditional): + result.update(typing_from_sympy_inspection(eq.true_block.args)) + if eq.false_block: + result.update(typing_from_sympy_inspection(eq.false_block.args)) + elif isinstance(eq, ast.Node) and not isinstance(eq, ast.SympyAssignment): continue - # problematic case here is when rhs is a symbol: then it is impossible to decide here without - # further information what type the left hand side is - default fallback is the dict value then - if isinstance(eq.rhs, Boolean) and not isinstance(eq.rhs, sp.Symbol): - result[eq.lhs.name] = "bool" + else: + # problematic case here is when rhs is a symbol: then it is impossible to decide here without + # further information what type the left hand side is - default fallback is the dict value then + if isinstance(eq.rhs, Boolean) and not isinstance(eq.rhs, sp.Symbol): + result[eq.lhs.name] = "bool" return result