Commit 0db06926 authored by Martin Bauer's avatar Martin Bauer
Browse files

Caching for LB methods - other small performance improvements

parent e4ed4efc
......@@ -2,6 +2,7 @@ import sympy as sp
from sympy.tensor import IndexedBase
from pystencils.field import Field
from pystencils.data_types import TypedSymbol, createType, get_type_from_sympy, createTypeFromString, castFunc
from pystencils.sympyextensions import fastSubs
class ResolvedFieldAccess(sp.Indexed):
......@@ -401,8 +402,8 @@ class SympyAssignment(Node):
self._isDeclaration = False
def subs(self, *args, **kwargs):
self.lhs = self.lhs.subs(*args, **kwargs)
self.rhs = self.rhs.subs(*args, **kwargs)
self.lhs = fastSubs(self.lhs, *args, **kwargs)
self.rhs = fastSubs(self.rhs, *args, **kwargs)
def args(self):
import sympy as sp
import json
from functools import lru_cache as memorycache
except ImportError:
......@@ -5,7 +8,7 @@ except ImportError:
from joblib import Memory
diskcache = Memory(cachedir="/tmp/lbmpy", verbose=False).cache
diskcache = Memory(cachedir="/tmp/pystencils/joblib_memcache", verbose=False).cache
except ImportError:
# fallback to in-memory caching if joblib is not available
diskcache = memorycache(maxsize=64)
......@@ -17,3 +20,25 @@ import sys
calledBySphinx = 'sphinx' in sys.modules
if calledBySphinx:
diskcache = memorycache(maxsize=64)
# ------------------------ Helper classes to JSON serialize sympy objects ----------------------------------------------
class SympyJSONEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, sp.Basic):
return {"_type": "sp", "str": str(obj)}
super(SympyJSONEncoder, self).default(obj)
class SympyJSONDecoder(json.JSONDecoder):
def __init__(self, *args, **kwargs):
json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)
def object_hook(self, obj):
if '_type' in obj:
return sp.sympify(obj['str'])
return obj
\ No newline at end of file
......@@ -30,20 +30,6 @@ class EquationCollection(object):
self.simplificationHints = simplificationHints
class SymbolGen:
def __init__(self):
self._ctr = 0
def __iter__(self):
return self
def __next__(self):
self._ctr += 1
return sp.Symbol("xi_" + str(self._ctr))
def next(self):
return self.__next__()
if subexpressionSymbolNameGenerator is None:
self.subexpressionSymbolNameGenerator = SymbolGen()
......@@ -293,3 +279,18 @@ class EquationCollection(object):
return {s: f(*args, **kwargs) for s, f in lambdas.items()}
return f
class SymbolGen:
def __init__(self):
self._ctr = 0
def __iter__(self):
return self
def __next__(self):
self._ctr += 1
return sp.Symbol("xi_" + str(self._ctr))
def next(self):
return self.__next__()
......@@ -121,10 +121,8 @@ class Field(object):
spatialDimensions = len(shape) - indexDimensions
assert spatialDimensions >= 1
if isinstance(layout, str) and (layout == 'numpy' or layout.lower() == 'c'):
layout = tuple(range(spatialDimensions))
elif isinstance(layout, str) and (layout == 'reverseNumpy' or layout.lower() == 'f'):
layout = tuple(reversed(range(spatialDimensions)))
if isinstance(layout, str):
layout = layoutStringToTuple(layout, spatialDimensions + indexDimensions)
shape = tuple(int(s) for s in shape)
strides = computeStrides(shape, layout)
......@@ -152,6 +150,9 @@ class Field(object):
# the coordinates are not the loop counters in that case, but are read from this index field
self.isIndexField = False
def newFieldWithDifferentName(self, newName):
return Field(newName, self._dtype, self._layout, self.shape, self.strides)
def spatialDimensions(self):
return len(self._layout)
......@@ -80,10 +80,12 @@ def productSymmetric(*args, withDiagonal=True):
yield tuple(a[i] for a, i in zip(args, idx))
def fastSubs(term, subsDict):
def fastSubs(term, subsDict, skip=None):
"""Similar to sympy subs function.
This version is much faster for big substitution dictionaries than sympy version"""
def visit(expr):
if skip and skip(expr):
return expr
if expr in subsDict:
return subsDict[expr]
if not hasattr(expr, 'args'):
import sympy as sp
import warnings
from pystencils.sympyextensions import fastSubs
from pystencils.transformations import filteredTreeIteration
from pystencils.data_types import TypedSymbol, VectorType, BasicType, getTypeOfExpression, castFunc, collateTypes, \
......@@ -89,7 +90,7 @@ def insertVectorCasts(astNode):
substitutionDict = {}
for asmt in filteredTreeIteration(astNode, ast.SympyAssignment):
subsExpr = asmt.rhs.subs(substitutionDict)
subsExpr = fastSubs(asmt.rhs, substitutionDict, skip=lambda e: isinstance(e, ast.ResolvedFieldAccess))
asmt.rhs = visitExpr(subsExpr)
rhsType = getTypeOfExpression(asmt.rhs)
if isinstance(asmt.lhs, TypedSymbol):
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment