diff --git a/equationcollection/simplifications.py b/equationcollection/simplifications.py index ebe3cc79c4a45d128cca8659271afa7fb7170f12..0e1e92f4872e58262d5b8667911d2392261d0819 100644 --- a/equationcollection/simplifications.py +++ b/equationcollection/simplifications.py @@ -1,7 +1,14 @@ import sympy as sp + +from pystencils.equationcollection.equationcollection import EquationCollection from pystencils.sympyextensions import replaceAdditive +def sympyCseOnEquationList(eqs): + ec = EquationCollection(eqs, []) + return sympyCSE(ec).allEquations + + def sympyCSE(equationCollection): """ Searches for common subexpressions inside the equation collection, in both the existing subexpressions as well diff --git a/field.py b/field.py index 97a5893a9cbe507a41181fcc0eafdc36cf12cc57..93915a37b1b6994c18e3491aeaf32c1cd51b2392 100644 --- a/field.py +++ b/field.py @@ -254,19 +254,18 @@ class Field(object): if constantOffsets: offsetName = offsetToDirectionString(offsets) - if field.indexDimensions == 0: - obj = super(Field.Access, self).__xnew__(self, fieldName + "_" + offsetName) + symbolName = fieldName + "_" + offsetName elif field.indexDimensions == 1: - obj = super(Field.Access, self).__xnew__(self, fieldName + "_" + offsetName + "^" + str(idx[0])) + symbolName = fieldName + "_" + offsetName + "^" + str(idx[0]) else: idxStr = ",".join([str(e) for e in idx]) - obj = super(Field.Access, self).__xnew__(self, fieldName + "_" + offsetName + "^" + idxStr) - + symbolName = fieldName + "_" + offsetName + "^" + idxStr else: offsetName = "%0.10X" % (abs(hash(tuple(offsetsAndIndex)))) - obj = super(Field.Access, self).__xnew__(self, fieldName + "_" + offsetName) + symbolName = fieldName + "_" + offsetName + obj = super(Field.Access, self).__xnew__(self, "{" + symbolName + "}") obj._field = field obj._offsets = [] for o in offsets: diff --git a/sympyextensions.py b/sympyextensions.py index d14630b9dca638915d43d8363b88facddc3e52ff..37eb4c6167da82a69247ad761ff68206ba1ecfcc 100644 --- a/sympyextensions.py +++ b/sympyextensions.py @@ -501,7 +501,7 @@ def getSymmetricPart(term, vars): :returns: :math:`\frac{1}{2} [ f(x_0, x_1, ..) + f(-x_0, -x_1) ]` """ substitutionDict = {e: -e for e in vars} - return sp.Rational(1, 2) * (term + fastSubs(term, substitutionDict)) + return sp.Rational(1, 2) * (term + term.subs(substitutionDict)) def sortEquationsTopologically(equationSequence): diff --git a/types.py b/types.py index d88fd4289513bcb61afbfef945d021493260c76c..2fb1fcf9328789a7983ab3f2c00c1b5aa52d77aa 100644 --- a/types.py +++ b/types.py @@ -12,7 +12,11 @@ class TypedSymbol(sp.Symbol): def __new_stage2__(cls, name, dtype): obj = super(TypedSymbol, cls).__xnew__(cls, name) - obj._dtype = createType(dtype) + try: + obj._dtype = createType(dtype) + except TypeError: + # on error keep the string + obj._dtype = dtype return obj __xnew__ = staticmethod(__new_stage2__)