From 58eb707a53f4f8fc801ee7bffd74f40e7455955e Mon Sep 17 00:00:00 2001
From: Michael Kuron <mkuron@icp.uni-stuttgart.de>
Date: Mon, 22 Feb 2021 16:32:21 +0100
Subject: [PATCH] move_constants_before_loop: only compare symbol names, not
 types

The vectorized RNG can switch between scalar and vector types due to loop cutting.
---
 pystencils/rng.py             | 4 ++--
 pystencils/transformations.py | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/pystencils/rng.py b/pystencils/rng.py
index c1daed1d4..fed90acef 100644
--- a/pystencils/rng.py
+++ b/pystencils/rng.py
@@ -2,7 +2,7 @@ import copy
 import numpy as np
 import sympy as sp
 
-from pystencils.data_types import TypedSymbol
+from pystencils.data_types import TypedSymbol, cast_func
 from pystencils.astnodes import LoopOverCoordinate
 from pystencils.backends.cbackend import CustomCodeNode
 from pystencils.sympyextensions import fast_subs
@@ -47,7 +47,7 @@ class RNGBase(CustomCodeNode):
     def get_code(self, dialect, vector_instruction_set, print_arg):
         code = "\n"
         for r in self.result_symbols:
-            if vector_instruction_set and isinstance(self.args[1], sp.Integer):
+            if vector_instruction_set and not self.args[1].atoms(cast_func):
                 # this vector RNG has become scalar through substitution
                 code += f"{r.dtype} {r.name};\n"
             else:
diff --git a/pystencils/transformations.py b/pystencils/transformations.py
index 6fd80687b..fc09f34e4 100644
--- a/pystencils/transformations.py
+++ b/pystencils/transformations.py
@@ -576,8 +576,8 @@ def move_constants_before_loop(ast_node):
             if isinstance(element, ast.Conditional):
                 break
             else:
-                critical_symbols = element.symbols_defined
-            if node.undefined_symbols.intersection(critical_symbols):
+                critical_symbols = set([s.name for s in element.symbols_defined])
+            if set([s.name for s in node.undefined_symbols]).intersection(critical_symbols):
                 break
             prev_element = element
             element = element.parent
-- 
GitLab