diff --git a/src/pystencils/backend/ast/analysis.py b/src/pystencils/backend/ast/analysis.py
index 15ee0680edb0b5b8197aec2545d182f90eb6c71a..5b37470cc89dc8b304d9fd033ea66ec67b9f8d8a 100644
--- a/src/pystencils/backend/ast/analysis.py
+++ b/src/pystencils/backend/ast/analysis.py
@@ -1,5 +1,7 @@
+from dataclasses import dataclass
 from typing import cast
 from functools import reduce
+import operator
 
 from .structural import (
     PsAssignment,
@@ -12,11 +14,28 @@ from .structural import (
     PsLoop,
     PsStatement,
 )
-from .expressions import PsSymbolExpr, PsConstantExpr
+from .expressions import (
+    PsAdd,
+    PsArrayAccess,
+    PsCall,
+    PsConstantExpr,
+    PsDiv,
+    PsIntDiv,
+    PsLiteralExpr,
+    PsMul,
+    PsNeg,
+    PsRem,
+    PsSub,
+    PsSymbolExpr,
+    PsTernary,
+)
 
 from ..symbols import PsSymbol
 from ..exceptions import PsInternalCompilerError
 
+from ...types import PsNumericType
+from ...types.exception import PsTypeError
+
 
 class UndefinedSymbolsCollector:
     """Collect undefined symbols.
@@ -120,3 +139,226 @@ def collect_required_headers(node: PsAstNode) -> set[str]:
             return reduce(
                 set.union, (collect_required_headers(c) for c in node.children), set()
             )
+
+
+@dataclass
+class OperationCounts:
+    """Tally of operations found in an AST, as produced by `OperationCounter`.
+
+    Supports ``a + b`` (field-wise sum) and ``n * a`` with an integer ``n`` on the
+    left (field-wise scaling). Both return a new instance.
+    """
+
+    float_adds: int = 0
+    float_muls: int = 0
+    float_divs: int = 0
+    int_adds: int = 0
+    int_muls: int = 0
+    int_divs: int = 0
+    calls: int = 0
+    branches: int = 0
+    loops_with_dynamic_bounds: int = 0
+
+    def __add__(self, other):
+        """Return the field-wise sum of two operation counts."""
+        if not isinstance(other, OperationCounts):
+            return NotImplemented
+
+        return OperationCounts(
+            float_adds=self.float_adds + other.float_adds,
+            float_muls=self.float_muls + other.float_muls,
+            float_divs=self.float_divs + other.float_divs,
+            int_adds=self.int_adds + other.int_adds,
+            int_muls=self.int_muls + other.int_muls,
+            int_divs=self.int_divs + other.int_divs,
+            calls=self.calls + other.calls,
+            branches=self.branches + other.branches,
+            loops_with_dynamic_bounds=self.loops_with_dynamic_bounds
+            + other.loops_with_dynamic_bounds,
+        )
+
+    def __rmul__(self, other):
+        """Return these counts scaled field-wise by the integer ``other``."""
+        if not isinstance(other, int):
+            return NotImplemented
+
+        return OperationCounts(
+            float_adds=other * self.float_adds,
+            float_muls=other * self.float_muls,
+            float_divs=other * self.float_divs,
+            int_adds=other * self.int_adds,
+            int_muls=other * self.int_muls,
+            int_divs=other * self.int_divs,
+            calls=other * self.calls,
+            branches=other * self.branches,
+            loops_with_dynamic_bounds=other * self.loops_with_dynamic_bounds,
+        )
+
+
+class OperationCounter:
+    """Counts the number of operations in an AST.
+
+    Assumes that the AST is typed. It is recommended that constant folding is
+    applied prior to this pass.
+
+    The counted operations are:
+      - Additions, multiplications and divisions of floating and integer type.
+        The counts of either type are reported separately and operations on
+        other types are ignored.
+      - Function calls.
+      - Branches.
+        Includes `PsConditional` and `PsTernary`. The operations in all branches
+        are summed up (i.e. the result is an overestimation).
+      - Loops with an unknown number of iterations.
+        The operations in the loop header and body are counted exactly once,
+        i.e. it is assumed that there is one loop iteration.
+
+    If the start, stop and step of the loop are `PsConstantExpr`, then any
+    operation within the body is multiplied by the number of iterations.
+    A constant step of zero is rejected with a `PsInternalCompilerError`.
+    """
+
+    def __call__(self, node: PsAstNode) -> OperationCounts:
+        """Counts the number of operations in the given AST."""
+        return self.visit(node)
+
+    def visit(self, node: PsAstNode) -> OperationCounts:
+        """Counts the operations in `node`, dispatching on its node type."""
+        match node:
+            case PsExpression():
+                return self.visit_expr(node)
+
+            case PsStatement(expr):
+                return self.visit_expr(expr)
+
+            case PsAssignment(lhs, rhs):
+                return self.visit_expr(lhs) + self.visit_expr(rhs)
+
+            case PsBlock(statements):
+                return reduce(
+                    operator.add, (self.visit(s) for s in statements), OperationCounts()
+                )
+
+            case PsLoop(_, start, stop, step, body):
+                if (
+                    isinstance(start, PsConstantExpr)
+                    and isinstance(stop, PsConstantExpr)
+                    and isinstance(step, PsConstantExpr)
+                ):
+                    val_start = start.constant.value
+                    val_stop = stop.constant.value
+                    val_step = step.constant.value
+
+                    if val_step == 0:
+                        raise PsInternalCompilerError(
+                            f"Can't count operations in loop with zero step: {node}"
+                        )
+
+                    # Trip count is ceil((stop - start) / step), clamped at zero.
+                    # Exact integer ceiling division avoids the rounding of float
+                    # division for large bounds.
+                    iteration_count = max(0, int(-((val_start - val_stop) // val_step)))
+
+                    return self.visit_expr(start) + iteration_count * (
+                        OperationCounts(int_adds=1)  # loop counter increment
+                        + self.visit_expr(stop)
+                        + self.visit_expr(step)
+                        + self.visit(body)
+                    )
+                else:
+                    return (
+                        OperationCounts(loops_with_dynamic_bounds=1)
+                        + self.visit_expr(start)
+                        + self.visit_expr(stop)
+                        + self.visit_expr(step)
+                        + self.visit(body)
+                    )
+
+            case PsConditional(cond, branch_true, branch_false):
+                op_counts = (
+                    OperationCounts(branches=1)
+                    + self.visit(cond)
+                    + self.visit(branch_true)
+                )
+                if branch_false is not None:
+                    op_counts += self.visit(branch_false)
+                return op_counts
+
+            case PsEmptyLeafMixIn():
+                return OperationCounts()
+
+            case unknown:
+                raise PsInternalCompilerError(f"Can't count operations in {unknown}")
+
+    def visit_expr(self, expr: PsExpression) -> OperationCounts:
+        """Counts the operations in the expression `expr` and its subexpressions."""
+        match expr:
+            case PsSymbolExpr(_) | PsConstantExpr(_) | PsLiteralExpr(_):
+                return OperationCounts()
+
+            case PsArrayAccess(_, index):
+                return self.visit_expr(index)
+
+            case PsCall(_, args):
+                return OperationCounts(calls=1) + reduce(
+                    operator.add, (self.visit_expr(a) for a in args), OperationCounts()
+                )
+
+            case PsTernary(cond, then, els):
+                return (
+                    OperationCounts(branches=1)
+                    + self.visit_expr(cond)
+                    + self.visit_expr(then)
+                    + self.visit_expr(els)
+                )
+
+            case PsNeg(arg):
+                if expr.dtype is None:
+                    raise PsTypeError(f"Untyped arithmetic expression: {expr}")
+
+                op_counts = self.visit_expr(arg)
+                if isinstance(expr.dtype, PsNumericType) and expr.dtype.is_float():
+                    op_counts.float_muls += 1
+                elif isinstance(expr.dtype, PsNumericType) and expr.dtype.is_int():
+                    op_counts.int_muls += 1
+                return op_counts
+
+            case PsAdd(arg1, arg2) | PsSub(arg1, arg2):
+                if expr.dtype is None:
+                    raise PsTypeError(f"Untyped arithmetic expression: {expr}")
+
+                op_counts = self.visit_expr(arg1) + self.visit_expr(arg2)
+                if isinstance(expr.dtype, PsNumericType) and expr.dtype.is_float():
+                    op_counts.float_adds += 1
+                elif isinstance(expr.dtype, PsNumericType) and expr.dtype.is_int():
+                    op_counts.int_adds += 1
+                return op_counts
+
+            case PsMul(arg1, arg2):
+                if expr.dtype is None:
+                    raise PsTypeError(f"Untyped arithmetic expression: {expr}")
+
+                op_counts = self.visit_expr(arg1) + self.visit_expr(arg2)
+                if isinstance(expr.dtype, PsNumericType) and expr.dtype.is_float():
+                    op_counts.float_muls += 1
+                elif isinstance(expr.dtype, PsNumericType) and expr.dtype.is_int():
+                    op_counts.int_muls += 1
+                return op_counts
+
+            case PsDiv(arg1, arg2) | PsIntDiv(arg1, arg2) | PsRem(arg1, arg2):
+                if expr.dtype is None:
+                    raise PsTypeError(f"Untyped arithmetic expression: {expr}")
+
+                op_counts = self.visit_expr(arg1) + self.visit_expr(arg2)
+                if isinstance(expr.dtype, PsNumericType) and expr.dtype.is_float():
+                    op_counts.float_divs += 1
+                elif isinstance(expr.dtype, PsNumericType) and expr.dtype.is_int():
+                    op_counts.int_divs += 1
+                return op_counts
+
+            case _:
+                return reduce(
+                    operator.add,
+                    (self.visit_expr(cast(PsExpression, c)) for c in expr.children),
+                    OperationCounts(),
+                )
diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py
index 908f31052ed4b8d49a66cd1ce801d9841ef4fb7d..3b76e514e566c9c7113c1bbcaf1b63df71a8cb57 100644
--- a/src/pystencils/backend/ast/expressions.py
+++ b/src/pystencils/backend/ast/expressions.py
@@ -226,7 +226,7 @@ class PsSubscript(PsLvalue, PsExpression):
             case 1:
                 self.index = failing_cast(PsExpression, c)
 
-    def __str__(self) -> str:
+    def __repr__(self) -> str:
         return f"Subscript({self._base})[{self._index}]"
 
 
@@ -419,7 +419,7 @@ class PsCall(PsExpression):
         if not isinstance(other, PsCall):
             return False
         return super().structurally_equal(other) and self._function == other._function
-        
+
     def __str__(self):
         args = ", ".join(str(arg) for arg in self._args)
         return f"PsCall({self._function}, ({args}))"
diff --git a/tests/nbackend/test_operation_counter.py b/tests/nbackend/test_operation_counter.py
new file mode 100644
index 0000000000000000000000000000000000000000..98b004c4dbc489adc9c0aa86b0d8baccc92bc60c
--- /dev/null
+++ b/tests/nbackend/test_operation_counter.py
@@ -0,0 +1,61 @@
+from pystencils.backend.ast.analysis import OperationCounter
+from pystencils.backend.ast.expressions import (
+    PsAdd,
+    PsConstant,
+    PsDiv,
+    PsExpression,
+    PsMul,
+    PsTernary,
+)
+from pystencils.backend.ast.structural import (
+    PsBlock,
+    PsDeclaration,
+    PsLoop,
+)
+
+from pystencils.backend.kernelcreation import KernelCreationContext, Typifier
+from pystencils.types import PsBoolType
+
+
+def test_count_operations():
+    ctx = KernelCreationContext()
+    typify = Typifier(ctx)
+    counter = OperationCounter()
+
+    x = PsExpression.make(ctx.get_symbol("x"))
+    y = PsExpression.make(ctx.get_symbol("y"))
+    z = PsExpression.make(ctx.get_symbol("z"))
+
+    i = PsExpression.make(ctx.get_symbol("i", ctx.index_dtype))
+    p = PsExpression.make(ctx.get_symbol("p", PsBoolType()))
+
+    zero = PsExpression.make(PsConstant(0, ctx.index_dtype))
+    two = PsExpression.make(PsConstant(2, ctx.index_dtype))
+    five = PsExpression.make(PsConstant(5, ctx.index_dtype))
+
+    ast = PsLoop(
+        i,
+        zero,
+        five,
+        two,
+        PsBlock(
+            [
+                PsDeclaration(x, PsAdd(y, z)),
+                PsDeclaration(y, PsMul(x, PsMul(y, z))),
+                PsDeclaration(z, PsDiv(PsDiv(PsDiv(x, y), z), PsTernary(p, x, y))),
+            ]
+        ),
+    )
+    ast = typify(ast)
+
+    op_count = counter(ast)
+
+    assert op_count.float_adds == 3 * 1
+    assert op_count.float_muls == 3 * 2
+    assert op_count.float_divs == 3 * 3
+    assert op_count.int_adds == 3 * 1
+    assert op_count.int_muls == 0
+    assert op_count.int_divs == 0
+    assert op_count.calls == 0
+    assert op_count.branches == 3 * 1
+    assert op_count.loops_with_dynamic_bounds == 0