From 0db069261db8611862f430d5178d56bcf0b12585 Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Tue, 10 Oct 2017 14:31:46 +0200
Subject: [PATCH] Caching for LB methods - other small performance improvements

---
 astnodes.py                              |  5 ++--
 cache.py                                 | 27 +++++++++++++++++++++-
 equationcollection/equationcollection.py | 29 ++++++++++++------------
 field.py                                 |  9 ++++----
 sympyextensions.py                       |  4 +++-
 vectorization.py                         |  3 ++-
 6 files changed, 54 insertions(+), 23 deletions(-)

diff --git a/astnodes.py b/astnodes.py
index 0fd41535c..09a23075f 100644
--- a/astnodes.py
+++ b/astnodes.py
@@ -2,6 +2,7 @@ import sympy as sp
 from sympy.tensor import IndexedBase
 from pystencils.field import Field
 from pystencils.data_types import TypedSymbol, createType, get_type_from_sympy, createTypeFromString, castFunc
+from pystencils.sympyextensions import fastSubs
 
 
 class ResolvedFieldAccess(sp.Indexed):
@@ -401,8 +402,8 @@ class SympyAssignment(Node):
             self._isDeclaration = False
 
     def subs(self, *args, **kwargs):
-        self.lhs = self.lhs.subs(*args, **kwargs)
-        self.rhs = self.rhs.subs(*args, **kwargs)
+        self.lhs = fastSubs(self.lhs, *args, **kwargs)
+        self.rhs = fastSubs(self.rhs, *args, **kwargs)
 
     @property
     def args(self):
diff --git a/cache.py b/cache.py
index 43590c58a..4ba9122c4 100644
--- a/cache.py
+++ b/cache.py
@@ -1,3 +1,6 @@
+import sympy as sp
+import json
+
 try:
     from functools import lru_cache as memorycache
 except ImportError:
@@ -5,7 +8,7 @@ except ImportError:
 
 try:
     from joblib import Memory
-    diskcache = Memory(cachedir="/tmp/lbmpy", verbose=False).cache
+    diskcache = Memory(cachedir="/tmp/pystencils/joblib_memcache", verbose=False).cache
 except ImportError:
     # fallback to in-memory caching if joblib is not available
     diskcache = memorycache(maxsize=64)
@@ -17,3 +20,25 @@ import sys
 calledBySphinx = 'sphinx' in sys.modules
 if calledBySphinx:
     diskcache = memorycache(maxsize=64)
+
+
+# ------------------------ Helper classes to JSON serialize sympy objects ----------------------------------------------
+
+
+class SympyJSONEncoder(json.JSONEncoder):
+    def default(self, obj):
+        if isinstance(obj, sp.Basic):
+            return {"_type": "sp", "str": str(obj)}
+        else:
+            super(SympyJSONEncoder, self).default(obj)
+
+
+class SympyJSONDecoder(json.JSONDecoder):
+    def __init__(self, *args, **kwargs):
+        json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)
+
+    def object_hook(self, obj):
+        if '_type' in obj:
+            return sp.sympify(obj['str'])
+        else:
+            return obj
\ No newline at end of file
diff --git a/equationcollection/equationcollection.py b/equationcollection/equationcollection.py
index 255f56a7d..5baac35a0 100644
--- a/equationcollection/equationcollection.py
+++ b/equationcollection/equationcollection.py
@@ -30,20 +30,6 @@ class EquationCollection(object):
 
         self.simplificationHints = simplificationHints
 
-        class SymbolGen:
-            def __init__(self):
-                self._ctr = 0
-
-            def __iter__(self):
-                return self
-
-            def __next__(self):
-                self._ctr += 1
-                return sp.Symbol("xi_" + str(self._ctr))
-            
-            def next(self):
-                return self.__next__()
-
         if subexpressionSymbolNameGenerator is None:
             self.subexpressionSymbolNameGenerator = SymbolGen()
         else:
@@ -293,3 +279,18 @@ class EquationCollection(object):
             return {s: f(*args, **kwargs) for s, f in lambdas.items()}
 
         return f
+
+
+class SymbolGen:
+    def __init__(self):
+        self._ctr = 0
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        self._ctr += 1
+        return sp.Symbol("xi_" + str(self._ctr))
+
+    def next(self):
+        return self.__next__()
diff --git a/field.py b/field.py
index 4ad087243..98e388c49 100644
--- a/field.py
+++ b/field.py
@@ -121,10 +121,8 @@ class Field(object):
         spatialDimensions = len(shape) - indexDimensions
         assert spatialDimensions >= 1
 
-        if isinstance(layout, str) and (layout == 'numpy' or layout.lower() == 'c'):
-            layout = tuple(range(spatialDimensions))
-        elif isinstance(layout, str) and (layout == 'reverseNumpy' or layout.lower() == 'f'):
-            layout = tuple(reversed(range(spatialDimensions)))
+        if isinstance(layout, str):
+            layout = layoutStringToTuple(layout, spatialDimensions + indexDimensions)
 
         shape = tuple(int(s) for s in shape)
         strides = computeStrides(shape, layout)
@@ -152,6 +150,9 @@ class Field(object):
         # the coordinates are not the loop counters in that case, but are read from this index field
         self.isIndexField = False
 
+    def newFieldWithDifferentName(self, newName):
+        return Field(newName, self._dtype, self._layout, self.shape, self.strides)
+
     @property
     def spatialDimensions(self):
         return len(self._layout)
diff --git a/sympyextensions.py b/sympyextensions.py
index 37eb4c616..62187fbd6 100644
--- a/sympyextensions.py
+++ b/sympyextensions.py
@@ -80,10 +80,12 @@ def productSymmetric(*args, withDiagonal=True):
             yield tuple(a[i] for a, i in zip(args, idx))
 
 
-def fastSubs(term, subsDict):
+def fastSubs(term, subsDict, skip=None):
     """Similar to sympy subs function.
     This version is much faster for big substitution dictionaries than sympy version"""
     def visit(expr):
+        if skip and skip(expr):
+            return expr
         if expr in subsDict:
             return subsDict[expr]
         if not hasattr(expr, 'args'):
diff --git a/vectorization.py b/vectorization.py
index 8aa6a01db..979666d79 100644
--- a/vectorization.py
+++ b/vectorization.py
@@ -1,6 +1,7 @@
 import sympy as sp
 import warnings
 
+from pystencils.sympyextensions import fastSubs
 from pystencils.transformations import filteredTreeIteration
 from pystencils.data_types import TypedSymbol, VectorType, BasicType, getTypeOfExpression, castFunc, collateTypes, \
     PointerType
@@ -89,7 +90,7 @@ def insertVectorCasts(astNode):
 
     substitutionDict = {}
     for asmt in filteredTreeIteration(astNode, ast.SympyAssignment):
-        subsExpr = asmt.rhs.subs(substitutionDict)
+        subsExpr = fastSubs(asmt.rhs, substitutionDict, skip=lambda e: isinstance(e, ast.ResolvedFieldAccess))
         asmt.rhs = visitExpr(subsExpr)
         rhsType = getTypeOfExpression(asmt.rhs)
         if isinstance(asmt.lhs, TypedSymbol):
-- 
GitLab