Skip to content
Snippets Groups Projects
Commit 6bb096fd authored by Markus Holzer's avatar Markus Holzer
Browse files

Merge branch 'vectorization' into 'master'

some fixes for lbmpy vectorization

See merge request pycodegen/pystencils!216
parents 1f9d65ed e46f1658
No related merge requests found
...@@ -209,10 +209,11 @@ def insert_vector_casts(ast_node): ...@@ -209,10 +209,11 @@ def insert_vector_casts(ast_node):
if expr.func is sp.Mul and expr.args[0] == -1: if expr.func is sp.Mul and expr.args[0] == -1:
# special treatment for the unary minus: make sure that the -1 has the same type as the argument # special treatment for the unary minus: make sure that the -1 has the same type as the argument
dtype = int dtype = int
for arg in expr.args[1:]: for arg in expr.atoms(vector_memory_access):
if type(arg) is vector_memory_access and arg.dtype.base_type.is_float(): if arg.dtype.base_type.is_float():
dtype = arg.dtype.base_type.numpy_dtype.type dtype = arg.dtype.base_type.numpy_dtype.type
elif type(arg) is TypedSymbol and type(arg.dtype) is VectorType and arg.dtype.base_type.is_float(): for arg in expr.atoms(TypedSymbol):
if type(arg.dtype) is VectorType and arg.dtype.base_type.is_float():
dtype = arg.dtype.base_type.numpy_dtype.type dtype = arg.dtype.base_type.numpy_dtype.type
if dtype is not int: if dtype is not int:
if dtype is np.float32: if dtype is np.float32:
......
...@@ -437,9 +437,10 @@ class AssignmentCollection: ...@@ -437,9 +437,10 @@ class AssignmentCollection:
class SymbolGen: class SymbolGen:
"""Default symbol generator producing number symbols ζ_0, ζ_1, ...""" """Default symbol generator producing number symbols ζ_0, ζ_1, ..."""
def __init__(self, symbol="xi"): def __init__(self, symbol="xi", dtype=None):
self._ctr = 0 self._ctr = 0
self._symbol = symbol self._symbol = symbol
self._dtype = dtype
def __iter__(self): def __iter__(self):
return self return self
...@@ -447,4 +448,6 @@ class SymbolGen: ...@@ -447,4 +448,6 @@ class SymbolGen:
def __next__(self): def __next__(self):
name = f"{self._symbol}_{self._ctr}" name = f"{self._symbol}_{self._ctr}"
self._ctr += 1 self._ctr += 1
if self._dtype is not None:
return pystencils.TypedSymbol(name, self._dtype)
return sp.Symbol(name) return sp.Symbol(name)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment