From b91e6990d23947b5e1ad8a870c8547a7175cf12d Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Mon, 18 Mar 2019 14:16:37 +0100
Subject: [PATCH] Correct automatic typing (double/bool) also for staggered
 kernels

---
 kernelcreation.py  |  5 +++--
 transformations.py | 15 ++++++++++-----
 2 files changed, 13 insertions(+), 7 deletions(-)

diff --git a/kernelcreation.py b/kernelcreation.py
index 69a80cba1..504404920 100644
--- a/kernelcreation.py
+++ b/kernelcreation.py
@@ -216,8 +216,9 @@ def create_staggered_kernel(staggered_field, expressions, subexpressions=(), tar
                                           for d in dimensions])
         sp_assignments = [SympyAssignment(a.lhs, a.rhs) for a in a_coll.all_assignments]
         if as_else_block and last_conditional:
-            last_conditional.false_block = Conditional(condition, Block(sp_assignments))
-            last_conditional = last_conditional.false_block
+            new_cond = Conditional(condition, Block(sp_assignments))
+            last_conditional.false_block = Block([new_cond])
+            last_conditional = new_cond
         else:
             last_conditional = Conditional(condition, Block(sp_assignments))
             final_assignments.append(last_conditional)
diff --git a/transformations.py b/transformations.py
index 787438f37..fa683e1d5 100644
--- a/transformations.py
+++ b/transformations.py
@@ -923,12 +923,17 @@ def typing_from_sympy_inspection(eqs, default_type="double"):
     """
     result = defaultdict(lambda: default_type)
     for eq in eqs:
-        if isinstance(eq, ast.Node):
+        if isinstance(eq, ast.Conditional):
+            result.update(typing_from_sympy_inspection(eq.true_block.args))
+            if eq.false_block:
+                result.update(typing_from_sympy_inspection(eq.false_block.args))
+        elif isinstance(eq, ast.Node) and not isinstance(eq, ast.SympyAssignment):
             continue
-        # problematic case here is when rhs is a symbol: then it is impossible to decide here without
-        # further information what type the left hand side is - default fallback is the dict value then
-        if isinstance(eq.rhs, Boolean) and not isinstance(eq.rhs, sp.Symbol):
-            result[eq.lhs.name] = "bool"
+        else:
+            # problematic case here is when rhs is a symbol: then it is impossible to decide here without
+            # further information what type the left hand side is - default fallback is the dict value then
+            if isinstance(eq.rhs, Boolean) and not isinstance(eq.rhs, sp.Symbol):
+                result[eq.lhs.name] = "bool"
     return result
 
 
-- 
GitLab