From 237648aee0865e5b348b48adb4a310e565760443 Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Sat, 15 Apr 2017 20:17:37 +0200
Subject: [PATCH] lbmpy: bugfix & tests for split optimization

---
 cpu/kernelcreation.py | 10 ++++++++--
 sympyextensions.py    | 22 +++++++++++++++++++++-
 2 files changed, 29 insertions(+), 3 deletions(-)

diff --git a/cpu/kernelcreation.py b/cpu/kernelcreation.py
index d05c3143b..8ae94cbbf 100644
--- a/cpu/kernelcreation.py
+++ b/cpu/kernelcreation.py
@@ -3,7 +3,7 @@ import sympy as sp
 from pystencils.astnodes import SympyAssignment, Block, LoopOverCoordinate, KernelFunction
 from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain, \
     typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop
-from pystencils.types import TypedSymbol, BasicType, StructType
+from pystencils.types import TypedSymbol, BasicType, StructType, createType
 from pystencils.field import Field
 import pystencils.astnodes as ast
 
@@ -30,11 +30,17 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl
 
     :return: :class:`pystencils.ast.KernelFunction` node
     """
+    if typeForSymbol is None:
+        typeForSymbol = 'double'
+
     def typeSymbol(term):
         if isinstance(term, Field.Access) or isinstance(term, TypedSymbol):
             return term
         elif isinstance(term, sp.Symbol):
-            return TypedSymbol(term.name, typeForSymbol[term.name])
+            if isinstance(typeForSymbol, str):
+                return TypedSymbol(term.name, createType(typeForSymbol))
+            else:
+                return TypedSymbol(term.name, typeForSymbol[term.name])
         else:
             raise ValueError("Term has to be field access or symbol")
 
diff --git a/sympyextensions.py b/sympyextensions.py
index 73dc9e2fa..97169a4ce 100644
--- a/sympyextensions.py
+++ b/sympyextensions.py
@@ -326,7 +326,9 @@ def countNumberOfOperations(term):
         elif t.func is sp.Float:
             pass
         elif isinstance(t, sp.Symbol):
-            pass
+            visitChildren = False
+        elif isinstance(t, sp.tensor.Indexed):
+            visitChildren = False
         elif t.is_integer:
             pass
         elif t.func is sp.Pow:
@@ -352,6 +354,24 @@ def countNumberOfOperations(term):
     return result
 
 
+def countNumberOfOperationsInAst(ast):
+    """Counts number of operations in an abstract syntax tree, see also :func:`countNumberOfOperations`"""
+    from pystencils.astnodes import SympyAssignment
+    result = {'adds': 0, 'muls': 0, 'divs': 0}
+
+    def visit(node):
+        if isinstance(node, SympyAssignment):
+            r = countNumberOfOperations(node.rhs)
+            result['adds'] += r['adds']
+            result['muls'] += r['muls']
+            result['divs'] += r['divs']
+        else:
+            for arg in node.args:
+                visit(arg)
+    visit(ast)
+    return result
+
+
 def matrixFromColumnVectors(columnVectors):
     """Creates a sympy matrix from column vectors.
         :param columnVectors: nested sequence - i.e. a sequence of column vectors
-- 
GitLab