# test_destructuring_field_class.py
1
import sympy
2
3
import jinja2

4
5
6

import pystencils
from pystencils.astnodes import DestructuringBindingsForFieldClass
7
8
from pystencils.kernelparameters import  FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol

9
10
11
12
13
14
15
16


def test_destructuring_field_class():
    """Wrap a GPU kernel body in DestructuringBindingsForFieldClass and compile it.

    The generated code is printed before and after the wrapping so it can be
    inspected by hand. Nothing is asserted; the test only checks that
    ``ast.compile()`` does not raise.
    """
    # Unpack in the same order as the names in the spec string so that each
    # Python variable refers to the field of the same name.
    z, y, x = pystencils.fields("z, y, x: [2d]")

    normal_assignments = pystencils.AssignmentCollection([pystencils.Assignment(
        z[0, 0], x[0, 0] * sympy.log(x[0, 0] * y[0, 0]))], [])

    ast = pystencils.create_kernel(normal_assignments, target='gpu')
    print(pystencils.show_code(ast))

    ast.body = DestructuringBindingsForFieldClass(ast.body)
    print(pystencils.show_code(ast))

    # NOTE(review): target='gpu' -- compiling presumably needs a CUDA
    # toolchain/device; TODO confirm whether this should be skipped without one.
    ast.compile()


class DestructuringEmojiClass(DestructuringBindingsForFieldClass):
    """DestructuringBindingsForFieldClass variant with emoji identifiers.

    Overrides the member-name mapping and the class-name template of the base
    node, and clears its ``headers`` list.
    """

    # Member name used for each kind of field symbol. The "%i" placeholder is
    # presumably filled with the coordinate index by the base class -- TODO
    # confirm.
    CLASS_TO_MEMBER_DICT = {
        FieldPointerSymbol: "🥶",
        FieldShapeSymbol: "😳_%i",
        FieldStrideSymbol: "🥵_%i",
    }

    # Template for the generated type name; receives ``dtype`` and ``ndim``.
    CLASS_NAME_TEMPLATE = jinja2.Template("🤯<{{ dtype }}, {{ ndim }}>")

    def __init__(self, node):
        super(DestructuringEmojiClass, self).__init__(node)
        # Replace whatever headers the base class set with an empty list,
        # presumably because no real header declares the emoji type.
        self.headers = list()
        
    
def test_destructuring_alternative_field_class():
    """Print a GPU kernel whose body is wrapped in DestructuringEmojiClass.

    The kernel is not compiled and nothing is asserted; the test only checks
    that code generation with the overridden member names does not raise.
    """
    # Unpack in the same order as the names in the spec string so that each
    # Python variable refers to the field of the same name.
    z, y, x = pystencils.fields("z, y, x: [2d]")

    normal_assignments = pystencils.AssignmentCollection([pystencils.Assignment(
        z[0, 0], x[0, 0] * sympy.log(x[0, 0] * y[0, 0]))], [])

    ast = pystencils.create_kernel(normal_assignments, target='gpu')
    ast.body = DestructuringEmojiClass(ast.body)
    print(pystencils.show_code(ast))
46
47
48

def main():
    """Run both destructuring tests when the module is executed as a script."""
    test_destructuring_field_class()
    test_destructuring_alternative_field_class()


if __name__ == '__main__':
    main()