diff --git a/equationcollection/equationcollection.py b/equationcollection/equationcollection.py index 66146ad0cb462ea2457774cb2e090c9118fb0927..1609c95d61e5602c8f9cadc03ac130f8dd7fa2c9 100644 --- a/equationcollection/equationcollection.py +++ b/equationcollection/equationcollection.py @@ -1,5 +1,5 @@ import sympy as sp -from copy import deepcopy +from copy import copy from pystencils.sympyextensions import fastSubs, countNumberOfOperations, sortEquationsTopologically @@ -40,11 +40,20 @@ class EquationCollection(object): return [] def copy(self, mainEquations=None, subexpressions=None): - res = deepcopy(self) + res = copy(self) + res.simplificationHints = self.simplificationHints.copy() + res.subexpressionSymbolNameGenerator = copy(self.subexpressionSymbolNameGenerator) + if mainEquations is not None: res.mainEquations = mainEquations + else: + res.mainEquations = self.mainEquations.copy() + if subexpressions is not None: res.subexpressions = subexpressions + else: + res.subexpressions = self.subexpressions.copy() + return res def copyWithSubstitutionsApplied(self, substitutionDict, addSubstitutionsAsSubexpressions=False, @@ -103,7 +112,7 @@ class EquationCollection(object): @property def operationCount(self): """See :func:`countNumberOfOperations` """ - return countNumberOfOperations(self.allEquations) + return countNumberOfOperations(self.allEquations, onlyType=None) def get(self, symbols, fromMainEquationsOnly=False): """Return the equations which have symbols as left hand sides""" diff --git a/sympyextensions.py b/sympyextensions.py index 566cc7441b57ff4c67d14496983938bb5042c138..6e29cabf9135f2ba4b1b82ccecd44e130db8a5d4 100644 --- a/sympyextensions.py +++ b/sympyextensions.py @@ -420,13 +420,14 @@ def countNumberOfOperations(term, onlyType='real'): """ Counts the number of additions, multiplications and division :param term: a sympy term, equation or sequence of terms/equations + :param onlyType: 'real' or 'int' to count only operations on these types, or None for all :return: a dictionary with 'adds', 'muls' and 'divs' keys """ result = {'adds': 0, 'muls': 0, 'divs': 0} if isinstance(term, Sequence): for element in term: - r = countNumberOfOperations(element) + r = countNumberOfOperations(element, onlyType) for operationName in result.keys(): result[operationName] += r[operationName] return result @@ -469,7 +470,7 @@ def countNumberOfOperations(term, onlyType='real'): elif t.is_integer: pass elif t.func is sp.Pow: - if checkType(t): + if checkType(t.args[0]): visitChildren = False if t.exp.is_integer and t.exp.is_number: if t.exp >= 0: