import ast
import inspect
import textwrap
from typing import Callable, Union, List, Dict, Tuple

import sympy as sp

from pystencils.assignment import Assignment
from pystencils.sympyextensions import SymbolCreator
from pystencils.kernelcreation import CreateKernelConfig

# Public API of this module: only the two decorators are exported.
__all__ = ['kernel', 'kernel_config']


def _kernel(func: Callable[..., None], **kwargs) -> Tuple[List[Assignment], str]:
    """
    Convenient function for kernel decorator to prevent code duplication.

    Re-parses the source of ``func`` and rewrites it with ``KernelFunctionRewrite``. It then executes
    the rewritten definition and calls the resulting function once. Each ``@=`` statement executed
    during that call appends one ``Assignment`` to the returned list.

    Args:
        func: decorated function
        **kwargs: kwargs for the function; if the function has a parameter named 's' and it is not
                  given here, a fresh SymbolCreator() is passed for it
    Returns:
        assignments, function_name
    """
    # Dedent so that indented definitions (e.g. functions nested in another scope) parse at top level.
    source = textwrap.dedent(inspect.getsource(func))
    tree = ast.parse(source)
    KernelFunctionRewrite().visit(tree)
    ast.fix_missing_locations(tree)

    assignments = []

    def assignment_adder(lhs, rhs):
        assignments.append(Assignment(lhs, rhs))

    # Execute in a copy of the function's globals so the defining module's namespace is untouched;
    # add captured closure variables so names from an enclosing scope still resolve.
    gl = func.__globals__.copy()
    gl.update(inspect.getclosurevars(func).nonlocals)
    # Inject the helpers last: the rewritten code relies on these exact names, so a closure
    # variable with the same name must not shadow them.
    gl['_add_assignment'] = assignment_adder
    gl['_Piecewise'] = sp.Piecewise
    exec(compile(tree, filename="<ast>", mode="exec"), gl)

    # exec re-created the function (with its decorators stripped by the rewrite); use that definition.
    func = gl[func.__name__]
    args = inspect.getfullargspec(func).args
    if 's' in args and 's' not in kwargs:
        kwargs['s'] = SymbolCreator()
    func(**kwargs)
    return assignments, func.__name__


def kernel(func: Callable[..., None], **kwargs) -> List[Assignment]:
    """Decorator that turns a Python function into a list of pystencils Assignments.

    Inside the decorated function, every line using the '@=' operator contributes one symbolic
    assignment to the returned list. The inline ternary 'if-else' is reinterpreted as a sympy
    Piecewise.

    The decorated function may not receive any arguments, with exception of an argument called 's'
    that specifies a SymbolCreator() (created automatically when not supplied).

    Args:
        func: decorated function
        **kwargs: kwargs for the function

    Returns:
        the collected assignments, in execution order

    Examples:
        >>> import pystencils as ps
        >>> @kernel
        ... def my_kernel(s):
        ...     f, g = ps.fields('f, g: [2D]')
        ...     s.neighbors @= f[0,1] + f[1,0]
        ...     g[0,0]      @= s.neighbors + f[0,0] if f[0,0] > 0 else 0
        >>> f, g = ps.fields('f, g: [2D]')
        >>> assert my_kernel[0].rhs == f[0,1] + f[1,0]
    """
    # Only the assignment list is of interest here; the function name is used by kernel_config.
    collected = _kernel(func, **kwargs)[0]
    return collected


def kernel_config(config: CreateKernelConfig, **kwargs) -> Callable[..., Dict]:
    """Decorator to simplify generation of pystencils Assignments, which takes a configuration
    and sets the kernel's function name accordingly.

    Changes the meaning of the '@=' operator. Each line containing this operator gives a symbolic assignment
    in the result list. Furthermore the meaning of the ternary inline 'if-else' changes meaning to denote a
    sympy Piecewise.

    The decorated function may not receive any arguments, with exception of an argument called 's' that specifies
    a SymbolCreator()

    The passed ``config`` is not modified. Each decorated function gets its own shallow copy with
    ``function_name`` set to that function's name, so one config can be reused for several kernels.

    Args:
        config: base kernel configuration, copied (shallowly) for every decorated function
        **kwargs: kwargs for the decorated function
    Returns:
        decorator turning a function into a dict with keys 'assignments' and 'config',
        meant for unpacking into create_kernel

    Examples:
        >>> import pystencils as ps
        >>> config = ps.CreateKernelConfig()
        >>> @kernel_config(config)
        ... def my_kernel(s):
        ...     f, g = ps.fields('f, g: [2D]')
        ...     s.neighbors @= f[0,1] + f[1,0]
        ...     g[0,0]      @= s.neighbors + f[0,0] if f[0,0] > 0 else 0
        >>> f, g = ps.fields('f, g: [2D]')
        >>> assert my_kernel['assignments'][0].rhs == f[0,1] + f[1,0]
    """
    def decorator(func: Callable[..., None]) -> Dict:
        """
        Args:
            func: decorated function
        Returns:
            Dict for unpacking into create_kernel
        """
        import copy

        assignments, func_name = _kernel(func, **kwargs)
        # Mutating the shared config in place would rename every kernel previously decorated with
        # it to the last function's name; give each kernel its own copy instead.
        kernel_specific_config = copy.copy(config)
        kernel_specific_config.function_name = func_name
        return {'assignments': assignments, 'config': kernel_specific_config}
    return decorator


# noinspection PyMethodMayBeStatic
class KernelFunctionRewrite(ast.NodeTransformer):
    """AST rewrite applied to the source of functions decorated with ``kernel`` / ``kernel_config``.

    * ``a if cond else b`` becomes ``_Piecewise((a, cond), (b, True))``; nested ternaries are rewritten too.
    * ``lhs @= rhs`` becomes the expression statement ``_add_assignment(lhs, rhs)``; other augmented
      assignments (``+=``, ``-=``, ...) are left untouched.
    * Decorators are removed from function definitions, so re-executing the rewritten source does
      not apply them again.

    ``_Piecewise`` and ``_add_assignment`` are not defined here; the namespace that executes the
    rewritten code has to provide them.
    """

    def visit_IfExp(self, node):
        # Rewrite children first so ternaries nested in body/test/orelse also become Piecewise
        # instead of remaining Python conditionals (which fail on symbolic conditions).
        self.generic_visit(node)
        piecewise_func = ast.copy_location(ast.Name(id='_Piecewise', ctx=ast.Load()), node)
        piecewise_args = [ast.Tuple(elts=[node.body, node.test], ctx=ast.Load()),
                          # ast.Constant replaces ast.NameConstant (deprecated in 3.8, removed in 3.14).
                          ast.Tuple(elts=[node.orelse, ast.Constant(value=True)], ctx=ast.Load())]
        result = ast.Call(func=piecewise_func, args=piecewise_args, keywords=[])
        return ast.copy_location(result, node)

    def visit_AugAssign(self, node):
        self.generic_visit(node)
        # Only '@=' has the special "symbolic assignment" meaning; keep normal Python semantics
        # for every other augmented assignment.
        if not isinstance(node.op, ast.MatMult):
            return node
        # The target becomes a call argument, i.e. it is now read instead of stored to.
        node.target.ctx = ast.Load()
        new_node = ast.Expr(ast.Call(func=ast.Name(id='_add_assignment', ctx=ast.Load()),
                                     args=[node.target, node.value],
                                     keywords=[]))
        return ast.copy_location(new_node, node)

    def visit_FunctionDef(self, node):
        self.generic_visit(node)
        # Strip decorators so exec'ing the rewritten definition does not re-trigger them.
        node.decorator_list = []
        return node