From 8e483da4da549b2bc5fe398ef798f10397a05ed5 Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Fri, 3 Feb 2017 16:11:19 +0100
Subject: [PATCH] Entropic to new lbmpy

---
 ast.py => astnodes.py     |  0
 backends/cbackend.py      |  2 +-
 backends/dot.py           |  2 +-
 cpu/kerncraft.py          |  2 +-
 cpu/kernelcreation.py     |  2 +-
 display_utils.py          |  2 +-
 gpucuda/kernelcreation.py |  2 +-
 llvm/kernelcreation.py    |  2 +-
 sympyextensions.py        | 30 +++++++++++++++++++-----------
 transformations.py        |  2 +-
 10 files changed, 27 insertions(+), 19 deletions(-)
 rename ast.py => astnodes.py (100%)

diff --git a/ast.py b/astnodes.py
similarity index 100%
rename from ast.py
rename to astnodes.py
diff --git a/backends/cbackend.py b/backends/cbackend.py
index f59741d5c..72be644e8 100644
--- a/backends/cbackend.py
+++ b/backends/cbackend.py
@@ -1,6 +1,6 @@
 import textwrap
 from sympy.utilities.codegen import CCodePrinter
-from pystencils.ast import Node
+from pystencils.astnodes import Node
 
 
 def generateC(astNode):
diff --git a/backends/dot.py b/backends/dot.py
index b8d99f75f..2c7e6ce2c 100644
--- a/backends/dot.py
+++ b/backends/dot.py
@@ -37,7 +37,7 @@ class DotPrinter(Printer):
 
 
 def __shortened(node):
-    from pystencils.ast import LoopOverCoordinate, KernelFunction, SympyAssignment
+    from pystencils.astnodes import LoopOverCoordinate, KernelFunction, SympyAssignment
     if isinstance(node, LoopOverCoordinate):
         return "Loop over dim %d" % (node.coordinateToLoopOver,)
     elif isinstance(node, KernelFunction):
diff --git a/cpu/kerncraft.py b/cpu/kerncraft.py
index 40238c8f2..3c00c167c 100644
--- a/cpu/kerncraft.py
+++ b/cpu/kerncraft.py
@@ -1,6 +1,6 @@
 from pystencils.transformations import makeLoopOverDomain, typingFromSympyInspection, \
     typeAllEquations, moveConstantsBeforeLoop, getOptimalLoopOrdering
-import pystencils.ast as ast
+import pystencils.astnodes as ast
 from pystencils.backends.cbackend import CBackend, CustomSympyPrinter
 from pystencils import TypedSymbol
 
diff --git a/cpu/kernelcreation.py b/cpu/kernelcreation.py
index 13fb0d785..b3340a39b 100644
--- a/cpu/kernelcreation.py
+++ b/cpu/kernelcreation.py
@@ -3,7 +3,7 @@ from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain,
     typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop
 from pystencils.types import TypedSymbol, DataType
 from pystencils.field import Field
-import pystencils.ast as ast
+import pystencils.astnodes as ast
 
 
 def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, splitGroups=(),
diff --git a/display_utils.py b/display_utils.py
index 490928a81..785fcd5d7 100644
--- a/display_utils.py
+++ b/display_utils.py
@@ -2,7 +2,7 @@
 
 def toDot(expr, graphStyle={}):
     """Show a sympy or pystencils AST as dot graph"""
-    from pystencils.ast import Node
+    from pystencils.astnodes import Node
     import graphviz
     if isinstance(expr, Node):
         from pystencils.backends.dot import dotprint
diff --git a/gpucuda/kernelcreation.py b/gpucuda/kernelcreation.py
index f3b68c2f7..43118652e 100644
--- a/gpucuda/kernelcreation.py
+++ b/gpucuda/kernelcreation.py
@@ -2,7 +2,7 @@ import sympy as sp
 
 from pystencils.transformations import resolveFieldAccesses, typeAllEquations, \
     parseBasePointerInfo, typingFromSympyInspection
-from pystencils.ast import Block, KernelFunction
+from pystencils.astnodes import Block, KernelFunction
 from pystencils import Field
 
 BLOCK_IDX = list(sp.symbols("blockIdx.x blockIdx.y blockIdx.z"))
diff --git a/llvm/kernelcreation.py b/llvm/kernelcreation.py
index e0957fc8d..a13001973 100644
--- a/llvm/kernelcreation.py
+++ b/llvm/kernelcreation.py
@@ -3,7 +3,7 @@ from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain,
     typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop
 from pystencils.types import TypedSymbol, DataType
 from pystencils.field import Field
-import pystencils.ast as ast
+import pystencils.astnodes as ast
 
 
 def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, splitGroups=(),
diff --git a/sympyextensions.py b/sympyextensions.py
index 1d3203e66..48e1b598b 100644
--- a/sympyextensions.py
+++ b/sympyextensions.py
@@ -1,7 +1,7 @@
-import sympy as sp
 import operator
 from collections import defaultdict, Sequence
 import warnings
+import sympy as sp
 
 
 def fastSubs(term, subsDict):
@@ -155,9 +155,13 @@ def replaceSecondOrderProducts(expr, searchSymbols, positive=None, replaceMixed=
 
 def removeHigherOrderTerms(term, order=3, symbols=None):
     """
-    Remove all terms from a sum that contain 'order' or more factors of given 'symbols'
-    Example: symbols = ['u_x', 'u_y'] and order =2
-             removes terms u_x**2, u_x*u_y, u_y**2, u_x**3, ....
+    Removes all terms that that contain more than 'order' factors of given 'symbols'
+
+    Example:
+        >>> x, y = sp.symbols("x y")
+        >>> term = x**2 * y + y**2 * x + y**3 + x + y ** 2
+        >>> removeHigherOrderTerms(term, order=2, symbols=[x, y])
+        x + y**2
     """
     from sympy.core.power import Pow
     from sympy.core.add import Add, Mul
@@ -171,15 +175,19 @@ def removeHigherOrderTerms(term, order=3, symbols=None):
 
     def velocityFactorsInProduct(product):
         uFactorCount = 0
-        for factor in product.args:
-            if type(factor) == Pow:
-                if factor.args[0] in symbols:
-                    uFactorCount += factor.args[1]
-            if factor in symbols:
-                uFactorCount += 1
+        if type(product) is Mul:
+            for factor in product.args:
+                if type(factor) == Pow:
+                    if factor.args[0] in symbols:
+                        uFactorCount += factor.args[1]
+                if factor in symbols:
+                    uFactorCount += 1
+        elif type(product) is Pow:
+            if product.args[0] in symbols:
+                uFactorCount += product.args[1]
         return uFactorCount
 
-    if type(term) == Mul:
+    if type(term) == Mul or type(term) == Pow:
         if velocityFactorsInProduct(term) <= order:
             return term
         else:
diff --git a/transformations.py b/transformations.py
index 816e46a08..c93d98ca2 100644
--- a/transformations.py
+++ b/transformations.py
@@ -6,7 +6,7 @@ from sympy.tensor import IndexedBase
 from pystencils.field import Field, offsetComponentToDirectionString
 from pystencils.types import TypedSymbol, DataType
 from pystencils.slicing import normalizeSlice
-import pystencils.ast as ast
+import pystencils.astnodes as ast
 
 
 def fastSubs(term, subsDict):
-- 
GitLab