diff --git a/cpu/kernelcreation.py b/cpu/kernelcreation.py index d05c3143bdee8df0cb13229b2e6a2ec7524ff721..8ae94cbbfbd52984f1bcafc824b7007e8960ce71 100644 --- a/cpu/kernelcreation.py +++ b/cpu/kernelcreation.py @@ -3,7 +3,7 @@ import sympy as sp from pystencils.astnodes import SympyAssignment, Block, LoopOverCoordinate, KernelFunction from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain, \ typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop -from pystencils.types import TypedSymbol, BasicType, StructType +from pystencils.types import TypedSymbol, BasicType, StructType, createType from pystencils.field import Field import pystencils.astnodes as ast @@ -30,11 +30,17 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl :return: :class:`pystencils.ast.KernelFunction` node """ + if typeForSymbol is None: + typeForSymbol = 'double' + def typeSymbol(term): if isinstance(term, Field.Access) or isinstance(term, TypedSymbol): return term elif isinstance(term, sp.Symbol): - return TypedSymbol(term.name, typeForSymbol[term.name]) + if isinstance(typeForSymbol, str): + return TypedSymbol(term.name, createType(typeForSymbol)) + else: + return TypedSymbol(term.name, typeForSymbol[term.name]) else: raise ValueError("Term has to be field access or symbol") diff --git a/sympyextensions.py b/sympyextensions.py index 73dc9e2faa4751b79366dffc08e617ec2eac6963..97169a4ce6cfb855d51cfd845d93d1a7689e1eb3 100644 --- a/sympyextensions.py +++ b/sympyextensions.py @@ -326,7 +326,9 @@ def countNumberOfOperations(term): elif t.func is sp.Float: pass elif isinstance(t, sp.Symbol): - pass + visitChildren = False + elif isinstance(t, sp.tensor.Indexed): + visitChildren = False elif t.is_integer: pass elif t.func is sp.Pow: @@ -352,6 +354,24 @@ def countNumberOfOperations(term): return result +def countNumberOfOperationsInAst(ast): + """Counts number of operations in an abstract syntax tree, see also :func:`countNumberOfOperations`""" + from pystencils.astnodes import SympyAssignment + result = {'adds': 0, 'muls': 0, 'divs': 0} + + def visit(node): + if isinstance(node, SympyAssignment): + r = countNumberOfOperations(node.rhs) + result['adds'] += r['adds'] + result['muls'] += r['muls'] + result['divs'] += r['divs'] + else: + for arg in node.args: + visit(arg) + visit(ast) + return result + + def matrixFromColumnVectors(columnVectors): """Creates a sympy matrix from column vectors. :param columnVectors: nested sequence - i.e. a sequence of column vectors