From 00738cb8a10e09dcf5a6e87a619c4748cd01ca8e Mon Sep 17 00:00:00 2001
From: Markus Holzer <markus.holzer@fau.de>
Date: Wed, 4 May 2022 16:37:24 +0200
Subject: [PATCH] More fixes

---
 pystencils/kernel_contrains_check.py |  3 +++
 pystencils/kernelcreation.py         |  2 +-
 pystencils/node_collection.py        | 11 ++++++++++-
 pystencils/typing/cast_functions.py  |  1 +
 pystencils/typing/utilities.py       |  2 --
 5 files changed, 15 insertions(+), 4 deletions(-)

diff --git a/pystencils/kernel_contrains_check.py b/pystencils/kernel_contrains_check.py
index f79677656..b4b681e1c 100644
--- a/pystencils/kernel_contrains_check.py
+++ b/pystencils/kernel_contrains_check.py
@@ -54,12 +54,15 @@ class KernelConstraintsCheck:
             # Disable double write check inside conditionals
             # would be triggered by e.g. in-kernel boundaries
             old_double_write = self.check_double_write_condition
+            old_independence_condition = self.check_independence_condition
             self.check_double_write_condition = False
+            self.check_independence_condition = False
             if obj.false_block:
                 self.visit(obj.false_block)
             self.process_expression(obj.condition_expr)
             self.process_expression(obj.true_block)
             self.check_double_write_condition = old_double_write
+            self.check_independence_condition = old_independence_condition
             self.scopes.pop()
         elif isinstance(obj, ast.Block):
             self.scopes.push()
diff --git a/pystencils/kernelcreation.py b/pystencils/kernelcreation.py
index cd744a0d9..4b02ca13d 100644
--- a/pystencils/kernelcreation.py
+++ b/pystencils/kernelcreation.py
@@ -75,7 +75,7 @@ def create_kernel(assignments: Union[Assignment, List[Assignment], AssignmentCol
             warnings.warn(f"It was not possible to apply the default pystencils optimisations to the "
                           f"AssignmentCollection due to the following problem :{e}")
         simplification_hints = assignments.simplification_hints
-        assignments = NodeCollection(assignments.all_assignments)
+        assignments = NodeCollection.from_assignment_collection(assignments)
         assignments.simplification_hints = simplification_hints
 
     if config.index_fields:
diff --git a/pystencils/node_collection.py b/pystencils/node_collection.py
index b393df0ed..821eda99b 100644
--- a/pystencils/node_collection.py
+++ b/pystencils/node_collection.py
@@ -31,7 +31,16 @@ class NodeCollection:
 
     @staticmethod
     def from_assignment_collection(assignment_collection: AssignmentCollection):
-        return NodeCollection([SympyAssignment(a.lhs, a.rhs) for a in assignment_collection.all_assignments])
+        nodes = list()
+        for assignemt in assignment_collection.all_assignments:
+            if isinstance(assignemt, Assignment):
+                nodes.append(SympyAssignment(assignemt.lhs, assignemt.rhs))
+            elif isinstance(assignemt, Node):
+                nodes.append(assignemt)
+            else:
+                raise ValueError(f"Unknown node in the AssignmentCollection: {assignemt}")
+
+        return NodeCollection(nodes)
 
     def evaluate_terms(self):
         evaluate_constant_terms = ReplaceOptim(
diff --git a/pystencils/typing/cast_functions.py b/pystencils/typing/cast_functions.py
index 1b83d223c..9e9ad372e 100644
--- a/pystencils/typing/cast_functions.py
+++ b/pystencils/typing/cast_functions.py
@@ -89,6 +89,7 @@ class CastFunc(sp.Function):
         else:
             return super().is_nonnegative
 
+
     @property
     def is_real(self):
         """
diff --git a/pystencils/typing/utilities.py b/pystencils/typing/utilities.py
index c5612e935..da40c510e 100644
--- a/pystencils/typing/utilities.py
+++ b/pystencils/typing/utilities.py
@@ -129,8 +129,6 @@ def get_type_of_expression(expr,
     expr = sp.sympify(expr)
     if isinstance(expr, sp.Integer):
         return create_type(default_int_type)
-    elif expr.is_real is False:
-        return create_type((np.zeros((1,), default_float_type) * 1j).dtype)
     elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float):
         return create_type(default_float_type)
     elif isinstance(expr, ResolvedFieldAccess):
-- 
GitLab