From 8e483da4da549b2bc5fe398ef798f10397a05ed5 Mon Sep 17 00:00:00 2001 From: Martin Bauer <martin.bauer@fau.de> Date: Fri, 3 Feb 2017 16:11:19 +0100 Subject: [PATCH] Entropic to new lbmpy --- ast.py => astnodes.py | 0 backends/cbackend.py | 2 +- backends/dot.py | 2 +- cpu/kerncraft.py | 2 +- cpu/kernelcreation.py | 2 +- display_utils.py | 2 +- gpucuda/kernelcreation.py | 2 +- llvm/kernelcreation.py | 2 +- sympyextensions.py | 30 +++++++++++++++++++----------- transformations.py | 2 +- 10 files changed, 27 insertions(+), 19 deletions(-) rename ast.py => astnodes.py (100%) diff --git a/ast.py b/astnodes.py similarity index 100% rename from ast.py rename to astnodes.py diff --git a/backends/cbackend.py b/backends/cbackend.py index f59741d5c..72be644e8 100644 --- a/backends/cbackend.py +++ b/backends/cbackend.py @@ -1,6 +1,6 @@ import textwrap from sympy.utilities.codegen import CCodePrinter -from pystencils.ast import Node +from pystencils.astnodes import Node def generateC(astNode): diff --git a/backends/dot.py b/backends/dot.py index b8d99f75f..2c7e6ce2c 100644 --- a/backends/dot.py +++ b/backends/dot.py @@ -37,7 +37,7 @@ class DotPrinter(Printer): def __shortened(node): - from pystencils.ast import LoopOverCoordinate, KernelFunction, SympyAssignment + from pystencils.astnodes import LoopOverCoordinate, KernelFunction, SympyAssignment if isinstance(node, LoopOverCoordinate): return "Loop over dim %d" % (node.coordinateToLoopOver,) elif isinstance(node, KernelFunction): diff --git a/cpu/kerncraft.py b/cpu/kerncraft.py index 40238c8f2..3c00c167c 100644 --- a/cpu/kerncraft.py +++ b/cpu/kerncraft.py @@ -1,6 +1,6 @@ from pystencils.transformations import makeLoopOverDomain, typingFromSympyInspection, \ typeAllEquations, moveConstantsBeforeLoop, getOptimalLoopOrdering -import pystencils.ast as ast +import pystencils.astnodes as ast from pystencils.backends.cbackend import CBackend, CustomSympyPrinter from pystencils import TypedSymbol diff --git a/cpu/kernelcreation.py b/cpu/kernelcreation.py index 13fb0d785..b3340a39b 100644 --- a/cpu/kernelcreation.py +++ b/cpu/kernelcreation.py @@ -3,7 +3,7 @@ from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain, typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop from pystencils.types import TypedSymbol, DataType from pystencils.field import Field -import pystencils.ast as ast +import pystencils.astnodes as ast def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, splitGroups=(), diff --git a/display_utils.py b/display_utils.py index 490928a81..785fcd5d7 100644 --- a/display_utils.py +++ b/display_utils.py @@ -2,7 +2,7 @@ def toDot(expr, graphStyle={}): """Show a sympy or pystencils AST as dot graph""" - from pystencils.ast import Node + from pystencils.astnodes import Node import graphviz if isinstance(expr, Node): from pystencils.backends.dot import dotprint diff --git a/gpucuda/kernelcreation.py b/gpucuda/kernelcreation.py index f3b68c2f7..43118652e 100644 --- a/gpucuda/kernelcreation.py +++ b/gpucuda/kernelcreation.py @@ -2,7 +2,7 @@ import sympy as sp from pystencils.transformations import resolveFieldAccesses, typeAllEquations, \ parseBasePointerInfo, typingFromSympyInspection -from pystencils.ast import Block, KernelFunction +from pystencils.astnodes import Block, KernelFunction from pystencils import Field BLOCK_IDX = list(sp.symbols("blockIdx.x blockIdx.y blockIdx.z")) diff --git a/llvm/kernelcreation.py b/llvm/kernelcreation.py index e0957fc8d..a13001973 100644 --- a/llvm/kernelcreation.py +++ b/llvm/kernelcreation.py @@ -3,7 +3,7 @@ from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain, typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop from pystencils.types import TypedSymbol, DataType from pystencils.field import Field -import pystencils.ast as ast +import pystencils.astnodes as ast def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, splitGroups=(), diff --git a/sympyextensions.py b/sympyextensions.py index 1d3203e66..48e1b598b 100644 --- a/sympyextensions.py +++ b/sympyextensions.py @@ -1,7 +1,7 @@ -import sympy as sp import operator from collections import defaultdict, Sequence import warnings +import sympy as sp def fastSubs(term, subsDict): @@ -155,9 +155,13 @@ def replaceSecondOrderProducts(expr, searchSymbols, positive=None, replaceMixed= def removeHigherOrderTerms(term, order=3, symbols=None): """ - Remove all terms from a sum that contain 'order' or more factors of given 'symbols' - Example: symbols = ['u_x', 'u_y'] and order =2 - removes terms u_x**2, u_x*u_y, u_y**2, u_x**3, .... + Removes all terms that that contain more than 'order' factors of given 'symbols' + + Example: + >>> x, y = sp.symbols("x y") + >>> term = x**2 * y + y**2 * x + y**3 + x + y ** 2 + >>> removeHigherOrderTerms(term, order=2, symbols=[x, y]) + x + y**2 """ from sympy.core.power import Pow from sympy.core.add import Add, Mul @@ -171,15 +175,19 @@ def removeHigherOrderTerms(term, order=3, symbols=None): def velocityFactorsInProduct(product): uFactorCount = 0 - for factor in product.args: - if type(factor) == Pow: - if factor.args[0] in symbols: - uFactorCount += factor.args[1] - if factor in symbols: - uFactorCount += 1 + if type(product) is Mul: + for factor in product.args: + if type(factor) == Pow: + if factor.args[0] in symbols: + uFactorCount += factor.args[1] + if factor in symbols: + uFactorCount += 1 + elif type(product) is Pow: + if product.args[0] in symbols: + uFactorCount += product.args[1] return uFactorCount - if type(term) == Mul: + if type(term) == Mul or type(term) == Pow: if velocityFactorsInProduct(term) <= order: return term else: diff --git a/transformations.py b/transformations.py index 816e46a08..c93d98ca2 100644 --- a/transformations.py +++ b/transformations.py @@ -6,7 +6,7 @@ from sympy.tensor import IndexedBase from pystencils.field import Field, offsetComponentToDirectionString from pystencils.types import TypedSymbol, DataType from pystencils.slicing import normalizeSlice -import pystencils.ast as ast +import pystencils.astnodes as ast def fastSubs(term, subsDict): -- GitLab