Commit c55f53dc authored by Martin Bauer's avatar Martin Bauer
Browse files

Improved method to create assignment list from python function

- uses ast instead of text-based editing
- can handle Piecewise defined functions
- new decorator for simplified usage
parent 3bbc5d49
"""Module to generate stencil kernels in C or CUDA using sympy expressions and call them as Python functions"""
from . import sympy_gmpy_bug_workaround # NOQA
from .field import Field, FieldType
from .field import Field, FieldType, fields
from .data_types import TypedSymbol
from .slicing import make_slice
from .kernelcreation import create_kernel, create_indexed_kernel
......@@ -9,8 +9,9 @@ from .assignment_collection import AssignmentCollection
from .assignment import Assignment
from .sympyextensions import SymbolCreator
from .datahandling import create_data_handling
from . kernel_decorator import kernel
__all__ = ['Field', 'FieldType',
__all__ = ['Field', 'FieldType', 'fields',
'create_kernel', 'create_indexed_kernel',
......@@ -18,4 +19,5 @@ __all__ = ['Field', 'FieldType',
import ast
import inspect
import sympy as sp
import textwrap
from pystencils.sympyextensions import SymbolCreator
from pystencils.assignment import Assignment
__all__ = ['kernel']
def kernel(func, **kwargs):
"""Decorator to simplify generation of pystencils Assignments.
Changes the meaning of the '@=' operator. Each line containing this operator gives a symbolic assignment
in the result list. Furthermore the meaning of the ternary inline 'if-else' changes meaning to denote a
sympy Piecewise.
The decorated function may not receive any arguments, with exception of an argument called 's' that specifies
a SymbolCreator()
>>> import pystencils as ps
>>> @kernel
... def my_kernel(s):
... f, g = ps.fields('f, g: [2D]')
... s.neighbors @= f[0,1] + f[1,0]
... g[0,0] @= s.neighbors + f[0,0] if f[0,0] > 0 else 0
>>> my_kernel
[Assignment(neighbors, f_E + f_N), Assignment(g_C, Piecewise((f_C + neighbors, f_C > 0), (0, True)))]
source = inspect.getsource(func)
source = textwrap.dedent(source)
a = ast.parse(source)
gl = func.__globals__.copy()
assignments = []
def assignment_adder(lhs, rhs):
assignments.append(Assignment(lhs, rhs))
gl['_add_assignment'] = assignment_adder
gl['_Piecewise'] = sp.Piecewise
exec(compile(a, filename="<ast>", mode="exec"), gl)
func = gl[func.__name__]
args = inspect.getfullargspec(func).args
if 's' in args and 's' not in kwargs:
kwargs['s'] = SymbolCreator()
return assignments
# noinspection PyMethodMayBeStatic
class KernelFunctionRewrite(ast.NodeTransformer):
def visit_IfExp(self, node):
piecewise_func = ast.Name(id='_Piecewise', ctx=ast.Load())
piecewise_func = ast.copy_location(piecewise_func, node)
piecewise_args = [ast.Tuple(elts=[node.body, node.test], ctx=ast.Load()),
ast.Tuple(elts=[node.orelse, ast.NameConstant(value=True)], ctx=ast.Load())]
result = ast.Call(func=piecewise_func, args=piecewise_args, keywords=[])
return ast.copy_location(result, node)
def visit_AugAssign(self, node):
self.generic_visit(node) = ast.Load()
new_node = ast.Expr(ast.Call(func=ast.Name(id='_add_assignment', ctx=ast.Load()),
args=[, node.value],
return ast.copy_location(new_node, node)
def visit_FunctionDef(self, node):
node.decorator_list = []
return node
......@@ -505,60 +505,6 @@ def sort_assignments_topologically(assignments: Sequence[Assignment]) -> List[As
return [Assignment(a, b) for a, b in res]
def assignments_from_python_function(func, **kwargs):
"""Mechanism to simplify the generation of a list of sympy equations.
Introduces a special "assignment operator" written as "@=". Each line containing this operator gives an
equation in the result list. Note that executing this function normally yields an error.
Additionally the shortcut object 'S' is available to quickly create new sympy symbols.
>>> def my_kernel(s):
... from pystencils import Field
... f = Field.create_generic('f', spatial_dimensions=2, index_dimensions=0)
... g = f.new_field_with_different_name('g')
... s.neighbors @= f[0,1] + f[1,0]
... g[0,0] @= s.neighbors + f[0,0]
>>> assignments_from_python_function(my_kernel)
[Assignment(neighbors, f_E + f_N), Assignment(g_C, f_C + neighbors)]
import inspect
import re
assignment_regexp = re.compile(r'(\s*)(.+?)@=(.*)')
whitespace_regexp = re.compile(r'(\s*)(.*)')
source_lines = inspect.getsourcelines(func)[0]
# determine indentation
first_code_line = source_lines[1]
match_res = whitespace_regexp.match(first_code_line)
assert match_res, "First line is not indented"
num_whitespaces = len(
for i in range(1, len(source_lines)):
source_line = source_lines[i][num_whitespaces:]
if 'return' in source_line:
raise ValueError("Function may not have a return statement!")
match_res = assignment_regexp.match(source_line)
if match_res:
source_line = "%s_result.append(Assignment(%s, %s))\n" % tuple(match_res.groups()[i] for i in range(3))
source_lines[i] = source_line
code = "".join(source_lines[1:])
result = []
locals_dict = {'_result': result,
'Assignment': Assignment,
's': SymbolCreator()}
globals_dict = inspect.stack()[1][0].f_globals.copy()
exec(code, globals_dict, locals_dict)
return result
class SymbolCreator:
def __getattribute__(self, name):
return sp.Symbol(name)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment