From 9d1e022d4408c88b99a77737ec5dc7507d0281b0 Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Wed, 11 Oct 2017 16:05:25 +0200
Subject: [PATCH] Easier kernel formulation with "@=" operator to create
 sp.Eq's

---
 sympyextensions.py | 59 ++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 59 insertions(+)

diff --git a/sympyextensions.py b/sympyextensions.py
index 62187fbd6..4abeb7b66 100644
--- a/sympyextensions.py
+++ b/sympyextensions.py
@@ -509,3 +509,62 @@ def getSymmetricPart(term, vars):
 def sortEquationsTopologically(equationSequence):
     res = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in equationSequence])
     return [sp.Eq(a, b) for a, b in res]
+
+
+def getEquationsFromFunction(func, **kwargs):
+    """
+    Mechanism to simplify the generation of a list of sympy equations. 
+    Introduces a special "assignment operator" written as "@=". Each line containing this operator gives an
+    equation in the result list. Note that executing this function normally yields an error.
+    
+    Additionally the shortcut object 'S' is available to quickly create new sympy symbols.
+    
+    Example:
+        
+    >>> def myKernel():
+    ...     from pystencils import Field
+    ...     f = Field.createGeneric('f', spatialDimensions=2, indexDimensions=0)
+    ...     g = f.newFieldWithDifferentName('g')
+    ...     
+    ...     S.neighbors @= f[0,1] + f[1,0]
+    ...     g[0,0]      @= S.neighbors + f[0,0]
+    >>> getEquationsFromFunction(myKernel)
+    [Eq(neighbors, f_E + f_N), Eq(g_C, f_C + neighbors)]
+    """
+    import inspect
+    import re
+
+    class SymbolCreator:
+        def __getattribute__(self, name):
+            return sp.Symbol(name)
+
+    assignmentRegexp = re.compile(r'(\s*)(.+?)@=(.*)')
+    whitespaceRegexp = re.compile(r'(\s*)(.*)')
+    sourceLines = inspect.getsourcelines(func)[0]
+
+    # determine indentation
+    firstCodeLine = sourceLines[1]
+    matchRes = whitespaceRegexp.match(firstCodeLine)
+    assert matchRes, "First line is not indented"
+    numWhitespaces = len(matchRes.group(1))
+
+    for i in range(1, len(sourceLines)):
+        sourceLine = sourceLines[i][numWhitespaces:]
+        if 'return' in sourceLine:
+            raise ValueError("Function may not have a return statement!")
+        matchRes = assignmentRegexp.match(sourceLine)
+        if matchRes:
+            sourceLine = "%s_result.append(Eq(%s, %s))\n" % matchRes.groups()
+        sourceLines[i] = sourceLine
+
+    code = "".join(sourceLines[1:])
+    result = []
+    localsDict = {'_result': result,
+                  'Eq': sp.Eq,
+                  'S': SymbolCreator()}
+    localsDict.update(kwargs)
+    globalsDict = inspect.stack()[1][0].f_globals.copy()
+    globalsDict.update(inspect.stack()[1][0].f_locals)
+
+    exec(code, globalsDict, localsDict)
+    return result
-- 
GitLab