diff --git a/astnodes.py b/astnodes.py index 0fd41535c25d73964bae3a7c876b722b3ef23da9..09a23075f769add17c230cf0193fa78669c013ee 100644 --- a/astnodes.py +++ b/astnodes.py @@ -2,6 +2,7 @@ import sympy as sp from sympy.tensor import IndexedBase from pystencils.field import Field from pystencils.data_types import TypedSymbol, createType, get_type_from_sympy, createTypeFromString, castFunc +from pystencils.sympyextensions import fastSubs class ResolvedFieldAccess(sp.Indexed): @@ -401,8 +402,8 @@ class SympyAssignment(Node): self._isDeclaration = False def subs(self, *args, **kwargs): - self.lhs = self.lhs.subs(*args, **kwargs) - self.rhs = self.rhs.subs(*args, **kwargs) + self.lhs = fastSubs(self.lhs, *args, **kwargs) + self.rhs = fastSubs(self.rhs, *args, **kwargs) @property def args(self): diff --git a/cache.py b/cache.py index 43590c58ac98b887bfe617a09f910f1a847b2c5e..4ba9122c44e8ca98ef1a5b1103bacd7bb291e512 100644 --- a/cache.py +++ b/cache.py @@ -1,3 +1,6 @@ +import sympy as sp +import json + try: from functools import lru_cache as memorycache except ImportError: @@ -5,7 +8,7 @@ except ImportError: try: from joblib import Memory - diskcache = Memory(cachedir="/tmp/lbmpy", verbose=False).cache + diskcache = Memory(cachedir="/tmp/pystencils/joblib_memcache", verbose=False).cache except ImportError: # fallback to in-memory caching if joblib is not available diskcache = memorycache(maxsize=64) @@ -17,3 +20,25 @@ import sys calledBySphinx = 'sphinx' in sys.modules if calledBySphinx: diskcache = memorycache(maxsize=64) + + +# ------------------------ Helper classes to JSON serialize sympy objects ---------------------------------------------- + + +class SympyJSONEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, sp.Basic): + return {"_type": "sp", "str": str(obj)} + else: + super(SympyJSONEncoder, self).default(obj) + + +class SympyJSONDecoder(json.JSONDecoder): + def __init__(self, *args, **kwargs): + json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs) + + def object_hook(self, obj): + if '_type' in obj: + return sp.sympify(obj['str']) + else: + return obj \ No newline at end of file diff --git a/equationcollection/equationcollection.py b/equationcollection/equationcollection.py index 255f56a7de67368ff402c9aad00c5b7b29ef9d47..5baac35a095fb52139a990e73d919530a7ceacef 100644 --- a/equationcollection/equationcollection.py +++ b/equationcollection/equationcollection.py @@ -30,20 +30,6 @@ class EquationCollection(object): self.simplificationHints = simplificationHints - class SymbolGen: - def __init__(self): - self._ctr = 0 - - def __iter__(self): - return self - - def __next__(self): - self._ctr += 1 - return sp.Symbol("xi_" + str(self._ctr)) - - def next(self): - return self.__next__() - if subexpressionSymbolNameGenerator is None: self.subexpressionSymbolNameGenerator = SymbolGen() else: @@ -293,3 +279,18 @@ class EquationCollection(object): return {s: f(*args, **kwargs) for s, f in lambdas.items()} return f + + +class SymbolGen: + def __init__(self): + self._ctr = 0 + + def __iter__(self): + return self + + def __next__(self): + self._ctr += 1 + return sp.Symbol("xi_" + str(self._ctr)) + + def next(self): + return self.__next__() diff --git a/field.py b/field.py index 4ad087243a4a11a41f7983b7b54b85dd1b0d26d9..98e388c49c96221350c6019815c873d63d4f1da3 100644 --- a/field.py +++ b/field.py @@ -121,10 +121,8 @@ class Field(object): spatialDimensions = len(shape) - indexDimensions assert spatialDimensions >= 1 - if isinstance(layout, str) and (layout == 'numpy' or layout.lower() == 'c'): - layout = tuple(range(spatialDimensions)) - elif isinstance(layout, str) and (layout == 'reverseNumpy' or layout.lower() == 'f'): - layout = tuple(reversed(range(spatialDimensions))) + if isinstance(layout, str): + layout = layoutStringToTuple(layout, spatialDimensions + indexDimensions) shape = tuple(int(s) for s in shape) strides = computeStrides(shape, layout) @@ -152,6 +150,9 @@ class Field(object): # the coordinates are not the loop counters in that case, but are read from this index field self.isIndexField = False + def newFieldWithDifferentName(self, newName): + return Field(newName, self._dtype, self._layout, self.shape, self.strides) + @property def spatialDimensions(self): return len(self._layout) diff --git a/sympyextensions.py b/sympyextensions.py index 37eb4c6167da82a69247ad761ff68206ba1ecfcc..62187fbd62804fb6a276c9dc919c8f1ec6e47a2f 100644 --- a/sympyextensions.py +++ b/sympyextensions.py @@ -80,10 +80,12 @@ def productSymmetric(*args, withDiagonal=True): yield tuple(a[i] for a, i in zip(args, idx)) -def fastSubs(term, subsDict): +def fastSubs(term, subsDict, skip=None): """Similar to sympy subs function. This version is much faster for big substitution dictionaries than sympy version""" def visit(expr): + if skip and skip(expr): + return expr if expr in subsDict: return subsDict[expr] if not hasattr(expr, 'args'): diff --git a/vectorization.py b/vectorization.py index 8aa6a01db004a6454997f9c6fe2950412462a79c..979666d7959ed6d0a65670ddd222bad2c60c55f6 100644 --- a/vectorization.py +++ b/vectorization.py @@ -1,6 +1,7 @@ import sympy as sp import warnings +from pystencils.sympyextensions import fastSubs from pystencils.transformations import filteredTreeIteration from pystencils.data_types import TypedSymbol, VectorType, BasicType, getTypeOfExpression, castFunc, collateTypes, \ PointerType @@ -89,7 +90,7 @@ def insertVectorCasts(astNode): substitutionDict = {} for asmt in filteredTreeIteration(astNode, ast.SympyAssignment): - subsExpr = asmt.rhs.subs(substitutionDict) + subsExpr = fastSubs(asmt.rhs, substitutionDict, skip=lambda e: isinstance(e, ast.ResolvedFieldAccess)) asmt.rhs = visitExpr(subsExpr) rhsType = getTypeOfExpression(asmt.rhs) if isinstance(asmt.lhs, TypedSymbol):