Commit c55f53dc authored by Martin Bauer's avatar Martin Bauer
Browse files

Improved method to create assignment list from python function

- uses ast instead of text-based editing
- can handle Piecewise defined functions
- new decorator for simplified usage
parent 3bbc5d49
"""Module to generate stencil kernels in C or CUDA using sympy expressions and call them as Python functions""" """Module to generate stencil kernels in C or CUDA using sympy expressions and call them as Python functions"""
from . import sympy_gmpy_bug_workaround # NOQA from . import sympy_gmpy_bug_workaround # NOQA
from .field import Field, FieldType from .field import Field, FieldType, fields
from .data_types import TypedSymbol from .data_types import TypedSymbol
from .slicing import make_slice from .slicing import make_slice
from .kernelcreation import create_kernel, create_indexed_kernel from .kernelcreation import create_kernel, create_indexed_kernel
...@@ -9,8 +9,9 @@ from .assignment_collection import AssignmentCollection ...@@ -9,8 +9,9 @@ from .assignment_collection import AssignmentCollection
from .assignment import Assignment from .assignment import Assignment
from .sympyextensions import SymbolCreator from .sympyextensions import SymbolCreator
from .datahandling import create_data_handling from .datahandling import create_data_handling
from . kernel_decorator import kernel
__all__ = ['Field', 'FieldType', __all__ = ['Field', 'FieldType', 'fields',
'TypedSymbol', 'TypedSymbol',
'make_slice', 'make_slice',
'create_kernel', 'create_indexed_kernel', 'create_kernel', 'create_indexed_kernel',
...@@ -18,4 +19,5 @@ __all__ = ['Field', 'FieldType', ...@@ -18,4 +19,5 @@ __all__ = ['Field', 'FieldType',
'AssignmentCollection', 'AssignmentCollection',
'Assignment', 'Assignment',
'SymbolCreator', 'SymbolCreator',
'create_data_handling'] 'create_data_handling',
'kernel']
import ast
import inspect
import sympy as sp
import textwrap
from pystencils.sympyextensions import SymbolCreator
from pystencils.assignment import Assignment
__all__ = ['kernel']
def kernel(func, **kwargs):
"""Decorator to simplify generation of pystencils Assignments.
Changes the meaning of the '@=' operator. Each line containing this operator gives a symbolic assignment
in the result list. Furthermore the meaning of the ternary inline 'if-else' changes meaning to denote a
sympy Piecewise.
The decorated function may not receive any arguments, with exception of an argument called 's' that specifies
a SymbolCreator()
Examples:
>>> import pystencils as ps
>>> @kernel
... def my_kernel(s):
... f, g = ps.fields('f, g: [2D]')
... s.neighbors @= f[0,1] + f[1,0]
... g[0,0] @= s.neighbors + f[0,0] if f[0,0] > 0 else 0
>>> my_kernel
[Assignment(neighbors, f_E + f_N), Assignment(g_C, Piecewise((f_C + neighbors, f_C > 0), (0, True)))]
"""
source = inspect.getsource(func)
source = textwrap.dedent(source)
a = ast.parse(source)
KernelFunctionRewrite().visit(a)
ast.fix_missing_locations(a)
gl = func.__globals__.copy()
assignments = []
def assignment_adder(lhs, rhs):
assignments.append(Assignment(lhs, rhs))
gl['_add_assignment'] = assignment_adder
gl['_Piecewise'] = sp.Piecewise
gl.update(inspect.getclosurevars(func).nonlocals)
exec(compile(a, filename="<ast>", mode="exec"), gl)
func = gl[func.__name__]
args = inspect.getfullargspec(func).args
if 's' in args and 's' not in kwargs:
kwargs['s'] = SymbolCreator()
func(**kwargs)
return assignments
# noinspection PyMethodMayBeStatic
class KernelFunctionRewrite(ast.NodeTransformer):
def visit_IfExp(self, node):
piecewise_func = ast.Name(id='_Piecewise', ctx=ast.Load())
piecewise_func = ast.copy_location(piecewise_func, node)
piecewise_args = [ast.Tuple(elts=[node.body, node.test], ctx=ast.Load()),
ast.Tuple(elts=[node.orelse, ast.NameConstant(value=True)], ctx=ast.Load())]
result = ast.Call(func=piecewise_func, args=piecewise_args, keywords=[])
return ast.copy_location(result, node)
def visit_AugAssign(self, node):
self.generic_visit(node)
node.target.ctx = ast.Load()
new_node = ast.Expr(ast.Call(func=ast.Name(id='_add_assignment', ctx=ast.Load()),
args=[node.target, node.value],
keywords=[]))
return ast.copy_location(new_node, node)
def visit_FunctionDef(self, node):
self.generic_visit(node)
node.decorator_list = []
return node
...@@ -505,60 +505,6 @@ def sort_assignments_topologically(assignments: Sequence[Assignment]) -> List[As ...@@ -505,60 +505,6 @@ def sort_assignments_topologically(assignments: Sequence[Assignment]) -> List[As
return [Assignment(a, b) for a, b in res] return [Assignment(a, b) for a, b in res]
def assignments_from_python_function(func, **kwargs):
"""Mechanism to simplify the generation of a list of sympy equations.
Introduces a special "assignment operator" written as "@=". Each line containing this operator gives an
equation in the result list. Note that executing this function normally yields an error.
Additionally the shortcut object 'S' is available to quickly create new sympy symbols.
Examples:
>>> def my_kernel(s):
... from pystencils import Field
... f = Field.create_generic('f', spatial_dimensions=2, index_dimensions=0)
... g = f.new_field_with_different_name('g')
...
... s.neighbors @= f[0,1] + f[1,0]
... g[0,0] @= s.neighbors + f[0,0]
>>> assignments_from_python_function(my_kernel)
[Assignment(neighbors, f_E + f_N), Assignment(g_C, f_C + neighbors)]
"""
import inspect
import re
assignment_regexp = re.compile(r'(\s*)(.+?)@=(.*)')
whitespace_regexp = re.compile(r'(\s*)(.*)')
source_lines = inspect.getsourcelines(func)[0]
# determine indentation
first_code_line = source_lines[1]
match_res = whitespace_regexp.match(first_code_line)
assert match_res, "First line is not indented"
num_whitespaces = len(match_res.group(1))
for i in range(1, len(source_lines)):
source_line = source_lines[i][num_whitespaces:]
if 'return' in source_line:
raise ValueError("Function may not have a return statement!")
match_res = assignment_regexp.match(source_line)
if match_res:
source_line = "%s_result.append(Assignment(%s, %s))\n" % tuple(match_res.groups()[i] for i in range(3))
source_lines[i] = source_line
code = "".join(source_lines[1:])
result = []
locals_dict = {'_result': result,
'Assignment': Assignment,
's': SymbolCreator()}
locals_dict.update(kwargs)
globals_dict = inspect.stack()[1][0].f_globals.copy()
globals_dict.update(inspect.stack()[1][0].f_locals)
exec(code, globals_dict, locals_dict)
return result
class SymbolCreator: class SymbolCreator:
def __getattribute__(self, name): def __getattribute__(self, name):
return sp.Symbol(name) return sp.Symbol(name)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment