From c55f53dc50453b571cc99a3dc362f5020c71f07e Mon Sep 17 00:00:00 2001 From: Martin Bauer <martin.bauer@fau.de> Date: Tue, 24 Apr 2018 13:26:05 +0200 Subject: [PATCH] Improved method to create assignment list from python function - uses ast instead of text-based editing - can handle Piecewise defined functions - new decorator for simplified usage --- __init__.py | 8 +++-- kernel_decorator.py | 78 +++++++++++++++++++++++++++++++++++++++++++++ sympyextensions.py | 54 ------------------------------- 3 files changed, 83 insertions(+), 57 deletions(-) create mode 100644 kernel_decorator.py diff --git a/__init__.py b/__init__.py index 8f7deed40..4ba13340e 100644 --- a/__init__.py +++ b/__init__.py @@ -1,6 +1,6 @@ """Module to generate stencil kernels in C or CUDA using sympy expressions and call them as Python functions""" from . import sympy_gmpy_bug_workaround # NOQA -from .field import Field, FieldType +from .field import Field, FieldType, fields from .data_types import TypedSymbol from .slicing import make_slice from .kernelcreation import create_kernel, create_indexed_kernel @@ -9,8 +9,9 @@ from .assignment_collection import AssignmentCollection from .assignment import Assignment from .sympyextensions import SymbolCreator from .datahandling import create_data_handling +from . kernel_decorator import kernel -__all__ = ['Field', 'FieldType', +__all__ = ['Field', 'FieldType', 'fields', 'TypedSymbol', 'make_slice', 'create_kernel', 'create_indexed_kernel', @@ -18,4 +19,5 @@ __all__ = ['Field', 'FieldType', 'AssignmentCollection', 'Assignment', 'SymbolCreator', - 'create_data_handling'] + 'create_data_handling', + 'kernel'] diff --git a/kernel_decorator.py b/kernel_decorator.py new file mode 100644 index 000000000..fa496c8a7 --- /dev/null +++ b/kernel_decorator.py @@ -0,0 +1,78 @@ +import ast +import inspect +import sympy as sp +import textwrap +from pystencils.sympyextensions import SymbolCreator +from pystencils.assignment import Assignment + +__all__ = ['kernel'] + + +def kernel(func, **kwargs): + """Decorator to simplify generation of pystencils Assignments. + + Changes the meaning of the '@=' operator. Each line containing this operator gives a symbolic assignment + in the result list. Furthermore the meaning of the ternary inline 'if-else' changes meaning to denote a + sympy Piecewise. + + The decorated function may not receive any arguments, with exception of an argument called 's' that specifies + a SymbolCreator() + + Examples: + >>> import pystencils as ps + >>> @kernel + ... def my_kernel(s): + ... f, g = ps.fields('f, g: [2D]') + ... s.neighbors @= f[0,1] + f[1,0] + ... g[0,0] @= s.neighbors + f[0,0] if f[0,0] > 0 else 0 + >>> my_kernel + [Assignment(neighbors, f_E + f_N), Assignment(g_C, Piecewise((f_C + neighbors, f_C > 0), (0, True)))] + """ + source = inspect.getsource(func) + source = textwrap.dedent(source) + a = ast.parse(source) + KernelFunctionRewrite().visit(a) + ast.fix_missing_locations(a) + gl = func.__globals__.copy() + + assignments = [] + + def assignment_adder(lhs, rhs): + assignments.append(Assignment(lhs, rhs)) + + gl['_add_assignment'] = assignment_adder + gl['_Piecewise'] = sp.Piecewise + gl.update(inspect.getclosurevars(func).nonlocals) + exec(compile(a, filename="<ast>", mode="exec"), gl) + func = gl[func.__name__] + args = inspect.getfullargspec(func).args + if 's' in args and 's' not in kwargs: + kwargs['s'] = SymbolCreator() + func(**kwargs) + return assignments + + +# noinspection PyMethodMayBeStatic +class KernelFunctionRewrite(ast.NodeTransformer): + + def visit_IfExp(self, node): + piecewise_func = ast.Name(id='_Piecewise', ctx=ast.Load()) + piecewise_func = ast.copy_location(piecewise_func, node) + piecewise_args = [ast.Tuple(elts=[node.body, node.test], ctx=ast.Load()), + ast.Tuple(elts=[node.orelse, ast.NameConstant(value=True)], ctx=ast.Load())] + result = ast.Call(func=piecewise_func, args=piecewise_args, keywords=[]) + + return ast.copy_location(result, node) + + def visit_AugAssign(self, node): + self.generic_visit(node) + node.target.ctx = ast.Load() + new_node = ast.Expr(ast.Call(func=ast.Name(id='_add_assignment', ctx=ast.Load()), + args=[node.target, node.value], + keywords=[])) + return ast.copy_location(new_node, node) + + def visit_FunctionDef(self, node): + self.generic_visit(node) + node.decorator_list = [] + return node diff --git a/sympyextensions.py b/sympyextensions.py index 55859f0f5..ec27fd238 100644 --- a/sympyextensions.py +++ b/sympyextensions.py @@ -505,60 +505,6 @@ def sort_assignments_topologically(assignments: Sequence[Assignment]) -> List[As return [Assignment(a, b) for a, b in res] -def assignments_from_python_function(func, **kwargs): - """Mechanism to simplify the generation of a list of sympy equations. - - Introduces a special "assignment operator" written as "@=". Each line containing this operator gives an - equation in the result list. Note that executing this function normally yields an error. - - Additionally the shortcut object 'S' is available to quickly create new sympy symbols. - - Examples: - >>> def my_kernel(s): - ... from pystencils import Field - ... f = Field.create_generic('f', spatial_dimensions=2, index_dimensions=0) - ... g = f.new_field_with_different_name('g') - ... - ... s.neighbors @= f[0,1] + f[1,0] - ... g[0,0] @= s.neighbors + f[0,0] - >>> assignments_from_python_function(my_kernel) - [Assignment(neighbors, f_E + f_N), Assignment(g_C, f_C + neighbors)] - """ - import inspect - import re - - assignment_regexp = re.compile(r'(\s*)(.+?)@=(.*)') - whitespace_regexp = re.compile(r'(\s*)(.*)') - source_lines = inspect.getsourcelines(func)[0] - - # determine indentation - first_code_line = source_lines[1] - match_res = whitespace_regexp.match(first_code_line) - assert match_res, "First line is not indented" - num_whitespaces = len(match_res.group(1)) - - for i in range(1, len(source_lines)): - source_line = source_lines[i][num_whitespaces:] - if 'return' in source_line: - raise ValueError("Function may not have a return statement!") - match_res = assignment_regexp.match(source_line) - if match_res: - source_line = "%s_result.append(Assignment(%s, %s))\n" % tuple(match_res.groups()[i] for i in range(3)) - source_lines[i] = source_line - - code = "".join(source_lines[1:]) - result = [] - locals_dict = {'_result': result, - 'Assignment': Assignment, - 's': SymbolCreator()} - locals_dict.update(kwargs) - globals_dict = inspect.stack()[1][0].f_globals.copy() - globals_dict.update(inspect.stack()[1][0].f_locals) - - exec(code, globals_dict, locals_dict) - return result - - class SymbolCreator: def __getattribute__(self, name): return sp.Symbol(name) -- GitLab