diff --git a/ast.py b/astnodes.py
similarity index 100%
rename from ast.py
rename to astnodes.py
diff --git a/backends/cbackend.py b/backends/cbackend.py
index f59741d5cc1b153215e8cd901366cd6cdebf01ca..72be644e809d1bcfa0309b8e7acfe44711d578c7 100644
--- a/backends/cbackend.py
+++ b/backends/cbackend.py
@@ -1,6 +1,6 @@
 import textwrap
 from sympy.utilities.codegen import CCodePrinter
-from pystencils.ast import Node
+from pystencils.astnodes import Node
 
 
 def generateC(astNode):
diff --git a/backends/dot.py b/backends/dot.py
index b8d99f75f454f041123646724370c4f494507350..2c7e6ce2ce6c03ec901b0946a0c02162ce1ed2a8 100644
--- a/backends/dot.py
+++ b/backends/dot.py
@@ -37,7 +37,7 @@ class DotPrinter(Printer):
 
 
 def __shortened(node):
-    from pystencils.ast import LoopOverCoordinate, KernelFunction, SympyAssignment
+    from pystencils.astnodes import LoopOverCoordinate, KernelFunction, SympyAssignment
     if isinstance(node, LoopOverCoordinate):
         return "Loop over dim %d" % (node.coordinateToLoopOver,)
     elif isinstance(node, KernelFunction):
diff --git a/cpu/kerncraft.py b/cpu/kerncraft.py
index 40238c8f27d68f783ac3966c00b4d5fd08a825a6..3c00c167c8904fa8406c1ee794349d95fbe41a2b 100644
--- a/cpu/kerncraft.py
+++ b/cpu/kerncraft.py
@@ -1,6 +1,6 @@
 from pystencils.transformations import makeLoopOverDomain, typingFromSympyInspection, \
     typeAllEquations, moveConstantsBeforeLoop, getOptimalLoopOrdering
-import pystencils.ast as ast
+import pystencils.astnodes as ast
 from pystencils.backends.cbackend import CBackend, CustomSympyPrinter
 from pystencils import TypedSymbol
 
diff --git a/cpu/kernelcreation.py b/cpu/kernelcreation.py
index 13fb0d785d0e064bcd1db84534cd850135021d38..b3340a39b3a3a9bfb80d4d8e8885a7db4240a66d 100644
--- a/cpu/kernelcreation.py
+++ b/cpu/kernelcreation.py
@@ -3,7 +3,7 @@ from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain,
     typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop
 from pystencils.types import TypedSymbol, DataType
 from pystencils.field import Field
-import pystencils.ast as ast
+import pystencils.astnodes as ast
 
 
 def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, splitGroups=(),
diff --git a/display_utils.py b/display_utils.py
index 490928a81cca6cd7a0f52b3d120315e3c6a9f136..785fcd5d76f0b5ba47a63db16608fb7d57d3cc61 100644
--- a/display_utils.py
+++ b/display_utils.py
@@ -2,7 +2,7 @@
 
 def toDot(expr, graphStyle={}):
     """Show a sympy or pystencils AST as dot graph"""
-    from pystencils.ast import Node
+    from pystencils.astnodes import Node
     import graphviz
     if isinstance(expr, Node):
         from pystencils.backends.dot import dotprint
diff --git a/gpucuda/kernelcreation.py b/gpucuda/kernelcreation.py
index f3b68c2f7eafdf2d2ccffbf9552911bee637c141..43118652efe822d7bfb4c7c22c7140a09315eb04 100644
--- a/gpucuda/kernelcreation.py
+++ b/gpucuda/kernelcreation.py
@@ -2,7 +2,7 @@ import sympy as sp
 
 from pystencils.transformations import resolveFieldAccesses, typeAllEquations, \
     parseBasePointerInfo, typingFromSympyInspection
-from pystencils.ast import Block, KernelFunction
+from pystencils.astnodes import Block, KernelFunction
 from pystencils import Field
 
 BLOCK_IDX = list(sp.symbols("blockIdx.x blockIdx.y blockIdx.z"))
diff --git a/llvm/kernelcreation.py b/llvm/kernelcreation.py
index e0957fc8d551661deddd3153010783f62a5f62a3..a13001973936d7ea95cb740a878c0a106ff660c9 100644
--- a/llvm/kernelcreation.py
+++ b/llvm/kernelcreation.py
@@ -3,7 +3,7 @@ from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain,
     typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop
 from pystencils.types import TypedSymbol, DataType
 from pystencils.field import Field
-import pystencils.ast as ast
+import pystencils.astnodes as ast
 
 
 def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, splitGroups=(),
diff --git a/sympyextensions.py b/sympyextensions.py
index 1d3203e662572e8a30181385d529d1ea9193c30c..48e1b598b671025334cf55c4d6bc5e990eff75a3 100644
--- a/sympyextensions.py
+++ b/sympyextensions.py
@@ -1,7 +1,7 @@
-import sympy as sp
 import operator
 from collections import defaultdict, Sequence
 import warnings
+import sympy as sp
 
 
 def fastSubs(term, subsDict):
@@ -155,9 +155,13 @@ def replaceSecondOrderProducts(expr, searchSymbols, positive=None, replaceMixed=
 
 def removeHigherOrderTerms(term, order=3, symbols=None):
     """
-    Remove all terms from a sum that contain 'order' or more factors of given 'symbols'
-    Example: symbols = ['u_x', 'u_y'] and order =2
-             removes terms u_x**2, u_x*u_y, u_y**2, u_x**3, ....
+    Removes all terms that that contain more than 'order' factors of given 'symbols'
+
+    Example:
+        >>> x, y = sp.symbols("x y")
+        >>> term = x**2 * y + y**2 * x + y**3 + x + y ** 2
+        >>> removeHigherOrderTerms(term, order=2, symbols=[x, y])
+        x + y**2
     """
     from sympy.core.power import Pow
     from sympy.core.add import Add, Mul
@@ -171,15 +175,19 @@ def removeHigherOrderTerms(term, order=3, symbols=None):
 
     def velocityFactorsInProduct(product):
         uFactorCount = 0
-        for factor in product.args:
-            if type(factor) == Pow:
-                if factor.args[0] in symbols:
-                    uFactorCount += factor.args[1]
-            if factor in symbols:
-                uFactorCount += 1
+        if type(product) is Mul:
+            for factor in product.args:
+                if type(factor) == Pow:
+                    if factor.args[0] in symbols:
+                        uFactorCount += factor.args[1]
+                if factor in symbols:
+                    uFactorCount += 1
+        elif type(product) is Pow:
+            if product.args[0] in symbols:
+                uFactorCount += product.args[1]
         return uFactorCount
 
-    if type(term) == Mul:
+    if type(term) == Mul or type(term) == Pow:
         if velocityFactorsInProduct(term) <= order:
             return term
         else:
diff --git a/transformations.py b/transformations.py
index 816e46a089f6827c5ae3b2d108ddca05b8fb8a54..c93d98ca291521838efda267d9e62bbe4a684ecf 100644
--- a/transformations.py
+++ b/transformations.py
@@ -6,7 +6,7 @@ from sympy.tensor import IndexedBase
 from pystencils.field import Field, offsetComponentToDirectionString
 from pystencils.types import TypedSymbol, DataType
 from pystencils.slicing import normalizeSlice
-import pystencils.ast as ast
+import pystencils.astnodes as ast
 
 
 def fastSubs(term, subsDict):