Commit 9f071fdd authored by Martin Bauer's avatar Martin Bauer
Browse files

pystencils: Assignment instead of sympy.Eq

- Previously sympy.Eq was used to represent assignments. However Eq
  represents equality not assignment. This means that sometimes sympy
  "simplified" an equation like a = a  to True,
-> replaced sp.Eq by pystencils.Assignment everywhere
- renamed EquationCollection to AssignmentCollection
parent eadeaadf
......@@ -3,14 +3,13 @@ from pystencils.data_types import TypedSymbol
from pystencils.slicing import makeSlice
from pystencils.kernelcreation import createKernel, createIndexedKernel
from pystencils.display_utils import showCode, toDot
from pystencils.equationcollection import EquationCollection
from sympy.codegen.ast import Assignment as Assign
from pystencils.assignment_collection import AssignmentCollection
from pystencils.assignment import Assignment
__all__ = ['Field', 'FieldType', 'extractCommonSubexpressions',
'TypedSymbol',
'makeSlice',
'createKernel', 'createIndexedKernel',
'showCode', 'toDot',
'EquationCollection',
'Assign']
'AssignmentCollection',
'Assignment']
from sympy.codegen.ast import Assignment
from sympy.printing.latex import LatexPrinter
__all__ = ['Assignment']
def print_assignment_latex(printer, expr):
"""sympy cannot print Assignments as Latex. Thus, this function is added to the sympy Latex printer"""
printed_lhs = printer.doprint(expr.lhs)
printed_rhs = printer.doprint(expr.rhs)
return f"{printed_lhs} \leftarrow {printed_rhs}"
LatexPrinter._print_Assignment = print_assignment_latex
from pystencils.assignment_collection.assignment_collection import AssignmentCollection
from pystencils.assignment_collection.simplificationstrategy import SimplificationStrategy
import sympy as sp
from copy import copy
from pystencils.assignment import Assignment
from pystencils.sympyextensions import fastSubs, countNumberOfOperations, sortEquationsTopologically
class EquationCollection(object):
class AssignmentCollection(object):
"""
A collection of equations with subexpression definitions, also represented as equations,
that are used in the main equations. EquationCollections can be passed to simplification methods.
that are used in the main equations. AssignmentCollection can be passed to simplification methods.
These simplification methods can change the subexpressions, but the number and
left hand side of the main equations themselves is not altered.
Additionally a dictionary of simplification hints is stored, which are set by the functions that create
equation collections to transport information to the simplification system.
:ivar mainEquations: list of sympy equations
:ivar mainAssignments: list of sympy equations
:ivar subexpressions: list of sympy equations defining subexpressions used in main equations
:ivar simplificationHints: dictionary that is used to annotate the equation collection with hints that are
used by the simplification system. See documentation of the simplification rules for
......@@ -22,7 +23,7 @@ class EquationCollection(object):
# ----------------------------------------- Creation ---------------------------------------------------------------
def __init__(self, equations, subExpressions, simplificationHints=None, subexpressionSymbolNameGenerator=None):
self.mainEquations = equations
self.mainAssignments = equations
self.subexpressions = subExpressions
if simplificationHints is None:
......@@ -39,15 +40,15 @@ class EquationCollection(object):
def mainTerms(self):
return []
def copy(self, mainEquations=None, subexpressions=None):
def copy(self, mainAssignments=None, subexpressions=None):
res = copy(self)
res.simplificationHints = self.simplificationHints.copy()
res.subexpressionSymbolNameGenerator = copy(self.subexpressionSymbolNameGenerator)
if mainEquations is not None:
res.mainEquations = mainEquations
if mainAssignments is not None:
res.mainAssignments = mainAssignments
else:
res.mainEquations = self.mainEquations.copy()
res.mainAssignments = self.mainAssignments.copy()
if subexpressions is not None:
res.subexpressions = subexpressions
......@@ -64,13 +65,13 @@ class EquationCollection(object):
"""
if substituteOnLhs:
newSubexpressions = [fastSubs(eq, substitutionDict) for eq in self.subexpressions]
newEquations = [fastSubs(eq, substitutionDict) for eq in self.mainEquations]
newEquations = [fastSubs(eq, substitutionDict) for eq in self.mainAssignments]
else:
newSubexpressions = [sp.Eq(eq.lhs, fastSubs(eq.rhs, substitutionDict)) for eq in self.subexpressions]
newEquations = [sp.Eq(eq.lhs, fastSubs(eq.rhs, substitutionDict)) for eq in self.mainEquations]
newSubexpressions = [Assignment(eq.lhs, fastSubs(eq.rhs, substitutionDict)) for eq in self.subexpressions]
newEquations = [Assignment(eq.lhs, fastSubs(eq.rhs, substitutionDict)) for eq in self.mainAssignments]
if addSubstitutionsAsSubexpressions:
newSubexpressions = [sp.Eq(b, a) for a, b in substitutionDict.items()] + newSubexpressions
newSubexpressions = [Assignment(b, a) for a, b in substitutionDict.items()] + newSubexpressions
newSubexpressions = sortEquationsTopologically(newSubexpressions)
return self.copy(newEquations, newSubexpressions)
......@@ -86,7 +87,7 @@ class EquationCollection(object):
@property
def allEquations(self):
"""Subexpression and main equations in one sequence"""
return self.subexpressions + self.mainEquations
return self.subexpressions + self.mainAssignments
@property
def freeSymbols(self):
......@@ -100,30 +101,30 @@ class EquationCollection(object):
def boundSymbols(self):
"""Set of all symbols which occur on left-hand-sides i.e. all symbols which are defined."""
boundSymbolsSet = set([eq.lhs for eq in self.allEquations])
assert len(boundSymbolsSet) == len(self.subexpressions) + len(self.mainEquations), \
assert len(boundSymbolsSet) == len(self.subexpressions) + len(self.mainAssignments), \
"Not in SSA form - same symbol assigned multiple times"
return boundSymbolsSet
@property
def definedSymbols(self):
"""All symbols that occur as left-hand-sides of the main equations"""
return set([eq.lhs for eq in self.mainEquations])
return set([eq.lhs for eq in self.mainAssignments])
@property
def operationCount(self):
"""See :func:`countNumberOfOperations` """
return countNumberOfOperations(self.allEquations, onlyType=None)
def get(self, symbols, fromMainEquationsOnly=False):
def get(self, symbols, frommainAssignmentsOnly=False):
"""Return the equations which have symbols as left hand sides"""
if not hasattr(symbols, "__len__"):
symbols = list(symbols)
symbols = set(symbols)
if not fromMainEquationsOnly:
if not frommainAssignmentsOnly:
eqsToSearchIn = self.allEquations
else:
eqsToSearchIn = self.mainEquations
eqsToSearchIn = self.mainAssignments
return [eq for eq in eqsToSearchIn if eq.lhs in symbols]
......@@ -145,19 +146,19 @@ class EquationCollection(object):
if len(self.subexpressions) > 0:
result += "<div>Subexpressions:</div>"
result += makeHtmlEquationTable(self.subexpressions)
result += "<div>Main Equations:</div>"
result += makeHtmlEquationTable(self.mainEquations)
result += "<div>Main Assignments:</div>"
result += makeHtmlEquationTable(self.mainAssignments)
return result
def __repr__(self):
return "Equation Collection for " + ",".join([str(eq.lhs) for eq in self.mainEquations])
return "Equation Collection for " + ",".join([str(eq.lhs) for eq in self.mainAssignments])
def __str__(self):
result = "Subexpressions\n"
for eq in self.subexpressions:
result += str(eq) + "\n"
result += "Main Equations\n"
for eq in self.mainEquations:
result += "Main Assignments\n"
for eq in self.mainAssignments:
result += str(eq) + "\n"
return result
......@@ -165,8 +166,8 @@ class EquationCollection(object):
def merge(self, other):
"""Returns a new collection which contains self and other. Subexpressions are renamed if they clash."""
ownDefs = set([e.lhs for e in self.mainEquations])
otherDefs = set([e.lhs for e in other.mainEquations])
ownDefs = set([e.lhs for e in self.mainAssignments])
otherDefs = set([e.lhs for e in other.mainAssignments])
assert len(ownDefs.intersection(otherDefs)) == 0, "Cannot merge, since both collection define the same symbols"
ownSubexpressionSymbols = {e.lhs: e.rhs for e in self.subexpressions}
......@@ -180,14 +181,14 @@ class EquationCollection(object):
else:
# different definition - a new name has to be introduced
newLhs = next(self.subexpressionSymbolNameGenerator)
newEq = sp.Eq(newLhs, fastSubs(otherSubexpressionEq.rhs, substitutionDict))
newEq = Assignment(newLhs, fastSubs(otherSubexpressionEq.rhs, substitutionDict))
processedOtherSubexpressionEquations.append(newEq)
substitutionDict[otherSubexpressionEq.lhs] = newLhs
else:
processedOtherSubexpressionEquations.append(fastSubs(otherSubexpressionEq, substitutionDict))
processedOtherMainEquations = [fastSubs(eq, substitutionDict) for eq in other.mainEquations]
return self.copy(self.mainEquations + processedOtherMainEquations,
processedOthermainAssignments = [fastSubs(eq, substitutionDict) for eq in other.mainAssignments]
return self.copy(self.mainAssignments + processedOthermainAssignments,
self.subexpressions + processedOtherSubexpressionEquations)
def getDependentSymbols(self, symbolSequence):
......@@ -226,28 +227,28 @@ class EquationCollection(object):
newEquations.append(eq)
newSubExpr = [eq for eq in self.subexpressions if eq.lhs in dependentSymbols and eq.lhs not in symbolsToExtract]
return EquationCollection(newEquations, newSubExpr)
return AssignmentCollection(newEquations, newSubExpr)
def newWithoutUnusedSubexpressions(self):
"""Returns a new equation collection containing only the subexpressions that
are used/referenced in the equations"""
allLhs = [eq.lhs for eq in self.mainEquations]
allLhs = [eq.lhs for eq in self.mainAssignments]
return self.extract(allLhs)
def appendToSubexpressions(self, rhs, lhs=None, topologicalSort=True):
if lhs is None:
lhs = sp.Dummy()
eq = sp.Eq(lhs, rhs)
eq = Assignment(lhs, rhs)
self.subexpressions.append(eq)
if topologicalSort:
self.topologicalSort(subexpressions=True, mainEquations=False)
self.topologicalSort(subexpressions=True, mainAssignments=False)
return lhs
def topologicalSort(self, subexpressions=True, mainEquations=True):
def topologicalSort(self, subexpressions=True, mainAssignments=True):
if subexpressions:
self.subexpressions = sortEquationsTopologically(self.subexpressions)
if mainEquations:
self.mainEquations = sortEquationsTopologically(self.mainEquations)
if mainAssignments:
self.mainAssignments = sortEquationsTopologically(self.mainAssignments)
def insertSubexpression(self, symbol):
newSubexpressions = []
......@@ -260,8 +261,8 @@ class EquationCollection(object):
if subsDict is None:
return self
newSubexpressions = [sp.Eq(eq.lhs, fastSubs(eq.rhs, subsDict)) for eq in newSubexpressions]
newEqs = [sp.Eq(eq.lhs, fastSubs(eq.rhs, subsDict)) for eq in self.mainEquations]
newSubexpressions = [Assignment(eq.lhs, fastSubs(eq.rhs, subsDict)) for eq in newSubexpressions]
newEqs = [Assignment(eq.lhs, fastSubs(eq.rhs, subsDict)) for eq in self.mainAssignments]
return self.copy(newEqs, newSubexpressions)
def insertSubexpressions(self, subexpressionSymbolsToKeep=set()):
......@@ -286,7 +287,7 @@ class EquationCollection(object):
else:
subsDict[subExpr[i].lhs] = subExpr[i].rhs
newEq = [fastSubs(eq, subsDict) for eq in self.mainEquations]
newEq = [fastSubs(eq, subsDict) for eq in self.mainAssignments]
return self.copy(newEq, keptSubexpressions)
def lambdify(self, symbols, module=None, fixedSymbols={}):
......@@ -296,7 +297,7 @@ class EquationCollection(object):
:param module: same as sympy.lambdify paramter of same same, i.e. which module to use e.g. 'numpy'
:param fixedSymbols: dictionary with substitutions, that are applied before lambdification
"""
eqs = self.copyWithSubstitutionsApplied(fixedSymbols).insertSubexpressions().mainEquations
eqs = self.copyWithSubstitutionsApplied(fixedSymbols).insertSubexpressions().mainAssignments
lambdas = {eq.lhs: sp.lambdify(symbols, eq.rhs, module) for eq in eqs}
def f(*args, **kwargs):
......
import sympy as sp
from pystencils.equationcollection.equationcollection import EquationCollection
from pystencils import Assignment, AssignmentCollection
from pystencils.sympyextensions import replaceAdditive
def sympyCseOnEquationList(eqs):
ec = EquationCollection(eqs, [])
ec = AssignmentCollection(eqs, [])
return sympyCSE(ec).allEquations
def sympyCSE(equationCollection):
def sympyCSE(assignment_collection):
"""
Searches for common subexpressions inside the equation collection, in both the existing subexpressions as well
as the equations themselves. It uses the sympy subexpression detection to do this. Return a new equation collection
with the additional subexpressions found
"""
symbolGen = equationCollection.subexpressionSymbolNameGenerator
replacements, newEq = sp.cse(equationCollection.subexpressions + equationCollection.mainEquations,
symbolGen = assignment_collection.subexpressionSymbolNameGenerator
replacements, newEq = sp.cse(assignment_collection.subexpressions + assignment_collection.mainAssignments,
symbols=symbolGen)
replacementEqs = [sp.Eq(*r) for r in replacements]
replacementEqs = [Assignment(*r) for r in replacements]
modifiedSubexpressions = newEq[:len(equationCollection.subexpressions)]
modifiedUpdateEquations = newEq[len(equationCollection.subexpressions):]
modifiedSubexpressions = newEq[:len(assignment_collection.subexpressions)]
modifiedUpdateEquations = newEq[len(assignment_collection.subexpressions):]
newSubexpressions = replacementEqs + modifiedSubexpressions
topologicallySortedPairs = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in newSubexpressions])
newSubexpressions = [sp.Eq(a[0], a[1]) for a in topologicallySortedPairs]
newSubexpressions = [Assignment(a[0], a[1]) for a in topologicallySortedPairs]
return equationCollection.copy(modifiedUpdateEquations, newSubexpressions)
return assignment_collection.copy(modifiedUpdateEquations, newSubexpressions)
def applyOnAllEquations(equationCollection, operation):
def applyOnAllEquations(assignment_collection, operation):
"""Applies sympy expand operation to all equations in collection"""
result = [sp.Eq(eq.lhs, operation(eq.rhs)) for eq in equationCollection.mainEquations]
return equationCollection.copy(result)
result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in assignment_collection.mainAssignments]
return assignment_collection.copy(result)
def applyOnAllSubexpressions(equationCollection, operation):
result = [sp.Eq(eq.lhs, operation(eq.rhs)) for eq in equationCollection.subexpressions]
return equationCollection.copy(equationCollection.mainEquations, result)
def applyOnAllSubexpressions(assignment_collection, operation):
result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in assignment_collection.subexpressions]
return assignment_collection.copy(assignment_collection.mainAssignments, result)
def subexpressionSubstitutionInExistingSubexpressions(equationCollection):
def subexpressionSubstitutionInExistingSubexpressions(assignment_collection):
"""Goes through the subexpressions list and replaces the term in the following subexpressions"""
result = []
for outerCtr, s in enumerate(equationCollection.subexpressions):
for outerCtr, s in enumerate(assignment_collection.subexpressions):
newRhs = s.rhs
for innerCtr in range(outerCtr):
subExpr = equationCollection.subexpressions[innerCtr]
subExpr = assignment_collection.subexpressions[innerCtr]
newRhs = replaceAdditive(newRhs, subExpr.lhs, subExpr.rhs, requiredMatchReplacement=1.0)
newRhs = newRhs.subs(subExpr.rhs, subExpr.lhs)
result.append(sp.Eq(s.lhs, newRhs))
result.append(Assignment(s.lhs, newRhs))
return equationCollection.copy(equationCollection.mainEquations, result)
return assignment_collection.copy(assignment_collection.mainAssignments, result)
def subexpressionSubstitutionInMainEquations(equationCollection):
"""Replaces already existing subexpressions in the equations of the equationCollection"""
def subexpressionSubstitutionInmainAssignments(assignment_collection):
"""Replaces already existing subexpressions in the equations of the assignment_collection"""
result = []
for s in equationCollection.mainEquations:
for s in assignment_collection.mainAssignments:
newRhs = s.rhs
for subExpr in equationCollection.subexpressions:
for subExpr in assignment_collection.subexpressions:
newRhs = replaceAdditive(newRhs, subExpr.lhs, subExpr.rhs, requiredMatchReplacement=1.0)
result.append(sp.Eq(s.lhs, newRhs))
return equationCollection.copy(result)
result.append(Assignment(s.lhs, newRhs))
return assignment_collection.copy(result)
def addSubexpressionsForDivisions(equationCollection):
def addSubexpressionsForDivisions(assignment_collection):
"""Introduces subexpressions for all divisions which have no constant in the denominator.
e.g. :math:`\frac{1}{x}` is replaced, :math:`\frac{1}{3}` is not replaced."""
divisors = set()
......@@ -79,9 +79,9 @@ def addSubexpressionsForDivisions(equationCollection):
for a in term.args:
searchDivisors(a)
for eq in equationCollection.allEquations:
for eq in assignment_collection.allEquations:
searchDivisors(eq.rhs)
newSymbolGen = equationCollection.subexpressionSymbolNameGenerator
newSymbolGen = assignment_collection.subexpressionSymbolNameGenerator
substitutions = {divisor: newSymbol for newSymbol, divisor in zip(newSymbolGen, divisors)}
return equationCollection.copyWithSubstitutionsApplied(substitutions, True)
return assignment_collection.copyWithSubstitutionsApplied(substitutions, True)
......@@ -30,11 +30,11 @@ class SimplificationStrategy(object):
updateRule = t(updateRule)
return updateRule
def __call__(self, equationCollection):
def __call__(self, assignment_collection):
"""Same as apply"""
return self.apply(equationCollection)
return self.apply(assignment_collection)
def createSimplificationReport(self, equationCollection):
def createSimplificationReport(self, assignment_collection):
"""
Returns a simplification report containing the number of operations at each simplification stage, together
with the run-time the simplification took.
......@@ -72,25 +72,25 @@ class SimplificationStrategy(object):
import timeit
report = Report()
op = equationCollection.operationCount
op = assignment_collection.operationCount
total = op['adds'] + op['muls'] + op['divs']
report.add(ReportElement("OriginalTerm", '-', op['adds'], op['muls'], op['divs'], total))
for t in self._rules:
startTime = timeit.default_timer()
equationCollection = t(equationCollection)
assignment_collection = t(assignment_collection)
endTime = timeit.default_timer()
op = equationCollection.operationCount
op = assignment_collection.operationCount
timeStr = "%.2f ms" % ((endTime - startTime) * 1000,)
total = op['adds'] + op['muls'] + op['divs']
report.add(ReportElement(t.__name__, timeStr, op['adds'], op['muls'], op['divs'], total))
return report
def showIntermediateResults(self, equationCollection, symbols=None):
def showIntermediateResults(self, assignment_collection, symbols=None):
class IntermediateResults:
def __init__(self, strategy, eqColl, resSyms):
self.strategy = strategy
self.equationCollection = eqColl
self.assignment_collection = eqColl
self.restrictSymbols = resSyms
def __str__(self):
......@@ -102,8 +102,8 @@ class SimplificationStrategy(object):
text += (" " * 3 + (" " * 3).join(str(eqColl).splitlines(True)))
return text
result = printEqCollection("Initial Version", self.equationCollection)
eqColl = self.equationCollection
result = printEqCollection("Initial Version", self.assignment_collection)
eqColl = self.assignment_collection
for rule in self.strategy.rules:
eqColl = rule(eqColl)
result += printEqCollection(rule.__name__, eqColl)
......@@ -119,14 +119,14 @@ class SimplificationStrategy(object):
text += "</div>"
return text
result = printEqCollection("Initial Version", self.equationCollection)
eqColl = self.equationCollection
result = printEqCollection("Initial Version", self.assignment_collection)
eqColl = self.assignment_collection
for rule in self.strategy.rules:
eqColl = rule(eqColl)
result += printEqCollection(rule.__name__, eqColl)
return result
return IntermediateResults(self, equationCollection, symbols)
return IntermediateResults(self, assignment_collection, symbols)
def __repr__(self):
result = "Simplification Strategy:\n"
......
......@@ -91,33 +91,14 @@ def dotprint(node, view=False, short=False, full=False, **kwargs):
:param kwargs: is directly passed to the DotPrinter class: http://graphviz.readthedocs.io/en/latest/api.html#digraph
:return: string in DOT format
"""
nodeToStrFunction = repr
node_to_str_function = repr
if short:
nodeToStrFunction = __shortened
node_to_str_function = __shortened
elif full:
nodeToStrFunction = lambda expr: repr(type(expr)) + repr(expr)
printer = DotPrinter(nodeToStrFunction, full, **kwargs)
node_to_str_function = lambda expr: repr(type(expr)) + repr(expr)
printer = DotPrinter(node_to_str_function, full, **kwargs)
dot = printer.doprint(node)
if view:
return graphviz.Source(dot)
return dot
if __name__ == "__main__":
from pystencils import Field
import sympy as sp
imgField = Field.createGeneric('I',
spatialDimensions=2, # 2D image
indexDimensions=1) # multiple values per pixel: e.g. RGB
w1, w2 = sp.symbols("w_1 w_2")
sobelX = -w2 * imgField[-1, 0](1) - w1 * imgField[-1, -1](1) - w1 * imgField[-1, +1](1) \
+ w2 * imgField[+1, 0](1) + w1 * imgField[+1, -1](1) - w1 * imgField[+1, +1](1)
sobelX
dstField = Field.createGeneric('dst', spatialDimensions=2, indexDimensions=0)
updateRule = sp.Eq(dstField[0, 0], sobelX)
updateRule
from pystencils import createKernel
ast = createKernel([updateRule])
print(dotprint(ast, short=True))
import sympy as sp
from pystencils import Assignment
from pystencils.boundaries.boundaryhandling import BoundaryOffsetInfo
......@@ -52,13 +52,13 @@ class Neumann(Boundary):
neighbor = BoundaryOffsetInfo.offsetFromDir(directionSymbol, field.spatialDimensions)
if field.indexDimensions == 0:
return [sp.Eq(field[neighbor], field.center)]
return [Assignment(field[neighbor], field.center)]
else:
from itertools import product
if not field.hasFixedIndexShape:
raise NotImplementedError("Neumann boundary works only for fields with fixed index shape")
indexIter = product(*(range(i) for i in field.indexShape))
return [sp.Eq(field[neighbor](*idx), field(*idx)) for idx in indexIter]
return [Assignment(field[neighbor](*idx), field(*idx)) for idx in indexIter]
def __hash__(self):
# All boundaries of these class behave equal -> should also be equal
......
import numpy as np
import sympy as sp
from pystencils.assignment import Assignment
from pystencils import Field, TypedSymbol, createIndexedKernel
from pystencils.backends.cbackend import CustomCppCode
from pystencils.boundaries.createindexlist import numpyDataTypeForBoundaryObject, createBoundaryIndexArray
......@@ -363,6 +364,6 @@ def createBoundaryKernel(field, indexField, stencil, boundaryFunctor, target='cp
elements = [BoundaryOffsetInfo(stencil)]
indexArrDtype = indexField.dtype.numpyDtype
dirSymbol = TypedSymbol("dir", indexArrDtype.fields['dir'][0])
elements += [sp.Eq(dirSymbol, indexField[0]('dir'))]
elements += [Assignment(dirSymbol, indexField[0]('dir'))]
elements += boundaryFunctor(field, directionSymbol=dirSymbol, indexField=indexField)
return createIndexedKernel(elements, [indexField], target=target, cpuOpenMP=openMP)
......@@ -33,7 +33,7 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol='double',
:return: :class:`pystencils.ast.KernelFunction` node
"""
def typeSymbol(term):
def type_symbol(term):
if isinstance(term, Field.Access) or isinstance(term, TypedSymbol):
return term
elif isinstance(term, sp.Symbol):
......@@ -58,7 +58,7 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol='double',
code.target = 'cpu'
if splitGroups:
typedSplitGroups = [[typeSymbol(s) for s in splitGroup] for splitGroup in splitGroups]
typedSplitGroups = [[type_symbol(s) for s in splitGroup] for splitGroup in splitGroups]
splitInnerLoop(code, typedSplitGroups)
basePointerInfo = [['spatialInner0'], ['spatialInner1']] if len(loopOrder) >= 2 else [['spatialInner0']]
......
from pystencils.equationcollection.equationcollection import EquationCollection
from pystencils.equationcollection.simplificationstrategy import SimplificationStrategy
......@@ -5,6 +5,7 @@ import sympy as sp
from sympy.core.cache import cacheit
from sympy.tensor import IndexedBase
from pystencils.assignment import Assignment
from pystencils.alignedarray import aligned_empty
from pystencils.data_types import TypedSymbol, createType, createCompositeTypeFromString, StructType
from pystencils.sympyextensions import isIntegerSequence
......@@ -71,10 +72,10 @@ class Field(object):
>>> src = Field.createGeneric("src", spatialDimensions=2, indexDimensions=1)
>>> dst = Field.createGeneric("dst", spatialDimensions=2, indexDimensions=1)
>>> for i, offset in enumerate(stencil):
... sp.Eq(dst[0,0](i), src[-offset](i))
Eq(dst_C^0, src_C^0)
Eq(dst_C^1, src_S^1)
Eq(dst_C^2, src_N^2)
... Assignment(dst[0,0](i), src[-offset](i))
Assignment(dst_C^0, src_C^0)
Assignment(dst_C^1, src_S^1)
Assignment(dst_C^2, src_N^2)
"""
@staticmethod
......@@ -437,22 +438,22 @@ def extractCommonSubexpressions(equations):
them in a topologically sorted order, ready for evaluation.
Usually called before list of equations is passed to :func:`createKernel`
"""
replacements, newEq = sp.cse(equations)
replacements, new_eq = sp.cse(equations)
# Workaround for older sympy versions: here subexpressions (temporary = True) a