diff --git a/pystencils/rng.py b/pystencils/rng.py index c1daed1d4ba43167a33650c7bcf80a2b167885b5..fed90aceff97bb94e82ae8ca054280c0140f203d 100644 --- a/pystencils/rng.py +++ b/pystencils/rng.py @@ -2,7 +2,7 @@ import copy import numpy as np import sympy as sp -from pystencils.data_types import TypedSymbol +from pystencils.data_types import TypedSymbol, cast_func from pystencils.astnodes import LoopOverCoordinate from pystencils.backends.cbackend import CustomCodeNode from pystencils.sympyextensions import fast_subs @@ -47,7 +47,7 @@ class RNGBase(CustomCodeNode): def get_code(self, dialect, vector_instruction_set, print_arg): code = "\n" for r in self.result_symbols: - if vector_instruction_set and isinstance(self.args[1], sp.Integer): + if vector_instruction_set and not self.args[1].atoms(cast_func): # this vector RNG has become scalar through substitution code += f"{r.dtype} {r.name};\n" else: diff --git a/pystencils/transformations.py b/pystencils/transformations.py index 6fd80687bd36c401303b7c8c215a5be3a64a774c..fc09f34e439c51c57b6579781619522eb166b510 100644 --- a/pystencils/transformations.py +++ b/pystencils/transformations.py @@ -576,8 +576,8 @@ def move_constants_before_loop(ast_node): if isinstance(element, ast.Conditional): break else: - critical_symbols = element.symbols_defined - if node.undefined_symbols.intersection(critical_symbols): + critical_symbols = set([s.name for s in element.symbols_defined]) + if set([s.name for s in node.undefined_symbols]).intersection(critical_symbols): break prev_element = element element = element.parent