From 9d1e022d4408c88b99a77737ec5dc7507d0281b0 Mon Sep 17 00:00:00 2001 From: Martin Bauer <martin.bauer@fau.de> Date: Wed, 11 Oct 2017 16:05:25 +0200 Subject: [PATCH] Easier kernel formulation with "@=" operator to create sp.Eq's --- sympyextensions.py | 59 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/sympyextensions.py b/sympyextensions.py index 62187fbd6..4abeb7b66 100644 --- a/sympyextensions.py +++ b/sympyextensions.py @@ -509,3 +509,62 @@ def getSymmetricPart(term, vars): def sortEquationsTopologically(equationSequence): res = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in equationSequence]) return [sp.Eq(a, b) for a, b in res] + + +def getEquationsFromFunction(func, **kwargs): + """ + Mechanism to simplify the generation of a list of sympy equations. + Introduces a special "assignment operator" written as "@=". Each line containing this operator gives an + equation in the result list. Note that executing this function normally yields an error. + + Additionally the shortcut object 'S' is available to quickly create new sympy symbols. + + Example: + + >>> def myKernel(): + ... from pystencils import Field + ... f = Field.createGeneric('f', spatialDimensions=2, indexDimensions=0) + ... g = f.newFieldWithDifferentName('g') + ... + ... S.neighbors @= f[0,1] + f[1,0] + ... g[0,0] @= S.neighbors + f[0,0] + >>> getEquationsFromFunction(myKernel) + [Eq(neighbors, f_E + f_N), Eq(g_C, f_C + neighbors)] + """ + import inspect + import re + + class SymbolCreator: + def __getattribute__(self, name): + return sp.Symbol(name) + + assignmentRegexp = re.compile(r'(\s*)(.+?)@=(.*)') + whitespaceRegexp = re.compile(r'(\s*)(.*)') + sourceLines = inspect.getsourcelines(func)[0] + + # determine indentation + firstCodeLine = sourceLines[1] + matchRes = whitespaceRegexp.match(firstCodeLine) + assert matchRes, "First line is not indented" + numWhitespaces = len(matchRes.group(1)) + + for i in range(1, len(sourceLines)): + sourceLine = sourceLines[i][numWhitespaces:] + if 'return' in sourceLine: + raise ValueError("Function may not have a return statement!") + matchRes = assignmentRegexp.match(sourceLine) + if matchRes: + sourceLine = "%s_result.append(Eq(%s, %s))\n" % matchRes.groups() + sourceLines[i] = sourceLine + + code = "".join(sourceLines[1:]) + result = [] + localsDict = {'_result': result, + 'Eq': sp.Eq, + 'S': SymbolCreator()} + localsDict.update(kwargs) + globalsDict = inspect.stack()[1][0].f_globals.copy() + globalsDict.update(inspect.stack()[1][0].f_locals) + + exec(code, globalsDict, localsDict) + return result -- GitLab