From fdc7f2f186653b79de342da78a5b6265fb16108a Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Fri, 11 Aug 2017 16:37:42 +0200
Subject: [PATCH] Worked on waLBerla kernel generation

---
 equationcollection/simplifications.py |  7 +++++++
 field.py                              | 11 +++++------
 sympyextensions.py                    |  2 +-
 types.py                              |  6 +++++-
 4 files changed, 18 insertions(+), 8 deletions(-)

diff --git a/equationcollection/simplifications.py b/equationcollection/simplifications.py
index ebe3cc79c..0e1e92f48 100644
--- a/equationcollection/simplifications.py
+++ b/equationcollection/simplifications.py
@@ -1,7 +1,14 @@
 import sympy as sp
+
+from pystencils.equationcollection.equationcollection import EquationCollection
 from pystencils.sympyextensions import replaceAdditive
 
 
+def sympyCseOnEquationList(eqs):
+    ec = EquationCollection(eqs, [])
+    return sympyCSE(ec).allEquations
+
+
 def sympyCSE(equationCollection):
     """
     Searches for common subexpressions inside the equation collection, in both the existing subexpressions as well
diff --git a/field.py b/field.py
index 97a5893a9..93915a37b 100644
--- a/field.py
+++ b/field.py
@@ -254,19 +254,18 @@ class Field(object):
 
             if constantOffsets:
                 offsetName = offsetToDirectionString(offsets)
-
                 if field.indexDimensions == 0:
-                    obj = super(Field.Access, self).__xnew__(self, fieldName + "_" + offsetName)
+                    symbolName = fieldName + "_" + offsetName
                 elif field.indexDimensions == 1:
-                    obj = super(Field.Access, self).__xnew__(self, fieldName + "_" + offsetName + "^" + str(idx[0]))
+                    symbolName = fieldName + "_" + offsetName + "^" + str(idx[0])
                 else:
                     idxStr = ",".join([str(e) for e in idx])
-                    obj = super(Field.Access, self).__xnew__(self, fieldName + "_" + offsetName + "^" + idxStr)
-
+                    symbolName = fieldName + "_" + offsetName + "^" + idxStr
             else:
                 offsetName = "%0.10X" % (abs(hash(tuple(offsetsAndIndex))))
-                obj = super(Field.Access, self).__xnew__(self, fieldName + "_" + offsetName)
+                symbolName = fieldName + "_" + offsetName
 
+            obj = super(Field.Access, self).__xnew__(self, "{" + symbolName + "}")
             obj._field = field
             obj._offsets = []
             for o in offsets:
diff --git a/sympyextensions.py b/sympyextensions.py
index d14630b9d..37eb4c616 100644
--- a/sympyextensions.py
+++ b/sympyextensions.py
@@ -501,7 +501,7 @@ def getSymmetricPart(term, vars):
     :returns: :math:`\frac{1}{2} [ f(x_0, x_1, ..) + f(-x_0, -x_1) ]`
     """
     substitutionDict = {e: -e for e in vars}
-    return sp.Rational(1, 2) * (term + fastSubs(term, substitutionDict))
+    return sp.Rational(1, 2) * (term + term.subs(substitutionDict))
 
 
 def sortEquationsTopologically(equationSequence):
diff --git a/types.py b/types.py
index d88fd4289..2fb1fcf93 100644
--- a/types.py
+++ b/types.py
@@ -12,7 +12,11 @@ class TypedSymbol(sp.Symbol):
 
     def __new_stage2__(cls, name, dtype):
         obj = super(TypedSymbol, cls).__xnew__(cls, name)
-        obj._dtype = createType(dtype)
+        try:
+            obj._dtype = createType(dtype)
+        except TypeError:
+            # on error keep the string
+            obj._dtype = dtype
         return obj
 
     __xnew__ = staticmethod(__new_stage2__)
-- 
GitLab