test_global_definitions.py 4.28 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# -*- coding: utf-8 -*-
#
# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
#
# Distributed under terms of the GPLv3 license.

"""

"""

import sympy

import pystencils.astnodes
from pystencils.backends.cbackend import CBackend
from pystencils.data_types import TypedSymbol


class BogusDeclaration(pystencils.astnodes.Node):
    """Dummy AST node that acts as the declaration of the symbol ``Foo``.

    The node has no children, reports the typed symbol ``Foo`` (``double``)
    as defined and uses no undefined symbols. ``BogusUsage`` places an
    instance of it in its ``required_global_declarations`` list.
    """

    def __init__(self, parent=None):
        """Store the optional parent node."""
        self.parent = parent

    @property
    def args(self):
        """Returns all arguments/children of this node (always empty)."""
        # The original used ``raise set()``, which fails with a TypeError
        # because a set is not an exception; an empty set is what was meant.
        return set()

    @property
    def symbols_defined(self):
        """Set of symbols which are defined by this node."""
        return {TypedSymbol('Foo', 'double')}

    @property
    def undefined_symbols(self):
        """Symbols which are used but are not defined inside this node (none)."""
        # The original evaluated ``set()`` without returning it, so the
        # property yielded None instead of an empty set.
        return set()

    def subs(self, subs_dict):
        """Inplace! substitute, similar to sympy's but modifies the AST inplace."""
        for a in self.args:
            a.subs(subs_dict)

    @property
    def func(self):
        """Return the class of this node."""
        return self.__class__

    def atoms(self, arg_type):
        """Returns a set of all descendants recursively, which are an instance of the given type."""
        result = set()
        for arg in self.args:
            if isinstance(arg, arg_type):
                result.add(arg)
            result.update(arg.atoms(arg_type))
        return result


class BogusUsage(pystencils.astnodes.Node):
    """Dummy AST node that uses the symbol ``Foo`` without defining it.

    When constructed with ``requires_global=True`` the instance carries a
    ``required_global_declarations`` list holding one ``BogusDeclaration``;
    otherwise that attribute is not set on the instance.
    """

    def __init__(self, requires_global: bool, parent=None):
        """Store the parent and optionally attach the global declaration."""
        self.parent = parent
        if requires_global:
            declaration = BogusDeclaration()
            self.required_global_declarations = [declaration]

    @property
    def args(self):
        """Children of this node: there are none."""
        return set()

    @property
    def symbols_defined(self):
        """Symbols defined by this node: there are none."""
        return set()

    @property
    def undefined_symbols(self):
        """Symbols used here but not defined here: only ``Foo`` (double)."""
        foo = TypedSymbol('Foo', 'double')
        return {foo}

    def subs(self, subs_dict):
        """Apply the substitution in place by forwarding it to every child."""
        for child in self.args:
            child.subs(subs_dict)

    @property
    def func(self):
        """Return the class of this node."""
        return self.__class__

    def atoms(self, arg_type):
        """Collect, recursively, every descendant that is an instance of ``arg_type``."""
        found = set()
        for child in self.args:
            if isinstance(child, arg_type):
                found.add(child)
            found.update(child.atoms(arg_type))
        return found


def test_global_definitions_with_global_symbol():
    """Append a ``BogusUsage`` that brings its own declaration and compile.

    Checks that the kernel compiles and that ``Foo`` does not show up among
    the kernel parameters.
    """
    # Teach the C printer how to emit the two dummy node types.
    CBackend._print_BogusUsage = lambda _, __: "// Bogus would go here"
    CBackend._print_BogusDeclaration = lambda _, __: "// Declaration would go here"

    # Local names are positional; the field names come from the spec string.
    dst, src_a, src_b = pystencils.fields("z, y, x: [2d]")

    target = dst[0, 0]
    expression = src_a[0, 0] * sympy.log(src_a[0, 0] * src_b[0, 0])
    update = pystencils.Assignment(target, expression)
    collection = pystencils.AssignmentCollection([update], [])

    kernel_ast = pystencils.create_kernel(collection)
    print(pystencils.show_code(kernel_ast))
    kernel_ast.body.append(BogusUsage(requires_global=True))
    print(pystencils.show_code(kernel_ast))

    compiled = kernel_ast.compile()
    assert compiled is not None

    parameter_symbols = [param.symbol for param in kernel_ast.get_parameters()]
    assert TypedSymbol('Foo', 'double') not in parameter_symbols


def test_global_definitions_without_global_symbol():
    """Append a ``BogusUsage`` without a global declaration and compile.

    Checks that the kernel compiles and that ``Foo`` is listed among the
    kernel parameters.
    """
    # Teach the C printer how to emit the two dummy node types.
    CBackend._print_BogusUsage = lambda _, __: "// Bogus would go here"
    CBackend._print_BogusDeclaration = lambda _, __: "// Declaration would go here"

    # Local names are positional; the field names come from the spec string.
    dst, src_a, src_b = pystencils.fields("z, y, x: [2d]")

    target = dst[0, 0]
    expression = src_a[0, 0] * sympy.log(src_a[0, 0] * src_b[0, 0])
    update = pystencils.Assignment(target, expression)
    collection = pystencils.AssignmentCollection([update], [])

    kernel_ast = pystencils.create_kernel(collection)
    print(pystencils.show_code(kernel_ast))
    kernel_ast.body.append(BogusUsage(requires_global=False))
    print(pystencils.show_code(kernel_ast))

    compiled = kernel_ast.compile()
    assert compiled is not None

    parameter_symbols = [param.symbol for param in kernel_ast.get_parameters()]
    assert TypedSymbol('Foo', 'double') in parameter_symbols


def main():
    """Run both tests of this module, one after the other."""
    checks = (
        test_global_definitions_with_global_symbol,
        test_global_definitions_without_global_symbol,
    )
    for check in checks:
        check()


# Allow running the tests directly as a script, without a test runner.
if __name__ == '__main__':
    main()