From 9f071fdd6c6bacfdbf3e8afdecbfd615304dd67c Mon Sep 17 00:00:00 2001 From: Martin Bauer <martin.bauer@fau.de> Date: Fri, 30 Mar 2018 19:33:52 +0200 Subject: [PATCH] pystencils: Assignment instead of sympy.Eq - Previously sympy.Eq was used to represent assignments. However Eq represents equality not assignment. This means that sometimes sympy "simplified" an equation like a = a to True, -> replaced sp.Eq by pystencils.Assignment everywhere - renamed EquationCollection to AssignmentCollection --- __init__.py | 9 +- assignment.py | 14 +++ assignment_collection/__init__.py | 2 + .../assignment_collection.py | 79 ++++++++--------- assignment_collection/simplifications.py | 87 +++++++++++++++++++ .../simplificationstrategy.py | 26 +++--- backends/dot.py | 27 +----- boundaries/boundaryconditions.py | 6 +- boundaries/boundaryhandling.py | 3 +- cpu/kernelcreation.py | 4 +- equationcollection/__init__.py | 2 - equationcollection/simplifications.py | 87 ------------------- field.py | 25 +++--- finitedifferences.py | 6 +- gpucuda/periodicity.py | 4 +- kernelcreation.py | 8 +- llvm/llvm.py | 2 +- sympyextensions.py | 11 +-- transformations/transformations.py | 3 +- 19 files changed, 202 insertions(+), 203 deletions(-) create mode 100644 assignment.py create mode 100644 assignment_collection/__init__.py rename equationcollection/equationcollection.py => assignment_collection/assignment_collection.py (81%) create mode 100644 assignment_collection/simplifications.py rename {equationcollection => assignment_collection}/simplificationstrategy.py (87%) delete mode 100644 equationcollection/__init__.py delete mode 100644 equationcollection/simplifications.py diff --git a/__init__.py b/__init__.py index d8b3397a3..36cc17032 100644 --- a/__init__.py +++ b/__init__.py @@ -3,14 +3,13 @@ from pystencils.data_types import TypedSymbol from pystencils.slicing import makeSlice from pystencils.kernelcreation import createKernel, createIndexedKernel from pystencils.display_utils import showCode, toDot -from pystencils.equationcollection import EquationCollection -from sympy.codegen.ast import Assignment as Assign - +from pystencils.assignment_collection import AssignmentCollection +from pystencils.assignment import Assignment __all__ = ['Field', 'FieldType', 'extractCommonSubexpressions', 'TypedSymbol', 'makeSlice', 'createKernel', 'createIndexedKernel', 'showCode', 'toDot', - 'EquationCollection', - 'Assign'] + 'AssignmentCollection', + 'Assignment'] diff --git a/assignment.py b/assignment.py new file mode 100644 index 000000000..135a5d126 --- /dev/null +++ b/assignment.py @@ -0,0 +1,14 @@ +from sympy.codegen.ast import Assignment +from sympy.printing.latex import LatexPrinter + +__all__ = ['Assignment'] + + +def print_assignment_latex(printer, expr): + """sympy cannot print Assignments as Latex. Thus, this function is added to the sympy Latex printer""" + printed_lhs = printer.doprint(expr.lhs) + printed_rhs = printer.doprint(expr.rhs) + return f"{printed_lhs} \leftarrow {printed_rhs}" + + +LatexPrinter._print_Assignment = print_assignment_latex diff --git a/assignment_collection/__init__.py b/assignment_collection/__init__.py new file mode 100644 index 000000000..a71a7d05c --- /dev/null +++ b/assignment_collection/__init__.py @@ -0,0 +1,2 @@ +from pystencils.assignment_collection.assignment_collection import AssignmentCollection +from pystencils.assignment_collection.simplificationstrategy import SimplificationStrategy diff --git a/equationcollection/equationcollection.py b/assignment_collection/assignment_collection.py similarity index 81% rename from equationcollection/equationcollection.py rename to assignment_collection/assignment_collection.py index 1609c95d6..ea95cf6b9 100644 --- a/equationcollection/equationcollection.py +++ b/assignment_collection/assignment_collection.py @@ -1,18 +1,19 @@ import sympy as sp from copy import copy +from pystencils.assignment import Assignment from pystencils.sympyextensions import fastSubs, countNumberOfOperations, sortEquationsTopologically -class EquationCollection(object): +class AssignmentCollection(object): """ A collection of equations with subexpression definitions, also represented as equations, - that are used in the main equations. EquationCollections can be passed to simplification methods. + that are used in the main equations. AssignmentCollection can be passed to simplification methods. These simplification methods can change the subexpressions, but the number and left hand side of the main equations themselves is not altered. Additionally a dictionary of simplification hints is stored, which are set by the functions that create equation collections to transport information to the simplification system. - :ivar mainEquations: list of sympy equations + :ivar mainAssignments: list of sympy equations :ivar subexpressions: list of sympy equations defining subexpressions used in main equations :ivar simplificationHints: dictionary that is used to annotate the equation collection with hints that are used by the simplification system. See documentation of the simplification rules for @@ -22,7 +23,7 @@ class EquationCollection(object): # ----------------------------------------- Creation --------------------------------------------------------------- def __init__(self, equations, subExpressions, simplificationHints=None, subexpressionSymbolNameGenerator=None): - self.mainEquations = equations + self.mainAssignments = equations self.subexpressions = subExpressions if simplificationHints is None: @@ -39,15 +40,15 @@ class EquationCollection(object): def mainTerms(self): return [] - def copy(self, mainEquations=None, subexpressions=None): + def copy(self, mainAssignments=None, subexpressions=None): res = copy(self) res.simplificationHints = self.simplificationHints.copy() res.subexpressionSymbolNameGenerator = copy(self.subexpressionSymbolNameGenerator) - if mainEquations is not None: - res.mainEquations = mainEquations + if mainAssignments is not None: + res.mainAssignments = mainAssignments else: - res.mainEquations = self.mainEquations.copy() + res.mainAssignments = self.mainAssignments.copy() if subexpressions is not None: res.subexpressions = subexpressions @@ -64,13 +65,13 @@ class EquationCollection(object): """ if substituteOnLhs: newSubexpressions = [fastSubs(eq, substitutionDict) for eq in self.subexpressions] - newEquations = [fastSubs(eq, substitutionDict) for eq in self.mainEquations] + newEquations = [fastSubs(eq, substitutionDict) for eq in self.mainAssignments] else: - newSubexpressions = [sp.Eq(eq.lhs, fastSubs(eq.rhs, substitutionDict)) for eq in self.subexpressions] - newEquations = [sp.Eq(eq.lhs, fastSubs(eq.rhs, substitutionDict)) for eq in self.mainEquations] + newSubexpressions = [Assignment(eq.lhs, fastSubs(eq.rhs, substitutionDict)) for eq in self.subexpressions] + newEquations = [Assignment(eq.lhs, fastSubs(eq.rhs, substitutionDict)) for eq in self.mainAssignments] if addSubstitutionsAsSubexpressions: - newSubexpressions = [sp.Eq(b, a) for a, b in substitutionDict.items()] + newSubexpressions + newSubexpressions = [Assignment(b, a) for a, b in substitutionDict.items()] + newSubexpressions newSubexpressions = sortEquationsTopologically(newSubexpressions) return self.copy(newEquations, newSubexpressions) @@ -86,7 +87,7 @@ class EquationCollection(object): @property def allEquations(self): """Subexpression and main equations in one sequence""" - return self.subexpressions + self.mainEquations + return self.subexpressions + self.mainAssignments @property def freeSymbols(self): @@ -100,30 +101,30 @@ class EquationCollection(object): def boundSymbols(self): """Set of all symbols which occur on left-hand-sides i.e. all symbols which are defined.""" boundSymbolsSet = set([eq.lhs for eq in self.allEquations]) - assert len(boundSymbolsSet) == len(self.subexpressions) + len(self.mainEquations), \ + assert len(boundSymbolsSet) == len(self.subexpressions) + len(self.mainAssignments), \ "Not in SSA form - same symbol assigned multiple times" return boundSymbolsSet @property def definedSymbols(self): """All symbols that occur as left-hand-sides of the main equations""" - return set([eq.lhs for eq in self.mainEquations]) + return set([eq.lhs for eq in self.mainAssignments]) @property def operationCount(self): """See :func:`countNumberOfOperations` """ return countNumberOfOperations(self.allEquations, onlyType=None) - def get(self, symbols, fromMainEquationsOnly=False): + def get(self, symbols, frommainAssignmentsOnly=False): """Return the equations which have symbols as left hand sides""" if not hasattr(symbols, "__len__"): symbols = list(symbols) symbols = set(symbols) - if not fromMainEquationsOnly: + if not frommainAssignmentsOnly: eqsToSearchIn = self.allEquations else: - eqsToSearchIn = self.mainEquations + eqsToSearchIn = self.mainAssignments return [eq for eq in eqsToSearchIn if eq.lhs in symbols] @@ -145,19 +146,19 @@ class EquationCollection(object): if len(self.subexpressions) > 0: result += "<div>Subexpressions:</div>" result += makeHtmlEquationTable(self.subexpressions) - result += "<div>Main Equations:</div>" - result += makeHtmlEquationTable(self.mainEquations) + result += "<div>Main Assignments:</div>" + result += makeHtmlEquationTable(self.mainAssignments) return result def __repr__(self): - return "Equation Collection for " + ",".join([str(eq.lhs) for eq in self.mainEquations]) + return "Equation Collection for " + ",".join([str(eq.lhs) for eq in self.mainAssignments]) def __str__(self): result = "Subexpressions\n" for eq in self.subexpressions: result += str(eq) + "\n" - result += "Main Equations\n" - for eq in self.mainEquations: + result += "Main Assignments\n" + for eq in self.mainAssignments: result += str(eq) + "\n" return result @@ -165,8 +166,8 @@ class EquationCollection(object): def merge(self, other): """Returns a new collection which contains self and other. Subexpressions are renamed if they clash.""" - ownDefs = set([e.lhs for e in self.mainEquations]) - otherDefs = set([e.lhs for e in other.mainEquations]) + ownDefs = set([e.lhs for e in self.mainAssignments]) + otherDefs = set([e.lhs for e in other.mainAssignments]) assert len(ownDefs.intersection(otherDefs)) == 0, "Cannot merge, since both collection define the same symbols" ownSubexpressionSymbols = {e.lhs: e.rhs for e in self.subexpressions} @@ -180,14 +181,14 @@ class EquationCollection(object): else: # different definition - a new name has to be introduced newLhs = next(self.subexpressionSymbolNameGenerator) - newEq = sp.Eq(newLhs, fastSubs(otherSubexpressionEq.rhs, substitutionDict)) + newEq = Assignment(newLhs, fastSubs(otherSubexpressionEq.rhs, substitutionDict)) processedOtherSubexpressionEquations.append(newEq) substitutionDict[otherSubexpressionEq.lhs] = newLhs else: processedOtherSubexpressionEquations.append(fastSubs(otherSubexpressionEq, substitutionDict)) - processedOtherMainEquations = [fastSubs(eq, substitutionDict) for eq in other.mainEquations] - return self.copy(self.mainEquations + processedOtherMainEquations, + processedOthermainAssignments = [fastSubs(eq, substitutionDict) for eq in other.mainAssignments] + return self.copy(self.mainAssignments + processedOthermainAssignments, self.subexpressions + processedOtherSubexpressionEquations) def getDependentSymbols(self, symbolSequence): @@ -226,28 +227,28 @@ class EquationCollection(object): newEquations.append(eq) newSubExpr = [eq for eq in self.subexpressions if eq.lhs in dependentSymbols and eq.lhs not in symbolsToExtract] - return EquationCollection(newEquations, newSubExpr) + return AssignmentCollection(newEquations, newSubExpr) def newWithoutUnusedSubexpressions(self): """Returns a new equation collection containing only the subexpressions that are used/referenced in the equations""" - allLhs = [eq.lhs for eq in self.mainEquations] + allLhs = [eq.lhs for eq in self.mainAssignments] return self.extract(allLhs) def appendToSubexpressions(self, rhs, lhs=None, topologicalSort=True): if lhs is None: lhs = sp.Dummy() - eq = sp.Eq(lhs, rhs) + eq = Assignment(lhs, rhs) self.subexpressions.append(eq) if topologicalSort: - self.topologicalSort(subexpressions=True, mainEquations=False) + self.topologicalSort(subexpressions=True, mainAssignments=False) return lhs - def topologicalSort(self, subexpressions=True, mainEquations=True): + def topologicalSort(self, subexpressions=True, mainAssignments=True): if subexpressions: self.subexpressions = sortEquationsTopologically(self.subexpressions) - if mainEquations: - self.mainEquations = sortEquationsTopologically(self.mainEquations) + if mainAssignments: + self.mainAssignments = sortEquationsTopologically(self.mainAssignments) def insertSubexpression(self, symbol): newSubexpressions = [] @@ -260,8 +261,8 @@ class EquationCollection(object): if subsDict is None: return self - newSubexpressions = [sp.Eq(eq.lhs, fastSubs(eq.rhs, subsDict)) for eq in newSubexpressions] - newEqs = [sp.Eq(eq.lhs, fastSubs(eq.rhs, subsDict)) for eq in self.mainEquations] + newSubexpressions = [Assignment(eq.lhs, fastSubs(eq.rhs, subsDict)) for eq in newSubexpressions] + newEqs = [Assignment(eq.lhs, fastSubs(eq.rhs, subsDict)) for eq in self.mainAssignments] return self.copy(newEqs, newSubexpressions) def insertSubexpressions(self, subexpressionSymbolsToKeep=set()): @@ -286,7 +287,7 @@ class EquationCollection(object): else: subsDict[subExpr[i].lhs] = subExpr[i].rhs - newEq = [fastSubs(eq, subsDict) for eq in self.mainEquations] + newEq = [fastSubs(eq, subsDict) for eq in self.mainAssignments] return self.copy(newEq, keptSubexpressions) def lambdify(self, symbols, module=None, fixedSymbols={}): @@ -296,7 +297,7 @@ class EquationCollection(object): :param module: same as sympy.lambdify paramter of same same, i.e. which module to use e.g. 'numpy' :param fixedSymbols: dictionary with substitutions, that are applied before lambdification """ - eqs = self.copyWithSubstitutionsApplied(fixedSymbols).insertSubexpressions().mainEquations + eqs = self.copyWithSubstitutionsApplied(fixedSymbols).insertSubexpressions().mainAssignments lambdas = {eq.lhs: sp.lambdify(symbols, eq.rhs, module) for eq in eqs} def f(*args, **kwargs): diff --git a/assignment_collection/simplifications.py b/assignment_collection/simplifications.py new file mode 100644 index 000000000..7d0eab53d --- /dev/null +++ b/assignment_collection/simplifications.py @@ -0,0 +1,87 @@ +import sympy as sp + +from pystencils import Assignment, AssignmentCollection +from pystencils.sympyextensions import replaceAdditive + + +def sympyCseOnEquationList(eqs): + ec = AssignmentCollection(eqs, []) + return sympyCSE(ec).allEquations + + +def sympyCSE(assignment_collection): + """ + Searches for common subexpressions inside the equation collection, in both the existing subexpressions as well + as the equations themselves. It uses the sympy subexpression detection to do this. Return a new equation collection + with the additional subexpressions found + """ + symbolGen = assignment_collection.subexpressionSymbolNameGenerator + replacements, newEq = sp.cse(assignment_collection.subexpressions + assignment_collection.mainAssignments, + symbols=symbolGen) + replacementEqs = [Assignment(*r) for r in replacements] + + modifiedSubexpressions = newEq[:len(assignment_collection.subexpressions)] + modifiedUpdateEquations = newEq[len(assignment_collection.subexpressions):] + + newSubexpressions = replacementEqs + modifiedSubexpressions + topologicallySortedPairs = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in newSubexpressions]) + newSubexpressions = [Assignment(a[0], a[1]) for a in topologicallySortedPairs] + + return assignment_collection.copy(modifiedUpdateEquations, newSubexpressions) + + +def applyOnAllEquations(assignment_collection, operation): + """Applies sympy expand operation to all equations in collection""" + result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in assignment_collection.mainAssignments] + return assignment_collection.copy(result) + + +def applyOnAllSubexpressions(assignment_collection, operation): + result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in assignment_collection.subexpressions] + return assignment_collection.copy(assignment_collection.mainAssignments, result) + + +def subexpressionSubstitutionInExistingSubexpressions(assignment_collection): + """Goes through the subexpressions list and replaces the term in the following subexpressions""" + result = [] + for outerCtr, s in enumerate(assignment_collection.subexpressions): + newRhs = s.rhs + for innerCtr in range(outerCtr): + subExpr = assignment_collection.subexpressions[innerCtr] + newRhs = replaceAdditive(newRhs, subExpr.lhs, subExpr.rhs, requiredMatchReplacement=1.0) + newRhs = newRhs.subs(subExpr.rhs, subExpr.lhs) + result.append(Assignment(s.lhs, newRhs)) + + return assignment_collection.copy(assignment_collection.mainAssignments, result) + + +def subexpressionSubstitutionInmainAssignments(assignment_collection): + """Replaces already existing subexpressions in the equations of the assignment_collection""" + result = [] + for s in assignment_collection.mainAssignments: + newRhs = s.rhs + for subExpr in assignment_collection.subexpressions: + newRhs = replaceAdditive(newRhs, subExpr.lhs, subExpr.rhs, requiredMatchReplacement=1.0) + result.append(Assignment(s.lhs, newRhs)) + return assignment_collection.copy(result) + + +def addSubexpressionsForDivisions(assignment_collection): + """Introduces subexpressions for all divisions which have no constant in the denominator. + e.g. :math:`\frac{1}{x}` is replaced, :math:`\frac{1}{3}` is not replaced.""" + divisors = set() + + def searchDivisors(term): + if term.func == sp.Pow: + if term.exp.is_integer and term.exp.is_number and term.exp < 0: + divisors.add(term) + else: + for a in term.args: + searchDivisors(a) + + for eq in assignment_collection.allEquations: + searchDivisors(eq.rhs) + + newSymbolGen = assignment_collection.subexpressionSymbolNameGenerator + substitutions = {divisor: newSymbol for newSymbol, divisor in zip(newSymbolGen, divisors)} + return assignment_collection.copyWithSubstitutionsApplied(substitutions, True) diff --git a/equationcollection/simplificationstrategy.py b/assignment_collection/simplificationstrategy.py similarity index 87% rename from equationcollection/simplificationstrategy.py rename to assignment_collection/simplificationstrategy.py index 38f4c3f8b..3d8cdd62f 100644 --- a/equationcollection/simplificationstrategy.py +++ b/assignment_collection/simplificationstrategy.py @@ -30,11 +30,11 @@ class SimplificationStrategy(object): updateRule = t(updateRule) return updateRule - def __call__(self, equationCollection): + def __call__(self, assignment_collection): """Same as apply""" - return self.apply(equationCollection) + return self.apply(assignment_collection) - def createSimplificationReport(self, equationCollection): + def createSimplificationReport(self, assignment_collection): """ Returns a simplification report containing the number of operations at each simplification stage, together with the run-time the simplification took. @@ -72,25 +72,25 @@ class SimplificationStrategy(object): import timeit report = Report() - op = equationCollection.operationCount + op = assignment_collection.operationCount total = op['adds'] + op['muls'] + op['divs'] report.add(ReportElement("OriginalTerm", '-', op['adds'], op['muls'], op['divs'], total)) for t in self._rules: startTime = timeit.default_timer() - equationCollection = t(equationCollection) + assignment_collection = t(assignment_collection) endTime = timeit.default_timer() - op = equationCollection.operationCount + op = assignment_collection.operationCount timeStr = "%.2f ms" % ((endTime - startTime) * 1000,) total = op['adds'] + op['muls'] + op['divs'] report.add(ReportElement(t.__name__, timeStr, op['adds'], op['muls'], op['divs'], total)) return report - def showIntermediateResults(self, equationCollection, symbols=None): + def showIntermediateResults(self, assignment_collection, symbols=None): class IntermediateResults: def __init__(self, strategy, eqColl, resSyms): self.strategy = strategy - self.equationCollection = eqColl + self.assignment_collection = eqColl self.restrictSymbols = resSyms def __str__(self): @@ -102,8 +102,8 @@ class SimplificationStrategy(object): text += (" " * 3 + (" " * 3).join(str(eqColl).splitlines(True))) return text - result = printEqCollection("Initial Version", self.equationCollection) - eqColl = self.equationCollection + result = printEqCollection("Initial Version", self.assignment_collection) + eqColl = self.assignment_collection for rule in self.strategy.rules: eqColl = rule(eqColl) result += printEqCollection(rule.__name__, eqColl) @@ -119,14 +119,14 @@ class SimplificationStrategy(object): text += "</div>" return text - result = printEqCollection("Initial Version", self.equationCollection) - eqColl = self.equationCollection + result = printEqCollection("Initial Version", self.assignment_collection) + eqColl = self.assignment_collection for rule in self.strategy.rules: eqColl = rule(eqColl) result += printEqCollection(rule.__name__, eqColl) return result - return IntermediateResults(self, equationCollection, symbols) + return IntermediateResults(self, assignment_collection, symbols) def __repr__(self): result = "Simplification Strategy:\n" diff --git a/backends/dot.py b/backends/dot.py index 8797c49f0..e78ac1bb0 100644 --- a/backends/dot.py +++ b/backends/dot.py @@ -91,33 +91,14 @@ def dotprint(node, view=False, short=False, full=False, **kwargs): :param kwargs: is directly passed to the DotPrinter class: http://graphviz.readthedocs.io/en/latest/api.html#digraph :return: string in DOT format """ - nodeToStrFunction = repr + node_to_str_function = repr if short: - nodeToStrFunction = __shortened + node_to_str_function = __shortened elif full: - nodeToStrFunction = lambda expr: repr(type(expr)) + repr(expr) - printer = DotPrinter(nodeToStrFunction, full, **kwargs) + node_to_str_function = lambda expr: repr(type(expr)) + repr(expr) + printer = DotPrinter(node_to_str_function, full, **kwargs) dot = printer.doprint(node) if view: return graphviz.Source(dot) return dot - -if __name__ == "__main__": - from pystencils import Field - import sympy as sp - imgField = Field.createGeneric('I', - spatialDimensions=2, # 2D image - indexDimensions=1) # multiple values per pixel: e.g. RGB - w1, w2 = sp.symbols("w_1 w_2") - sobelX = -w2 * imgField[-1, 0](1) - w1 * imgField[-1, -1](1) - w1 * imgField[-1, +1](1) \ - + w2 * imgField[+1, 0](1) + w1 * imgField[+1, -1](1) - w1 * imgField[+1, +1](1) - sobelX - - dstField = Field.createGeneric('dst', spatialDimensions=2, indexDimensions=0) - updateRule = sp.Eq(dstField[0, 0], sobelX) - updateRule - - from pystencils import createKernel - ast = createKernel([updateRule]) - print(dotprint(ast, short=True)) diff --git a/boundaries/boundaryconditions.py b/boundaries/boundaryconditions.py index 0a799cdac..1cc67d17c 100644 --- a/boundaries/boundaryconditions.py +++ b/boundaries/boundaryconditions.py @@ -1,4 +1,4 @@ -import sympy as sp +from pystencils import Assignment from pystencils.boundaries.boundaryhandling import BoundaryOffsetInfo @@ -52,13 +52,13 @@ class Neumann(Boundary): neighbor = BoundaryOffsetInfo.offsetFromDir(directionSymbol, field.spatialDimensions) if field.indexDimensions == 0: - return [sp.Eq(field[neighbor], field.center)] + return [Assignment(field[neighbor], field.center)] else: from itertools import product if not field.hasFixedIndexShape: raise NotImplementedError("Neumann boundary works only for fields with fixed index shape") indexIter = product(*(range(i) for i in field.indexShape)) - return [sp.Eq(field[neighbor](*idx), field(*idx)) for idx in indexIter] + return [Assignment(field[neighbor](*idx), field(*idx)) for idx in indexIter] def __hash__(self): # All boundaries of these class behave equal -> should also be equal diff --git a/boundaries/boundaryhandling.py b/boundaries/boundaryhandling.py index 7787346c3..76999a549 100644 --- a/boundaries/boundaryhandling.py +++ b/boundaries/boundaryhandling.py @@ -1,5 +1,6 @@ import numpy as np import sympy as sp +from pystencils.assignment import Assignment from pystencils import Field, TypedSymbol, createIndexedKernel from pystencils.backends.cbackend import CustomCppCode from pystencils.boundaries.createindexlist import numpyDataTypeForBoundaryObject, createBoundaryIndexArray @@ -363,6 +364,6 @@ def createBoundaryKernel(field, indexField, stencil, boundaryFunctor, target='cp elements = [BoundaryOffsetInfo(stencil)] indexArrDtype = indexField.dtype.numpyDtype dirSymbol = TypedSymbol("dir", indexArrDtype.fields['dir'][0]) - elements += [sp.Eq(dirSymbol, indexField[0]('dir'))] + elements += [Assignment(dirSymbol, indexField[0]('dir'))] elements += boundaryFunctor(field, directionSymbol=dirSymbol, indexField=indexField) return createIndexedKernel(elements, [indexField], target=target, cpuOpenMP=openMP) diff --git a/cpu/kernelcreation.py b/cpu/kernelcreation.py index 0c96c1f45..c4c2833b4 100644 --- a/cpu/kernelcreation.py +++ b/cpu/kernelcreation.py @@ -33,7 +33,7 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol='double', :return: :class:`pystencils.ast.KernelFunction` node """ - def typeSymbol(term): + def type_symbol(term): if isinstance(term, Field.Access) or isinstance(term, TypedSymbol): return term elif isinstance(term, sp.Symbol): @@ -58,7 +58,7 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol='double', code.target = 'cpu' if splitGroups: - typedSplitGroups = [[typeSymbol(s) for s in splitGroup] for splitGroup in splitGroups] + typedSplitGroups = [[type_symbol(s) for s in splitGroup] for splitGroup in splitGroups] splitInnerLoop(code, typedSplitGroups) basePointerInfo = [['spatialInner0'], ['spatialInner1']] if len(loopOrder) >= 2 else [['spatialInner0']] diff --git a/equationcollection/__init__.py b/equationcollection/__init__.py deleted file mode 100644 index 434f02e21..000000000 --- a/equationcollection/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from pystencils.equationcollection.equationcollection import EquationCollection -from pystencils.equationcollection.simplificationstrategy import SimplificationStrategy diff --git a/equationcollection/simplifications.py b/equationcollection/simplifications.py deleted file mode 100644 index 0e1e92f48..000000000 --- a/equationcollection/simplifications.py +++ /dev/null @@ -1,87 +0,0 @@ -import sympy as sp - -from pystencils.equationcollection.equationcollection import EquationCollection -from pystencils.sympyextensions import replaceAdditive - - -def sympyCseOnEquationList(eqs): - ec = EquationCollection(eqs, []) - return sympyCSE(ec).allEquations - - -def sympyCSE(equationCollection): - """ - Searches for common subexpressions inside the equation collection, in both the existing subexpressions as well - as the equations themselves. It uses the sympy subexpression detection to do this. Return a new equation collection - with the additional subexpressions found - """ - symbolGen = equationCollection.subexpressionSymbolNameGenerator - replacements, newEq = sp.cse(equationCollection.subexpressions + equationCollection.mainEquations, - symbols=symbolGen) - replacementEqs = [sp.Eq(*r) for r in replacements] - - modifiedSubexpressions = newEq[:len(equationCollection.subexpressions)] - modifiedUpdateEquations = newEq[len(equationCollection.subexpressions):] - - newSubexpressions = replacementEqs + modifiedSubexpressions - topologicallySortedPairs = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in newSubexpressions]) - newSubexpressions = [sp.Eq(a[0], a[1]) for a in topologicallySortedPairs] - - return equationCollection.copy(modifiedUpdateEquations, newSubexpressions) - - -def applyOnAllEquations(equationCollection, operation): - """Applies sympy expand operation to all equations in collection""" - result = [sp.Eq(eq.lhs, operation(eq.rhs)) for eq in equationCollection.mainEquations] - return equationCollection.copy(result) - - -def applyOnAllSubexpressions(equationCollection, operation): - result = [sp.Eq(eq.lhs, operation(eq.rhs)) for eq in equationCollection.subexpressions] - return equationCollection.copy(equationCollection.mainEquations, result) - - -def subexpressionSubstitutionInExistingSubexpressions(equationCollection): - """Goes through the subexpressions list and replaces the term in the following subexpressions""" - result = [] - for outerCtr, s in enumerate(equationCollection.subexpressions): - newRhs = s.rhs - for innerCtr in range(outerCtr): - subExpr = equationCollection.subexpressions[innerCtr] - newRhs = replaceAdditive(newRhs, subExpr.lhs, subExpr.rhs, requiredMatchReplacement=1.0) - newRhs = newRhs.subs(subExpr.rhs, subExpr.lhs) - result.append(sp.Eq(s.lhs, newRhs)) - - return equationCollection.copy(equationCollection.mainEquations, result) - - -def subexpressionSubstitutionInMainEquations(equationCollection): - """Replaces already existing subexpressions in the equations of the equationCollection""" - result = [] - for s in equationCollection.mainEquations: - newRhs = s.rhs - for subExpr in equationCollection.subexpressions: - newRhs = replaceAdditive(newRhs, subExpr.lhs, subExpr.rhs, requiredMatchReplacement=1.0) - result.append(sp.Eq(s.lhs, newRhs)) - return equationCollection.copy(result) - - -def addSubexpressionsForDivisions(equationCollection): - """Introduces subexpressions for all divisions which have no constant in the denominator. - e.g. :math:`\frac{1}{x}` is replaced, :math:`\frac{1}{3}` is not replaced.""" - divisors = set() - - def searchDivisors(term): - if term.func == sp.Pow: - if term.exp.is_integer and term.exp.is_number and term.exp < 0: - divisors.add(term) - else: - for a in term.args: - searchDivisors(a) - - for eq in equationCollection.allEquations: - searchDivisors(eq.rhs) - - newSymbolGen = equationCollection.subexpressionSymbolNameGenerator - substitutions = {divisor: newSymbol for newSymbol, divisor in zip(newSymbolGen, divisors)} - return equationCollection.copyWithSubstitutionsApplied(substitutions, True) diff --git a/field.py b/field.py index 24a51755f..d310f9b77 100644 --- a/field.py +++ b/field.py @@ -5,6 +5,7 @@ import sympy as sp from sympy.core.cache import cacheit from sympy.tensor import IndexedBase +from pystencils.assignment import Assignment from pystencils.alignedarray import aligned_empty from pystencils.data_types import TypedSymbol, createType, createCompositeTypeFromString, StructType from pystencils.sympyextensions import isIntegerSequence @@ -71,10 +72,10 @@ class Field(object): >>> src = Field.createGeneric("src", spatialDimensions=2, indexDimensions=1) >>> dst = Field.createGeneric("dst", spatialDimensions=2, indexDimensions=1) >>> for i, offset in enumerate(stencil): - ... sp.Eq(dst[0,0](i), src[-offset](i)) - Eq(dst_C^0, src_C^0) - Eq(dst_C^1, src_S^1) - Eq(dst_C^2, src_N^2) + ... Assignment(dst[0,0](i), src[-offset](i)) + Assignment(dst_C^0, src_C^0) + Assignment(dst_C^1, src_S^1) + Assignment(dst_C^2, src_N^2) """ @staticmethod @@ -437,22 +438,22 @@ def extractCommonSubexpressions(equations): them in a topologically sorted order, ready for evaluation. Usually called before list of equations is passed to :func:`createKernel` """ - replacements, newEq = sp.cse(equations) + replacements, new_eq = sp.cse(equations) # Workaround for older sympy versions: here subexpressions (temporary = True) are extracted # which leads to problems in Piecewise functions which have to a default case indicated by True - symbolsEqualToTrue = {r[0]: True for r in replacements if r[1] is sp.true} + symbols_equal_to_true = {r[0]: True for r in replacements if r[1] is sp.true} - replacementEqs = [sp.Eq(*r) for r in replacements if r[1] is not sp.true] - equations = replacementEqs + newEq - topologicallySortedPairs = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in equations]) - equations = [sp.Eq(a[0], a[1].subs(symbolsEqualToTrue)) for a in topologicallySortedPairs] + replacement_eqs = [Assignment(*r) for r in replacements if r[1] is not sp.true] + equations = replacement_eqs + new_eq + topologically_sorted_pairs = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in equations]) + equations = [Assignment(a[0], a[1].subs(symbols_equal_to_true)) for a in topologically_sorted_pairs] return equations def getLayoutFromStrides(strides, indexDimensionIds=[]): coordinates = list(range(len(strides))) - relevantStrides = [stride for i, stride in enumerate(strides) if i not in indexDimensionIds] - result = [x for (y, x) in sorted(zip(relevantStrides, coordinates), key=lambda pair: pair[0], reverse=True)] + relevant_strides = [stride for i, stride in enumerate(strides) if i not in indexDimensionIds] + result = [x for (y, x) in sorted(zip(relevant_strides, coordinates), key=lambda pair: pair[0], reverse=True)] return normalizeLayout(result) diff --git a/finitedifferences.py b/finitedifferences.py index 039a603e1..8dba62707 100644 --- a/finitedifferences.py +++ b/finitedifferences.py @@ -1,7 +1,7 @@ import numpy as np import sympy as sp -from pystencils.equationcollection import EquationCollection +from pystencils.assignment_collection import AssignmentCollection from pystencils.field import Field from pystencils.transformations import fastSubs from pystencils.derivative import Diff @@ -355,8 +355,8 @@ class Discretization2ndOrder: return [self(e) for e in expr] elif isinstance(expr, sp.Matrix): return expr.applyfunc(self.__call__) - elif isinstance(expr, EquationCollection): - return expr.copy(mainEquations=[e for e in expr.mainEquations], + elif isinstance(expr, AssignmentCollection): + return expr.copy(mainAssignments=[e for e in expr.mainAssignments], subexpressions=[e for e in expr.subexpressions]) transientTerms = expr.atoms(Transient) diff --git a/gpucuda/periodicity.py b/gpucuda/periodicity.py index 039f21799..5009fce89 100644 --- a/gpucuda/periodicity.py +++ b/gpucuda/periodicity.py @@ -1,6 +1,6 @@ import sympy as sp import numpy as np -from pystencils import Field +from pystencils import Field, Assignment from pystencils.slicing import normalizeSlice, getPeriodicBoundarySrcDstSlices from pystencils.gpucuda import makePythonFunction from pystencils.gpucuda.kernelcreation import createCUDAKernel @@ -20,7 +20,7 @@ def createCopyKernel(domainSize, fromSlice, toSlice, indexDimensions=0, indexDim updateEqs = [] for i in range(indexDimShape): - eq = sp.Eq(f(i), f[tuple(offset)](i)) + eq = Assignment(f(i), f[tuple(offset)](i)) updateEqs.append(eq) ast = createCUDAKernel(updateEqs, iterationSlice=toSlice) diff --git a/kernelcreation.py b/kernelcreation.py index 7d27ed6c9..9c84347d9 100644 --- a/kernelcreation.py +++ b/kernelcreation.py @@ -1,4 +1,4 @@ -from pystencils.equationcollection import EquationCollection +from pystencils.assignment_collection import AssignmentCollection from pystencils.gpucuda.indexing import indexingCreatorFromParams @@ -7,7 +7,7 @@ def createKernel(equations, target='cpu', dataType="double", iterationSlice=None gpuIndexing='block', gpuIndexingParams={}): """ Creates abstract syntax tree (AST) of kernel, using a list of update equations. - :param equations: either be a plain list of equations or a EquationCollection object + :param equations: either be a plain list of equations or a AssignmentCollection object :param target: 'cpu', 'llvm' or 'gpu' :param dataType: data type used for all untyped symbols (i.e. non-fields), can also be a dict from symbol name to type @@ -32,7 +32,7 @@ def createKernel(equations, target='cpu', dataType="double", iterationSlice=None # ---- Normalizing parameters splitGroups = () - if isinstance(equations, EquationCollection): + if isinstance(equations, AssignmentCollection): if 'splitGroups' in equations.simplificationHints: splitGroups = equations.simplificationHints['splitGroups'] equations = equations.allEquations @@ -83,7 +83,7 @@ def createIndexedKernel(equations, indexFields, target='cpu', dataType="double", coordinateNames: name of the coordinate fields in the struct data type """ - if isinstance(equations, EquationCollection): + if isinstance(equations, AssignmentCollection): equations = equations.allEquations if target == 'cpu': from pystencils.cpu import createIndexedKernel diff --git a/llvm/llvm.py b/llvm/llvm.py index 6529c699e..08ad220bb 100644 --- a/llvm/llvm.py +++ b/llvm/llvm.py @@ -9,7 +9,7 @@ from pystencils.llvm.control_flow import Loop from pystencils.data_types import createType, to_llvm_type, getTypeOfExpression, collateTypes, \ createCompositeTypeFromString from sympy import Indexed -from sympy.codegen.ast import Assignment +from pystencils.assignment import Assignment def generateLLVM(ast_node, module=None, builder=None): diff --git a/sympyextensions.py b/sympyextensions.py index 6e29cabf9..d407d8f61 100644 --- a/sympyextensions.py +++ b/sympyextensions.py @@ -6,6 +6,7 @@ import warnings import sympy as sp from pystencils.data_types import getTypeOfExpression, getBaseType +from pystencils.assignment import Assignment def prod(seq): @@ -264,7 +265,7 @@ def replaceSecondOrderProducts(expr, searchSymbols, positive=None, replaceMixed= mixedSymbol = sp.Symbol(mixedSymbolName.replace("_", "")) if mixedSymbol not in mixedSymbolsReplaced: mixedSymbolsReplaced.add(mixedSymbol) - replaceMixed.append(sp.Eq(mixedSymbol, u + sign * v)) + replaceMixed.append(Assignment(mixedSymbol, u + sign * v)) else: mixedSymbol = u + sign * v return sp.Rational(1, 2) * sign * otherFactors * (mixedSymbol ** 2 - u ** 2 - v ** 2) @@ -431,7 +432,7 @@ def countNumberOfOperations(term, onlyType='real'): for operationName in result.keys(): result[operationName] += r[operationName] return result - elif isinstance(term, sp.Eq): + elif isinstance(term, Assignment): term = term.rhs term = term.evalf() @@ -538,7 +539,7 @@ def getSymmetricPart(term, vars): def sortEquationsTopologically(equationSequence): res = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in equationSequence]) - return [sp.Eq(a, b) for a, b in res] + return [Assignment(a, b) for a, b in res] def getEquationsFromFunction(func, **kwargs): @@ -559,7 +560,7 @@ def getEquationsFromFunction(func, **kwargs): ... S.neighbors @= f[0,1] + f[1,0] ... g[0,0] @= S.neighbors + f[0,0] >>> getEquationsFromFunction(myKernel) - [Eq(neighbors, f_E + f_N), Eq(g_C, f_C + neighbors)] + [Assignment(neighbors, f_E + f_N), Assignment(g_C, f_C + neighbors)] """ import inspect import re @@ -590,7 +591,7 @@ def getEquationsFromFunction(func, **kwargs): code = "".join(sourceLines[1:]) result = [] localsDict = {'_result': result, - 'Eq': sp.Eq, + 'Eq': Assignment, 'S': SymbolCreator()} localsDict.update(kwargs) globalsDict = inspect.stack()[1][0].f_globals.copy() diff --git a/transformations/transformations.py b/transformations/transformations.py index 299201f0c..5c9c9d1fe 100644 --- a/transformations/transformations.py +++ b/transformations/transformations.py @@ -7,6 +7,7 @@ import sympy as sp from sympy.logic.boolalg import Boolean from sympy.tensor import IndexedBase +from pystencils.assignment import Assignment from pystencils.field import Field, FieldType, offsetComponentToDirectionString from pystencils.data_types import TypedSymbol, createType, PointerType, StructType, getBaseType, castFunc from pystencils.slicing import normalizeSlice @@ -737,7 +738,7 @@ def typeAllEquations(eqs, typeForSymbol): def visit(object): if isinstance(object, list) or isinstance(object, tuple): return [visit(e) for e in object] - if isinstance(object, sp.Eq) or isinstance(object, ast.SympyAssignment): + if isinstance(object, sp.Eq) or isinstance(object, ast.SympyAssignment) or isinstance(object, Assignment): newLhs = processLhs(object.lhs) newRhs = processRhs(object.rhs) return ast.SympyAssignment(newLhs, newRhs) -- GitLab