From 9f071fdd6c6bacfdbf3e8afdecbfd615304dd67c Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Fri, 30 Mar 2018 19:33:52 +0200
Subject: [PATCH] pystencils: Assignment instead of sympy.Eq

- Previously sympy.Eq was used to represent assignments. However Eq
  represents equality not assignment. This means that sometimes sympy
  "simplified" an equation like a = a  to True,
-> replaced sp.Eq by pystencils.Assignment everywhere
- renamed EquationCollection to AssignmentCollection
---
 __init__.py                                   |  9 +-
 assignment.py                                 | 14 +++
 assignment_collection/__init__.py             |  2 +
 .../assignment_collection.py                  | 79 ++++++++---------
 assignment_collection/simplifications.py      | 87 +++++++++++++++++++
 .../simplificationstrategy.py                 | 26 +++---
 backends/dot.py                               | 27 +-----
 boundaries/boundaryconditions.py              |  6 +-
 boundaries/boundaryhandling.py                |  3 +-
 cpu/kernelcreation.py                         |  4 +-
 equationcollection/__init__.py                |  2 -
 equationcollection/simplifications.py         | 87 -------------------
 field.py                                      | 25 +++---
 finitedifferences.py                          |  6 +-
 gpucuda/periodicity.py                        |  4 +-
 kernelcreation.py                             |  8 +-
 llvm/llvm.py                                  |  2 +-
 sympyextensions.py                            | 11 +--
 transformations/transformations.py            |  3 +-
 19 files changed, 202 insertions(+), 203 deletions(-)
 create mode 100644 assignment.py
 create mode 100644 assignment_collection/__init__.py
 rename equationcollection/equationcollection.py => assignment_collection/assignment_collection.py (81%)
 create mode 100644 assignment_collection/simplifications.py
 rename {equationcollection => assignment_collection}/simplificationstrategy.py (87%)
 delete mode 100644 equationcollection/__init__.py
 delete mode 100644 equationcollection/simplifications.py

diff --git a/__init__.py b/__init__.py
index d8b3397a3..36cc17032 100644
--- a/__init__.py
+++ b/__init__.py
@@ -3,14 +3,13 @@ from pystencils.data_types import TypedSymbol
 from pystencils.slicing import makeSlice
 from pystencils.kernelcreation import createKernel, createIndexedKernel
 from pystencils.display_utils import showCode, toDot
-from pystencils.equationcollection import EquationCollection
-from sympy.codegen.ast import Assignment as Assign
-
+from pystencils.assignment_collection import AssignmentCollection
+from pystencils.assignment import Assignment
 
 __all__ = ['Field', 'FieldType', 'extractCommonSubexpressions',
            'TypedSymbol',
            'makeSlice',
            'createKernel', 'createIndexedKernel',
            'showCode', 'toDot',
-           'EquationCollection',
-           'Assign']
+           'AssignmentCollection',
+           'Assignment']
diff --git a/assignment.py b/assignment.py
new file mode 100644
index 000000000..135a5d126
--- /dev/null
+++ b/assignment.py
@@ -0,0 +1,14 @@
+from sympy.codegen.ast import Assignment
+from sympy.printing.latex import LatexPrinter
+
+__all__ = ['Assignment']
+
+
+def print_assignment_latex(printer, expr):
+    """sympy cannot print Assignments as Latex. Thus, this function is added to the sympy Latex printer"""
+    printed_lhs = printer.doprint(expr.lhs)
+    printed_rhs = printer.doprint(expr.rhs)
+    return f"{printed_lhs} \leftarrow {printed_rhs}"
+
+
+LatexPrinter._print_Assignment = print_assignment_latex
diff --git a/assignment_collection/__init__.py b/assignment_collection/__init__.py
new file mode 100644
index 000000000..a71a7d05c
--- /dev/null
+++ b/assignment_collection/__init__.py
@@ -0,0 +1,2 @@
+from pystencils.assignment_collection.assignment_collection import AssignmentCollection
+from pystencils.assignment_collection.simplificationstrategy import SimplificationStrategy
diff --git a/equationcollection/equationcollection.py b/assignment_collection/assignment_collection.py
similarity index 81%
rename from equationcollection/equationcollection.py
rename to assignment_collection/assignment_collection.py
index 1609c95d6..ea95cf6b9 100644
--- a/equationcollection/equationcollection.py
+++ b/assignment_collection/assignment_collection.py
@@ -1,18 +1,19 @@
 import sympy as sp
 from copy import copy
+from pystencils.assignment import Assignment
 from pystencils.sympyextensions import fastSubs, countNumberOfOperations, sortEquationsTopologically
 
 
-class EquationCollection(object):
+class AssignmentCollection(object):
     """
     A collection of equations with subexpression definitions, also represented as equations,
-    that are used in the main equations. EquationCollections can be passed to simplification methods.
+    that are used in the main equations. AssignmentCollection can be passed to simplification methods.
     These simplification methods can change the subexpressions, but the number and
     left hand side of the main equations themselves is not altered.
     Additionally a dictionary of simplification hints is stored, which are set by the functions that create
     equation collections to transport information to the simplification system.
 
-    :ivar mainEquations: list of sympy equations
+    :ivar mainAssignments: list of sympy equations
     :ivar subexpressions: list of sympy equations defining subexpressions used in main equations
     :ivar simplificationHints: dictionary that is used to annotate the equation collection with hints that are
                                used by the simplification system. See documentation of the simplification rules for
@@ -22,7 +23,7 @@ class EquationCollection(object):
     # ----------------------------------------- Creation ---------------------------------------------------------------
 
     def __init__(self, equations, subExpressions, simplificationHints=None, subexpressionSymbolNameGenerator=None):
-        self.mainEquations = equations
+        self.mainAssignments = equations
         self.subexpressions = subExpressions
 
         if simplificationHints is None:
@@ -39,15 +40,15 @@ class EquationCollection(object):
     def mainTerms(self):
         return []
 
-    def copy(self, mainEquations=None, subexpressions=None):
+    def copy(self, mainAssignments=None, subexpressions=None):
         res = copy(self)
         res.simplificationHints = self.simplificationHints.copy()
         res.subexpressionSymbolNameGenerator = copy(self.subexpressionSymbolNameGenerator)
 
-        if mainEquations is not None:
-            res.mainEquations = mainEquations
+        if mainAssignments is not None:
+            res.mainAssignments = mainAssignments
         else:
-            res.mainEquations = self.mainEquations.copy()
+            res.mainAssignments = self.mainAssignments.copy()
 
         if subexpressions is not None:
             res.subexpressions = subexpressions
@@ -64,13 +65,13 @@ class EquationCollection(object):
         """
         if substituteOnLhs:
             newSubexpressions = [fastSubs(eq, substitutionDict) for eq in self.subexpressions]
-            newEquations = [fastSubs(eq, substitutionDict) for eq in self.mainEquations]
+            newEquations = [fastSubs(eq, substitutionDict) for eq in self.mainAssignments]
         else:
-            newSubexpressions = [sp.Eq(eq.lhs, fastSubs(eq.rhs, substitutionDict)) for eq in self.subexpressions]
-            newEquations = [sp.Eq(eq.lhs, fastSubs(eq.rhs, substitutionDict)) for eq in self.mainEquations]
+            newSubexpressions = [Assignment(eq.lhs, fastSubs(eq.rhs, substitutionDict)) for eq in self.subexpressions]
+            newEquations = [Assignment(eq.lhs, fastSubs(eq.rhs, substitutionDict)) for eq in self.mainAssignments]
 
         if addSubstitutionsAsSubexpressions:
-            newSubexpressions = [sp.Eq(b, a) for a, b in substitutionDict.items()] + newSubexpressions
+            newSubexpressions = [Assignment(b, a) for a, b in substitutionDict.items()] + newSubexpressions
             newSubexpressions = sortEquationsTopologically(newSubexpressions)
         return self.copy(newEquations, newSubexpressions)
 
@@ -86,7 +87,7 @@ class EquationCollection(object):
     @property
     def allEquations(self):
         """Subexpression and main equations in one sequence"""
-        return self.subexpressions + self.mainEquations
+        return self.subexpressions + self.mainAssignments
 
     @property
     def freeSymbols(self):
@@ -100,30 +101,30 @@ class EquationCollection(object):
     def boundSymbols(self):
         """Set of all symbols which occur on left-hand-sides i.e. all symbols which are defined."""
         boundSymbolsSet = set([eq.lhs for eq in self.allEquations])
-        assert len(boundSymbolsSet) == len(self.subexpressions) + len(self.mainEquations), \
+        assert len(boundSymbolsSet) == len(self.subexpressions) + len(self.mainAssignments), \
             "Not in SSA form - same symbol assigned multiple times"
         return boundSymbolsSet
 
     @property
     def definedSymbols(self):
         """All symbols that occur as left-hand-sides of the main equations"""
-        return set([eq.lhs for eq in self.mainEquations])
+        return set([eq.lhs for eq in self.mainAssignments])
 
     @property
     def operationCount(self):
         """See :func:`countNumberOfOperations` """
         return countNumberOfOperations(self.allEquations, onlyType=None)
 
-    def get(self, symbols, fromMainEquationsOnly=False):
+    def get(self, symbols, frommainAssignmentsOnly=False):
         """Return the equations which have symbols as left hand sides"""
         if not hasattr(symbols, "__len__"):
             symbols = list(symbols)
         symbols = set(symbols)
 
-        if not fromMainEquationsOnly:
+        if not frommainAssignmentsOnly:
             eqsToSearchIn = self.allEquations
         else:
-            eqsToSearchIn = self.mainEquations
+            eqsToSearchIn = self.mainAssignments
 
         return [eq for eq in eqsToSearchIn if eq.lhs in symbols]
 
@@ -145,19 +146,19 @@ class EquationCollection(object):
         if len(self.subexpressions) > 0:
             result += "<div>Subexpressions:</div>"
             result += makeHtmlEquationTable(self.subexpressions)
-        result += "<div>Main Equations:</div>"
-        result += makeHtmlEquationTable(self.mainEquations)
+        result += "<div>Main Assignments:</div>"
+        result += makeHtmlEquationTable(self.mainAssignments)
         return result
 
     def __repr__(self):
-        return "Equation Collection for " + ",".join([str(eq.lhs) for eq in self.mainEquations])
+        return "Equation Collection for " + ",".join([str(eq.lhs) for eq in self.mainAssignments])
 
     def __str__(self):
         result = "Subexpressions\n"
         for eq in self.subexpressions:
             result += str(eq) + "\n"
-        result += "Main Equations\n"
-        for eq in self.mainEquations:
+        result += "Main Assignments\n"
+        for eq in self.mainAssignments:
             result += str(eq) + "\n"
         return result
 
@@ -165,8 +166,8 @@ class EquationCollection(object):
 
     def merge(self, other):
         """Returns a new collection which contains self and other. Subexpressions are renamed if they clash."""
-        ownDefs = set([e.lhs for e in self.mainEquations])
-        otherDefs = set([e.lhs for e in other.mainEquations])
+        ownDefs = set([e.lhs for e in self.mainAssignments])
+        otherDefs = set([e.lhs for e in other.mainAssignments])
         assert len(ownDefs.intersection(otherDefs)) == 0, "Cannot merge, since both collection define the same symbols"
 
         ownSubexpressionSymbols = {e.lhs: e.rhs for e in self.subexpressions}
@@ -180,14 +181,14 @@ class EquationCollection(object):
                 else:
                     # different definition - a new name has to be introduced
                     newLhs = next(self.subexpressionSymbolNameGenerator)
-                    newEq = sp.Eq(newLhs, fastSubs(otherSubexpressionEq.rhs, substitutionDict))
+                    newEq = Assignment(newLhs, fastSubs(otherSubexpressionEq.rhs, substitutionDict))
                     processedOtherSubexpressionEquations.append(newEq)
                     substitutionDict[otherSubexpressionEq.lhs] = newLhs
             else:
                 processedOtherSubexpressionEquations.append(fastSubs(otherSubexpressionEq, substitutionDict))
 
-        processedOtherMainEquations = [fastSubs(eq, substitutionDict) for eq in other.mainEquations]
-        return self.copy(self.mainEquations + processedOtherMainEquations,
+        processedOthermainAssignments = [fastSubs(eq, substitutionDict) for eq in other.mainAssignments]
+        return self.copy(self.mainAssignments + processedOthermainAssignments,
                          self.subexpressions + processedOtherSubexpressionEquations)
 
     def getDependentSymbols(self, symbolSequence):
@@ -226,28 +227,28 @@ class EquationCollection(object):
                 newEquations.append(eq)
 
         newSubExpr = [eq for eq in self.subexpressions if eq.lhs in dependentSymbols and eq.lhs not in symbolsToExtract]
-        return EquationCollection(newEquations, newSubExpr)
+        return AssignmentCollection(newEquations, newSubExpr)
 
     def newWithoutUnusedSubexpressions(self):
         """Returns a new equation collection containing only the subexpressions that
         are used/referenced in the equations"""
-        allLhs = [eq.lhs for eq in self.mainEquations]
+        allLhs = [eq.lhs for eq in self.mainAssignments]
         return self.extract(allLhs)
 
     def appendToSubexpressions(self, rhs, lhs=None, topologicalSort=True):
         if lhs is None:
             lhs = sp.Dummy()
-        eq = sp.Eq(lhs, rhs)
+        eq = Assignment(lhs, rhs)
         self.subexpressions.append(eq)
         if topologicalSort:
-            self.topologicalSort(subexpressions=True, mainEquations=False)
+            self.topologicalSort(subexpressions=True, mainAssignments=False)
         return lhs
 
-    def topologicalSort(self, subexpressions=True, mainEquations=True):
+    def topologicalSort(self, subexpressions=True, mainAssignments=True):
         if subexpressions:
             self.subexpressions = sortEquationsTopologically(self.subexpressions)
-        if mainEquations:
-            self.mainEquations = sortEquationsTopologically(self.mainEquations)
+        if mainAssignments:
+            self.mainAssignments = sortEquationsTopologically(self.mainAssignments)
 
     def insertSubexpression(self, symbol):
         newSubexpressions = []
@@ -260,8 +261,8 @@ class EquationCollection(object):
         if subsDict is None:
             return self
 
-        newSubexpressions = [sp.Eq(eq.lhs, fastSubs(eq.rhs, subsDict)) for eq in newSubexpressions]
-        newEqs = [sp.Eq(eq.lhs, fastSubs(eq.rhs, subsDict)) for eq in self.mainEquations]
+        newSubexpressions = [Assignment(eq.lhs, fastSubs(eq.rhs, subsDict)) for eq in newSubexpressions]
+        newEqs = [Assignment(eq.lhs, fastSubs(eq.rhs, subsDict)) for eq in self.mainAssignments]
         return self.copy(newEqs, newSubexpressions)
 
     def insertSubexpressions(self, subexpressionSymbolsToKeep=set()):
@@ -286,7 +287,7 @@ class EquationCollection(object):
             else:
                 subsDict[subExpr[i].lhs] = subExpr[i].rhs
 
-        newEq = [fastSubs(eq, subsDict) for eq in self.mainEquations]
+        newEq = [fastSubs(eq, subsDict) for eq in self.mainAssignments]
         return self.copy(newEq, keptSubexpressions)
 
     def lambdify(self, symbols, module=None, fixedSymbols={}):
@@ -296,7 +297,7 @@ class EquationCollection(object):
         :param module: same as sympy.lambdify paramter of same same, i.e. which module to use e.g. 'numpy'
         :param fixedSymbols: dictionary with substitutions, that are applied before lambdification
         """
-        eqs = self.copyWithSubstitutionsApplied(fixedSymbols).insertSubexpressions().mainEquations
+        eqs = self.copyWithSubstitutionsApplied(fixedSymbols).insertSubexpressions().mainAssignments
         lambdas = {eq.lhs: sp.lambdify(symbols, eq.rhs, module) for eq in eqs}
 
         def f(*args, **kwargs):
diff --git a/assignment_collection/simplifications.py b/assignment_collection/simplifications.py
new file mode 100644
index 000000000..7d0eab53d
--- /dev/null
+++ b/assignment_collection/simplifications.py
@@ -0,0 +1,87 @@
+import sympy as sp
+
+from pystencils import Assignment, AssignmentCollection
+from pystencils.sympyextensions import replaceAdditive
+
+
+def sympyCseOnEquationList(eqs):
+    ec = AssignmentCollection(eqs, [])
+    return sympyCSE(ec).allEquations
+
+
+def sympyCSE(assignment_collection):
+    """
+    Searches for common subexpressions inside the equation collection, in both the existing subexpressions as well
+    as the equations themselves. It uses the sympy subexpression detection to do this. Return a new equation collection
+    with the additional subexpressions found
+    """
+    symbolGen = assignment_collection.subexpressionSymbolNameGenerator
+    replacements, newEq = sp.cse(assignment_collection.subexpressions + assignment_collection.mainAssignments,
+                                 symbols=symbolGen)
+    replacementEqs = [Assignment(*r) for r in replacements]
+
+    modifiedSubexpressions = newEq[:len(assignment_collection.subexpressions)]
+    modifiedUpdateEquations = newEq[len(assignment_collection.subexpressions):]
+
+    newSubexpressions = replacementEqs + modifiedSubexpressions
+    topologicallySortedPairs = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in newSubexpressions])
+    newSubexpressions = [Assignment(a[0], a[1]) for a in topologicallySortedPairs]
+
+    return assignment_collection.copy(modifiedUpdateEquations, newSubexpressions)
+
+
+def applyOnAllEquations(assignment_collection, operation):
+    """Applies sympy expand operation to all equations in collection"""
+    result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in assignment_collection.mainAssignments]
+    return assignment_collection.copy(result)
+
+
+def applyOnAllSubexpressions(assignment_collection, operation):
+    result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in assignment_collection.subexpressions]
+    return assignment_collection.copy(assignment_collection.mainAssignments, result)
+
+
+def subexpressionSubstitutionInExistingSubexpressions(assignment_collection):
+    """Goes through the subexpressions list and replaces the term in the following subexpressions"""
+    result = []
+    for outerCtr, s in enumerate(assignment_collection.subexpressions):
+        newRhs = s.rhs
+        for innerCtr in range(outerCtr):
+            subExpr = assignment_collection.subexpressions[innerCtr]
+            newRhs = replaceAdditive(newRhs, subExpr.lhs, subExpr.rhs, requiredMatchReplacement=1.0)
+            newRhs = newRhs.subs(subExpr.rhs, subExpr.lhs)
+        result.append(Assignment(s.lhs, newRhs))
+
+    return assignment_collection.copy(assignment_collection.mainAssignments, result)
+
+
+def subexpressionSubstitutionInmainAssignments(assignment_collection):
+    """Replaces already existing subexpressions in the equations of the assignment_collection"""
+    result = []
+    for s in assignment_collection.mainAssignments:
+        newRhs = s.rhs
+        for subExpr in assignment_collection.subexpressions:
+            newRhs = replaceAdditive(newRhs, subExpr.lhs, subExpr.rhs, requiredMatchReplacement=1.0)
+        result.append(Assignment(s.lhs, newRhs))
+    return assignment_collection.copy(result)
+
+
+def addSubexpressionsForDivisions(assignment_collection):
+    """Introduces subexpressions for all divisions which have no constant in the denominator.
+    e.g.  :math:`\frac{1}{x}` is replaced, :math:`\frac{1}{3}` is not replaced."""
+    divisors = set()
+
+    def searchDivisors(term):
+        if term.func == sp.Pow:
+            if term.exp.is_integer and term.exp.is_number and term.exp < 0:
+                divisors.add(term)
+        else:
+            for a in term.args:
+                searchDivisors(a)
+
+    for eq in assignment_collection.allEquations:
+        searchDivisors(eq.rhs)
+
+    newSymbolGen = assignment_collection.subexpressionSymbolNameGenerator
+    substitutions = {divisor: newSymbol for newSymbol, divisor in zip(newSymbolGen, divisors)}
+    return assignment_collection.copyWithSubstitutionsApplied(substitutions, True)
diff --git a/equationcollection/simplificationstrategy.py b/assignment_collection/simplificationstrategy.py
similarity index 87%
rename from equationcollection/simplificationstrategy.py
rename to assignment_collection/simplificationstrategy.py
index 38f4c3f8b..3d8cdd62f 100644
--- a/equationcollection/simplificationstrategy.py
+++ b/assignment_collection/simplificationstrategy.py
@@ -30,11 +30,11 @@ class SimplificationStrategy(object):
             updateRule = t(updateRule)
         return updateRule
 
-    def __call__(self, equationCollection):
+    def __call__(self, assignment_collection):
         """Same as apply"""
-        return self.apply(equationCollection)
+        return self.apply(assignment_collection)
 
-    def createSimplificationReport(self, equationCollection):
+    def createSimplificationReport(self, assignment_collection):
         """
         Returns a simplification report containing the number of operations at each simplification stage, together
         with the run-time the simplification took.
@@ -72,25 +72,25 @@ class SimplificationStrategy(object):
 
         import timeit
         report = Report()
-        op = equationCollection.operationCount
+        op = assignment_collection.operationCount
         total = op['adds'] + op['muls'] + op['divs']
         report.add(ReportElement("OriginalTerm",  '-', op['adds'], op['muls'], op['divs'], total))
         for t in self._rules:
             startTime = timeit.default_timer()
-            equationCollection = t(equationCollection)
+            assignment_collection = t(assignment_collection)
             endTime = timeit.default_timer()
-            op = equationCollection.operationCount
+            op = assignment_collection.operationCount
             timeStr = "%.2f ms" % ((endTime - startTime) * 1000,)
             total = op['adds'] + op['muls'] + op['divs']
             report.add(ReportElement(t.__name__, timeStr, op['adds'], op['muls'], op['divs'], total))
         return report
 
-    def showIntermediateResults(self, equationCollection, symbols=None):
+    def showIntermediateResults(self, assignment_collection, symbols=None):
 
         class IntermediateResults:
             def __init__(self, strategy, eqColl, resSyms):
                 self.strategy = strategy
-                self.equationCollection = eqColl
+                self.assignment_collection = eqColl
                 self.restrictSymbols = resSyms
 
             def __str__(self):
@@ -102,8 +102,8 @@ class SimplificationStrategy(object):
                         text += (" " * 3 + (" " * 3).join(str(eqColl).splitlines(True)))
                     return text
 
-                result = printEqCollection("Initial Version", self.equationCollection)
-                eqColl = self.equationCollection
+                result = printEqCollection("Initial Version", self.assignment_collection)
+                eqColl = self.assignment_collection
                 for rule in self.strategy.rules:
                     eqColl = rule(eqColl)
                     result += printEqCollection(rule.__name__, eqColl)
@@ -119,14 +119,14 @@ class SimplificationStrategy(object):
                     text += "</div>"
                     return text
 
-                result = printEqCollection("Initial Version", self.equationCollection)
-                eqColl = self.equationCollection
+                result = printEqCollection("Initial Version", self.assignment_collection)
+                eqColl = self.assignment_collection
                 for rule in self.strategy.rules:
                     eqColl = rule(eqColl)
                     result += printEqCollection(rule.__name__, eqColl)
                 return result
 
-        return IntermediateResults(self, equationCollection, symbols)
+        return IntermediateResults(self, assignment_collection, symbols)
 
     def __repr__(self):
         result = "Simplification Strategy:\n"
diff --git a/backends/dot.py b/backends/dot.py
index 8797c49f0..e78ac1bb0 100644
--- a/backends/dot.py
+++ b/backends/dot.py
@@ -91,33 +91,14 @@ def dotprint(node, view=False, short=False, full=False, **kwargs):
     :param kwargs: is directly passed to the DotPrinter class: http://graphviz.readthedocs.io/en/latest/api.html#digraph
     :return: string in DOT format
     """
-    nodeToStrFunction = repr
+    node_to_str_function = repr
     if short:
-        nodeToStrFunction = __shortened
+        node_to_str_function = __shortened
     elif full:
-        nodeToStrFunction = lambda expr: repr(type(expr)) + repr(expr)
-    printer = DotPrinter(nodeToStrFunction, full, **kwargs)
+        node_to_str_function = lambda expr: repr(type(expr)) + repr(expr)
+    printer = DotPrinter(node_to_str_function, full, **kwargs)
     dot = printer.doprint(node)
     if view:
         return graphviz.Source(dot)
     return dot
 
-
-if __name__ == "__main__":
-    from pystencils import Field
-    import sympy as sp
-    imgField = Field.createGeneric('I',
-                                   spatialDimensions=2, # 2D image
-                                   indexDimensions=1)   # multiple values per pixel: e.g. RGB
-    w1, w2 = sp.symbols("w_1 w_2")
-    sobelX = -w2 * imgField[-1, 0](1) - w1 * imgField[-1, -1](1) - w1 * imgField[-1, +1](1) \
-             + w2 * imgField[+1, 0](1) + w1 * imgField[+1, -1](1) - w1 * imgField[+1, +1](1)
-    sobelX
-
-    dstField = Field.createGeneric('dst', spatialDimensions=2, indexDimensions=0)
-    updateRule = sp.Eq(dstField[0, 0], sobelX)
-    updateRule
-
-    from pystencils import createKernel
-    ast = createKernel([updateRule])
-    print(dotprint(ast, short=True))
diff --git a/boundaries/boundaryconditions.py b/boundaries/boundaryconditions.py
index 0a799cdac..1cc67d17c 100644
--- a/boundaries/boundaryconditions.py
+++ b/boundaries/boundaryconditions.py
@@ -1,4 +1,4 @@
-import sympy as sp
+from pystencils import Assignment
 from pystencils.boundaries.boundaryhandling import BoundaryOffsetInfo
 
 
@@ -52,13 +52,13 @@ class Neumann(Boundary):
 
         neighbor = BoundaryOffsetInfo.offsetFromDir(directionSymbol, field.spatialDimensions)
         if field.indexDimensions == 0:
-            return [sp.Eq(field[neighbor], field.center)]
+            return [Assignment(field[neighbor], field.center)]
         else:
             from itertools import product
             if not field.hasFixedIndexShape:
                 raise NotImplementedError("Neumann boundary works only for fields with fixed index shape")
             indexIter = product(*(range(i) for i in field.indexShape))
-            return [sp.Eq(field[neighbor](*idx), field(*idx)) for idx in indexIter]
+            return [Assignment(field[neighbor](*idx), field(*idx)) for idx in indexIter]
 
     def __hash__(self):
         # All boundaries of these class behave equal -> should also be equal
diff --git a/boundaries/boundaryhandling.py b/boundaries/boundaryhandling.py
index 7787346c3..76999a549 100644
--- a/boundaries/boundaryhandling.py
+++ b/boundaries/boundaryhandling.py
@@ -1,5 +1,6 @@
 import numpy as np
 import sympy as sp
+from pystencils.assignment import Assignment
 from pystencils import Field, TypedSymbol, createIndexedKernel
 from pystencils.backends.cbackend import CustomCppCode
 from pystencils.boundaries.createindexlist import numpyDataTypeForBoundaryObject, createBoundaryIndexArray
@@ -363,6 +364,6 @@ def createBoundaryKernel(field, indexField, stencil, boundaryFunctor, target='cp
     elements = [BoundaryOffsetInfo(stencil)]
     indexArrDtype = indexField.dtype.numpyDtype
     dirSymbol = TypedSymbol("dir", indexArrDtype.fields['dir'][0])
-    elements += [sp.Eq(dirSymbol, indexField[0]('dir'))]
+    elements += [Assignment(dirSymbol, indexField[0]('dir'))]
     elements += boundaryFunctor(field, directionSymbol=dirSymbol, indexField=indexField)
     return createIndexedKernel(elements, [indexField], target=target, cpuOpenMP=openMP)
diff --git a/cpu/kernelcreation.py b/cpu/kernelcreation.py
index 0c96c1f45..c4c2833b4 100644
--- a/cpu/kernelcreation.py
+++ b/cpu/kernelcreation.py
@@ -33,7 +33,7 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol='double',
     :return: :class:`pystencils.ast.KernelFunction` node
     """
 
-    def typeSymbol(term):
+    def type_symbol(term):
         if isinstance(term, Field.Access) or isinstance(term, TypedSymbol):
             return term
         elif isinstance(term, sp.Symbol):
@@ -58,7 +58,7 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol='double',
     code.target = 'cpu'
 
     if splitGroups:
-        typedSplitGroups = [[typeSymbol(s) for s in splitGroup] for splitGroup in splitGroups]
+        typedSplitGroups = [[type_symbol(s) for s in splitGroup] for splitGroup in splitGroups]
         splitInnerLoop(code, typedSplitGroups)
 
     basePointerInfo = [['spatialInner0'], ['spatialInner1']] if len(loopOrder) >= 2 else [['spatialInner0']]
diff --git a/equationcollection/__init__.py b/equationcollection/__init__.py
deleted file mode 100644
index 434f02e21..000000000
--- a/equationcollection/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from pystencils.equationcollection.equationcollection import EquationCollection
-from pystencils.equationcollection.simplificationstrategy import SimplificationStrategy
diff --git a/equationcollection/simplifications.py b/equationcollection/simplifications.py
deleted file mode 100644
index 0e1e92f48..000000000
--- a/equationcollection/simplifications.py
+++ /dev/null
@@ -1,87 +0,0 @@
-import sympy as sp
-
-from pystencils.equationcollection.equationcollection import EquationCollection
-from pystencils.sympyextensions import replaceAdditive
-
-
-def sympyCseOnEquationList(eqs):
-    ec = EquationCollection(eqs, [])
-    return sympyCSE(ec).allEquations
-
-
-def sympyCSE(equationCollection):
-    """
-    Searches for common subexpressions inside the equation collection, in both the existing subexpressions as well
-    as the equations themselves. It uses the sympy subexpression detection to do this. Return a new equation collection
-    with the additional subexpressions found
-    """
-    symbolGen = equationCollection.subexpressionSymbolNameGenerator
-    replacements, newEq = sp.cse(equationCollection.subexpressions + equationCollection.mainEquations,
-                                 symbols=symbolGen)
-    replacementEqs = [sp.Eq(*r) for r in replacements]
-
-    modifiedSubexpressions = newEq[:len(equationCollection.subexpressions)]
-    modifiedUpdateEquations = newEq[len(equationCollection.subexpressions):]
-
-    newSubexpressions = replacementEqs + modifiedSubexpressions
-    topologicallySortedPairs = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in newSubexpressions])
-    newSubexpressions = [sp.Eq(a[0], a[1]) for a in topologicallySortedPairs]
-
-    return equationCollection.copy(modifiedUpdateEquations, newSubexpressions)
-
-
-def applyOnAllEquations(equationCollection, operation):
-    """Applies sympy expand operation to all equations in collection"""
-    result = [sp.Eq(eq.lhs, operation(eq.rhs)) for eq in equationCollection.mainEquations]
-    return equationCollection.copy(result)
-
-
-def applyOnAllSubexpressions(equationCollection, operation):
-    result = [sp.Eq(eq.lhs, operation(eq.rhs)) for eq in equationCollection.subexpressions]
-    return equationCollection.copy(equationCollection.mainEquations, result)
-
-
-def subexpressionSubstitutionInExistingSubexpressions(equationCollection):
-    """Goes through the subexpressions list and replaces the term in the following subexpressions"""
-    result = []
-    for outerCtr, s in enumerate(equationCollection.subexpressions):
-        newRhs = s.rhs
-        for innerCtr in range(outerCtr):
-            subExpr = equationCollection.subexpressions[innerCtr]
-            newRhs = replaceAdditive(newRhs, subExpr.lhs, subExpr.rhs, requiredMatchReplacement=1.0)
-            newRhs = newRhs.subs(subExpr.rhs, subExpr.lhs)
-        result.append(sp.Eq(s.lhs, newRhs))
-
-    return equationCollection.copy(equationCollection.mainEquations, result)
-
-
-def subexpressionSubstitutionInMainEquations(equationCollection):
-    """Replaces already existing subexpressions in the equations of the equationCollection"""
-    result = []
-    for s in equationCollection.mainEquations:
-        newRhs = s.rhs
-        for subExpr in equationCollection.subexpressions:
-            newRhs = replaceAdditive(newRhs, subExpr.lhs, subExpr.rhs, requiredMatchReplacement=1.0)
-        result.append(sp.Eq(s.lhs, newRhs))
-    return equationCollection.copy(result)
-
-
-def addSubexpressionsForDivisions(equationCollection):
-    """Introduces subexpressions for all divisions which have no constant in the denominator.
-    e.g.  :math:`\frac{1}{x}` is replaced, :math:`\frac{1}{3}` is not replaced."""
-    divisors = set()
-
-    def searchDivisors(term):
-        if term.func == sp.Pow:
-            if term.exp.is_integer and term.exp.is_number and term.exp < 0:
-                divisors.add(term)
-        else:
-            for a in term.args:
-                searchDivisors(a)
-
-    for eq in equationCollection.allEquations:
-        searchDivisors(eq.rhs)
-
-    newSymbolGen = equationCollection.subexpressionSymbolNameGenerator
-    substitutions = {divisor: newSymbol for newSymbol, divisor in zip(newSymbolGen, divisors)}
-    return equationCollection.copyWithSubstitutionsApplied(substitutions, True)
diff --git a/field.py b/field.py
index 24a51755f..d310f9b77 100644
--- a/field.py
+++ b/field.py
@@ -5,6 +5,7 @@ import sympy as sp
 from sympy.core.cache import cacheit
 from sympy.tensor import IndexedBase
 
+from pystencils.assignment import Assignment
 from pystencils.alignedarray import aligned_empty
 from pystencils.data_types import TypedSymbol, createType, createCompositeTypeFromString, StructType
 from pystencils.sympyextensions import isIntegerSequence
@@ -71,10 +72,10 @@ class Field(object):
         >>> src = Field.createGeneric("src", spatialDimensions=2, indexDimensions=1)
         >>> dst = Field.createGeneric("dst", spatialDimensions=2, indexDimensions=1)
         >>> for i, offset in enumerate(stencil):
-        ...     sp.Eq(dst[0,0](i), src[-offset](i))
-        Eq(dst_C^0, src_C^0)
-        Eq(dst_C^1, src_S^1)
-        Eq(dst_C^2, src_N^2)
+        ...     Assignment(dst[0,0](i), src[-offset](i))
+        Assignment(dst_C^0, src_C^0)
+        Assignment(dst_C^1, src_S^1)
+        Assignment(dst_C^2, src_N^2)
     """
 
     @staticmethod
@@ -437,22 +438,22 @@ def extractCommonSubexpressions(equations):
     them in a topologically sorted order, ready for evaluation.
     Usually called before list of equations is passed to :func:`createKernel`
     """
-    replacements, newEq = sp.cse(equations)
+    replacements, new_eq = sp.cse(equations)
     # Workaround for older sympy versions: here subexpressions (temporary = True) are extracted
     # which leads to problems in Piecewise functions which have to a default case indicated by True
-    symbolsEqualToTrue = {r[0]: True for r in replacements if r[1] is sp.true}
+    symbols_equal_to_true = {r[0]: True for r in replacements if r[1] is sp.true}
 
-    replacementEqs = [sp.Eq(*r) for r in replacements if r[1] is not sp.true]
-    equations = replacementEqs + newEq
-    topologicallySortedPairs = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in equations])
-    equations = [sp.Eq(a[0], a[1].subs(symbolsEqualToTrue)) for a in topologicallySortedPairs]
+    replacement_eqs = [Assignment(*r) for r in replacements if r[1] is not sp.true]
+    equations = replacement_eqs + new_eq
+    topologically_sorted_pairs = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in equations])
+    equations = [Assignment(a[0], a[1].subs(symbols_equal_to_true)) for a in topologically_sorted_pairs]
     return equations
 
 
 def getLayoutFromStrides(strides, indexDimensionIds=[]):
     coordinates = list(range(len(strides)))
-    relevantStrides = [stride for i, stride in enumerate(strides) if i not in indexDimensionIds]
-    result = [x for (y, x) in sorted(zip(relevantStrides, coordinates), key=lambda pair: pair[0], reverse=True)]
+    relevant_strides = [stride for i, stride in enumerate(strides) if i not in indexDimensionIds]
+    result = [x for (y, x) in sorted(zip(relevant_strides, coordinates), key=lambda pair: pair[0], reverse=True)]
     return normalizeLayout(result)
 
 
diff --git a/finitedifferences.py b/finitedifferences.py
index 039a603e1..8dba62707 100644
--- a/finitedifferences.py
+++ b/finitedifferences.py
@@ -1,7 +1,7 @@
 import numpy as np
 import sympy as sp
 
-from pystencils.equationcollection import EquationCollection
+from pystencils.assignment_collection import AssignmentCollection
 from pystencils.field import Field
 from pystencils.transformations import fastSubs
 from pystencils.derivative import Diff
@@ -355,8 +355,8 @@ class Discretization2ndOrder:
             return [self(e) for e in expr]
         elif isinstance(expr, sp.Matrix):
             return expr.applyfunc(self.__call__)
-        elif isinstance(expr, EquationCollection):
-            return expr.copy(mainEquations=[e for e in expr.mainEquations],
+        elif isinstance(expr, AssignmentCollection):
+            return expr.copy(mainAssignments=[e for e in expr.mainAssignments],
                              subexpressions=[e for e in expr.subexpressions])
 
         transientTerms = expr.atoms(Transient)
diff --git a/gpucuda/periodicity.py b/gpucuda/periodicity.py
index 039f21799..5009fce89 100644
--- a/gpucuda/periodicity.py
+++ b/gpucuda/periodicity.py
@@ -1,6 +1,6 @@
 import sympy as sp
 import numpy as np
-from pystencils import Field
+from pystencils import Field, Assignment
 from pystencils.slicing import normalizeSlice, getPeriodicBoundarySrcDstSlices
 from pystencils.gpucuda import makePythonFunction
 from pystencils.gpucuda.kernelcreation import createCUDAKernel
@@ -20,7 +20,7 @@ def createCopyKernel(domainSize, fromSlice, toSlice, indexDimensions=0, indexDim
 
     updateEqs = []
     for i in range(indexDimShape):
-        eq = sp.Eq(f(i), f[tuple(offset)](i))
+        eq = Assignment(f(i), f[tuple(offset)](i))
         updateEqs.append(eq)
 
     ast = createCUDAKernel(updateEqs, iterationSlice=toSlice)
diff --git a/kernelcreation.py b/kernelcreation.py
index 7d27ed6c9..9c84347d9 100644
--- a/kernelcreation.py
+++ b/kernelcreation.py
@@ -1,4 +1,4 @@
-from pystencils.equationcollection import EquationCollection
+from pystencils.assignment_collection import AssignmentCollection
 from pystencils.gpucuda.indexing import indexingCreatorFromParams
 
 
@@ -7,7 +7,7 @@ def createKernel(equations, target='cpu', dataType="double", iterationSlice=None
                  gpuIndexing='block', gpuIndexingParams={}):
     """
     Creates abstract syntax tree (AST) of kernel, using a list of update equations.
-    :param equations: either be a plain list of equations or a EquationCollection object
+    :param equations: either be a plain list of equations or a AssignmentCollection object
     :param target: 'cpu', 'llvm' or 'gpu'
     :param dataType: data type used for all untyped symbols (i.e. non-fields), can also be a dict from symbol name
                      to type
@@ -32,7 +32,7 @@ def createKernel(equations, target='cpu', dataType="double", iterationSlice=None
 
     # ----  Normalizing parameters
     splitGroups = ()
-    if isinstance(equations, EquationCollection):
+    if isinstance(equations, AssignmentCollection):
         if 'splitGroups' in equations.simplificationHints:
             splitGroups = equations.simplificationHints['splitGroups']
         equations = equations.allEquations
@@ -83,7 +83,7 @@ def createIndexedKernel(equations, indexFields, target='cpu', dataType="double",
     coordinateNames: name of the coordinate fields in the struct data type
     """
 
-    if isinstance(equations, EquationCollection):
+    if isinstance(equations, AssignmentCollection):
         equations = equations.allEquations
     if target == 'cpu':
         from pystencils.cpu import createIndexedKernel
diff --git a/llvm/llvm.py b/llvm/llvm.py
index 6529c699e..08ad220bb 100644
--- a/llvm/llvm.py
+++ b/llvm/llvm.py
@@ -9,7 +9,7 @@ from pystencils.llvm.control_flow import Loop
 from pystencils.data_types import createType, to_llvm_type, getTypeOfExpression, collateTypes, \
     createCompositeTypeFromString
 from sympy import Indexed
-from sympy.codegen.ast import Assignment
+from pystencils.assignment import Assignment
 
 
 def generateLLVM(ast_node, module=None, builder=None):
diff --git a/sympyextensions.py b/sympyextensions.py
index 6e29cabf9..d407d8f61 100644
--- a/sympyextensions.py
+++ b/sympyextensions.py
@@ -6,6 +6,7 @@ import warnings
 import sympy as sp
 
 from pystencils.data_types import getTypeOfExpression, getBaseType
+from pystencils.assignment import Assignment
 
 
 def prod(seq):
@@ -264,7 +265,7 @@ def replaceSecondOrderProducts(expr, searchSymbols, positive=None, replaceMixed=
                 mixedSymbol = sp.Symbol(mixedSymbolName.replace("_", ""))
                 if mixedSymbol not in mixedSymbolsReplaced:
                     mixedSymbolsReplaced.add(mixedSymbol)
-                    replaceMixed.append(sp.Eq(mixedSymbol, u + sign * v))
+                    replaceMixed.append(Assignment(mixedSymbol, u + sign * v))
             else:
                 mixedSymbol = u + sign * v
             return sp.Rational(1, 2) * sign * otherFactors * (mixedSymbol ** 2 - u ** 2 - v ** 2)
@@ -431,7 +432,7 @@ def countNumberOfOperations(term, onlyType='real'):
             for operationName in result.keys():
                 result[operationName] += r[operationName]
         return result
-    elif isinstance(term, sp.Eq):
+    elif isinstance(term, Assignment):
         term = term.rhs
 
     term = term.evalf()
@@ -538,7 +539,7 @@ def getSymmetricPart(term, vars):
 
 def sortEquationsTopologically(equationSequence):
     res = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in equationSequence])
-    return [sp.Eq(a, b) for a, b in res]
+    return [Assignment(a, b) for a, b in res]
 
 
 def getEquationsFromFunction(func, **kwargs):
@@ -559,7 +560,7 @@ def getEquationsFromFunction(func, **kwargs):
     ...     S.neighbors @= f[0,1] + f[1,0]
     ...     g[0,0]      @= S.neighbors + f[0,0]
     >>> getEquationsFromFunction(myKernel)
-    [Eq(neighbors, f_E + f_N), Eq(g_C, f_C + neighbors)]
+    [Assignment(neighbors, f_E + f_N), Assignment(g_C, f_C + neighbors)]
     """
     import inspect
     import re
@@ -590,7 +591,7 @@ def getEquationsFromFunction(func, **kwargs):
     code = "".join(sourceLines[1:])
     result = []
     localsDict = {'_result': result,
-                  'Eq': sp.Eq,
+                  'Eq': Assignment,
                   'S': SymbolCreator()}
     localsDict.update(kwargs)
     globalsDict = inspect.stack()[1][0].f_globals.copy()
diff --git a/transformations/transformations.py b/transformations/transformations.py
index 299201f0c..5c9c9d1fe 100644
--- a/transformations/transformations.py
+++ b/transformations/transformations.py
@@ -7,6 +7,7 @@ import sympy as sp
 from sympy.logic.boolalg import Boolean
 from sympy.tensor import IndexedBase
 
+from pystencils.assignment import Assignment
 from pystencils.field import Field, FieldType, offsetComponentToDirectionString
 from pystencils.data_types import TypedSymbol, createType, PointerType, StructType, getBaseType, castFunc
 from pystencils.slicing import normalizeSlice
@@ -737,7 +738,7 @@ def typeAllEquations(eqs, typeForSymbol):
     def visit(object):
         if isinstance(object, list) or isinstance(object, tuple):
             return [visit(e) for e in object]
-        if isinstance(object, sp.Eq) or isinstance(object, ast.SympyAssignment):
+        if isinstance(object, sp.Eq) or isinstance(object, ast.SympyAssignment) or isinstance(object, Assignment):
             newLhs = processLhs(object.lhs)
             newRhs = processRhs(object.rhs)
             return ast.SympyAssignment(newLhs, newRhs)
-- 
GitLab