Skip to content
Snippets Groups Projects
Commit 237648ae authored by Martin Bauer's avatar Martin Bauer
Browse files

lbmpy: bugfix & tests for split optimization

parent 93b1d694
No related merge requests found
......@@ -3,7 +3,7 @@ import sympy as sp
from pystencils.astnodes import SympyAssignment, Block, LoopOverCoordinate, KernelFunction
from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain, \
typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop
from pystencils.types import TypedSymbol, BasicType, StructType
from pystencils.types import TypedSymbol, BasicType, StructType, createType
from pystencils.field import Field
import pystencils.astnodes as ast
......@@ -30,11 +30,17 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl
:return: :class:`pystencils.ast.KernelFunction` node
"""
if typeForSymbol is None:
typeForSymbol = 'double'
def typeSymbol(term):
if isinstance(term, Field.Access) or isinstance(term, TypedSymbol):
return term
elif isinstance(term, sp.Symbol):
return TypedSymbol(term.name, typeForSymbol[term.name])
if isinstance(typeForSymbol, str):
return TypedSymbol(term.name, createType(typeForSymbol))
else:
return TypedSymbol(term.name, typeForSymbol[term.name])
else:
raise ValueError("Term has to be field access or symbol")
......
......@@ -326,7 +326,9 @@ def countNumberOfOperations(term):
elif t.func is sp.Float:
pass
elif isinstance(t, sp.Symbol):
pass
visitChildren = False
elif isinstance(t, sp.tensor.Indexed):
visitChildren = False
elif t.is_integer:
pass
elif t.func is sp.Pow:
......@@ -352,6 +354,24 @@ def countNumberOfOperations(term):
return result
def countNumberOfOperationsInAst(ast):
"""Counts number of operations in an abstract syntax tree, see also :func:`countNumberOfOperations`"""
from pystencils.astnodes import SympyAssignment
result = {'adds': 0, 'muls': 0, 'divs': 0}
def visit(node):
if isinstance(node, SympyAssignment):
r = countNumberOfOperations(node.rhs)
result['adds'] += r['adds']
result['muls'] += r['muls']
result['divs'] += r['divs']
else:
for arg in node.args:
visit(arg)
visit(ast)
return result
def matrixFromColumnVectors(columnVectors):
"""Creates a sympy matrix from column vectors.
:param columnVectors: nested sequence - i.e. a sequence of column vectors
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment