diff --git a/ast.py b/astnodes.py similarity index 100% rename from ast.py rename to astnodes.py diff --git a/backends/cbackend.py b/backends/cbackend.py index f59741d5cc1b153215e8cd901366cd6cdebf01ca..72be644e809d1bcfa0309b8e7acfe44711d578c7 100644 --- a/backends/cbackend.py +++ b/backends/cbackend.py @@ -1,6 +1,6 @@ import textwrap from sympy.utilities.codegen import CCodePrinter -from pystencils.ast import Node +from pystencils.astnodes import Node def generateC(astNode): diff --git a/backends/dot.py b/backends/dot.py index b8d99f75f454f041123646724370c4f494507350..2c7e6ce2ce6c03ec901b0946a0c02162ce1ed2a8 100644 --- a/backends/dot.py +++ b/backends/dot.py @@ -37,7 +37,7 @@ class DotPrinter(Printer): def __shortened(node): - from pystencils.ast import LoopOverCoordinate, KernelFunction, SympyAssignment + from pystencils.astnodes import LoopOverCoordinate, KernelFunction, SympyAssignment if isinstance(node, LoopOverCoordinate): return "Loop over dim %d" % (node.coordinateToLoopOver,) elif isinstance(node, KernelFunction): diff --git a/cpu/kerncraft.py b/cpu/kerncraft.py index 40238c8f27d68f783ac3966c00b4d5fd08a825a6..3c00c167c8904fa8406c1ee794349d95fbe41a2b 100644 --- a/cpu/kerncraft.py +++ b/cpu/kerncraft.py @@ -1,6 +1,6 @@ from pystencils.transformations import makeLoopOverDomain, typingFromSympyInspection, \ typeAllEquations, moveConstantsBeforeLoop, getOptimalLoopOrdering -import pystencils.ast as ast +import pystencils.astnodes as ast from pystencils.backends.cbackend import CBackend, CustomSympyPrinter from pystencils import TypedSymbol diff --git a/cpu/kernelcreation.py b/cpu/kernelcreation.py index 13fb0d785d0e064bcd1db84534cd850135021d38..b3340a39b3a3a9bfb80d4d8e8885a7db4240a66d 100644 --- a/cpu/kernelcreation.py +++ b/cpu/kernelcreation.py @@ -3,7 +3,7 @@ from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain, typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop from pystencils.types import TypedSymbol, DataType from pystencils.field import Field -import pystencils.ast as ast +import pystencils.astnodes as ast def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, splitGroups=(), diff --git a/display_utils.py b/display_utils.py index 490928a81cca6cd7a0f52b3d120315e3c6a9f136..785fcd5d76f0b5ba47a63db16608fb7d57d3cc61 100644 --- a/display_utils.py +++ b/display_utils.py @@ -2,7 +2,7 @@ def toDot(expr, graphStyle={}): """Show a sympy or pystencils AST as dot graph""" - from pystencils.ast import Node + from pystencils.astnodes import Node import graphviz if isinstance(expr, Node): from pystencils.backends.dot import dotprint diff --git a/gpucuda/kernelcreation.py b/gpucuda/kernelcreation.py index f3b68c2f7eafdf2d2ccffbf9552911bee637c141..43118652efe822d7bfb4c7c22c7140a09315eb04 100644 --- a/gpucuda/kernelcreation.py +++ b/gpucuda/kernelcreation.py @@ -2,7 +2,7 @@ import sympy as sp from pystencils.transformations import resolveFieldAccesses, typeAllEquations, \ parseBasePointerInfo, typingFromSympyInspection -from pystencils.ast import Block, KernelFunction +from pystencils.astnodes import Block, KernelFunction from pystencils import Field BLOCK_IDX = list(sp.symbols("blockIdx.x blockIdx.y blockIdx.z")) diff --git a/llvm/kernelcreation.py b/llvm/kernelcreation.py index e0957fc8d551661deddd3153010783f62a5f62a3..a13001973936d7ea95cb740a878c0a106ff660c9 100644 --- a/llvm/kernelcreation.py +++ b/llvm/kernelcreation.py @@ -3,7 +3,7 @@ from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain, typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop from pystencils.types import TypedSymbol, DataType from pystencils.field import Field -import pystencils.ast as ast +import pystencils.astnodes as ast def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, splitGroups=(), diff --git a/sympyextensions.py b/sympyextensions.py index 1d3203e662572e8a30181385d529d1ea9193c30c..48e1b598b671025334cf55c4d6bc5e990eff75a3 100644 --- a/sympyextensions.py +++ b/sympyextensions.py @@ -1,7 +1,7 @@ -import sympy as sp import operator from collections import defaultdict, Sequence import warnings +import sympy as sp def fastSubs(term, subsDict): @@ -155,9 +155,13 @@ def replaceSecondOrderProducts(expr, searchSymbols, positive=None, replaceMixed= def removeHigherOrderTerms(term, order=3, symbols=None): """ - Remove all terms from a sum that contain 'order' or more factors of given 'symbols' - Example: symbols = ['u_x', 'u_y'] and order =2 - removes terms u_x**2, u_x*u_y, u_y**2, u_x**3, .... + Removes all terms that that contain more than 'order' factors of given 'symbols' + + Example: + >>> x, y = sp.symbols("x y") + >>> term = x**2 * y + y**2 * x + y**3 + x + y ** 2 + >>> removeHigherOrderTerms(term, order=2, symbols=[x, y]) + x + y**2 """ from sympy.core.power import Pow from sympy.core.add import Add, Mul @@ -171,15 +175,19 @@ def removeHigherOrderTerms(term, order=3, symbols=None): def velocityFactorsInProduct(product): uFactorCount = 0 - for factor in product.args: - if type(factor) == Pow: - if factor.args[0] in symbols: - uFactorCount += factor.args[1] - if factor in symbols: - uFactorCount += 1 + if type(product) is Mul: + for factor in product.args: + if type(factor) == Pow: + if factor.args[0] in symbols: + uFactorCount += factor.args[1] + if factor in symbols: + uFactorCount += 1 + elif type(product) is Pow: + if product.args[0] in symbols: + uFactorCount += product.args[1] return uFactorCount - if type(term) == Mul: + if type(term) == Mul or type(term) == Pow: if velocityFactorsInProduct(term) <= order: return term else: diff --git a/transformations.py b/transformations.py index 816e46a089f6827c5ae3b2d108ddca05b8fb8a54..c93d98ca291521838efda267d9e62bbe4a684ecf 100644 --- a/transformations.py +++ b/transformations.py @@ -6,7 +6,7 @@ from sympy.tensor import IndexedBase from pystencils.field import Field, offsetComponentToDirectionString from pystencils.types import TypedSymbol, DataType from pystencils.slicing import normalizeSlice -import pystencils.ast as ast +import pystencils.astnodes as ast def fastSubs(term, subsDict):