An error occurred while loading the file. Please try again.
-
Martin Bauer authored
- Whether a field access is printed in compass or index notation seems to depend on the sympy version, so the doctests that rely on the printed form were removed
65ddbe06
kernel_decorator.py 2.72 KiB
import ast
import inspect
import sympy as sp
import textwrap
from pystencils.sympyextensions import SymbolCreator
from pystencils.assignment import Assignment
__all__ = ['kernel']
def kernel(func, **kwargs):
    """Decorator that turns a plain Python function into a list of pystencils Assignments.

    The source of the decorated function is re-parsed and rewritten before it is executed:

    - every line using the '@=' operator contributes one symbolic ``Assignment(lhs, rhs)``
      to the returned list
    - the inline ternary ``a if cond else b`` is turned into a sympy ``Piecewise``

    The decorated function is called immediately with ``**kwargs``. If it declares an
    argument named 's' and no value for 's' was passed, a fresh ``SymbolCreator()`` is
    supplied for it.

    Args:
        func: function whose body describes the assignments
        **kwargs: keyword arguments forwarded to the rewritten function

    Returns:
        list of the Assignments collected while running the rewritten function

    Examples:
        >>> import pystencils as ps
        >>> @kernel
        ... def my_kernel(s):
        ...     f, g = ps.fields('f, g: [2D]')
        ...     s.neighbors @= f[0,1] + f[1,0]
        ...     g[0,0]      @= s.neighbors + f[0,0] if f[0,0] > 0 else 0
        >>> f, g = ps.fields('f, g: [2D]')
        >>> assert my_kernel[0].rhs == f[0,1] + f[1,0]
    """
    # Re-parse the function source and rewrite '@=' and ternaries on the AST level
    tree = ast.parse(textwrap.dedent(inspect.getsource(func)))
    KernelFunctionRewrite().visit(tree)
    ast.fix_missing_locations(tree)

    collected = []

    def record_assignment(lhs, rhs):
        collected.append(Assignment(lhs, rhs))

    # Execute the rewritten definition in a copy of the function's globals, extended by
    # the helpers the rewritten code refers to and by the function's closure variables
    namespace = func.__globals__.copy()
    namespace['_add_assignment'] = record_assignment
    namespace['_Piecewise'] = sp.Piecewise
    namespace.update(inspect.getclosurevars(func).nonlocals)
    code = compile(tree, filename="<ast>", mode="exec")
    exec(code, namespace)

    rewritten = namespace[func.__name__]
    arg_names = inspect.getfullargspec(rewritten).args
    if 's' in arg_names and 's' not in kwargs:
        kwargs['s'] = SymbolCreator()

    # Running the rewritten function fills 'collected' through record_assignment
    rewritten(**kwargs)
    return collected
# noinspection PyMethodMayBeStatic
class KernelFunctionRewrite(ast.NodeTransformer):
    """AST transformer that prepares a kernel function for symbolic evaluation.

    The rewritten code expects two names to be available where it is executed:
    ``_add_assignment(lhs, rhs)`` and ``_Piecewise(*pairs)``.

    NOTE(review): every augmented assignment is rewritten, regardless of its operator,
    so e.g. ``x += y`` is treated exactly like ``x @= y``.
    """

    def visit_IfExp(self, node):
        """Replace ``body if test else orelse`` by ``_Piecewise((body, test), (orelse, True))``."""
        # Visit the children first, otherwise conditional expressions nested inside
        # body/test/orelse would be left as plain Python ternaries
        self.generic_visit(node)

        piecewise_func = ast.copy_location(ast.Name(id='_Piecewise', ctx=ast.Load()), node)
        # ast.Constant replaces the deprecated ast.NameConstant
        piecewise_args = [ast.Tuple(elts=[node.body, node.test], ctx=ast.Load()),
                          ast.Tuple(elts=[node.orelse, ast.Constant(value=True)], ctx=ast.Load())]
        result = ast.Call(func=piecewise_func, args=piecewise_args, keywords=[])
        return ast.copy_location(result, node)

    def visit_AugAssign(self, node):
        """Replace ``target <op>= value`` by the expression statement ``_add_assignment(target, value)``."""
        self.generic_visit(node)
        # The target is now read as a call argument instead of being stored to
        node.target.ctx = ast.Load()
        new_node = ast.Expr(ast.Call(func=ast.Name(id='_add_assignment', ctx=ast.Load()),
                                     args=[node.target, node.value],
                                     keywords=[]))
        return ast.copy_location(new_node, node)

    def visit_FunctionDef(self, node):
        """Transform the function body and drop all decorators of the function."""
        self.generic_visit(node)
        # Without this, executing the rewritten definition would apply the decorators again
        node.decorator_list = []
        return node