From c55f53dc50453b571cc99a3dc362f5020c71f07e Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Tue, 24 Apr 2018 13:26:05 +0200
Subject: [PATCH] Improved method to create assignment list from python
 function

- uses ast instead of text-based editing
- can handle Piecewise defined functions
- new decorator for simplified usage
---
 __init__.py         |  8 +++--
 kernel_decorator.py | 78 +++++++++++++++++++++++++++++++++++++++++++++
 sympyextensions.py  | 54 -------------------------------
 3 files changed, 83 insertions(+), 57 deletions(-)
 create mode 100644 kernel_decorator.py

diff --git a/__init__.py b/__init__.py
index 8f7deed40..4ba13340e 100644
--- a/__init__.py
+++ b/__init__.py
@@ -1,6 +1,6 @@
 """Module to generate stencil kernels in C or CUDA using sympy expressions and call them as Python functions"""
 from . import sympy_gmpy_bug_workaround  # NOQA
-from .field import Field, FieldType
+from .field import Field, FieldType, fields
 from .data_types import TypedSymbol
 from .slicing import make_slice
 from .kernelcreation import create_kernel, create_indexed_kernel
@@ -9,8 +9,9 @@ from .assignment_collection import AssignmentCollection
 from .assignment import Assignment
 from .sympyextensions import SymbolCreator
 from .datahandling import create_data_handling
+from . kernel_decorator import kernel
 
-__all__ = ['Field', 'FieldType',
+__all__ = ['Field', 'FieldType', 'fields',
            'TypedSymbol',
            'make_slice',
            'create_kernel', 'create_indexed_kernel',
@@ -18,4 +19,5 @@ __all__ = ['Field', 'FieldType',
            'AssignmentCollection',
            'Assignment',
            'SymbolCreator',
-           'create_data_handling']
+           'create_data_handling',
+           'kernel']
diff --git a/kernel_decorator.py b/kernel_decorator.py
new file mode 100644
index 000000000..fa496c8a7
--- /dev/null
+++ b/kernel_decorator.py
@@ -0,0 +1,78 @@
+import ast
+import inspect
+import sympy as sp
+import textwrap
+from pystencils.sympyextensions import SymbolCreator
+from pystencils.assignment import Assignment
+
+__all__ = ['kernel']
+
+
+def kernel(func, **kwargs):
+    """Decorator to simplify generation of pystencils Assignments.
+
+    Changes the meaning of the '@=' operator. Each line containing this operator gives a symbolic assignment
+    in the result list. Furthermore the meaning of the ternary inline 'if-else' changes meaning to denote a
+    sympy Piecewise.
+
+    The decorated function may not receive any arguments, with exception of an argument called 's' that specifies
+    a SymbolCreator()
+
+    Examples:
+        >>> import pystencils as ps
+        >>> @kernel
+        ... def my_kernel(s):
+        ...     f, g = ps.fields('f, g: [2D]')
+        ...     s.neighbors @= f[0,1] + f[1,0]
+        ...     g[0,0]      @= s.neighbors + f[0,0] if f[0,0] > 0 else 0
+        >>> my_kernel
+        [Assignment(neighbors, f_E + f_N), Assignment(g_C, Piecewise((f_C + neighbors, f_C > 0), (0, True)))]
+    """
+    source = inspect.getsource(func)
+    source = textwrap.dedent(source)
+    a = ast.parse(source)
+    KernelFunctionRewrite().visit(a)
+    ast.fix_missing_locations(a)
+    gl = func.__globals__.copy()
+
+    assignments = []
+
+    def assignment_adder(lhs, rhs):
+        assignments.append(Assignment(lhs, rhs))
+
+    gl['_add_assignment'] = assignment_adder
+    gl['_Piecewise'] = sp.Piecewise
+    gl.update(inspect.getclosurevars(func).nonlocals)
+    exec(compile(a, filename="<ast>", mode="exec"), gl)
+    func = gl[func.__name__]
+    args = inspect.getfullargspec(func).args
+    if 's' in args and 's' not in kwargs:
+        kwargs['s'] = SymbolCreator()
+    func(**kwargs)
+    return assignments
+
+
+# noinspection PyMethodMayBeStatic
+class KernelFunctionRewrite(ast.NodeTransformer):
+
+    def visit_IfExp(self, node):
+        piecewise_func = ast.Name(id='_Piecewise', ctx=ast.Load())
+        piecewise_func = ast.copy_location(piecewise_func, node)
+        piecewise_args = [ast.Tuple(elts=[node.body, node.test], ctx=ast.Load()),
+                          ast.Tuple(elts=[node.orelse, ast.NameConstant(value=True)], ctx=ast.Load())]
+        result = ast.Call(func=piecewise_func, args=piecewise_args, keywords=[])
+
+        return ast.copy_location(result, node)
+
+    def visit_AugAssign(self, node):
+        self.generic_visit(node)
+        node.target.ctx = ast.Load()
+        new_node = ast.Expr(ast.Call(func=ast.Name(id='_add_assignment', ctx=ast.Load()),
+                                     args=[node.target, node.value],
+                                     keywords=[]))
+        return ast.copy_location(new_node, node)
+
+    def visit_FunctionDef(self, node):
+        self.generic_visit(node)
+        node.decorator_list = []
+        return node
diff --git a/sympyextensions.py b/sympyextensions.py
index 55859f0f5..ec27fd238 100644
--- a/sympyextensions.py
+++ b/sympyextensions.py
@@ -505,60 +505,6 @@ def sort_assignments_topologically(assignments: Sequence[Assignment]) -> List[As
     return [Assignment(a, b) for a, b in res]
 
 
-def assignments_from_python_function(func, **kwargs):
-    """Mechanism to simplify the generation of a list of sympy equations.
-
-    Introduces a special "assignment operator" written as "@=". Each line containing this operator gives an
-    equation in the result list. Note that executing this function normally yields an error.
-
-    Additionally the shortcut object 'S' is available to quickly create new sympy symbols.
-
-    Examples:
-        >>> def my_kernel(s):
-        ...     from pystencils import Field
-        ...     f = Field.create_generic('f', spatial_dimensions=2, index_dimensions=0)
-        ...     g = f.new_field_with_different_name('g')
-        ...
-        ...     s.neighbors @= f[0,1] + f[1,0]
-        ...     g[0,0]      @= s.neighbors + f[0,0]
-        >>> assignments_from_python_function(my_kernel)
-        [Assignment(neighbors, f_E + f_N), Assignment(g_C, f_C + neighbors)]
-    """
-    import inspect
-    import re
-
-    assignment_regexp = re.compile(r'(\s*)(.+?)@=(.*)')
-    whitespace_regexp = re.compile(r'(\s*)(.*)')
-    source_lines = inspect.getsourcelines(func)[0]
-
-    # determine indentation
-    first_code_line = source_lines[1]
-    match_res = whitespace_regexp.match(first_code_line)
-    assert match_res, "First line is not indented"
-    num_whitespaces = len(match_res.group(1))
-
-    for i in range(1, len(source_lines)):
-        source_line = source_lines[i][num_whitespaces:]
-        if 'return' in source_line:
-            raise ValueError("Function may not have a return statement!")
-        match_res = assignment_regexp.match(source_line)
-        if match_res:
-            source_line = "%s_result.append(Assignment(%s, %s))\n" % tuple(match_res.groups()[i] for i in range(3))
-        source_lines[i] = source_line
-
-    code = "".join(source_lines[1:])
-    result = []
-    locals_dict = {'_result': result,
-                   'Assignment': Assignment,
-                   's': SymbolCreator()}
-    locals_dict.update(kwargs)
-    globals_dict = inspect.stack()[1][0].f_globals.copy()
-    globals_dict.update(inspect.stack()[1][0].f_locals)
-
-    exec(code, globals_dict, locals_dict)
-    return result
-
-
 class SymbolCreator:
     def __getattribute__(self, name):
         return sp.Symbol(name)
-- 
GitLab