From 6fed27bbc7e59c1982b01384c75daf5d9d5bff4a Mon Sep 17 00:00:00 2001 From: Michael Kuron <mkuron@icp.uni-stuttgart.de> Date: Fri, 19 Feb 2021 16:40:58 +0100 Subject: [PATCH] some fixes for lbmpy vectorization --- pystencils/cpu/vectorization.py | 2 ++ pystencils/simp/assignment_collection.py | 5 ++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py index 2f49d44a2..4c632b145 100644 --- a/pystencils/cpu/vectorization.py +++ b/pystencils/cpu/vectorization.py @@ -210,6 +210,8 @@ def insert_vector_casts(ast_node): # special treatment for the unary minus: make sure that the -1 has the same type as the argument dtype = int for arg in expr.args[1:]: + if type(arg) is sp.Pow: + arg = arg.args[0] if type(arg) is vector_memory_access and arg.dtype.base_type.is_float(): dtype = arg.dtype.base_type.numpy_dtype.type elif type(arg) is TypedSymbol and type(arg.dtype) is VectorType and arg.dtype.base_type.is_float(): diff --git a/pystencils/simp/assignment_collection.py b/pystencils/simp/assignment_collection.py index 950644089..33102dee5 100644 --- a/pystencils/simp/assignment_collection.py +++ b/pystencils/simp/assignment_collection.py @@ -437,9 +437,10 @@ class AssignmentCollection: class SymbolGen: """Default symbol generator producing number symbols ζ_0, ζ_1, ...""" - def __init__(self, symbol="xi"): + def __init__(self, symbol="xi", dtype=None): self._ctr = 0 self._symbol = symbol + self._dtype = dtype def __iter__(self): return self @@ -447,4 +448,6 @@ class SymbolGen: def __next__(self): name = f"{self._symbol}_{self._ctr}" self._ctr += 1 + if self._dtype is not None: + return pystencils.TypedSymbol(name, self._dtype) return sp.Symbol(name) -- GitLab