Skip to content
Snippets Groups Projects
Commit 9d1e022d authored by Martin Bauer's avatar Martin Bauer
Browse files

Easier kernel formulation with "@=" operator to create sp.Eq's

parent 231fb6af
Branches
Tags
No related merge requests found
......@@ -509,3 +509,62 @@ def getSymmetricPart(term, vars):
def sortEquationsTopologically(equationSequence):
res = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in equationSequence])
return [sp.Eq(a, b) for a, b in res]
def getEquationsFromFunction(func, **kwargs):
"""
Mechanism to simplify the generation of a list of sympy equations.
Introduces a special "assignment operator" written as "@=". Each line containing this operator gives an
equation in the result list. Note that executing this function normally yields an error.
Additionally the shortcut object 'S' is available to quickly create new sympy symbols.
Example:
>>> def myKernel():
... from pystencils import Field
... f = Field.createGeneric('f', spatialDimensions=2, indexDimensions=0)
... g = f.newFieldWithDifferentName('g')
...
... S.neighbors @= f[0,1] + f[1,0]
... g[0,0] @= S.neighbors + f[0,0]
>>> getEquationsFromFunction(myKernel)
[Eq(neighbors, f_E + f_N), Eq(g_C, f_C + neighbors)]
"""
import inspect
import re
class SymbolCreator:
def __getattribute__(self, name):
return sp.Symbol(name)
assignmentRegexp = re.compile(r'(\s*)(.+?)@=(.*)')
whitespaceRegexp = re.compile(r'(\s*)(.*)')
sourceLines = inspect.getsourcelines(func)[0]
# determine indentation
firstCodeLine = sourceLines[1]
matchRes = whitespaceRegexp.match(firstCodeLine)
assert matchRes, "First line is not indented"
numWhitespaces = len(matchRes.group(1))
for i in range(1, len(sourceLines)):
sourceLine = sourceLines[i][numWhitespaces:]
if 'return' in sourceLine:
raise ValueError("Function may not have a return statement!")
matchRes = assignmentRegexp.match(sourceLine)
if matchRes:
sourceLine = "%s_result.append(Eq(%s, %s))\n" % matchRes.groups()
sourceLines[i] = sourceLine
code = "".join(sourceLines[1:])
result = []
localsDict = {'_result': result,
'Eq': sp.Eq,
'S': SymbolCreator()}
localsDict.update(kwargs)
globalsDict = inspect.stack()[1][0].f_globals.copy()
globalsDict.update(inspect.stack()[1][0].f_locals)
exec(code, globalsDict, localsDict)
return result
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment