Commit 9d1e022d authored by Martin Bauer's avatar Martin Bauer
Browse files

Easier kernel formulation with "@=" operator to create sp.Eq's

parent 231fb6af
......@@ -509,3 +509,62 @@ def getSymmetricPart(term, vars):
def sortEquationsTopologically(equationSequence):
res = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in equationSequence])
return [sp.Eq(a, b) for a, b in res]
def getEquationsFromFunction(func, **kwargs):
"""
Mechanism to simplify the generation of a list of sympy equations.
Introduces a special "assignment operator" written as "@=". Each line containing this operator gives an
equation in the result list. Note that executing this function normally yields an error.
Additionally the shortcut object 'S' is available to quickly create new sympy symbols.
Example:
>>> def myKernel():
... from pystencils import Field
... f = Field.createGeneric('f', spatialDimensions=2, indexDimensions=0)
... g = f.newFieldWithDifferentName('g')
...
... S.neighbors @= f[0,1] + f[1,0]
... g[0,0] @= S.neighbors + f[0,0]
>>> getEquationsFromFunction(myKernel)
[Eq(neighbors, f_E + f_N), Eq(g_C, f_C + neighbors)]
"""
import inspect
import re
class SymbolCreator:
def __getattribute__(self, name):
return sp.Symbol(name)
assignmentRegexp = re.compile(r'(\s*)(.+?)@=(.*)')
whitespaceRegexp = re.compile(r'(\s*)(.*)')
sourceLines = inspect.getsourcelines(func)[0]
# determine indentation
firstCodeLine = sourceLines[1]
matchRes = whitespaceRegexp.match(firstCodeLine)
assert matchRes, "First line is not indented"
numWhitespaces = len(matchRes.group(1))
for i in range(1, len(sourceLines)):
sourceLine = sourceLines[i][numWhitespaces:]
if 'return' in sourceLine:
raise ValueError("Function may not have a return statement!")
matchRes = assignmentRegexp.match(sourceLine)
if matchRes:
sourceLine = "%s_result.append(Eq(%s, %s))\n" % matchRes.groups()
sourceLines[i] = sourceLine
code = "".join(sourceLines[1:])
result = []
localsDict = {'_result': result,
'Eq': sp.Eq,
'S': SymbolCreator()}
localsDict.update(kwargs)
globalsDict = inspect.stack()[1][0].f_globals.copy()
globalsDict.update(inspect.stack()[1][0].f_locals)
exec(code, globalsDict, localsDict)
return result
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment