Skip to content
Snippets Groups Projects
Commit 2bad904a authored by Daniel Bauer's avatar Daniel Bauer :speech_balloon: Committed by Frederik Hennig
Browse files

Add AST pass for counting operations

parent d6621ef9
Branches
Tags
1 merge request!397Add AST pass for counting operations
from dataclasses import dataclass
from typing import cast
from functools import reduce
import operator
from .structural import (
PsAssignment,
......@@ -12,11 +14,28 @@ from .structural import (
PsLoop,
PsStatement,
)
from .expressions import PsSymbolExpr, PsConstantExpr
from .expressions import (
PsAdd,
PsArrayAccess,
PsCall,
PsConstantExpr,
PsDiv,
PsIntDiv,
PsLiteralExpr,
PsMul,
PsNeg,
PsRem,
PsSub,
PsSymbolExpr,
PsTernary,
)
from ..symbols import PsSymbol
from ..exceptions import PsInternalCompilerError
from ...types import PsNumericType
from ...types.exception import PsTypeError
class UndefinedSymbolsCollector:
"""Collect undefined symbols.
......@@ -120,3 +139,212 @@ def collect_required_headers(node: PsAstNode) -> set[str]:
return reduce(
set.union, (collect_required_headers(c) for c in node.children), set()
)
@dataclass
class OperationCounts:
float_adds: int = 0
float_muls: int = 0
float_divs: int = 0
int_adds: int = 0
int_muls: int = 0
int_divs: int = 0
calls: int = 0
branches: int = 0
loops_with_dynamic_bounds: int = 0
def __add__(self, other):
if not isinstance(other, OperationCounts):
return NotImplemented
return OperationCounts(
float_adds=self.float_adds + other.float_adds,
float_muls=self.float_muls + other.float_muls,
float_divs=self.float_divs + other.float_divs,
int_adds=self.int_adds + other.int_adds,
int_muls=self.int_muls + other.int_muls,
int_divs=self.int_divs + other.int_divs,
calls=self.calls + other.calls,
branches=self.branches + other.branches,
loops_with_dynamic_bounds=self.loops_with_dynamic_bounds
+ other.loops_with_dynamic_bounds,
)
def __rmul__(self, other):
if not isinstance(other, int):
return NotImplemented
return OperationCounts(
float_adds=other * self.float_adds,
float_muls=other * self.float_muls,
float_divs=other * self.float_divs,
int_adds=other * self.int_adds,
int_muls=other * self.int_muls,
int_divs=other * self.int_divs,
calls=other * self.calls,
branches=other * self.branches,
loops_with_dynamic_bounds=other * self.loops_with_dynamic_bounds,
)
class OperationCounter:
"""Counts the number of operations in an AST.
Assumes that the AST is typed. It is recommended that constant folding is
applied prior to this pass.
The counted operations are:
- Additions, multiplications and divisions of floating and integer type.
The counts of either type are reported separately and operations on
other types are ignored.
- Function calls.
- Branches.
Includes `PsConditional` and `PsTernary`. The operations in all branches
are summed up (i.e. the result is an overestimation).
- Loops with an unknown number of iterations.
The operations in the loop header and body are counted exactly once,
i.e. it is assumed that there is one loop iteration.
If the start, stop and step of the loop are `PsConstantExpr`, then any
operation within the body is multiplied by the number of iterations.
"""
def __call__(self, node: PsAstNode) -> OperationCounts:
"""Counts the number of operations in the given AST."""
return self.visit(node)
def visit(self, node: PsAstNode) -> OperationCounts:
match node:
case PsExpression():
return self.visit_expr(node)
case PsStatement(expr):
return self.visit_expr(expr)
case PsAssignment(lhs, rhs):
return self.visit_expr(lhs) + self.visit_expr(rhs)
case PsBlock(statements):
return reduce(
operator.add, (self.visit(s) for s in statements), OperationCounts()
)
case PsLoop(_, start, stop, step, body):
if (
isinstance(start, PsConstantExpr)
and isinstance(stop, PsConstantExpr)
and isinstance(step, PsConstantExpr)
):
val_start = start.constant.value
val_stop = stop.constant.value
val_step = step.constant.value
if (val_stop - val_start) % val_step == 0:
iteration_count = max(0, int((val_stop - val_start) / val_step))
else:
iteration_count = max(
0, int((val_stop - val_start) / val_step) + 1
)
return self.visit_expr(start) + iteration_count * (
OperationCounts(int_adds=1) # loop counter increment
+ self.visit_expr(stop)
+ self.visit_expr(step)
+ self.visit(body)
)
else:
return (
OperationCounts(loops_with_dynamic_bounds=1)
+ self.visit_expr(start)
+ self.visit_expr(stop)
+ self.visit_expr(step)
+ self.visit(body)
)
case PsConditional(cond, branch_true, branch_false):
op_counts = (
OperationCounts(branches=1)
+ self.visit(cond)
+ self.visit(branch_true)
)
if branch_false is not None:
op_counts += self.visit(branch_false)
return op_counts
case PsEmptyLeafMixIn():
return OperationCounts()
case unknown:
raise PsInternalCompilerError(f"Can't count operations in {unknown}")
def visit_expr(self, expr: PsExpression) -> OperationCounts:
match expr:
case PsSymbolExpr(_) | PsConstantExpr(_) | PsLiteralExpr(_):
return OperationCounts()
case PsArrayAccess(_, index):
return self.visit_expr(index)
case PsCall(_, args):
return OperationCounts(calls=1) + reduce(
operator.add, (self.visit(a) for a in args), OperationCounts()
)
case PsTernary(cond, then, els):
return (
OperationCounts(branches=1)
+ self.visit_expr(cond)
+ self.visit_expr(then)
+ self.visit_expr(els)
)
case PsNeg(arg):
if expr.dtype is None:
raise PsTypeError(f"Untyped arithmetic expression: {expr}")
op_counts = self.visit_expr(arg)
if isinstance(expr.dtype, PsNumericType) and expr.dtype.is_float():
op_counts.float_muls += 1
elif isinstance(expr.dtype, PsNumericType) and expr.dtype.is_int():
op_counts.int_muls += 1
return op_counts
case PsAdd(arg1, arg2) | PsSub(arg1, arg2):
if expr.dtype is None:
raise PsTypeError(f"Untyped arithmetic expression: {expr}")
op_counts = self.visit_expr(arg1) + self.visit_expr(arg2)
if isinstance(expr.dtype, PsNumericType) and expr.dtype.is_float():
op_counts.float_adds += 1
elif isinstance(expr.dtype, PsNumericType) and expr.dtype.is_int():
op_counts.int_adds += 1
return op_counts
case PsMul(arg1, arg2):
if expr.dtype is None:
raise PsTypeError(f"Untyped arithmetic expression: {expr}")
op_counts = self.visit_expr(arg1) + self.visit_expr(arg2)
if isinstance(expr.dtype, PsNumericType) and expr.dtype.is_float():
op_counts.float_muls += 1
elif isinstance(expr.dtype, PsNumericType) and expr.dtype.is_int():
op_counts.int_muls += 1
return op_counts
case PsDiv(arg1, arg2) | PsIntDiv(arg1, arg2) | PsRem(arg1, arg2):
if expr.dtype is None:
raise PsTypeError(f"Untyped arithmetic expression: {expr}")
op_counts = self.visit_expr(arg1) + self.visit_expr(arg2)
if isinstance(expr.dtype, PsNumericType) and expr.dtype.is_float():
op_counts.float_divs += 1
elif isinstance(expr.dtype, PsNumericType) and expr.dtype.is_int():
op_counts.int_divs += 1
return op_counts
case _:
return reduce(
operator.add,
(self.visit_expr(cast(PsExpression, c)) for c in expr.children),
OperationCounts(),
)
......@@ -226,7 +226,7 @@ class PsSubscript(PsLvalue, PsExpression):
case 1:
self.index = failing_cast(PsExpression, c)
def __str__(self) -> str:
def __repr__(self) -> str:
return f"Subscript({self._base})[{self._index}]"
......@@ -419,7 +419,7 @@ class PsCall(PsExpression):
if not isinstance(other, PsCall):
return False
return super().structurally_equal(other) and self._function == other._function
def __str__(self):
args = ", ".join(str(arg) for arg in self._args)
return f"PsCall({self._function}, ({args}))"
......
from pystencils.backend.ast.analysis import OperationCounter
from pystencils.backend.ast.expressions import (
PsAdd,
PsConstant,
PsDiv,
PsExpression,
PsMul,
PsTernary,
)
from pystencils.backend.ast.structural import (
PsBlock,
PsDeclaration,
PsLoop,
)
from pystencils.backend.kernelcreation import KernelCreationContext, Typifier
from pystencils.types import PsBoolType
def test_count_operations():
ctx = KernelCreationContext()
typify = Typifier(ctx)
counter = OperationCounter()
x = PsExpression.make(ctx.get_symbol("x"))
y = PsExpression.make(ctx.get_symbol("y"))
z = PsExpression.make(ctx.get_symbol("z"))
i = PsExpression.make(ctx.get_symbol("i", ctx.index_dtype))
p = PsExpression.make(ctx.get_symbol("p", PsBoolType()))
zero = PsExpression.make(PsConstant(0, ctx.index_dtype))
two = PsExpression.make(PsConstant(2, ctx.index_dtype))
five = PsExpression.make(PsConstant(5, ctx.index_dtype))
ast = PsLoop(
i,
zero,
five,
two,
PsBlock(
[
PsDeclaration(x, PsAdd(y, z)),
PsDeclaration(y, PsMul(x, PsMul(y, z))),
PsDeclaration(z, PsDiv(PsDiv(PsDiv(x, y), z), PsTernary(p, x, y))),
]
),
)
ast = typify(ast)
op_count = counter(ast)
assert op_count.float_adds == 3 * 1
assert op_count.float_muls == 3 * 2
assert op_count.float_divs == 3 * 3
assert op_count.int_adds == 3 * 1
assert op_count.int_muls == 0
assert op_count.int_divs == 0
assert op_count.calls == 0
assert op_count.branches == 3 * 1
assert op_count.loops_with_dynamic_bounds == 0
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment