Skip to content
Snippets Groups Projects
Commit 6bb096fd authored by Markus Holzer's avatar Markus Holzer
Browse files

Merge branch 'vectorization' into 'master'

some fixes for lbmpy vectorization

See merge request pycodegen/pystencils!216
parents 1f9d65ed e46f1658
No related merge requests found
......@@ -209,10 +209,11 @@ def insert_vector_casts(ast_node):
if expr.func is sp.Mul and expr.args[0] == -1:
# special treatment for the unary minus: make sure that the -1 has the same type as the argument
dtype = int
for arg in expr.args[1:]:
if type(arg) is vector_memory_access and arg.dtype.base_type.is_float():
for arg in expr.atoms(vector_memory_access):
if arg.dtype.base_type.is_float():
dtype = arg.dtype.base_type.numpy_dtype.type
elif type(arg) is TypedSymbol and type(arg.dtype) is VectorType and arg.dtype.base_type.is_float():
for arg in expr.atoms(TypedSymbol):
if type(arg.dtype) is VectorType and arg.dtype.base_type.is_float():
dtype = arg.dtype.base_type.numpy_dtype.type
if dtype is not int:
if dtype is np.float32:
......
......@@ -437,9 +437,10 @@ class AssignmentCollection:
class SymbolGen:
"""Default symbol generator producing number symbols ζ_0, ζ_1, ..."""
def __init__(self, symbol="xi"):
def __init__(self, symbol="xi", dtype=None):
self._ctr = 0
self._symbol = symbol
self._dtype = dtype
def __iter__(self):
return self
......@@ -447,4 +448,6 @@ class SymbolGen:
def __next__(self):
name = f"{self._symbol}_{self._ctr}"
self._ctr += 1
if self._dtype is not None:
return pystencils.TypedSymbol(name, self._dtype)
return sp.Symbol(name)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment