From 237648aee0865e5b348b48adb4a310e565760443 Mon Sep 17 00:00:00 2001 From: Martin Bauer <martin.bauer@fau.de> Date: Sat, 15 Apr 2017 20:17:37 +0200 Subject: [PATCH] lbmpy: bugfix & tests for split optimization --- cpu/kernelcreation.py | 10 ++++++++-- sympyextensions.py | 22 +++++++++++++++++++++- 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/cpu/kernelcreation.py b/cpu/kernelcreation.py index d05c3143b..8ae94cbbf 100644 --- a/cpu/kernelcreation.py +++ b/cpu/kernelcreation.py @@ -3,7 +3,7 @@ import sympy as sp from pystencils.astnodes import SympyAssignment, Block, LoopOverCoordinate, KernelFunction from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain, \ typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop -from pystencils.types import TypedSymbol, BasicType, StructType +from pystencils.types import TypedSymbol, BasicType, StructType, createType from pystencils.field import Field import pystencils.astnodes as ast @@ -30,11 +30,17 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl :return: :class:`pystencils.ast.KernelFunction` node """ + if typeForSymbol is None: + typeForSymbol = 'double' + def typeSymbol(term): if isinstance(term, Field.Access) or isinstance(term, TypedSymbol): return term elif isinstance(term, sp.Symbol): - return TypedSymbol(term.name, typeForSymbol[term.name]) + if isinstance(typeForSymbol, str): + return TypedSymbol(term.name, createType(typeForSymbol)) + else: + return TypedSymbol(term.name, typeForSymbol[term.name]) else: raise ValueError("Term has to be field access or symbol") diff --git a/sympyextensions.py b/sympyextensions.py index 73dc9e2fa..97169a4ce 100644 --- a/sympyextensions.py +++ b/sympyextensions.py @@ -326,7 +326,9 @@ def countNumberOfOperations(term): elif t.func is sp.Float: pass elif isinstance(t, sp.Symbol): - pass + visitChildren = False + elif isinstance(t, sp.tensor.Indexed): + visitChildren = False elif t.is_integer: pass elif t.func is sp.Pow: @@ -352,6 +354,24 @@ def countNumberOfOperations(term): return result +def countNumberOfOperationsInAst(ast): + """Counts number of operations in an abstract syntax tree, see also :func:`countNumberOfOperations`""" + from pystencils.astnodes import SympyAssignment + result = {'adds': 0, 'muls': 0, 'divs': 0} + + def visit(node): + if isinstance(node, SympyAssignment): + r = countNumberOfOperations(node.rhs) + result['adds'] += r['adds'] + result['muls'] += r['muls'] + result['divs'] += r['divs'] + else: + for arg in node.args: + visit(arg) + visit(ast) + return result + + def matrixFromColumnVectors(columnVectors): """Creates a sympy matrix from column vectors. :param columnVectors: nested sequence - i.e. a sequence of column vectors -- GitLab