Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# -*- coding: utf-8 -*-
#
# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
#
# Distributed under terms of the GPLv3 license.
"""
"""
import sympy
import pystencils.astnodes
from pystencils.backends.cbackend import CBackend
from pystencils.data_types import TypedSymbol
class BogusDeclaration(pystencils.astnodes.Node):
    """Dummy AST node that acts as a global declaration of ``Foo``.

    The node has no children. It defines exactly one symbol, ``Foo`` of type
    ``double``, and uses no symbols itself.
    """

    def __init__(self, parent=None):
        self.parent = parent

    @property
    def args(self):
        """Return the children of this node; a declaration has none."""
        # Must *return* the empty container: ``raise set()`` would throw a
        # TypeError, since only BaseException subclasses can be raised.
        return set()

    @property
    def symbols_defined(self):
        """Return the symbols defined by this node: only ``Foo`` (double)."""
        return {TypedSymbol('Foo', 'double')}

    @property
    def undefined_symbols(self):
        """Return the symbols used but not defined here; there are none."""
        # An explicit return is required, otherwise callers would get None.
        return set()

    def subs(self, subs_dict):
        """Substitute in place in all children (a no-op, as there are none)."""
        for a in self.args:
            a.subs(subs_dict)

    @property
    def func(self):
        """Return this node's class."""
        return self.__class__

    def atoms(self, arg_type):
        """Return all descendants, recursively, that are instances of ``arg_type``."""
        result = set()
        for arg in self.args:
            if isinstance(arg, arg_type):
                result.add(arg)
            result.update(arg.atoms(arg_type))
        return result
class BogusUsage(pystencils.astnodes.Node):
    """Dummy AST node that uses the symbol ``Foo`` (double) without defining it.

    Args:
        requires_global: when true, the node carries a single
            ``BogusDeclaration`` in ``required_global_declarations``; when
            false, that attribute is not set at all.
        parent: optional parent node.
    """

    def __init__(self, requires_global: bool, parent=None):
        # The attribute is only created when requested; its absence is the
        # "no global declaration needed" case.
        if requires_global:
            self.required_global_declarations = [BogusDeclaration()]
        self.parent = parent

    @property
    def args(self):
        """Return the children of this node; it has none."""
        return set()

    @property
    def symbols_defined(self):
        """Return the symbols this node defines; it defines none."""
        return set()

    @property
    def undefined_symbols(self):
        """Return the symbols used here but defined elsewhere: ``Foo`` (double)."""
        foo = TypedSymbol('Foo', 'double')
        return {foo}

    def subs(self, subs_dict):
        """Apply ``subs_dict`` in place to every child node."""
        for child in self.args:
            child.subs(subs_dict)

    @property
    def func(self):
        """Return this node's class."""
        return self.__class__

    def atoms(self, arg_type):
        """Collect, recursively, every descendant that is an instance of ``arg_type``."""
        found = set()
        for child in self.args:
            if isinstance(child, arg_type):
                found.add(child)
            found |= child.atoms(arg_type)
        return found
def _teach_backend_bogus_nodes():
    """Install C-backend printers for the dummy nodes used by these tests.

    The printers are attached directly to the ``CBackend`` class and emit only
    C comments, so the dummy nodes add no real statements to the kernel.
    """
    CBackend._print_BogusUsage = lambda _, __: "// Bogus would go here"
    CBackend._print_BogusDeclaration = lambda _, __: "// Declaration would go here"


def _create_kernel_with_bogus_usage(requires_global):
    """Build a small 2D kernel and append a ``BogusUsage`` node to its body.

    Args:
        requires_global: forwarded to ``BogusUsage``; controls whether the
            node requests a global ``BogusDeclaration`` for ``Foo``.

    Returns:
        The AST from ``pystencils.create_kernel`` with the bogus node appended.
        The generated code is printed before and after the append, for
        debugging.
    """
    _teach_backend_bogus_nodes()
    # Unpack in the same order as the names in the description string, so
    # that the Python name ``x`` really refers to the field called "x".
    z, y, x = pystencils.fields("z, y, x: [2d]")
    normal_assignments = pystencils.AssignmentCollection(
        [pystencils.Assignment(z[0, 0], x[0, 0] * sympy.log(x[0, 0] * y[0, 0]))],
        [])
    ast = pystencils.create_kernel(normal_assignments)
    print(pystencils.show_code(ast))
    ast.body.append(BogusUsage(requires_global=requires_global))
    print(pystencils.show_code(ast))
    return ast


def test_global_definitions_with_global_symbol():
    """``Foo`` is declared globally, so it must not become a kernel parameter."""
    ast = _create_kernel_with_bogus_usage(requires_global=True)
    kernel = ast.compile()
    assert kernel is not None
    assert TypedSymbol('Foo', 'double') not in [p.symbol for p in ast.get_parameters()]


def test_global_definitions_without_global_symbol():
    """Without a global declaration, ``Foo`` must appear as a kernel parameter."""
    ast = _create_kernel_with_bogus_usage(requires_global=False)
    kernel = ast.compile()
    assert kernel is not None
    assert TypedSymbol('Foo', 'double') in [p.symbol for p in ast.get_parameters()]
def main():
    """Run both global-declaration tests when the module is executed directly."""
    for run_test in (test_global_definitions_with_global_symbol,
                     test_global_definitions_without_global_symbol):
        run_test()


if __name__ == '__main__':
    main()