From 00738cb8a10e09dcf5a6e87a619c4748cd01ca8e Mon Sep 17 00:00:00 2001 From: Markus Holzer <markus.holzer@fau.de> Date: Wed, 4 May 2022 16:37:24 +0200 Subject: [PATCH] More fixes --- pystencils/kernel_contrains_check.py | 3 +++ pystencils/kernelcreation.py | 2 +- pystencils/node_collection.py | 11 ++++++++++- pystencils/typing/cast_functions.py | 1 + pystencils/typing/utilities.py | 2 -- 5 files changed, 15 insertions(+), 4 deletions(-) diff --git a/pystencils/kernel_contrains_check.py b/pystencils/kernel_contrains_check.py index f79677656..b4b681e1c 100644 --- a/pystencils/kernel_contrains_check.py +++ b/pystencils/kernel_contrains_check.py @@ -54,12 +54,15 @@ class KernelConstraintsCheck: # Disable double write check inside conditionals # would be triggered by e.g. in-kernel boundaries old_double_write = self.check_double_write_condition + old_independence_condition = self.check_independence_condition self.check_double_write_condition = False + self.check_independence_condition = False if obj.false_block: self.visit(obj.false_block) self.process_expression(obj.condition_expr) self.process_expression(obj.true_block) self.check_double_write_condition = old_double_write + self.check_independence_condition = old_independence_condition self.scopes.pop() elif isinstance(obj, ast.Block): self.scopes.push() diff --git a/pystencils/kernelcreation.py b/pystencils/kernelcreation.py index cd744a0d9..4b02ca13d 100644 --- a/pystencils/kernelcreation.py +++ b/pystencils/kernelcreation.py @@ -75,7 +75,7 @@ def create_kernel(assignments: Union[Assignment, List[Assignment], AssignmentCol warnings.warn(f"It was not possible to apply the default pystencils optimisations to the " f"AssignmentCollection due to the following problem :{e}") simplification_hints = assignments.simplification_hints - assignments = NodeCollection(assignments.all_assignments) + assignments = NodeCollection.from_assignment_collection(assignments) assignments.simplification_hints = simplification_hints if config.index_fields: diff --git a/pystencils/node_collection.py b/pystencils/node_collection.py index b393df0ed..821eda99b 100644 --- a/pystencils/node_collection.py +++ b/pystencils/node_collection.py @@ -31,7 +31,16 @@ class NodeCollection: @staticmethod def from_assignment_collection(assignment_collection: AssignmentCollection): - return NodeCollection([SympyAssignment(a.lhs, a.rhs) for a in assignment_collection.all_assignments]) + nodes = list() + for assignemt in assignment_collection.all_assignments: + if isinstance(assignemt, Assignment): + nodes.append(SympyAssignment(assignemt.lhs, assignemt.rhs)) + elif isinstance(assignemt, Node): + nodes.append(assignemt) + else: + raise ValueError(f"Unknown node in the AssignmentCollection: {assignemt}") + + return NodeCollection(nodes) def evaluate_terms(self): evaluate_constant_terms = ReplaceOptim( diff --git a/pystencils/typing/cast_functions.py b/pystencils/typing/cast_functions.py index 1b83d223c..9e9ad372e 100644 --- a/pystencils/typing/cast_functions.py +++ b/pystencils/typing/cast_functions.py @@ -89,6 +89,7 @@ class CastFunc(sp.Function): else: return super().is_nonnegative + @property def is_real(self): """ diff --git a/pystencils/typing/utilities.py b/pystencils/typing/utilities.py index c5612e935..da40c510e 100644 --- a/pystencils/typing/utilities.py +++ b/pystencils/typing/utilities.py @@ -129,8 +129,6 @@ def get_type_of_expression(expr, expr = sp.sympify(expr) if isinstance(expr, sp.Integer): return create_type(default_int_type) - elif expr.is_real is False: - return create_type((np.zeros((1,), default_float_type) * 1j).dtype) elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float): return create_type(default_float_type) elif isinstance(expr, ResolvedFieldAccess): -- GitLab