Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from pystencils.backend.ast.analysis import OperationCounter
from pystencils.backend.ast.expressions import (
PsAdd,
PsConstant,
PsDiv,
PsExpression,
PsMul,
PsTernary,
)
from pystencils.backend.ast.structural import (
PsBlock,
PsDeclaration,
PsLoop,
)
from pystencils.backend.kernelcreation import KernelCreationContext, Typifier
from pystencils.types import PsBoolType
def test_count_operations():
    """Check the totals reported by ``OperationCounter`` for a small loop nest.

    A ``PsLoop`` is built over the index symbol ``i`` with the integer
    constants 0, 5 and 2, wrapping a block of three declarations. Per
    execution of the block, the declarations contain:

    * one ``PsAdd`` (``y + z``),
    * two ``PsMul`` (``x * (y * z)``),
    * three ``PsDiv`` (``((x / y) / z) / <ternary>``),
    * one ``PsTernary`` selecting between ``x`` and ``y`` on the boolean ``p``.

    The AST is typified before counting. Every expected total is written as
    a per-body count scaled by 3.
    """
    context = KernelCreationContext()
    typifier = Typifier(context)
    count_ops = OperationCounter()

    # Symbols created without an explicit dtype; their type is left to the
    # context/typifier.
    x = PsExpression.make(context.get_symbol("x"))
    y = PsExpression.make(context.get_symbol("y"))
    z = PsExpression.make(context.get_symbol("z"))

    # Loop index and branch condition carry explicit types.
    i = PsExpression.make(context.get_symbol("i", context.index_dtype))
    p = PsExpression.make(context.get_symbol("p", PsBoolType()))

    # Integer constants handed to the loop constructor.
    const_zero = PsExpression.make(PsConstant(0, context.index_dtype))
    const_two = PsExpression.make(PsConstant(2, context.index_dtype))
    const_five = PsExpression.make(PsConstant(5, context.index_dtype))

    # One addition.
    decl_x = PsDeclaration(x, PsAdd(y, z))
    # Two (nested) multiplications.
    decl_y = PsDeclaration(y, PsMul(x, PsMul(y, z)))
    # Three (nested) divisions, the last one by a ternary expression.
    decl_z = PsDeclaration(
        z,
        PsDiv(PsDiv(PsDiv(x, y), z), PsTernary(p, x, y)),
    )

    loop = PsLoop(
        i,
        const_zero,
        const_five,
        const_two,
        PsBlock([decl_x, decl_y, decl_z]),
    )

    typed_loop = typifier(loop)
    totals = count_ops(typed_loop)

    # Factor applied to every per-body count below; presumably the number of
    # loop iterations (i = 0, 2, 4) -- TODO confirm against PsLoop semantics.
    scale = 3

    assert totals.float_adds == scale * 1
    assert totals.float_muls == scale * 2
    assert totals.float_divs == scale * 3

    assert totals.int_adds == scale * 1
    assert totals.int_muls == 0
    assert totals.int_divs == 0

    assert totals.calls == 0
    assert totals.branches == scale * 1
    assert totals.loops_with_dynamic_bounds == 0