diff --git a/equationcollection/equationcollection.py b/equationcollection/equationcollection.py
index 8c6c6f32fef341296fd5b3163eb185f690bfd921..2ea0c2b6f4f59ce66d6c0cbf029a1ccaeb003336 100644
--- a/equationcollection/equationcollection.py
+++ b/equationcollection/equationcollection.py
@@ -1,4 +1,5 @@
 import sympy as sp
+from copy import copy, deepcopy
 from pystencils.sympyextensions import fastSubs, countNumberOfOperations
 
 
@@ -20,52 +21,50 @@ class EquationCollection:
 
     # ----------------------------------------- Creation ---------------------------------------------------------------
 
-    def __init__(self, equations, subExpressions, simplificationHints={}, subexpressionSymbolNameGenerator=None):
+    def __init__(self, equations, subExpressions, simplificationHints=None, subexpressionSymbolNameGenerator=None):
         self.mainEquations = equations
         self.subexpressions = subExpressions
+
+        if simplificationHints is None:
+            simplificationHints = {}
+
         self.simplificationHints = simplificationHints
 
-        def symbolGen():
-            """Use this generator to create new unused symbols for subexpressions"""
-            counter = 0
-            while True:
-                counter += 1
-                newSymbol = sp.Symbol("xi_" + str(counter))
-                if newSymbol in self.boundSymbols:
-                    continue
-                yield newSymbol
+        class SymbolGen:
+            def __init__(self):
+                self._ctr = 0
+
+            def __iter__(self):
+                return self
+
+            def __next__(self):
+                self._ctr += 1
+                return sp.Symbol("xi_" + str(self._ctr))
 
         if subexpressionSymbolNameGenerator is None:
-            self.subexpressionSymbolNameGenerator = symbolGen()
+            self.subexpressionSymbolNameGenerator = SymbolGen()
         else:
             self.subexpressionSymbolNameGenerator = subexpressionSymbolNameGenerator
 
-    def newWithAdditionalSubexpressions(self, newEquations, additionalSubExpressions):
-        """
-        Returns a new equation collection, that has `newEquations` as mainEquations.
-        The `additionalSubExpressions` are appended to the existing subexpressions.
-        Simplifications hints are copied over.
-        """
-        assert len(self.mainEquations) == len(newEquations), "Number of update equations cannot be changed"
-        res = EquationCollection(newEquations,
-                                 self.subexpressions + additionalSubExpressions,
-                                 self.simplificationHints)
-        res.subexpressionSymbolNameGenerator = self.subexpressionSymbolNameGenerator
+    def copy(self, mainEquations=None, subexpressions=None):
+        res = deepcopy(self)
+        if mainEquations is not None:
+            res.mainEquations = mainEquations
+        if subexpressions is not None:
+            res.subexpressions = subexpressions
         return res
 
-    def newWithSubstitutionsApplied(self, substitutionDict, addSubstitutionsAsSubexpresions=False):
+    def copyWithSubstitutionsApplied(self, substitutionDict, addSubstitutionsAsSubexpressions=False):
         """
         Returns a new equation collection, where terms are substituted according to the passed `substitutionDict`.
         Substitutions are made in the subexpression terms and the main equations
         """
         newSubexpressions = [fastSubs(eq, substitutionDict) for eq in self.subexpressions]
         newEquations = [fastSubs(eq, substitutionDict) for eq in self.mainEquations]
-        if addSubstitutionsAsSubexpresions:
+        if addSubstitutionsAsSubexpressions:
             newSubexpressions = [sp.Eq(b, a) for a, b in substitutionDict.items()] + newSubexpressions
 
-        res = EquationCollection(newEquations, newSubexpressions, self.simplificationHints)
-        res.subexpressionSymbolNameGenerator = self.subexpressionSymbolNameGenerator
-        return res
+        return self.copy(newEquations, newSubexpressions)
 
     def addSimplificationHint(self, key, value):
         """
@@ -178,41 +177,45 @@ class EquationCollection:
                     substitutionDict[otherSubexpressionEq.lhs] = newLhs
             else:
                 processedOtherSubexpressionEquations.append(fastSubs(otherSubexpressionEq, substitutionDict))
-        return EquationCollection(self.mainEquations + other.mainEquations,
-                                  self.subexpressions + processedOtherSubexpressionEquations)
+        return self.copy(self.mainEquations + other.mainEquations,
+                         self.subexpressions + processedOtherSubexpressionEquations)
 
-    def extract(self, symbolsToExtract):
-        """
-        Creates a new equation collection with equations that have symbolsToExtract as left-hand-sides and
-        only the necessary subexpressions that are used in these equations
-        """
-        symbolsToExtract = set(symbolsToExtract)
-        newEquations = []
+    def getDependentSymbols(self, symbolSequence):
+        """Returns a list of symbols that depend on the passed symbols."""
 
-        subexprMap = {e.lhs: e.rhs for e in self.subexpressions}
-        handledSymbols = set()
-        queue = []
+        queue = list(symbolSequence)
 
         def addSymbolsFromExpr(expr):
             dependentSymbols = expr.atoms(sp.Symbol)
             for ds in dependentSymbols:
-                if ds not in handledSymbols:
-                    queue.append(ds)
-                    handledSymbols.add(ds)
+                queue.append(ds)
 
-        for eq in self.allEquations:
-            if eq.lhs in symbolsToExtract:
-                newEquations.append(eq)
-                addSymbolsFromExpr(eq.rhs)
+        handledSymbols = set()
+        eqMap = {e.lhs: e.rhs for e in self.allEquations}
 
         while len(queue) > 0:
             e = queue.pop(0)
-            if e not in subexprMap:
+            if e in handledSymbols:
                 continue
-            else:
-                addSymbolsFromExpr(subexprMap[e])
+            if e in eqMap:
+                addSymbolsFromExpr(eqMap[e])
+            handledSymbols.add(e)
+
+        return handledSymbols
+
+    def extract(self, symbolsToExtract):
+        """
+        Creates a new equation collection with equations that have symbolsToExtract as left-hand-sides and
+        only the necessary subexpressions that are used in these equations
+        """
+        symbolsToExtract = set(symbolsToExtract)
+        dependentSymbols = self.getDependentSymbols(symbolsToExtract)
+        newEquations = []
+        for eq in self.allEquations:
+            if eq.lhs in symbolsToExtract:
+                newEquations.append(eq)
 
-        newSubExpr = [eq for eq in self.subexpressions if eq.lhs in handledSymbols and eq.lhs not in symbolsToExtract]
+        newSubExpr = [eq for eq in self.subexpressions if eq.lhs in dependentSymbols and eq.lhs not in symbolsToExtract]
         return EquationCollection(newEquations, newSubExpr)
 
     def newWithoutUnusedSubexpressions(self):
@@ -221,18 +224,30 @@ class EquationCollection:
         allLhs = [eq.lhs for eq in self.mainEquations]
         return self.extract(allLhs)
 
-    def insertSubexpressions(self):
+    def insertSubexpressions(self, subexpressionSymbolsToKeep=set()):
         """Returns a new equation collection by inserting all subexpressions into the main equations"""
         if len(self.subexpressions) == 0:
-            return EquationCollection(self.mainEquations, self.subexpressions, self.simplificationHints)
-        subsDict = {self.subexpressions[0].lhs: self.subexpressions[0].rhs}
+            return self.copy()
+
+        subexpressionSymbolsToKeep = set(subexpressionSymbolsToKeep)
+
+        keptSubexpressions = []
+        if self.subexpressions[0].lhs in subexpressionSymbolsToKeep:
+            subsDict = {}
+            keptSubexpressions = self.subexpressions[0]
+        else:
+            subsDict = {self.subexpressions[0].lhs: self.subexpressions[0].rhs}
+
         subExpr = [e for e in self.subexpressions]
         for i in range(1, len(subExpr)):
             subExpr[i] = fastSubs(subExpr[i], subsDict)
-            subsDict[subExpr[i].lhs] = subExpr[i].rhs
+            if subExpr[i].lhs in subexpressionSymbolsToKeep:
+                keptSubexpressions.append(subExpr[i])
+            else:
+                subsDict[subExpr[i].lhs] = subExpr[i].rhs
 
         newEq = [fastSubs(eq, subsDict) for eq in self.mainEquations]
-        return EquationCollection(newEq, [], self.simplificationHints)
+        return self.copy(newEq, keptSubexpressions)
 
     def lambdify(self, symbols, module=None, fixedSymbols={}):
         """
@@ -241,7 +256,7 @@ class EquationCollection:
         :param module: same as sympy.lambdify paramter of same same, i.e. which module to use e.g. 'numpy'
         :param fixedSymbols: dictionary with substitutions, that are applied before lambdification
         """
-        eqs = self.newWithSubstitutionsApplied(fixedSymbols).insertSubexpressions().mainEquations
+        eqs = self.copyWithSubstitutionsApplied(fixedSymbols).insertSubexpressions().mainEquations
         lambdas = {eq.lhs: sp.lambdify(symbols, eq.rhs, module) for eq in eqs}
 
         def f(*args, **kwargs):
diff --git a/equationcollection/simplifications.py b/equationcollection/simplifications.py
index 412a889cca6273a1da429c2b65fca7fd1dcdcfc2..32837359911f8a348cc14d7873a451e5aee12986 100644
--- a/equationcollection/simplifications.py
+++ b/equationcollection/simplifications.py
@@ -1,5 +1,4 @@
 import sympy as sp
-from pystencils.equationcollection import EquationCollection
 from pystencils.sympyextensions import replaceAdditive
 
 
@@ -21,21 +20,18 @@ def sympyCSE(equationCollection):
     topologicallySortedPairs = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in newSubexpressions])
     newSubexpressions = [sp.Eq(a[0], a[1]) for a in topologicallySortedPairs]
 
-    return EquationCollection(modifiedUpdateEquations, newSubexpressions, equationCollection.simplificationHints,
-                              equationCollection.subexpressionSymbolNameGenerator)
+    return equationCollection.copy(modifiedUpdateEquations, newSubexpressions)
 
 
 def applyOnAllEquations(equationCollection, operation):
     """Applies sympy expand operation to all equations in collection"""
     result = [operation(s) for s in equationCollection.mainEquations]
-    return equationCollection.newWithAdditionalSubexpressions(result, [])
+    return equationCollection.copy(result)
 
 
 def applyOnAllSubexpressions(equationCollection, operation):
-    return EquationCollection(equationCollection.mainEquations,
-                              [operation(s) for s in equationCollection.subexpressions],
-                              equationCollection.simplificationHints,
-                              equationCollection.subexpressionSymbolNameGenerator)
+    return equationCollection.copy(equationCollection.mainEquations,
+                                   [operation(s) for s in equationCollection.subexpressions])
 
 
 def subexpressionSubstitutionInExistingSubexpressions(equationCollection):
@@ -49,8 +45,7 @@ def subexpressionSubstitutionInExistingSubexpressions(equationCollection):
             newRhs = newRhs.subs(subExpr.rhs, subExpr.lhs)
         result.append(sp.Eq(s.lhs, newRhs))
 
-    return EquationCollection(equationCollection.mainEquations, result, equationCollection.simplificationHints,
-                              equationCollection.subexpressionSymbolNameGenerator)
+    return equationCollection.copy(equationCollection.mainEquations, result)
 
 
 def subexpressionSubstitutionInMainEquations(equationCollection):
@@ -61,7 +56,7 @@ def subexpressionSubstitutionInMainEquations(equationCollection):
         for subExpr in equationCollection.subexpressions:
             newRhs = replaceAdditive(newRhs, subExpr.lhs, subExpr.rhs, requiredMatchReplacement=1.0)
         result.append(sp.Eq(s.lhs, newRhs))
-    return equationCollection.newWithAdditionalSubexpressions(result, [])
+    return equationCollection.copy(result)
 
 
 def addSubexpressionsForDivisions(equationCollection):
@@ -80,4 +75,4 @@ def addSubexpressionsForDivisions(equationCollection):
 
     newSymbolGen = equationCollection.subexpressionSymbolNameGenerator
     substitutions = {divisor: newSymbol for newSymbol, divisor in zip(newSymbolGen, divisors)}
-    return equationCollection.newWithSubstitutionsApplied(substitutions, True)
+    return equationCollection.copyWithSubstitutionsApplied(substitutions, True)
diff --git a/sympyextensions.py b/sympyextensions.py
index e02b58be9b3c94f62117be32bec8f8f5900ee4ae..8daa67136a85d8bca900f2091ec6b35fbfce543f 100644
--- a/sympyextensions.py
+++ b/sympyextensions.py
@@ -14,7 +14,11 @@ def fastSubs(term, subsDict):
             return expr
         paramList = [visit(a) for a in expr.args]
         return expr if not paramList else expr.func(*paramList)
-    return visit(term)
+
+    if len(subsDict) == 0:
+        return term
+    else:
+        return visit(term)
 
 
 def replaceAdditive(expr, replacement, subExpression, requiredMatchReplacement=0.5, requiredMatchOriginal=None):