diff --git a/__init__.py b/__init__.py index d8b3397a36a7d0ecde0364302e0ae9f9bebf6e59..36cc170327d526897bc65ce00205c599de7cfb10 100644 --- a/__init__.py +++ b/__init__.py @@ -3,14 +3,13 @@ from pystencils.data_types import TypedSymbol from pystencils.slicing import makeSlice from pystencils.kernelcreation import createKernel, createIndexedKernel from pystencils.display_utils import showCode, toDot -from pystencils.equationcollection import EquationCollection -from sympy.codegen.ast import Assignment as Assign - +from pystencils.assignment_collection import AssignmentCollection +from pystencils.assignment import Assignment __all__ = ['Field', 'FieldType', 'extractCommonSubexpressions', 'TypedSymbol', 'makeSlice', 'createKernel', 'createIndexedKernel', 'showCode', 'toDot', - 'EquationCollection', - 'Assign'] + 'AssignmentCollection', + 'Assignment'] diff --git a/assignment.py b/assignment.py new file mode 100644 index 0000000000000000000000000000000000000000..135a5d1264c095a8a6955e9a75d29b62e20048d8 --- /dev/null +++ b/assignment.py @@ -0,0 +1,14 @@ +from sympy.codegen.ast import Assignment +from sympy.printing.latex import LatexPrinter + +__all__ = ['Assignment'] + + +def print_assignment_latex(printer, expr): + """sympy cannot print Assignments as Latex. Thus, this function is added to the sympy Latex printer""" + printed_lhs = printer.doprint(expr.lhs) + printed_rhs = printer.doprint(expr.rhs) + return f"{printed_lhs} \leftarrow {printed_rhs}" + + +LatexPrinter._print_Assignment = print_assignment_latex diff --git a/assignment_collection/__init__.py b/assignment_collection/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a71a7d05cec3af128062e31ebbdd1c246159ef51 --- /dev/null +++ b/assignment_collection/__init__.py @@ -0,0 +1,2 @@ +from pystencils.assignment_collection.assignment_collection import AssignmentCollection +from pystencils.assignment_collection.simplificationstrategy import SimplificationStrategy diff --git a/equationcollection/equationcollection.py b/assignment_collection/assignment_collection.py similarity index 81% rename from equationcollection/equationcollection.py rename to assignment_collection/assignment_collection.py index 1609c95d61e5602c8f9cadc03ac130f8dd7fa2c9..ea95cf6b9890924485af752493aac096aa11a852 100644 --- a/equationcollection/equationcollection.py +++ b/assignment_collection/assignment_collection.py @@ -1,18 +1,19 @@ import sympy as sp from copy import copy +from pystencils.assignment import Assignment from pystencils.sympyextensions import fastSubs, countNumberOfOperations, sortEquationsTopologically -class EquationCollection(object): +class AssignmentCollection(object): """ A collection of equations with subexpression definitions, also represented as equations, - that are used in the main equations. EquationCollections can be passed to simplification methods. + that are used in the main equations. AssignmentCollection can be passed to simplification methods. These simplification methods can change the subexpressions, but the number and left hand side of the main equations themselves is not altered. Additionally a dictionary of simplification hints is stored, which are set by the functions that create equation collections to transport information to the simplification system. - :ivar mainEquations: list of sympy equations + :ivar mainAssignments: list of sympy equations :ivar subexpressions: list of sympy equations defining subexpressions used in main equations :ivar simplificationHints: dictionary that is used to annotate the equation collection with hints that are used by the simplification system. See documentation of the simplification rules for @@ -22,7 +23,7 @@ class EquationCollection(object): # ----------------------------------------- Creation --------------------------------------------------------------- def __init__(self, equations, subExpressions, simplificationHints=None, subexpressionSymbolNameGenerator=None): - self.mainEquations = equations + self.mainAssignments = equations self.subexpressions = subExpressions if simplificationHints is None: @@ -39,15 +40,15 @@ class EquationCollection(object): def mainTerms(self): return [] - def copy(self, mainEquations=None, subexpressions=None): + def copy(self, mainAssignments=None, subexpressions=None): res = copy(self) res.simplificationHints = self.simplificationHints.copy() res.subexpressionSymbolNameGenerator = copy(self.subexpressionSymbolNameGenerator) - if mainEquations is not None: - res.mainEquations = mainEquations + if mainAssignments is not None: + res.mainAssignments = mainAssignments else: - res.mainEquations = self.mainEquations.copy() + res.mainAssignments = self.mainAssignments.copy() if subexpressions is not None: res.subexpressions = subexpressions @@ -64,13 +65,13 @@ class EquationCollection(object): """ if substituteOnLhs: newSubexpressions = [fastSubs(eq, substitutionDict) for eq in self.subexpressions] - newEquations = [fastSubs(eq, substitutionDict) for eq in self.mainEquations] + newEquations = [fastSubs(eq, substitutionDict) for eq in self.mainAssignments] else: - newSubexpressions = [sp.Eq(eq.lhs, fastSubs(eq.rhs, substitutionDict)) for eq in self.subexpressions] - newEquations = [sp.Eq(eq.lhs, fastSubs(eq.rhs, substitutionDict)) for eq in self.mainEquations] + newSubexpressions = [Assignment(eq.lhs, fastSubs(eq.rhs, substitutionDict)) for eq in self.subexpressions] + newEquations = [Assignment(eq.lhs, fastSubs(eq.rhs, substitutionDict)) for eq in self.mainAssignments] if addSubstitutionsAsSubexpressions: - newSubexpressions = [sp.Eq(b, a) for a, b in substitutionDict.items()] + newSubexpressions + newSubexpressions = [Assignment(b, a) for a, b in substitutionDict.items()] + newSubexpressions newSubexpressions = sortEquationsTopologically(newSubexpressions) return self.copy(newEquations, newSubexpressions) @@ -86,7 +87,7 @@ class EquationCollection(object): @property def allEquations(self): """Subexpression and main equations in one sequence""" - return self.subexpressions + self.mainEquations + return self.subexpressions + self.mainAssignments @property def freeSymbols(self): @@ -100,30 +101,30 @@ class EquationCollection(object): def boundSymbols(self): """Set of all symbols which occur on left-hand-sides i.e. all symbols which are defined.""" boundSymbolsSet = set([eq.lhs for eq in self.allEquations]) - assert len(boundSymbolsSet) == len(self.subexpressions) + len(self.mainEquations), \ + assert len(boundSymbolsSet) == len(self.subexpressions) + len(self.mainAssignments), \ "Not in SSA form - same symbol assigned multiple times" return boundSymbolsSet @property def definedSymbols(self): """All symbols that occur as left-hand-sides of the main equations""" - return set([eq.lhs for eq in self.mainEquations]) + return set([eq.lhs for eq in self.mainAssignments]) @property def operationCount(self): """See :func:`countNumberOfOperations` """ return countNumberOfOperations(self.allEquations, onlyType=None) - def get(self, symbols, fromMainEquationsOnly=False): + def get(self, symbols, frommainAssignmentsOnly=False): """Return the equations which have symbols as left hand sides""" if not hasattr(symbols, "__len__"): symbols = list(symbols) symbols = set(symbols) - if not fromMainEquationsOnly: + if not frommainAssignmentsOnly: eqsToSearchIn = self.allEquations else: - eqsToSearchIn = self.mainEquations + eqsToSearchIn = self.mainAssignments return [eq for eq in eqsToSearchIn if eq.lhs in symbols] @@ -145,19 +146,19 @@ class EquationCollection(object): if len(self.subexpressions) > 0: result += "<div>Subexpressions:</div>" result += makeHtmlEquationTable(self.subexpressions) - result += "<div>Main Equations:</div>" - result += makeHtmlEquationTable(self.mainEquations) + result += "<div>Main Assignments:</div>" + result += makeHtmlEquationTable(self.mainAssignments) return result def __repr__(self): - return "Equation Collection for " + ",".join([str(eq.lhs) for eq in self.mainEquations]) + return "Equation Collection for " + ",".join([str(eq.lhs) for eq in self.mainAssignments]) def __str__(self): result = "Subexpressions\n" for eq in self.subexpressions: result += str(eq) + "\n" - result += "Main Equations\n" - for eq in self.mainEquations: + result += "Main Assignments\n" + for eq in self.mainAssignments: result += str(eq) + "\n" return result @@ -165,8 +166,8 @@ class EquationCollection(object): def merge(self, other): """Returns a new collection which contains self and other. Subexpressions are renamed if they clash.""" - ownDefs = set([e.lhs for e in self.mainEquations]) - otherDefs = set([e.lhs for e in other.mainEquations]) + ownDefs = set([e.lhs for e in self.mainAssignments]) + otherDefs = set([e.lhs for e in other.mainAssignments]) assert len(ownDefs.intersection(otherDefs)) == 0, "Cannot merge, since both collection define the same symbols" ownSubexpressionSymbols = {e.lhs: e.rhs for e in self.subexpressions} @@ -180,14 +181,14 @@ class EquationCollection(object): else: # different definition - a new name has to be introduced newLhs = next(self.subexpressionSymbolNameGenerator) - newEq = sp.Eq(newLhs, fastSubs(otherSubexpressionEq.rhs, substitutionDict)) + newEq = Assignment(newLhs, fastSubs(otherSubexpressionEq.rhs, substitutionDict)) processedOtherSubexpressionEquations.append(newEq) substitutionDict[otherSubexpressionEq.lhs] = newLhs else: processedOtherSubexpressionEquations.append(fastSubs(otherSubexpressionEq, substitutionDict)) - processedOtherMainEquations = [fastSubs(eq, substitutionDict) for eq in other.mainEquations] - return self.copy(self.mainEquations + processedOtherMainEquations, + processedOthermainAssignments = [fastSubs(eq, substitutionDict) for eq in other.mainAssignments] + return self.copy(self.mainAssignments + processedOthermainAssignments, self.subexpressions + processedOtherSubexpressionEquations) def getDependentSymbols(self, symbolSequence): @@ -226,28 +227,28 @@ class EquationCollection(object): newEquations.append(eq) newSubExpr = [eq for eq in self.subexpressions if eq.lhs in dependentSymbols and eq.lhs not in symbolsToExtract] - return EquationCollection(newEquations, newSubExpr) + return AssignmentCollection(newEquations, newSubExpr) def newWithoutUnusedSubexpressions(self): """Returns a new equation collection containing only the subexpressions that are used/referenced in the equations""" - allLhs = [eq.lhs for eq in self.mainEquations] + allLhs = [eq.lhs for eq in self.mainAssignments] return self.extract(allLhs) def appendToSubexpressions(self, rhs, lhs=None, topologicalSort=True): if lhs is None: lhs = sp.Dummy() - eq = sp.Eq(lhs, rhs) + eq = Assignment(lhs, rhs) self.subexpressions.append(eq) if topologicalSort: - self.topologicalSort(subexpressions=True, mainEquations=False) + self.topologicalSort(subexpressions=True, mainAssignments=False) return lhs - def topologicalSort(self, subexpressions=True, mainEquations=True): + def topologicalSort(self, subexpressions=True, mainAssignments=True): if subexpressions: self.subexpressions = sortEquationsTopologically(self.subexpressions) - if mainEquations: - self.mainEquations = sortEquationsTopologically(self.mainEquations) + if mainAssignments: + self.mainAssignments = sortEquationsTopologically(self.mainAssignments) def insertSubexpression(self, symbol): newSubexpressions = [] @@ -260,8 +261,8 @@ class EquationCollection(object): if subsDict is None: return self - newSubexpressions = [sp.Eq(eq.lhs, fastSubs(eq.rhs, subsDict)) for eq in newSubexpressions] - newEqs = [sp.Eq(eq.lhs, fastSubs(eq.rhs, subsDict)) for eq in self.mainEquations] + newSubexpressions = [Assignment(eq.lhs, fastSubs(eq.rhs, subsDict)) for eq in newSubexpressions] + newEqs = [Assignment(eq.lhs, fastSubs(eq.rhs, subsDict)) for eq in self.mainAssignments] return self.copy(newEqs, newSubexpressions) def insertSubexpressions(self, subexpressionSymbolsToKeep=set()): @@ -286,7 +287,7 @@ class EquationCollection(object): else: subsDict[subExpr[i].lhs] = subExpr[i].rhs - newEq = [fastSubs(eq, subsDict) for eq in self.mainEquations] + newEq = [fastSubs(eq, subsDict) for eq in self.mainAssignments] return self.copy(newEq, keptSubexpressions) def lambdify(self, symbols, module=None, fixedSymbols={}): @@ -296,7 +297,7 @@ class EquationCollection(object): :param module: same as sympy.lambdify paramter of same same, i.e. which module to use e.g. 'numpy' :param fixedSymbols: dictionary with substitutions, that are applied before lambdification """ - eqs = self.copyWithSubstitutionsApplied(fixedSymbols).insertSubexpressions().mainEquations + eqs = self.copyWithSubstitutionsApplied(fixedSymbols).insertSubexpressions().mainAssignments lambdas = {eq.lhs: sp.lambdify(symbols, eq.rhs, module) for eq in eqs} def f(*args, **kwargs): diff --git a/assignment_collection/simplifications.py b/assignment_collection/simplifications.py new file mode 100644 index 0000000000000000000000000000000000000000..7d0eab53d92faa37d30c7d31c3db44365d79588c --- /dev/null +++ b/assignment_collection/simplifications.py @@ -0,0 +1,87 @@ +import sympy as sp + +from pystencils import Assignment, AssignmentCollection +from pystencils.sympyextensions import replaceAdditive + + +def sympyCseOnEquationList(eqs): + ec = AssignmentCollection(eqs, []) + return sympyCSE(ec).allEquations + + +def sympyCSE(assignment_collection): + """ + Searches for common subexpressions inside the equation collection, in both the existing subexpressions as well + as the equations themselves. It uses the sympy subexpression detection to do this. Return a new equation collection + with the additional subexpressions found + """ + symbolGen = assignment_collection.subexpressionSymbolNameGenerator + replacements, newEq = sp.cse(assignment_collection.subexpressions + assignment_collection.mainAssignments, + symbols=symbolGen) + replacementEqs = [Assignment(*r) for r in replacements] + + modifiedSubexpressions = newEq[:len(assignment_collection.subexpressions)] + modifiedUpdateEquations = newEq[len(assignment_collection.subexpressions):] + + newSubexpressions = replacementEqs + modifiedSubexpressions + topologicallySortedPairs = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in newSubexpressions]) + newSubexpressions = [Assignment(a[0], a[1]) for a in topologicallySortedPairs] + + return assignment_collection.copy(modifiedUpdateEquations, newSubexpressions) + + +def applyOnAllEquations(assignment_collection, operation): + """Applies sympy expand operation to all equations in collection""" + result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in assignment_collection.mainAssignments] + return assignment_collection.copy(result) + + +def applyOnAllSubexpressions(assignment_collection, operation): + result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in assignment_collection.subexpressions] + return assignment_collection.copy(assignment_collection.mainAssignments, result) + + +def subexpressionSubstitutionInExistingSubexpressions(assignment_collection): + """Goes through the subexpressions list and replaces the term in the following subexpressions""" + result = [] + for outerCtr, s in enumerate(assignment_collection.subexpressions): + newRhs = s.rhs + for innerCtr in range(outerCtr): + subExpr = assignment_collection.subexpressions[innerCtr] + newRhs = replaceAdditive(newRhs, subExpr.lhs, subExpr.rhs, requiredMatchReplacement=1.0) + newRhs = newRhs.subs(subExpr.rhs, subExpr.lhs) + result.append(Assignment(s.lhs, newRhs)) + + return assignment_collection.copy(assignment_collection.mainAssignments, result) + + +def subexpressionSubstitutionInmainAssignments(assignment_collection): + """Replaces already existing subexpressions in the equations of the assignment_collection""" + result = [] + for s in assignment_collection.mainAssignments: + newRhs = s.rhs + for subExpr in assignment_collection.subexpressions: + newRhs = replaceAdditive(newRhs, subExpr.lhs, subExpr.rhs, requiredMatchReplacement=1.0) + result.append(Assignment(s.lhs, newRhs)) + return assignment_collection.copy(result) + + +def addSubexpressionsForDivisions(assignment_collection): + """Introduces subexpressions for all divisions which have no constant in the denominator. + e.g. :math:`\frac{1}{x}` is replaced, :math:`\frac{1}{3}` is not replaced.""" + divisors = set() + + def searchDivisors(term): + if term.func == sp.Pow: + if term.exp.is_integer and term.exp.is_number and term.exp < 0: + divisors.add(term) + else: + for a in term.args: + searchDivisors(a) + + for eq in assignment_collection.allEquations: + searchDivisors(eq.rhs) + + newSymbolGen = assignment_collection.subexpressionSymbolNameGenerator + substitutions = {divisor: newSymbol for newSymbol, divisor in zip(newSymbolGen, divisors)} + return assignment_collection.copyWithSubstitutionsApplied(substitutions, True) diff --git a/equationcollection/simplificationstrategy.py b/assignment_collection/simplificationstrategy.py similarity index 87% rename from equationcollection/simplificationstrategy.py rename to assignment_collection/simplificationstrategy.py index 38f4c3f8b5918c56b8f9313d31b00a805cdd90d9..3d8cdd62fb0a93752f524e2cbf1c62f384949f3f 100644 --- a/equationcollection/simplificationstrategy.py +++ b/assignment_collection/simplificationstrategy.py @@ -30,11 +30,11 @@ class SimplificationStrategy(object): updateRule = t(updateRule) return updateRule - def __call__(self, equationCollection): + def __call__(self, assignment_collection): """Same as apply""" - return self.apply(equationCollection) + return self.apply(assignment_collection) - def createSimplificationReport(self, equationCollection): + def createSimplificationReport(self, assignment_collection): """ Returns a simplification report containing the number of operations at each simplification stage, together with the run-time the simplification took. @@ -72,25 +72,25 @@ class SimplificationStrategy(object): import timeit report = Report() - op = equationCollection.operationCount + op = assignment_collection.operationCount total = op['adds'] + op['muls'] + op['divs'] report.add(ReportElement("OriginalTerm", '-', op['adds'], op['muls'], op['divs'], total)) for t in self._rules: startTime = timeit.default_timer() - equationCollection = t(equationCollection) + assignment_collection = t(assignment_collection) endTime = timeit.default_timer() - op = equationCollection.operationCount + op = assignment_collection.operationCount timeStr = "%.2f ms" % ((endTime - startTime) * 1000,) total = op['adds'] + op['muls'] + op['divs'] report.add(ReportElement(t.__name__, timeStr, op['adds'], op['muls'], op['divs'], total)) return report - def showIntermediateResults(self, equationCollection, symbols=None): + def showIntermediateResults(self, assignment_collection, symbols=None): class IntermediateResults: def __init__(self, strategy, eqColl, resSyms): self.strategy = strategy - self.equationCollection = eqColl + self.assignment_collection = eqColl self.restrictSymbols = resSyms def __str__(self): @@ -102,8 +102,8 @@ class SimplificationStrategy(object): text += (" " * 3 + (" " * 3).join(str(eqColl).splitlines(True))) return text - result = printEqCollection("Initial Version", self.equationCollection) - eqColl = self.equationCollection + result = printEqCollection("Initial Version", self.assignment_collection) + eqColl = self.assignment_collection for rule in self.strategy.rules: eqColl = rule(eqColl) result += printEqCollection(rule.__name__, eqColl) @@ -119,14 +119,14 @@ class SimplificationStrategy(object): text += "</div>" return text - result = printEqCollection("Initial Version", self.equationCollection) - eqColl = self.equationCollection + result = printEqCollection("Initial Version", self.assignment_collection) + eqColl = self.assignment_collection for rule in self.strategy.rules: eqColl = rule(eqColl) result += printEqCollection(rule.__name__, eqColl) return result - return IntermediateResults(self, equationCollection, symbols) + return IntermediateResults(self, assignment_collection, symbols) def __repr__(self): result = "Simplification Strategy:\n" diff --git a/backends/dot.py b/backends/dot.py index 8797c49f07ef6e244f07deac2458c4026d29b2bf..e78ac1bb0c550be5af7537edbf098f4aa086bf0f 100644 --- a/backends/dot.py +++ b/backends/dot.py @@ -91,33 +91,14 @@ def dotprint(node, view=False, short=False, full=False, **kwargs): :param kwargs: is directly passed to the DotPrinter class: http://graphviz.readthedocs.io/en/latest/api.html#digraph :return: string in DOT format """ - nodeToStrFunction = repr + node_to_str_function = repr if short: - nodeToStrFunction = __shortened + node_to_str_function = __shortened elif full: - nodeToStrFunction = lambda expr: repr(type(expr)) + repr(expr) - printer = DotPrinter(nodeToStrFunction, full, **kwargs) + node_to_str_function = lambda expr: repr(type(expr)) + repr(expr) + printer = DotPrinter(node_to_str_function, full, **kwargs) dot = printer.doprint(node) if view: return graphviz.Source(dot) return dot - -if __name__ == "__main__": - from pystencils import Field - import sympy as sp - imgField = Field.createGeneric('I', - spatialDimensions=2, # 2D image - indexDimensions=1) # multiple values per pixel: e.g. RGB - w1, w2 = sp.symbols("w_1 w_2") - sobelX = -w2 * imgField[-1, 0](1) - w1 * imgField[-1, -1](1) - w1 * imgField[-1, +1](1) \ - + w2 * imgField[+1, 0](1) + w1 * imgField[+1, -1](1) - w1 * imgField[+1, +1](1) - sobelX - - dstField = Field.createGeneric('dst', spatialDimensions=2, indexDimensions=0) - updateRule = sp.Eq(dstField[0, 0], sobelX) - updateRule - - from pystencils import createKernel - ast = createKernel([updateRule]) - print(dotprint(ast, short=True)) diff --git a/boundaries/boundaryconditions.py b/boundaries/boundaryconditions.py index 0a799cdac8722d5d42ce8c6bfb136e29cbb351f5..1cc67d17ca8df82b8c3b601533528c03a03368f6 100644 --- a/boundaries/boundaryconditions.py +++ b/boundaries/boundaryconditions.py @@ -1,4 +1,4 @@ -import sympy as sp +from pystencils import Assignment from pystencils.boundaries.boundaryhandling import BoundaryOffsetInfo @@ -52,13 +52,13 @@ class Neumann(Boundary): neighbor = BoundaryOffsetInfo.offsetFromDir(directionSymbol, field.spatialDimensions) if field.indexDimensions == 0: - return [sp.Eq(field[neighbor], field.center)] + return [Assignment(field[neighbor], field.center)] else: from itertools import product if not field.hasFixedIndexShape: raise NotImplementedError("Neumann boundary works only for fields with fixed index shape") indexIter = product(*(range(i) for i in field.indexShape)) - return [sp.Eq(field[neighbor](*idx), field(*idx)) for idx in indexIter] + return [Assignment(field[neighbor](*idx), field(*idx)) for idx in indexIter] def __hash__(self): # All boundaries of these class behave equal -> should also be equal diff --git a/boundaries/boundaryhandling.py b/boundaries/boundaryhandling.py index 7787346c34e60e25097a4befa8a8c0b5b44060fc..76999a549ca25639401b03c667ce7ca951430e1c 100644 --- a/boundaries/boundaryhandling.py +++ b/boundaries/boundaryhandling.py @@ -1,5 +1,6 @@ import numpy as np import sympy as sp +from pystencils.assignment import Assignment from pystencils import Field, TypedSymbol, createIndexedKernel from pystencils.backends.cbackend import CustomCppCode from pystencils.boundaries.createindexlist import numpyDataTypeForBoundaryObject, createBoundaryIndexArray @@ -363,6 +364,6 @@ def createBoundaryKernel(field, indexField, stencil, boundaryFunctor, target='cp elements = [BoundaryOffsetInfo(stencil)] indexArrDtype = indexField.dtype.numpyDtype dirSymbol = TypedSymbol("dir", indexArrDtype.fields['dir'][0]) - elements += [sp.Eq(dirSymbol, indexField[0]('dir'))] + elements += [Assignment(dirSymbol, indexField[0]('dir'))] elements += boundaryFunctor(field, directionSymbol=dirSymbol, indexField=indexField) return createIndexedKernel(elements, [indexField], target=target, cpuOpenMP=openMP) diff --git a/cpu/kernelcreation.py b/cpu/kernelcreation.py index 0c96c1f4557f16235a785ceec71f464f0833d4ca..c4c2833b49aa1c6282621164630e0074220ae6bf 100644 --- a/cpu/kernelcreation.py +++ b/cpu/kernelcreation.py @@ -33,7 +33,7 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol='double', :return: :class:`pystencils.ast.KernelFunction` node """ - def typeSymbol(term): + def type_symbol(term): if isinstance(term, Field.Access) or isinstance(term, TypedSymbol): return term elif isinstance(term, sp.Symbol): @@ -58,7 +58,7 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol='double', code.target = 'cpu' if splitGroups: - typedSplitGroups = [[typeSymbol(s) for s in splitGroup] for splitGroup in splitGroups] + typedSplitGroups = [[type_symbol(s) for s in splitGroup] for splitGroup in splitGroups] splitInnerLoop(code, typedSplitGroups) basePointerInfo = [['spatialInner0'], ['spatialInner1']] if len(loopOrder) >= 2 else [['spatialInner0']] diff --git a/equationcollection/__init__.py b/equationcollection/__init__.py deleted file mode 100644 index 434f02e213073ec556c9d70b5ed9091283533347..0000000000000000000000000000000000000000 --- a/equationcollection/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from pystencils.equationcollection.equationcollection import EquationCollection -from pystencils.equationcollection.simplificationstrategy import SimplificationStrategy diff --git a/equationcollection/simplifications.py b/equationcollection/simplifications.py deleted file mode 100644 index 0e1e92f4872e58262d5b8667911d2392261d0819..0000000000000000000000000000000000000000 --- a/equationcollection/simplifications.py +++ /dev/null @@ -1,87 +0,0 @@ -import sympy as sp - -from pystencils.equationcollection.equationcollection import EquationCollection -from pystencils.sympyextensions import replaceAdditive - - -def sympyCseOnEquationList(eqs): - ec = EquationCollection(eqs, []) - return sympyCSE(ec).allEquations - - -def sympyCSE(equationCollection): - """ - Searches for common subexpressions inside the equation collection, in both the existing subexpressions as well - as the equations themselves. It uses the sympy subexpression detection to do this. Return a new equation collection - with the additional subexpressions found - """ - symbolGen = equationCollection.subexpressionSymbolNameGenerator - replacements, newEq = sp.cse(equationCollection.subexpressions + equationCollection.mainEquations, - symbols=symbolGen) - replacementEqs = [sp.Eq(*r) for r in replacements] - - modifiedSubexpressions = newEq[:len(equationCollection.subexpressions)] - modifiedUpdateEquations = newEq[len(equationCollection.subexpressions):] - - newSubexpressions = replacementEqs + modifiedSubexpressions - topologicallySortedPairs = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in newSubexpressions]) - newSubexpressions = [sp.Eq(a[0], a[1]) for a in topologicallySortedPairs] - - return equationCollection.copy(modifiedUpdateEquations, newSubexpressions) - - -def applyOnAllEquations(equationCollection, operation): - """Applies sympy expand operation to all equations in collection""" - result = [sp.Eq(eq.lhs, operation(eq.rhs)) for eq in equationCollection.mainEquations] - return equationCollection.copy(result) - - -def applyOnAllSubexpressions(equationCollection, operation): - result = [sp.Eq(eq.lhs, operation(eq.rhs)) for eq in equationCollection.subexpressions] - return equationCollection.copy(equationCollection.mainEquations, result) - - -def subexpressionSubstitutionInExistingSubexpressions(equationCollection): - """Goes through the subexpressions list and replaces the term in the following subexpressions""" - result = [] - for outerCtr, s in enumerate(equationCollection.subexpressions): - newRhs = s.rhs - for innerCtr in range(outerCtr): - subExpr = equationCollection.subexpressions[innerCtr] - newRhs = replaceAdditive(newRhs, subExpr.lhs, subExpr.rhs, requiredMatchReplacement=1.0) - newRhs = newRhs.subs(subExpr.rhs, subExpr.lhs) - result.append(sp.Eq(s.lhs, newRhs)) - - return equationCollection.copy(equationCollection.mainEquations, result) - - -def subexpressionSubstitutionInMainEquations(equationCollection): - """Replaces already existing subexpressions in the equations of the equationCollection""" - result = [] - for s in equationCollection.mainEquations: - newRhs = s.rhs - for subExpr in equationCollection.subexpressions: - newRhs = replaceAdditive(newRhs, subExpr.lhs, subExpr.rhs, requiredMatchReplacement=1.0) - result.append(sp.Eq(s.lhs, newRhs)) - return equationCollection.copy(result) - - -def addSubexpressionsForDivisions(equationCollection): - """Introduces subexpressions for all divisions which have no constant in the denominator. - e.g. :math:`\frac{1}{x}` is replaced, :math:`\frac{1}{3}` is not replaced.""" - divisors = set() - - def searchDivisors(term): - if term.func == sp.Pow: - if term.exp.is_integer and term.exp.is_number and term.exp < 0: - divisors.add(term) - else: - for a in term.args: - searchDivisors(a) - - for eq in equationCollection.allEquations: - searchDivisors(eq.rhs) - - newSymbolGen = equationCollection.subexpressionSymbolNameGenerator - substitutions = {divisor: newSymbol for newSymbol, divisor in zip(newSymbolGen, divisors)} - return equationCollection.copyWithSubstitutionsApplied(substitutions, True) diff --git a/field.py b/field.py index 24a51755f17909e16449c070b950daeae87a1e48..d310f9b77d341205912bc6c030fa2877c9cd3166 100644 --- a/field.py +++ b/field.py @@ -5,6 +5,7 @@ import sympy as sp from sympy.core.cache import cacheit from sympy.tensor import IndexedBase +from pystencils.assignment import Assignment from pystencils.alignedarray import aligned_empty from pystencils.data_types import TypedSymbol, createType, createCompositeTypeFromString, StructType from pystencils.sympyextensions import isIntegerSequence @@ -71,10 +72,10 @@ class Field(object): >>> src = Field.createGeneric("src", spatialDimensions=2, indexDimensions=1) >>> dst = Field.createGeneric("dst", spatialDimensions=2, indexDimensions=1) >>> for i, offset in enumerate(stencil): - ... sp.Eq(dst[0,0](i), src[-offset](i)) - Eq(dst_C^0, src_C^0) - Eq(dst_C^1, src_S^1) - Eq(dst_C^2, src_N^2) + ... Assignment(dst[0,0](i), src[-offset](i)) + Assignment(dst_C^0, src_C^0) + Assignment(dst_C^1, src_S^1) + Assignment(dst_C^2, src_N^2) """ @staticmethod @@ -437,22 +438,22 @@ def extractCommonSubexpressions(equations): them in a topologically sorted order, ready for evaluation. Usually called before list of equations is passed to :func:`createKernel` """ - replacements, newEq = sp.cse(equations) + replacements, new_eq = sp.cse(equations) # Workaround for older sympy versions: here subexpressions (temporary = True) are extracted # which leads to problems in Piecewise functions which have to a default case indicated by True - symbolsEqualToTrue = {r[0]: True for r in replacements if r[1] is sp.true} + symbols_equal_to_true = {r[0]: True for r in replacements if r[1] is sp.true} - replacementEqs = [sp.Eq(*r) for r in replacements if r[1] is not sp.true] - equations = replacementEqs + newEq - topologicallySortedPairs = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in equations]) - equations = [sp.Eq(a[0], a[1].subs(symbolsEqualToTrue)) for a in topologicallySortedPairs] + replacement_eqs = [Assignment(*r) for r in replacements if r[1] is not sp.true] + equations = replacement_eqs + new_eq + topologically_sorted_pairs = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in equations]) + equations = [Assignment(a[0], a[1].subs(symbols_equal_to_true)) for a in topologically_sorted_pairs] return equations def getLayoutFromStrides(strides, indexDimensionIds=[]): coordinates = list(range(len(strides))) - relevantStrides = [stride for i, stride in enumerate(strides) if i not in indexDimensionIds] - result = [x for (y, x) in sorted(zip(relevantStrides, coordinates), key=lambda pair: pair[0], reverse=True)] + relevant_strides = [stride for i, stride in enumerate(strides) if i not in indexDimensionIds] + result = [x for (y, x) in sorted(zip(relevant_strides, coordinates), key=lambda pair: pair[0], reverse=True)] return normalizeLayout(result) diff --git a/finitedifferences.py b/finitedifferences.py index 039a603e14c04e92bf06066ee3bd4e1032256816..8dba62707946d985d3f5d381f3923018519f48b7 100644 --- a/finitedifferences.py +++ b/finitedifferences.py @@ -1,7 +1,7 @@ import numpy as np import sympy as sp -from pystencils.equationcollection import EquationCollection +from pystencils.assignment_collection import AssignmentCollection from pystencils.field import Field from pystencils.transformations import fastSubs from pystencils.derivative import Diff @@ -355,8 +355,8 @@ class Discretization2ndOrder: return [self(e) for e in expr] elif isinstance(expr, sp.Matrix): return expr.applyfunc(self.__call__) - elif isinstance(expr, EquationCollection): - return expr.copy(mainEquations=[e for e in expr.mainEquations], + elif isinstance(expr, AssignmentCollection): + return expr.copy(mainAssignments=[e for e in expr.mainAssignments], subexpressions=[e for e in expr.subexpressions]) transientTerms = expr.atoms(Transient) diff --git a/gpucuda/periodicity.py b/gpucuda/periodicity.py index 039f21799d9092ea35d89d7a8736691c6a8adb05..5009fce894b41b75bd9b4845b1d7f96fa536e256 100644 --- a/gpucuda/periodicity.py +++ b/gpucuda/periodicity.py @@ -1,6 +1,6 @@ import sympy as sp import numpy as np -from pystencils import Field +from pystencils import Field, Assignment from pystencils.slicing import normalizeSlice, getPeriodicBoundarySrcDstSlices from pystencils.gpucuda import makePythonFunction from pystencils.gpucuda.kernelcreation import createCUDAKernel @@ -20,7 +20,7 @@ def createCopyKernel(domainSize, fromSlice, toSlice, indexDimensions=0, indexDim updateEqs = [] for i in range(indexDimShape): - eq = sp.Eq(f(i), f[tuple(offset)](i)) + eq = Assignment(f(i), f[tuple(offset)](i)) updateEqs.append(eq) ast = createCUDAKernel(updateEqs, iterationSlice=toSlice) diff --git a/kernelcreation.py b/kernelcreation.py index 7d27ed6c9179b9af0bc9332c6b9c705ebd59fdf8..9c84347d93f940ed783a1cd45681b595c7c1c98c 100644 --- a/kernelcreation.py +++ b/kernelcreation.py @@ -1,4 +1,4 @@ -from pystencils.equationcollection import EquationCollection +from pystencils.assignment_collection import AssignmentCollection from pystencils.gpucuda.indexing import indexingCreatorFromParams @@ -7,7 +7,7 @@ def createKernel(equations, target='cpu', dataType="double", iterationSlice=None gpuIndexing='block', gpuIndexingParams={}): """ Creates abstract syntax tree (AST) of kernel, using a list of update equations. - :param equations: either be a plain list of equations or a EquationCollection object + :param equations: either be a plain list of equations or a AssignmentCollection object :param target: 'cpu', 'llvm' or 'gpu' :param dataType: data type used for all untyped symbols (i.e. non-fields), can also be a dict from symbol name to type @@ -32,7 +32,7 @@ def createKernel(equations, target='cpu', dataType="double", iterationSlice=None # ---- Normalizing parameters splitGroups = () - if isinstance(equations, EquationCollection): + if isinstance(equations, AssignmentCollection): if 'splitGroups' in equations.simplificationHints: splitGroups = equations.simplificationHints['splitGroups'] equations = equations.allEquations @@ -83,7 +83,7 @@ def createIndexedKernel(equations, indexFields, target='cpu', dataType="double", coordinateNames: name of the coordinate fields in the struct data type """ - if isinstance(equations, EquationCollection): + if isinstance(equations, AssignmentCollection): equations = equations.allEquations if target == 'cpu': from pystencils.cpu import createIndexedKernel diff --git a/llvm/llvm.py b/llvm/llvm.py index 6529c699eef065cc4e8e9abe960df8a5c6bcfdd2..08ad220bb0a07627392e42d70484c697f7941a79 100644 --- a/llvm/llvm.py +++ b/llvm/llvm.py @@ -9,7 +9,7 @@ from pystencils.llvm.control_flow import Loop from pystencils.data_types import createType, to_llvm_type, getTypeOfExpression, collateTypes, \ createCompositeTypeFromString from sympy import Indexed -from sympy.codegen.ast import Assignment +from pystencils.assignment import Assignment def generateLLVM(ast_node, module=None, builder=None): diff --git a/sympyextensions.py b/sympyextensions.py index 6e29cabf9135f2ba4b1b82ccecd44e130db8a5d4..d407d8f61ef463ecb2fcabe00af598f99c63ece7 100644 --- a/sympyextensions.py +++ b/sympyextensions.py @@ -6,6 +6,7 @@ import warnings import sympy as sp from pystencils.data_types import getTypeOfExpression, getBaseType +from pystencils.assignment import Assignment def prod(seq): @@ -264,7 +265,7 @@ def replaceSecondOrderProducts(expr, searchSymbols, positive=None, replaceMixed= mixedSymbol = sp.Symbol(mixedSymbolName.replace("_", "")) if mixedSymbol not in mixedSymbolsReplaced: mixedSymbolsReplaced.add(mixedSymbol) - replaceMixed.append(sp.Eq(mixedSymbol, u + sign * v)) + replaceMixed.append(Assignment(mixedSymbol, u + sign * v)) else: mixedSymbol = u + sign * v return sp.Rational(1, 2) * sign * otherFactors * (mixedSymbol ** 2 - u ** 2 - v ** 2) @@ -431,7 +432,7 @@ def countNumberOfOperations(term, onlyType='real'): for operationName in result.keys(): result[operationName] += r[operationName] return result - elif isinstance(term, sp.Eq): + elif isinstance(term, Assignment): term = term.rhs term = term.evalf() @@ -538,7 +539,7 @@ def getSymmetricPart(term, vars): def sortEquationsTopologically(equationSequence): res = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in equationSequence]) - return [sp.Eq(a, b) for a, b in res] + return [Assignment(a, b) for a, b in res] def getEquationsFromFunction(func, **kwargs): @@ -559,7 +560,7 @@ def getEquationsFromFunction(func, **kwargs): ... S.neighbors @= f[0,1] + f[1,0] ... g[0,0] @= S.neighbors + f[0,0] >>> getEquationsFromFunction(myKernel) - [Eq(neighbors, f_E + f_N), Eq(g_C, f_C + neighbors)] + [Assignment(neighbors, f_E + f_N), Assignment(g_C, f_C + neighbors)] """ import inspect import re @@ -590,7 +591,7 @@ def getEquationsFromFunction(func, **kwargs): code = "".join(sourceLines[1:]) result = [] localsDict = {'_result': result, - 'Eq': sp.Eq, + 'Eq': Assignment, 'S': SymbolCreator()} localsDict.update(kwargs) globalsDict = inspect.stack()[1][0].f_globals.copy() diff --git a/transformations/transformations.py b/transformations/transformations.py index 299201f0cdf39c9b7fa2bd0ddbcac9eec6832496..5c9c9d1fe765939ddfb60dc98d77de0d7b20917b 100644 --- a/transformations/transformations.py +++ b/transformations/transformations.py @@ -7,6 +7,7 @@ import sympy as sp from sympy.logic.boolalg import Boolean from sympy.tensor import IndexedBase +from pystencils.assignment import Assignment from pystencils.field import Field, FieldType, offsetComponentToDirectionString from pystencils.data_types import TypedSymbol, createType, PointerType, StructType, getBaseType, castFunc from pystencils.slicing import normalizeSlice @@ -737,7 +738,7 @@ def typeAllEquations(eqs, typeForSymbol): def visit(object): if isinstance(object, list) or isinstance(object, tuple): return [visit(e) for e in object] - if isinstance(object, sp.Eq) or isinstance(object, ast.SympyAssignment): + if isinstance(object, sp.Eq) or isinstance(object, ast.SympyAssignment) or isinstance(object, Assignment): newLhs = processLhs(object.lhs) newRhs = processRhs(object.rhs) return ast.SympyAssignment(newLhs, newRhs)