# test_operation_counter.py  (web-viewer navigation text from the page scrape removed)
from pystencils.backend.ast.analysis import OperationCounter
from pystencils.backend.ast.expressions import (
    PsAdd,
    PsConstant,
    PsDiv,
    PsExpression,
    PsMul,
    PsTernary,
)
from pystencils.backend.ast.structural import (
    PsBlock,
    PsDeclaration,
    PsLoop,
)

from pystencils.backend.kernelcreation import KernelCreationContext, Typifier
from pystencils.types import PsBoolType


def test_count_operations():
    """Check the per-category totals `OperationCounter` reports for a small loop.

    The loop body holds three declarations that together contain one
    `PsAdd`, two `PsMul`, three `PsDiv` and one `PsTernary`.  The loop is
    built from the constants 0, 5 and 2, and every expected total below is
    the per-body count scaled by three.
    """
    context = KernelCreationContext()
    typifier = Typifier(context)
    count_ops = OperationCounter()

    #   Symbols requested without an explicit dtype
    #   NOTE(review): presumably these receive the context's default
    #   floating-point type during typification -- the float_* assertions
    #   below rely on that; confirm against the Typifier.
    x, y, z = (
        PsExpression.make(context.get_symbol(name)) for name in ("x", "y", "z")
    )

    #   Loop counter (index type) and the boolean selector for the ternary
    counter_sym = PsExpression.make(context.get_symbol("i", context.index_dtype))
    predicate = PsExpression.make(context.get_symbol("p", PsBoolType()))

    #   Index-typed constants handed to the loop constructor
    zero, two, five = (
        PsExpression.make(PsConstant(value, context.index_dtype))
        for value in (0, 2, 5)
    )

    #   Body statements: 1 add / 2 muls / 3 divs + 1 ternary
    add_decl = PsDeclaration(x, PsAdd(y, z))
    mul_decl = PsDeclaration(y, PsMul(x, PsMul(y, z)))
    div_decl = PsDeclaration(
        z,
        PsDiv(PsDiv(PsDiv(x, y), z), PsTernary(predicate, x, y)),
    )
    body = PsBlock([add_decl, mul_decl, div_decl])

    #   NOTE(review): positional arguments are presumably
    #   (counter, start, stop, step), i.e. i = 0, 2, 4 -> three iterations;
    #   verify against PsLoop's signature.
    loop = PsLoop(counter_sym, zero, five, two, body)
    typed_loop = typifier(loop)

    op_count = count_ops(typed_loop)

    trips = 3  # factor applied to every per-body count below
    expected_totals = {
        "float_adds": trips * 1,
        "float_muls": trips * 2,
        "float_divs": trips * 3,
        #   NOTE(review): presumably the counter increment of each trip
        "int_adds": trips * 1,
        "int_muls": 0,
        "int_divs": 0,
        "calls": 0,
        "branches": trips * 1,
        "loops_with_dynamic_bounds": 0,
    }
    for category, expected in expected_totals.items():
        assert getattr(op_count, category) == expected, category