diff --git a/docs/source/backend/ast.rst b/docs/source/backend/ast.rst index 44f8f25409ad08cff1d550dddcd9099bbbc798df..823ee1d97a2919df3ab349090439e7c62a2bc358 100644 --- a/docs/source/backend/ast.rst +++ b/docs/source/backend/ast.rst @@ -2,13 +2,15 @@ Abstract Syntax Tree ******************** +.. automodule:: pystencils.backend.ast + API Documentation ================= Inheritance Diagram ------------------- -.. inheritance-diagram:: pystencils.backend.ast.astnode.PsAstNode pystencils.backend.ast.structural pystencils.backend.ast.expressions pystencils.backend.extensions.foreign_ast +.. inheritance-diagram:: pystencils.backend.ast.astnode.PsAstNode pystencils.backend.ast.structural pystencils.backend.ast.expressions pystencils.backend.ast.vector pystencils.backend.extensions.foreign_ast :top-classes: pystencils.types.PsAstNode :parts: 1 @@ -29,3 +31,9 @@ Expressions .. automodule:: pystencils.backend.ast.expressions :members: + +SIMD Nodes +---------- + +.. automodule:: pystencils.backend.ast.vector + :members: diff --git a/docs/source/backend/errors.rst b/docs/source/backend/errors.rst new file mode 100644 index 0000000000000000000000000000000000000000..231b7e295ad4a82b27a21229d4ddcf5ae1cfff3f --- /dev/null +++ b/docs/source/backend/errors.rst @@ -0,0 +1,6 @@ +********************* +Errors and Exceptions +********************* + +.. automodule:: pystencils.backend.exceptions + :members: diff --git a/docs/source/backend/index.rst b/docs/source/backend/index.rst index 70ed684c69e494a2c9d72f91fdcc62641f4acd69..a1d39410bbb012c4bafc0a8f91538eab0d30276a 100644 --- a/docs/source/backend/index.rst +++ b/docs/source/backend/index.rst @@ -16,6 +16,7 @@ who wish to customize or extend the behaviour of the code generator in their app platforms transformations output + errors jit extensions diff --git a/src/pystencils/__init__.py b/src/pystencils/__init__.py index 8b70da86b72324be68b91a82f1f9b5d98435eb55..66533a0b7a71ad867582e9159b04d1c7c78561cf 100644 --- a/src/pystencils/__init__.py +++ b/src/pystencils/__init__.py @@ -1,6 +1,6 @@ """Module to generate stencil kernels in C or CUDA using sympy expressions and call them as Python functions""" -from .enums import Target +from .target import Target from .defaults import DEFAULTS from . import fd from . import stencil as stencil diff --git a/src/pystencils/backend/ast/analysis.py b/src/pystencils/backend/ast/analysis.py index 3c6d2ef557e44a882edf4e104df4bd4e2a8830fd..edeba04f2b8e5d8727abe1150b9e574808e6811a 100644 --- a/src/pystencils/backend/ast/analysis.py +++ b/src/pystencils/backend/ast/analysis.py @@ -29,7 +29,7 @@ from .expressions import ( PsSymbolExpr, PsTernary, PsSubscript, - PsMemAcc + PsMemAcc, ) from ..memory import PsSymbol @@ -285,11 +285,8 @@ class OperationCounter: return OperationCounts() case PsBufferAcc(_, indices) | PsSubscript(_, indices): - return reduce( - operator.add, - (self.visit_expr(idx) for idx in indices) - ) - + return reduce(operator.add, (self.visit_expr(idx) for idx in indices)) + case PsMemAcc(_, offset): return self.visit_expr(offset) diff --git a/src/pystencils/backend/ast/astnode.py b/src/pystencils/backend/ast/astnode.py index 4ef557fe1cc855b3b24d88ffcab377e87c74c384..64374f1a211a6dcfcc383c9e743cb217bc21630a 100644 --- a/src/pystencils/backend/ast/astnode.py +++ b/src/pystencils/backend/ast/astnode.py @@ -48,6 +48,10 @@ class PsAstNode(ABC): for c1, c2 in zip(self.children, other.children) ) ) + + def __str__(self) -> str: + from ..emission import emit_ir + return emit_ir(self) class PsLeafMixIn(ABC): diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py index d73b1faa758f8ce31312c674712ec89bfd5683ab..a850470ffacf4f528ca5883e5d91b14fa6aa5f9c 100644 --- a/src/pystencils/backend/ast/expressions.py +++ b/src/pystencils/backend/ast/expressions.py @@ -11,10 +11,8 @@ from ..memory import PsSymbol, PsBuffer, BufferBasePtr from ..constants import PsConstant from ..literals import PsLiteral from ..functions import PsFunction -from ...types import ( - PsType, - PsVectorType, -) +from ...types import PsType + from .util import failing_cast from ..exceptions import PsInternalCompilerError @@ -58,15 +56,23 @@ class PsExpression(PsAstNode, ABC): return self._dtype def __add__(self, other: PsExpression) -> PsAdd: + if not isinstance(other, PsExpression): + return NotImplemented return PsAdd(self, other) def __sub__(self, other: PsExpression) -> PsSub: + if not isinstance(other, PsExpression): + return NotImplemented return PsSub(self, other) def __mul__(self, other: PsExpression) -> PsMul: + if not isinstance(other, PsExpression): + return NotImplemented return PsMul(self, other) def __truediv__(self, other: PsExpression) -> PsDiv: + if not isinstance(other, PsExpression): + return NotImplemented return PsDiv(self, other) def __neg__(self) -> PsNeg: @@ -100,7 +106,7 @@ class PsExpression(PsAstNode, ABC): def clone(self): """Clone this expression. - + .. note:: Subclasses of `PsExpression` should not override this method, but implement `_clone_expr` instead. @@ -115,7 +121,7 @@ class PsExpression(PsAstNode, ABC): @abstractmethod def _clone_expr(self) -> PsExpression: """Implementation of expression cloning. - + :meta public: """ pass @@ -362,61 +368,6 @@ class PsMemAcc(PsLvalue, PsExpression): return f"PsMemAcc({repr(self._ptr)}, {repr(self._offset)})" -class PsVectorMemAcc(PsMemAcc): - """Pointer-based vectorized memory access.""" - - __match_args__ = ("base_ptr", "base_index") - - def __init__( - self, - base_ptr: PsExpression, - base_index: PsExpression, - vector_entries: int, - stride: int = 1, - alignment: int = 0, - ): - super().__init__(base_ptr, base_index) - - self._vector_entries = vector_entries - self._stride = stride - self._alignment = alignment - - @property - def vector_entries(self) -> int: - return self._vector_entries - - @property - def stride(self) -> int: - return self._stride - - @property - def alignment(self) -> int: - return self._alignment - - def get_vector_type(self) -> PsVectorType: - return cast(PsVectorType, self._dtype) - - def _clone_expr(self) -> PsVectorMemAcc: - return PsVectorMemAcc( - self._ptr.clone(), - self._offset.clone(), - self.vector_entries, - self._stride, - self._alignment, - ) - - def structurally_equal(self, other: PsAstNode) -> bool: - if not isinstance(other, PsVectorMemAcc): - return False - - return ( - super().structurally_equal(other) - and self._vector_entries == other._vector_entries - and self._stride == other._stride - and self._alignment == other._alignment - ) - - class PsLookup(PsExpression, PsLvalue): __match_args__ = ("aggregate", "member_name") @@ -508,9 +459,9 @@ class PsCall(PsExpression): return False return super().structurally_equal(other) and self._function == other._function - def __str__(self): - args = ", ".join(str(arg) for arg in self._args) - return f"PsCall({self._function}, ({args}))" + def __repr__(self): + args = ", ".join(repr(arg) for arg in self._args) + return f"PsCall({repr(self._function)}, ({args}))" class PsTernary(PsExpression): @@ -554,9 +505,6 @@ class PsTernary(PsExpression): case 2: self._else = failing_cast(PsExpression, c) - def __str__(self) -> str: - return f"PsTernary({self._cond}, {self._then}, {self._else})" - def __repr__(self) -> str: return f"PsTernary({repr(self._cond)}, {repr(self._then)}, {repr(self._else)})" @@ -778,19 +726,19 @@ class PsBitwiseOr(PsBinOp, PsIntOpTrait): class PsAnd(PsBinOp, PsBoolOpTrait): @property def python_operator(self) -> Callable[[Any, Any], Any] | None: - return operator.and_ + return np.logical_and class PsOr(PsBinOp, PsBoolOpTrait): @property def python_operator(self) -> Callable[[Any, Any], Any] | None: - return operator.or_ + return np.logical_or class PsNot(PsUnOp, PsBoolOpTrait): @property def python_operator(self) -> Callable[[Any], Any] | None: - return operator.not_ + return np.logical_not class PsRel(PsBinOp): diff --git a/src/pystencils/backend/ast/structural.py b/src/pystencils/backend/ast/structural.py index 3ae462c41c0170dcaa4a27adbd6d039df8c099d8..1d716fa9676db4a2dda8f5241990461cab48d835 100644 --- a/src/pystencils/backend/ast/structural.py +++ b/src/pystencils/backend/ast/structural.py @@ -306,6 +306,9 @@ class PsConditional(PsAstNode): case _: assert False, "unreachable code" + def __repr__(self) -> str: + return f"PsConditional({repr(self._condition)}, {repr(self._branch_true)}, {repr(self._branch_false)})" + class PsEmptyLeafMixIn: """Mix-in marking AST leaves that can be treated as empty by the code generator, diff --git a/src/pystencils/backend/ast/vector.py b/src/pystencils/backend/ast/vector.py new file mode 100644 index 0000000000000000000000000000000000000000..705d250949f3662695d506feeff30c20649eb1c5 --- /dev/null +++ b/src/pystencils/backend/ast/vector.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +from typing import cast + +from .astnode import PsAstNode +from .expressions import PsExpression, PsLvalue, PsUnOp +from .util import failing_cast + +from ...types import PsVectorType + + +class PsVectorOp: + """Mix-in for vector operations""" + + +class PsVecBroadcast(PsUnOp, PsVectorOp): + """Broadcast a scalar value to N vector lanes.""" + + __match_args__ = ("lanes", "operand") + + def __init__(self, lanes: int, operand: PsExpression): + super().__init__(operand) + self._lanes = lanes + + @property + def lanes(self) -> int: + return self._lanes + + @lanes.setter + def lanes(self, n: int): + self._lanes = n + + def _clone_expr(self) -> PsVecBroadcast: + return PsVecBroadcast(self._lanes, self._operand.clone()) + + def structurally_equal(self, other: PsAstNode) -> bool: + if not isinstance(other, PsVecBroadcast): + return False + return ( + super().structurally_equal(other) + and self._lanes == other._lanes + ) + + +class PsVecMemAcc(PsExpression, PsLvalue, PsVectorOp): + """Pointer-based vectorized memory access. + + Args: + base_ptr: Pointer identifying the accessed memory region + offset: Offset inside the memory region + vector_entries: Number of elements to access + stride: Optional integer step size for strided access, or ``None`` for contiguous access + aligned: For contiguous accesses, whether the access is guaranteed to be naturally aligned + according to the vector data type + """ + + __match_args__ = ("pointer", "offset", "vector_entries", "stride", "aligned") + + def __init__( + self, + base_ptr: PsExpression, + offset: PsExpression, + vector_entries: int, + stride: PsExpression | None = None, + aligned: bool = False, + ): + super().__init__() + + self._ptr = base_ptr + self._offset = offset + self._vector_entries = vector_entries + self._stride = stride + self._aligned = aligned + + @property + def pointer(self) -> PsExpression: + return self._ptr + + @pointer.setter + def pointer(self, expr: PsExpression): + self._ptr = expr + + @property + def offset(self) -> PsExpression: + return self._offset + + @offset.setter + def offset(self, expr: PsExpression): + self._offset = expr + + @property + def vector_entries(self) -> int: + return self._vector_entries + + @property + def stride(self) -> PsExpression | None: + return self._stride + + @stride.setter + def stride(self, expr: PsExpression | None): + self._stride = expr + + @property + def aligned(self) -> bool: + return self._aligned + + def get_vector_type(self) -> PsVectorType: + return cast(PsVectorType, self._dtype) + + def get_children(self) -> tuple[PsAstNode, ...]: + return (self._ptr, self._offset) + (() if self._stride is None else (self._stride,)) + + def set_child(self, idx: int, c: PsAstNode): + idx = [0, 1, 2][idx] + match idx: + case 0: + self._ptr = failing_cast(PsExpression, c) + case 1: + self._offset = failing_cast(PsExpression, c) + case 2: + self._stride = failing_cast(PsExpression, c) + + def _clone_expr(self) -> PsVecMemAcc: + return PsVecMemAcc( + self._ptr.clone(), + self._offset.clone(), + self.vector_entries, + self._stride.clone() if self._stride is not None else None, + self._aligned, + ) + + def structurally_equal(self, other: PsAstNode) -> bool: + if not isinstance(other, PsVecMemAcc): + return False + + return ( + super().structurally_equal(other) + and self._vector_entries == other._vector_entries + and self._aligned == other._aligned + ) + + def __repr__(self) -> str: + return ( + f"PsVecMemAcc({repr(self._ptr)}, {repr(self._offset)}, {repr(self._vector_entries)}, " + f"stride={repr(self._stride)}, aligned={repr(self._aligned)})" + ) diff --git a/src/pystencils/backend/constants.py b/src/pystencils/backend/constants.py index b867d89d34597ef524c36a1eb7e720b2dadf0cd2..10e8946d9c2ae3c5e8fb6198590be7b269f379a9 100644 --- a/src/pystencils/backend/constants.py +++ b/src/pystencils/backend/constants.py @@ -1,5 +1,6 @@ from __future__ import annotations from typing import Any +import numpy as np from ..types import PsNumericType, constify from .exceptions import PsInternalCompilerError @@ -85,4 +86,4 @@ class PsConstant: if not isinstance(other, PsConstant): return False - return (self._value, self._dtype) == (other._value, other._dtype) + return (self._dtype == other._dtype) and bool(np.all(self._value == other._value)) diff --git a/src/pystencils/backend/emission/__init__.py b/src/pystencils/backend/emission/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b38b0c459b0a408b87476faeb7e395f9a4822838 --- /dev/null +++ b/src/pystencils/backend/emission/__init__.py @@ -0,0 +1,4 @@ +from .c_printer import emit_code, CAstPrinter +from .ir_printer import emit_ir, IRAstPrinter + +__all__ = ["emit_code", "CAstPrinter", "emit_ir", "IRAstPrinter"] diff --git a/src/pystencils/backend/emission.py b/src/pystencils/backend/emission/base_printer.py similarity index 84% rename from src/pystencils/backend/emission.py rename to src/pystencils/backend/emission/base_printer.py index 6196d69bee44be11d48d2e3e18e731a730fb47d4..9b1d5481a79447dcf70e9706c5d11450b18da22b 100644 --- a/src/pystencils/backend/emission.py +++ b/src/pystencils/backend/emission/base_printer.py @@ -1,9 +1,10 @@ from __future__ import annotations from enum import Enum +from abc import ABC, abstractmethod -from ..enums import Target +from ...target import Target -from .ast.structural import ( +from ..ast.structural import ( PsAstNode, PsBlock, PsStatement, @@ -15,7 +16,7 @@ from .ast.structural import ( PsPragma, ) -from .ast.expressions import ( +from ..ast.expressions import ( PsExpression, PsAdd, PsAddressOf, @@ -39,7 +40,6 @@ from .ast.expressions import ( PsSub, PsSymbolExpr, PsLiteralExpr, - PsVectorMemAcc, PsTernary, PsAnd, PsOr, @@ -51,24 +51,15 @@ from .ast.expressions import ( PsGe, PsLe, PsSubscript, - PsBufferAcc, ) -from .extensions.foreign_ast import PsForeignExpression +from ..extensions.foreign_ast import PsForeignExpression -from .exceptions import PsInternalCompilerError -from .memory import PsSymbol -from ..types import PsScalarType, PsArrayType +from ..memory import PsSymbol +from ..constants import PsConstant +from ...types import PsType -from .kernelfunction import KernelFunction, GpuKernelFunction - - -__all__ = ["emit_code", "CAstPrinter"] - - -def emit_code(kernel: KernelFunction): - printer = CAstPrinter() - return printer(kernel) +from ..kernelfunction import KernelFunction, GpuKernelFunction class EmissionError(Exception): @@ -170,7 +161,14 @@ class PrinterCtx: return " " * self.indent_level + line -class CAstPrinter: +class BasePrinter(ABC): + """Base code printer. + + The base printer is capable of printing syntax tree nodes valid in all output dialects. + It is specialized in `CAstPrinter` for the C output language, + and in `IRAstPrinter` for debug-printing the entire IR. + """ + def __init__(self, indent_width=3): self._indent_width = indent_width @@ -182,14 +180,6 @@ class CAstPrinter: else: return self.visit(obj, PrinterCtx()) - def print_signature(self, func: KernelFunction) -> str: - prefix = self._func_prefix(func) - params_str = ", ".join( - f"{p.dtype.c_string()} {p.name}" for p in func.parameters - ) - signature = " ".join([prefix, "void", func.name, f"({params_str})"]) - return signature - def visit(self, node: PsAstNode, pc: PrinterCtx) -> str: match node: case PsBlock(statements): @@ -219,13 +209,14 @@ class CAstPrinter: case PsLoop(ctr, start, stop, step, body): ctr_symbol = ctr.symbol + ctr_decl = self._symbol_decl(ctr_symbol) start_code = self.visit(start, pc) stop_code = self.visit(stop, pc) step_code = self.visit(step, pc) body_code = self.visit(body, pc) code = ( - f"for({ctr_symbol.dtype} {ctr_symbol.name} = {start_code};" + f"for({ctr_decl} = {start_code};" + f" {ctr.symbol.name} < {stop_code};" + f" {ctr.symbol.name} += {step_code})\n" + body_code @@ -259,20 +250,11 @@ class CAstPrinter: return symbol.name case PsConstantExpr(constant): - dtype = constant.get_dtype() - if not isinstance(dtype, PsScalarType): - raise EmissionError( - "Cannot print literals for non-scalar constants." - ) - - return dtype.create_literal(constant.value) + return self._constant_literal(constant) case PsLiteralExpr(lit): return lit.text - case PsVectorMemAcc(): - raise EmissionError("Cannot print vectorized array accesses") - case PsMemAcc(base, offset): pc.push_op(Ops.Subscript, LR.Left) base_code = self.visit(base, pc) @@ -283,14 +265,16 @@ class CAstPrinter: pc.pop_op() return pc.parenthesize(f"{base_code}[{index_code}]", Ops.Subscript) - + case PsSubscript(base, indices): pc.push_op(Ops.Subscript, LR.Left) base_code = self.visit(base, pc) pc.pop_op() pc.push_op(Ops.Weakest, LR.Middle) - indices_code = "".join("[" + self.visit(idx, pc) + "]" for idx in indices) + indices_code = "".join( + "[" + self.visit(idx, pc) + "]" for idx in indices + ) pc.pop_op() return pc.parenthesize(base_code + indices_code, Ops.Subscript) @@ -334,13 +318,6 @@ class CAstPrinter: return pc.parenthesize(f"!{operand_code}", Ops.Not) - # case PsDeref(operand): - # pc.push_op(Ops.Deref, LR.Right) - # operand_code = self.visit(operand, pc) - # pc.pop_op() - - # return pc.parenthesize(f"*{operand_code}", Ops.Deref) - case PsAddressOf(operand): pc.push_op(Ops.AddressOf, LR.Right) operand_code = self.visit(operand, pc) @@ -353,7 +330,7 @@ class CAstPrinter: operand_code = self.visit(operand, pc) pc.pop_op() - type_str = target_type.c_string() + type_str = self._type_str(target_type) return pc.parenthesize(f"({type_str}) {operand_code}", Ops.Cast) case PsTernary(cond, then, els): @@ -370,6 +347,7 @@ class CAstPrinter: ) case PsArrayInitList(_): + def print_arr(item) -> str: if isinstance(item, PsExpression): return self.visit(item, pc) @@ -388,15 +366,19 @@ class CAstPrinter: foreign_code = node.get_code(self.visit(c, pc) for c in children) pc.pop_op() return foreign_code - - case PsBufferAcc(): - raise PsInternalCompilerError( - f"Unable to print C code for buffer access {node}.\n" - f"Buffer accesses must be lowered using the `LowerToC` pass before emission." - ) case _: - raise NotImplementedError(f"Don't know how to print {node}") + raise NotImplementedError( + f"BasePrinter does not know how to print {type(node)}" + ) + + def print_signature(self, func: KernelFunction) -> str: + prefix = self._func_prefix(func) + params_str = ", ".join( + f"{self._type_str(p.dtype)} {p.name}" for p in func.parameters + ) + signature = " ".join([prefix, "void", func.name, f"({params_str})"]) + return signature def _func_prefix(self, func: KernelFunction): if isinstance(func, GpuKernelFunction) and func.target == Target.CUDA: @@ -404,20 +386,17 @@ class CAstPrinter: else: return "FUNC_PREFIX" - def _symbol_decl(self, symb: PsSymbol): - dtype = symb.get_dtype() - - if isinstance(dtype, PsArrayType): - array_dims = dtype.shape - dtype = dtype.base_type - else: - array_dims = () + @abstractmethod + def _symbol_decl(self, symb: PsSymbol) -> str: + pass - code = f"{dtype.c_string()} {symb.name}" - for d in array_dims: - code += f"[{str(d) if d is not None else ''}]" + @abstractmethod + def _constant_literal(self, constant: PsConstant) -> str: + pass - return code + @abstractmethod + def _type_str(self, dtype: PsType) -> str: + """Return a valid string representation of the given type""" def _char_and_op(self, node: PsBinOp) -> tuple[str, Ops]: match node: diff --git a/src/pystencils/backend/emission/c_printer.py b/src/pystencils/backend/emission/c_printer.py new file mode 100644 index 0000000000000000000000000000000000000000..0efe875887c75f08df20f825cdc1411c6de9ae96 --- /dev/null +++ b/src/pystencils/backend/emission/c_printer.py @@ -0,0 +1,62 @@ +from pystencils.backend.ast.astnode import PsAstNode +from pystencils.backend.constants import PsConstant +from pystencils.backend.emission.base_printer import PrinterCtx, EmissionError +from pystencils.backend.memory import PsSymbol +from .base_printer import BasePrinter + +from ..kernelfunction import KernelFunction +from ...types import PsType, PsArrayType, PsScalarType +from ..ast.expressions import PsBufferAcc +from ..ast.vector import PsVecMemAcc + + +def emit_code(kernel: KernelFunction): + printer = CAstPrinter() + return printer(kernel) + + +class CAstPrinter(BasePrinter): + + def __init__(self, indent_width=3): + super().__init__(indent_width) + + def visit(self, node: PsAstNode, pc: PrinterCtx) -> str: + match node: + case PsVecMemAcc(): + raise EmissionError("Cannot print vectorized array accesses to C code.") + + case PsBufferAcc(): + raise EmissionError( + f"Unable to print C code for buffer access {node}.\n" + f"Buffer accesses must be lowered using the `LowerToC` pass before emission." + ) + + case _: + return super().visit(node, pc) + + def _symbol_decl(self, symb: PsSymbol): + dtype = symb.get_dtype() + + if isinstance(dtype, PsArrayType): + array_dims = dtype.shape + dtype = dtype.base_type + else: + array_dims = () + + code = f"{self._type_str(dtype)} {symb.name}" + for d in array_dims: + code += f"[{str(d) if d is not None else ''}]" + + return code + + def _constant_literal(self, constant: PsConstant): + dtype = constant.get_dtype() + if not isinstance(dtype, PsScalarType): + raise EmissionError( + "Cannot print literals for non-scalar constants." + ) + + return dtype.create_literal(constant.value) + + def _type_str(self, dtype: PsType): + return dtype.c_string() diff --git a/src/pystencils/backend/emission/ir_printer.py b/src/pystencils/backend/emission/ir_printer.py new file mode 100644 index 0000000000000000000000000000000000000000..0b4a18bd50b8edd54c6b4b7269531a13840e6b1c --- /dev/null +++ b/src/pystencils/backend/emission/ir_printer.py @@ -0,0 +1,81 @@ +from pystencils.backend.constants import PsConstant +from pystencils.backend.emission.base_printer import PrinterCtx +from pystencils.backend.memory import PsSymbol +from pystencils.types.meta import PsType, deconstify + +from .base_printer import BasePrinter, Ops, LR + +from ..ast import PsAstNode +from ..ast.expressions import PsBufferAcc +from ..ast.vector import PsVecMemAcc, PsVecBroadcast + + +def emit_ir(ir: PsAstNode): + """Emit the IR as C-like pseudo-code for inspection.""" + ir_printer = IRAstPrinter() + return ir_printer(ir) + + +class IRAstPrinter(BasePrinter): + + def __init__(self, indent_width=3): + super().__init__(indent_width) + + def visit(self, node: PsAstNode, pc: PrinterCtx) -> str: + match node: + case PsBufferAcc(ptr, indices): + pc.push_op(Ops.Subscript, LR.Left) + base_code = self.visit(ptr, pc) + pc.pop_op() + + pc.push_op(Ops.Weakest, LR.Middle) + indices_code = ", ".join(self.visit(idx, pc) for idx in indices) + pc.pop_op() + + return pc.parenthesize( + base_code + "[" + indices_code + "]", Ops.Subscript + ) + + case PsVecMemAcc(ptr, offset, lanes, stride): + pc.push_op(Ops.Subscript, LR.Left) + ptr_code = self.visit(ptr, pc) + pc.pop_op() + + pc.push_op(Ops.Weakest, LR.Middle) + offset_code = self.visit(offset, pc) + pc.pop_op() + + stride_code = "" if stride is None else f", stride={stride}" + + code = f"vec_load< {lanes}{stride_code} >({ptr_code}, {offset_code})" + return pc.parenthesize(code, Ops.Subscript) + + case PsVecBroadcast(lanes, operand): + pc.push_op(Ops.Weakest, LR.Middle) + operand_code = self.visit(operand, pc) + pc.pop_op() + + return pc.parenthesize( + f"vec_broadcast<{lanes}>({operand_code})", Ops.Weakest + ) + + case _: + return super().visit(node, pc) + + def _symbol_decl(self, symb: PsSymbol): + return f"{symb.name}: {self._type_str(symb.dtype)}" + + def _constant_literal(self, constant: PsConstant) -> str: + return f"[{constant.value}: {self._deconst_type_str(constant.dtype)}]" + + def _type_str(self, dtype: PsType | None): + if dtype is None: + return "<untyped>" + else: + return str(dtype) + + def _deconst_type_str(self, dtype: PsType | None): + if dtype is None: + return "<untyped>" + else: + return str(deconstify(dtype)) diff --git a/src/pystencils/backend/exceptions.py b/src/pystencils/backend/exceptions.py index d42f7c11fdc4dc13b7c520119057336da3b6e3e2..ec4f58fd8b7142009494ba3035e7d8e7aeacfce9 100644 --- a/src/pystencils/backend/exceptions.py +++ b/src/pystencils/backend/exceptions.py @@ -13,5 +13,9 @@ class KernelConstraintsError(Exception): """Indicates a constraint violation in the symbolic kernel""" +class VectorizationError(Exception): + """Indicates an error during a vectorization procedure""" + + class MaterializationError(Exception): """Indicates a fatal error during materialization of any abstract kernel component.""" diff --git a/src/pystencils/backend/functions.py b/src/pystencils/backend/functions.py index 30b243d9cd614d9f843021dff52167297fccdbba..be268ac99b92a2ca55693bb3f953ab3bf18d7e51 100644 --- a/src/pystencils/backend/functions.py +++ b/src/pystencils/backend/functions.py @@ -139,6 +139,8 @@ class CFunction(PsFunction): class PsMathFunction(PsFunction): """Homogenously typed mathematical functions.""" + __match_args__ = ("func",) + def __init__(self, func: MathFunctions) -> None: super().__init__(func.function_name, func.num_args) self._func = func diff --git a/src/pystencils/backend/jit/cpu_extension_module.py b/src/pystencils/backend/jit/cpu_extension_module.py index d7f64455082523a0d189e9864080cb62827ea156..4412f8879a346d5c3635271e9d3700fed041435f 100644 --- a/src/pystencils/backend/jit/cpu_extension_module.py +++ b/src/pystencils/backend/jit/cpu_extension_module.py @@ -272,7 +272,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ extract_func = self._scalar_extractor(param.dtype) code = self.TMPL_EXTRACT_SCALAR.format( name=param.name, - target_type=str(param.dtype), + target_type=param.dtype.c_string(), extract_function=extract_func, ) self._scalar_extractions[param] = code @@ -288,14 +288,14 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ for prop in param.properties: match prop: case FieldBasePtr(): - code = f"{param.dtype} {param.name} = ({param.dtype}) {buffer}.buf;" + code = f"{param.dtype.c_string()} {param.name} = ({param.dtype}) {buffer}.buf;" break case FieldShape(_, coord): - code = f"{param.dtype} {param.name} = {buffer}.shape[{coord}];" + code = f"{param.dtype.c_string()} {param.name} = {buffer}.shape[{coord}];" break case FieldStride(_, coord): code = ( - f"{param.dtype} {param.name} = " + f"{param.dtype.c_string()} {param.name} = " f"{buffer}.strides[{coord}] / {field.dtype.itemsize};" ) break diff --git a/src/pystencils/backend/jit/gpu_cupy.py b/src/pystencils/backend/jit/gpu_cupy.py index 7f38d9d434333c0d504babe5ed1fea65f6f85dad..563a9c06a7261343e6435737c47854a55d54a05e 100644 --- a/src/pystencils/backend/jit/gpu_cupy.py +++ b/src/pystencils/backend/jit/gpu_cupy.py @@ -8,7 +8,7 @@ try: except ImportError: HAVE_CUPY = False -from ...enums import Target +from ...target import Target from ...field import FieldType from ...types import PsType diff --git a/src/pystencils/backend/jit/jit.py b/src/pystencils/backend/jit/jit.py index 2184245707e861052dce61e7f8f12b720fa37922..2d091c4a009f27ba1d1efb2e7bab37021ff001dd 100644 --- a/src/pystencils/backend/jit/jit.py +++ b/src/pystencils/backend/jit/jit.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod if TYPE_CHECKING: from ..kernelfunction import KernelFunction, KernelParameter - from ...enums import Target + from ...target import Target class JitError(Exception): diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py index 839b8fd9829a83b46dbe2419013959b42943b96c..615f7396dd292227e39380062a8c77c425ecf401 100644 --- a/src/pystencils/backend/kernelcreation/context.py +++ b/src/pystencils/backend/kernelcreation/context.py @@ -136,6 +136,14 @@ class KernelCreationContext: symb.apply_dtype(dtype) return symb + + def get_new_symbol(self, name: str, dtype: PsType | None = None) -> PsSymbol: + """Always create a new symbol, deduplicating its name if another symbol with the same name already exists.""" + + if name in self._symbols: + return self.duplicate_symbol(self._symbols[name], dtype) + else: + return self.get_symbol(name, dtype) def find_symbol(self, name: str) -> PsSymbol | None: """Find a symbol with the given name in the symbol table, if it exists. @@ -170,11 +178,13 @@ class KernelCreationContext: self._symbols[old.name] = new - def duplicate_symbol(self, symb: PsSymbol) -> PsSymbol: + def duplicate_symbol( + self, symb: PsSymbol, new_dtype: PsType | None = None + ) -> PsSymbol: """Canonically duplicates the given symbol. - A new symbol with the same data type, and new name ``symb.name + "__<counter>"`` is created, - added to the symbol table, and returned. + A new symbol with the new name ``symb.name + "__<counter>"`` and optionally a different data type + is created, added to the symbol table, and returned. The ``counter`` reflects the number of previously created duplicates of this symbol. """ if (result := self._symbol_ctr_pattern.search(symb.name)) is not None: @@ -183,12 +193,15 @@ class KernelCreationContext: else: basename = symb.name + if new_dtype is None: + new_dtype = symb.dtype + initial_count = self._symbol_dup_table[basename] for i in count(initial_count): dup_name = f"{basename}__{i}" if self.find_symbol(dup_name) is None: self._symbol_dup_table[basename] = i + 1 - return self.get_symbol(dup_name, symb.dtype) + return self.get_symbol(dup_name, new_dtype) assert False, "unreachable code" @property diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index bdc8f11336886dade665e475a023ac49a7595eba..7ca0c370c3db2cb1496f829f3905e04b27255e15 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -43,7 +43,6 @@ from ..ast.expressions import ( PsLookup, PsRightShift, PsSubscript, - PsVectorMemAcc, PsTernary, PsRel, PsEq, @@ -57,6 +56,7 @@ from ..ast.expressions import ( PsNot, PsMemAcc ) +from ..ast.vector import PsVecMemAcc from ..constants import PsConstant from ...types import PsNumericType, PsStructType, PsType @@ -158,7 +158,7 @@ class FreezeExpressions: if isinstance(lhs, PsSymbolExpr): return PsDeclaration(lhs, rhs) - elif isinstance(lhs, (PsBufferAcc, PsLookup, PsVectorMemAcc)): # todo + elif isinstance(lhs, (PsBufferAcc, PsLookup, PsVecMemAcc)): return PsAssignment(lhs, rhs) else: raise FreezeError( diff --git a/src/pystencils/backend/kernelcreation/iteration_space.py b/src/pystencils/backend/kernelcreation/iteration_space.py index bae0328e4348836463f2d6831fc1905855354548..4f057e1fc8c3c68c6dae3d35ace54c8be8f3de21 100644 --- a/src/pystencils/backend/kernelcreation/iteration_space.py +++ b/src/pystencils/backend/kernelcreation/iteration_space.py @@ -149,7 +149,7 @@ class FullIterationSpace(IterationSpace): f" did not equal iteration space dimensionality ({dim})" ) - archetype_size = archetype_array.shape[:dim] + archetype_size = tuple(archetype_array.shape[:dim]) else: archetype_size = (None,) * dim diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index c8fad68f106dea78126be9d9ada51e2c57180cd2..cd4103433af0dcb91a593f66aa01cbecd0a630ce 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -12,6 +12,8 @@ from ...types import ( PsDereferencableType, PsPointerType, PsBoolType, + PsScalarType, + PsVectorType, constify, deconstify, ) @@ -47,6 +49,7 @@ from ..ast.expressions import ( PsNeg, PsNot, ) +from ..ast.vector import PsVecBroadcast, PsVecMemAcc from ..functions import PsMathFunction, CFunction from ..ast.util import determine_memory_object @@ -195,7 +198,7 @@ class TypeContext: case PsNumericOpTrait() if not isinstance( self._target_type, PsNumericType - ) or isinstance(self._target_type, PsBoolType): + ) or self._target_type.is_bool(): # FIXME: PsBoolType derives from PsNumericType, but is not numeric raise TypificationError( f"Numerical operation encountered in non-numerical type context:\n" @@ -203,14 +206,20 @@ class TypeContext: f" Type Context: {self._target_type}" ) - case PsIntOpTrait() if not isinstance(self._target_type, PsIntegerType): + case PsIntOpTrait() if not ( + isinstance(self._target_type, PsNumericType) + and self._target_type.is_int() + ): raise TypificationError( f"Integer operation encountered in non-integer type context:\n" f" Expression: {expr}" f" Type Context: {self._target_type}" ) - case PsBoolOpTrait() if not isinstance(self._target_type, PsBoolType): + case PsBoolOpTrait() if not ( + isinstance(self._target_type, PsNumericType) + and self._target_type.is_bool() + ): raise TypificationError( f"Boolean operation encountered in non-boolean type context:\n" f" Expression: {expr}" @@ -427,7 +436,7 @@ class Typifier: for idx in indices: self._handle_idx(idx) - case PsMemAcc(ptr, offset): + case PsMemAcc(ptr, offset) | PsVecMemAcc(ptr, offset): ptr_tc = TypeContext() self.visit_expr(ptr, ptr_tc) @@ -439,6 +448,9 @@ class Typifier: tc.apply_dtype(ptr_tc.target_type.base_type, expr) self._handle_idx(offset) + if isinstance(expr, PsVecMemAcc) and expr.stride is not None: + self._handle_idx(expr.stride) + case PsSubscript(arr, indices): if isinstance(arr, PsArrayInitList): shape = arr.shape @@ -474,7 +486,9 @@ class Typifier: self._handle_idx(idx) case PsAddressOf(arg): - if not isinstance(arg, (PsSymbolExpr, PsSubscript, PsMemAcc, PsBufferAcc, PsLookup)): + if not isinstance( + arg, (PsSymbolExpr, PsSubscript, PsMemAcc, PsBufferAcc, PsLookup) + ): raise TypificationError( f"Illegal expression below AddressOf operator: {arg}" ) @@ -559,7 +573,13 @@ class Typifier: f" Arguments Type: {args_tc.target_type}" ) - tc.apply_dtype(PsBoolType(), expr) + if isinstance(args_tc.target_type, PsVectorType): + tc.apply_dtype( + PsVectorType(PsBoolType(), args_tc.target_type.vector_entries), + expr, + ) + else: + tc.apply_dtype(PsBoolType(), expr) case PsBinOp(op1, op2): self.visit_expr(op1, tc) @@ -606,6 +626,22 @@ class Typifier: tc.apply_dtype(dtype, expr) + case PsVecBroadcast(lanes, arg): + op_tc = TypeContext() + self.visit_expr(arg, op_tc) + + if op_tc.target_type is None: + raise TypificationError( + f"Unable to determine type of argument to vector broadcast: {arg}" + ) + + if not isinstance(op_tc.target_type, PsScalarType): + raise TypificationError( + f"Illegal type in argument to vector broadcast: {op_tc.target_type}" + ) + + tc.apply_dtype(PsVectorType(op_tc.target_type, lanes), expr) + case _: raise NotImplementedError(f"Can't typify {expr}") diff --git a/src/pystencils/backend/kernelfunction.py b/src/pystencils/backend/kernelfunction.py index 8868179307fa8a76bec5049ea8cf05ac3a4b46e0..0118c4f40a2b0702c51f39e9bc75fe4dc29cd67e 100644 --- a/src/pystencils/backend/kernelfunction.py +++ b/src/pystencils/backend/kernelfunction.py @@ -22,7 +22,7 @@ from .platforms import Platform, GpuThreadsRange from .constraints import KernelParamsConstraint from ..types import PsType -from ..enums import Target +from ..target import Target from ..field import Field from ..sympyextensions import TypedSymbol diff --git a/src/pystencils/backend/memory.py b/src/pystencils/backend/memory.py index 9b72a4e4337f1de152291b3287cffb612999a786..e578f924b587b509032970e2ac4f80836c71ae26 100644 --- a/src/pystencils/backend/memory.py +++ b/src/pystencils/backend/memory.py @@ -153,8 +153,8 @@ class PsBuffer: self._element_type = element_type self._index_dtype = idx_dtype - self._shape = tuple(shape) - self._strides = tuple(strides) + self._shape = list(shape) + self._strides = list(strides) base_ptr.add_property(BufferBasePtr(self)) self._base_ptr = base_ptr @@ -170,12 +170,12 @@ class PsBuffer: return self._base_ptr @property - def shape(self) -> tuple[PsSymbol | PsConstant, ...]: + def shape(self) -> list[PsSymbol | PsConstant]: """Buffer shape symbols and/or constants""" return self._shape @property - def strides(self) -> tuple[PsSymbol | PsConstant, ...]: + def strides(self) -> list[PsSymbol | PsConstant]: """Buffer stride symbols and/or constants""" return self._strides diff --git a/src/pystencils/backend/platforms/generic_cpu.py b/src/pystencils/backend/platforms/generic_cpu.py index f8cae89fcb8f0f2310a83049c1f3453ba9329b39..94fbfa0e1dce6aaf747134b21d946eac4dc023a6 100644 --- a/src/pystencils/backend/platforms/generic_cpu.py +++ b/src/pystencils/backend/platforms/generic_cpu.py @@ -1,5 +1,5 @@ -from typing import Sequence from abc import ABC, abstractmethod +from typing import Sequence from pystencils.backend.ast.expressions import PsCall @@ -22,14 +22,13 @@ from ..ast.expressions import ( PsSymbolExpr, PsExpression, PsBufferAcc, - PsVectorMemAcc, PsLookup, PsGe, PsLe, - PsTernary + PsTernary, ) +from ..ast.vector import PsVecMemAcc from ...types import PsVectorType, PsCustomType -from ..transformations.select_intrinsics import IntrinsicOps class GenericCpu(Platform): @@ -162,24 +161,23 @@ class GenericVectorCpu(GenericCpu, ABC): or raise an `MaterializationError` if type is not supported.""" @abstractmethod - def constant_vector(self, c: PsConstant) -> PsExpression: + def constant_intrinsic(self, c: PsConstant) -> PsExpression: """Return an expression that initializes a constant vector, or raise an `MaterializationError` if not supported.""" @abstractmethod def op_intrinsic( - self, op: IntrinsicOps, vtype: PsVectorType, args: Sequence[PsExpression] + self, expr: PsExpression, operands: Sequence[PsExpression] ) -> PsExpression: """Return an expression intrinsically invoking the given operation - on the given arguments with the given vector type, or raise an `MaterializationError` if not supported.""" @abstractmethod - def vector_load(self, acc: PsVectorMemAcc) -> PsExpression: + def vector_load(self, acc: PsVecMemAcc) -> PsExpression: """Return an expression intrinsically performing a vector load, or raise an `MaterializationError` if not supported.""" @abstractmethod - def vector_store(self, acc: PsVectorMemAcc, arg: PsExpression) -> PsExpression: + def vector_store(self, acc: PsVecMemAcc, arg: PsExpression) -> PsExpression: """Return an expression intrinsically performing a vector store, or raise an `MaterializationError` if not supported.""" diff --git a/src/pystencils/backend/platforms/x86.py b/src/pystencils/backend/platforms/x86.py index 33838df08bcdb13094a96387fd3db565e4ba5932..aaa8b351b973294401a7892bb9a00e7f7cc672ba 100644 --- a/src/pystencils/backend/platforms/x86.py +++ b/src/pystencils/backend/platforms/x86.py @@ -1,15 +1,21 @@ from __future__ import annotations +from typing import Sequence from enum import Enum from functools import cache -from typing import Sequence from ..ast.expressions import ( PsExpression, - PsVectorMemAcc, PsAddressOf, PsMemAcc, + PsUnOp, + PsBinOp, + PsAdd, + PsSub, + PsMul, + PsDiv, + PsConstantExpr ) -from ..transformations.select_intrinsics import IntrinsicOps +from ..ast.vector import PsVecMemAcc, PsVecBroadcast from ...types import PsCustomType, PsVectorType, PsPointerType from ..constants import PsConstant @@ -24,6 +30,7 @@ class X86VectorArch(Enum): SSE = 128 AVX = 256 AVX512 = 512 + AVX512_FP16 = AVX512 + 1 # TODO improve modelling? def __ge__(self, other: X86VectorArch) -> bool: return self.value >= other.value @@ -48,7 +55,7 @@ class X86VectorArch(Enum): prefix = "_mm512" case other: raise MaterializationError( - f"X86/{self} does not support vector width {other}" + f"x86/{self} does not support vector width {other}" ) return prefix @@ -56,7 +63,7 @@ class X86VectorArch(Enum): def intrin_suffix(self, vtype: PsVectorType) -> str: scalar_type = vtype.scalar_type match scalar_type: - case Fp(16) if self >= X86VectorArch.AVX512: + case Fp(16) if self >= X86VectorArch.AVX512_FP16: suffix = "ph" case Fp(32): suffix = "ps" @@ -66,7 +73,7 @@ class X86VectorArch(Enum): suffix = f"epi{width}" case _: raise MaterializationError( - f"X86/{self} does not support scalar type {scalar_type}" + f"x86/{self} does not support scalar type {scalar_type}" ) return suffix @@ -82,6 +89,10 @@ class X86VectorCpu(GenericVectorCpu): def __init__(self, vector_arch: X86VectorArch): self._vector_arch = vector_arch + @property + def vector_arch(self) -> X86VectorArch: + return self._vector_arch + @property def required_headers(self) -> set[str]: if self._vector_arch == X86VectorArch.SSE: @@ -112,37 +123,45 @@ class X86VectorCpu(GenericVectorCpu): suffix = "i" case _: raise MaterializationError( - f"X86/{self._vector_arch} does not support scalar type {scalar_type}" + f"x86/{self._vector_arch} does not support scalar type {scalar_type}" ) if vector_type.width > self._vector_arch.max_vector_width: raise MaterializationError( - f"X86/{self._vector_arch} does not support {vector_type}" + f"x86/{self._vector_arch} does not support {vector_type}" ) return PsCustomType(f"__m{vector_type.width}{suffix}") - def constant_vector(self, c: PsConstant) -> PsExpression: + def constant_intrinsic(self, c: PsConstant) -> PsExpression: vtype = c.dtype assert isinstance(vtype, PsVectorType) stype = vtype.scalar_type prefix = self._vector_arch.intrin_prefix(vtype) suffix = self._vector_arch.intrin_suffix(vtype) + + if stype == SInt(64) and vtype.vector_entries <= 4: + suffix += "x" + set_func = CFunction( f"{prefix}_set_{suffix}", (stype,) * vtype.vector_entries, vtype ) - values = c.value + values = [PsConstantExpr(PsConstant(v, stype)) for v in c.value] return set_func(*values) def op_intrinsic( - self, op: IntrinsicOps, vtype: PsVectorType, args: Sequence[PsExpression] + self, expr: PsExpression, operands: Sequence[PsExpression] ) -> PsExpression: - func = _x86_op_intrin(self._vector_arch, op, vtype) - return func(*args) + match expr: + case PsUnOp() | PsBinOp(): + func = _x86_op_intrin(self._vector_arch, expr, expr.get_dtype()) + return func(*operands) + case _: + raise MaterializationError(f"Cannot map {type(expr)} to x86 intrinsic") - def vector_load(self, acc: PsVectorMemAcc) -> PsExpression: - if acc.stride == 1: + def vector_load(self, acc: PsVecMemAcc) -> PsExpression: + if acc.stride is None: load_func = _x86_packed_load(self._vector_arch, acc.dtype, False) return load_func( PsAddressOf(PsMemAcc(acc.pointer, acc.offset)) @@ -150,8 +169,8 @@ class X86VectorCpu(GenericVectorCpu): else: raise NotImplementedError("Gather loads not implemented yet.") - def vector_store(self, acc: PsVectorMemAcc, arg: PsExpression) -> PsExpression: - if acc.stride == 1: + def vector_store(self, acc: PsVecMemAcc, arg: PsExpression) -> PsExpression: + if acc.stride is None: store_func = _x86_packed_store(self._vector_arch, acc.dtype, False) return store_func( PsAddressOf(PsMemAcc(acc.pointer, acc.offset)), @@ -189,24 +208,32 @@ def _x86_packed_store( @cache def _x86_op_intrin( - varch: X86VectorArch, op: IntrinsicOps, vtype: PsVectorType + varch: X86VectorArch, op: PsUnOp | PsBinOp, vtype: PsVectorType ) -> CFunction: prefix = varch.intrin_prefix(vtype) suffix = varch.intrin_suffix(vtype) match op: - case IntrinsicOps.ADD: + case PsVecBroadcast(): + opstr = "set1" + if vtype.scalar_type == SInt(64) and vtype.vector_entries <= 4: + suffix += "x" + case PsAdd(): opstr = "add" - case IntrinsicOps.SUB: + case PsSub(): opstr = "sub" - case IntrinsicOps.MUL: + case PsMul() if vtype.is_int(): + raise MaterializationError( + f"Unable to select intrinsic for integer multiplication: " + f"{varch.name} does not support packed integer multiplication.\n" + f" at: {op}" + ) + case PsMul(): opstr = "mul" - case IntrinsicOps.DIV: + case PsDiv(): opstr = "div" - case IntrinsicOps.FMA: - opstr = "fmadd" case _: - assert False + raise MaterializationError(f"Unable to select operation intrinsic for {type(op)}") - num_args = 3 if op == IntrinsicOps.FMA else 2 + num_args = 1 if isinstance(op, PsUnOp) else 2 return CFunction(f"{prefix}_{opstr}_{suffix}", (vtype,) * num_args, vtype) diff --git a/src/pystencils/backend/transformations/__init__.py b/src/pystencils/backend/transformations/__init__.py index 7375af618a438d9145e76fe2097d7176b7d2b2ea..44613e524cbb7fa6e8a5ef11291d054e18093f70 100644 --- a/src/pystencils/backend/transformations/__init__.py +++ b/src/pystencils/backend/transformations/__init__.py @@ -21,7 +21,7 @@ All transformations in this module retain canonicality of the AST. Canonicality allows transformations to forego various checks that would otherwise be necessary to prove their legality. -Certain transformations, like the auto-vectorizer (TODO), state additional requirements, e.g. +Certain transformations, like the `LoopVectorizer`, state additional requirements, e.g. the absence of loop-carried dependencies. Transformations @@ -48,6 +48,12 @@ Simplifying Transformations .. autoclass:: EliminateBranches :members: __call__ + +Code Rewriting +-------------- + +.. autofunction:: substitute_symbols + Code Motion ----------- @@ -66,6 +72,21 @@ Loop Reshaping Transformations .. autoclass:: AddOpenMP :members: +Vectorization +------------- + +.. autoclass:: VectorizationAxis + :members: + +.. autoclass:: VectorizationContext + :members: + +.. autoclass:: AstVectorizer + :members: + +.. autoclass:: LoopVectorizer + :members: + Code Lowering and Materialization --------------------------------- @@ -75,22 +96,29 @@ Code Lowering and Materialization .. autoclass:: SelectFunctions :members: __call__ +.. autoclass:: SelectIntrinsics + :members: + """ from .canonicalize_symbols import CanonicalizeSymbols from .canonical_clone import CanonicalClone +from .rewrite import substitute_symbols from .eliminate_constants import EliminateConstants from .eliminate_branches import EliminateBranches from .hoist_loop_invariant_decls import HoistLoopInvariantDeclarations from .reshape_loops import ReshapeLoops from .add_pragmas import InsertPragmasAtLoops, LoopPragma, AddOpenMP +from .ast_vectorizer import VectorizationAxis, VectorizationContext, AstVectorizer +from .loop_vectorizer import LoopVectorizer from .lower_to_c import LowerToC from .select_functions import SelectFunctions -from .select_intrinsics import MaterializeVectorIntrinsics +from .select_intrinsics import SelectIntrinsics __all__ = [ "CanonicalizeSymbols", "CanonicalClone", + "substitute_symbols", "EliminateConstants", "EliminateBranches", "HoistLoopInvariantDeclarations", @@ -98,7 +126,11 @@ __all__ = [ "InsertPragmasAtLoops", "LoopPragma", "AddOpenMP", + "VectorizationAxis", + "VectorizationContext", + "AstVectorizer", + "LoopVectorizer", "LowerToC", "SelectFunctions", - "MaterializeVectorIntrinsics", + "SelectIntrinsics", ] diff --git a/src/pystencils/backend/transformations/ast_vectorizer.py b/src/pystencils/backend/transformations/ast_vectorizer.py new file mode 100644 index 0000000000000000000000000000000000000000..d0792fc366fa716f3c81b8f35adb4ceddb460d78 --- /dev/null +++ b/src/pystencils/backend/transformations/ast_vectorizer.py @@ -0,0 +1,488 @@ +from __future__ import annotations +from typing import overload + +from dataclasses import dataclass + +from ...types import PsType, PsVectorType, PsBoolType, PsScalarType + +from ..kernelcreation import KernelCreationContext, AstFactory +from ..memory import PsSymbol +from ..constants import PsConstant +from ..functions import PsMathFunction + +from ..ast import PsAstNode +from ..ast.structural import PsBlock, PsDeclaration, PsAssignment +from ..ast.expressions import ( + PsExpression, + PsAddressOf, + PsCast, + PsUnOp, + PsBinOp, + PsSymbolExpr, + PsConstantExpr, + PsLiteral, + PsCall, + PsMemAcc, + PsBufferAcc, + PsAdd, + PsMul, + PsSub, + PsNeg, + PsDiv, +) +from ..ast.vector import PsVectorOp, PsVecBroadcast, PsVecMemAcc +from ..ast.analysis import UndefinedSymbolsCollector + +from ..exceptions import PsInternalCompilerError, VectorizationError + + +@dataclass(frozen=True) +class VectorizationAxis: + """Information about the iteration axis along which a subtree is being vectorized.""" + + counter: PsSymbol + """Scalar iteration counter of this axis""" + + vectorized_counter: PsSymbol | None = None + """Vectorized iteration counter of this axis""" + + step: PsExpression = PsExpression.make(PsConstant(1)) + """Step size of the scalar iteration""" + + def get_vectorized_counter(self) -> PsSymbol: + if self.vectorized_counter is None: + raise PsInternalCompilerError( + "No vectorized counter defined on this vectorization axis" + ) + + return self.vectorized_counter + + +class VectorizationContext: + """Context information for AST vectorization. + + Args: + lanes: Number of vector lanes + axis: Iteration axis along which code is being vectorized + """ + + def __init__( + self, + ctx: KernelCreationContext, + lanes: int, + axis: VectorizationAxis, + vectorized_symbols: dict[PsSymbol, PsSymbol] | None = None, + ) -> None: + self._ctx = ctx + self._lanes = lanes + self._axis: VectorizationAxis = axis + self._vectorized_symbols: dict[PsSymbol, PsSymbol] = ( + {**vectorized_symbols} if vectorized_symbols is not None else dict() + ) + self._lane_mask: PsSymbol | None = None + + if axis.vectorized_counter is not None: + self._vectorized_symbols[axis.counter] = axis.vectorized_counter + + @property + def lanes(self) -> int: + """Number of vector lanes""" + return self._lanes + + @property + def axis(self) -> VectorizationAxis: + """Iteration axis along which to vectorize""" + return self._axis + + @property + def vectorized_symbols(self) -> dict[PsSymbol, PsSymbol]: + """Dictionary mapping scalar symbols that are being vectorized to their vectorized copies""" + return self._vectorized_symbols + + @property + def lane_mask(self) -> PsSymbol | None: + """Symbol representing the current lane execution mask, or ``None`` if all lanes are active.""" + return self._lane_mask + + @lane_mask.setter + def lane_mask(self, mask: PsSymbol | None): + self._lane_mask = mask + + def get_lane_mask_expr(self) -> PsExpression: + """Retrieve an expression representing the current lane execution mask.""" + if self._lane_mask is not None: + return PsExpression.make(self._lane_mask) + else: + return PsExpression.make( + PsConstant(True, PsVectorType(PsBoolType(), self._lanes)) + ) + + def vectorize_symbol(self, symb: PsSymbol) -> PsSymbol: + """Vectorize the given symbol of scalar type. + + Creates a duplicate of the given symbol with vectorized data type, + adds it to the ``vectorized_symbols`` dict, + and returns the duplicate. + + Raises: + VectorizationError: If the symbol's data type was not a `PsScalarType`, + or if the symbol was already vectorized + """ + if symb in self._vectorized_symbols: + raise VectorizationError(f"Symbol {symb} was already vectorized.") + + vec_type = self.vector_type(symb.get_dtype()) + vec_symb = self._ctx.duplicate_symbol(symb, vec_type) + self._vectorized_symbols[symb] = vec_symb + return vec_symb + + def vector_type(self, scalar_type: PsType) -> PsVectorType: + """Vectorize the given scalar data type. + + Raises: + VectorizationError: If the given data type was not a `PsScalarType`. + """ + if not isinstance(scalar_type, PsScalarType): + raise VectorizationError( + f"Unable to vectorize type {scalar_type}: was not a scalar numeric type" + ) + return PsVectorType(scalar_type, self._lanes) + + +@dataclass +class Affine: + coeff: PsExpression + offset: PsExpression + + def __neg__(self): + return Affine(-self.coeff, -self.offset) + + def __add__(self, other: Affine): + return Affine(self.coeff + other.coeff, self.offset + other.offset) + + def __sub__(self, other: Affine): + return Affine(self.coeff - other.coeff, self.offset - other.offset) + + def __mul__(self, factor: PsExpression): + if not isinstance(factor, PsExpression): + return NotImplemented + return Affine(self.coeff * factor, self.offset * factor) + + def __rmul__(self, factor: PsExpression): + if not isinstance(factor, PsExpression): + return NotImplemented + return Affine(self.coeff * factor, self.offset * factor) + + def __truediv__(self, divisor: PsExpression): + if not isinstance(divisor, PsExpression): + return NotImplemented + return Affine(self.coeff / divisor, self.offset / divisor) + + +class AstVectorizer: + """Transform a scalar subtree into a SIMD-parallel version of itself. + + The `AstVectorizer` constructs a vectorized copy of a subtree by creating a SIMD-parallel + version of each of its nodes, one at a time. + It relies on information given in a `VectorizationContext` that defines the current environment, + including the vectorization axis, the number of vector lanes, and an execution mask determining + which vector lanes are active. + + **Memory Accesses:** + The AST vectorizer is capable of vectorizing `PsMemAcc` and `PsBufferAcc` only under certain circumstances: + + - If all indices are independent of both the vectorization axis' counter and any vectorized symbols, + the memory access is *lane-invariant*, and its result will be broadcast to all vector lanes. + - If at most one index depends on the axis counter via an affine expression, and does not depend on any + vectorized symbols, the memory access can be performed in parallel, either contiguously or strided, + and is replaced by a `PsVecMemAcc`. + - All other cases cause vectorization to fail. + + **Legality:** + The AST vectorizer performs no legality checks and in particular assumes the absence of loop-carried + dependencies; i.e. all iterations of the vectorized subtree must already be independent of each + other, and insensitive to execution order. + + **Result and Failures:** + The AST vectorizer does not alter the original subtree, but constructs and returns a copy of it. + Any symbols declared within the subtree are therein replaced by canonically renamed, + vectorized copies of themselves. + + If the AST vectorizer is unable to transform a subtree, it raises a `VectorizationError`. + """ + + def __init__(self, ctx: KernelCreationContext): + self._ctx = ctx + self._factory = AstFactory(ctx) + self._collect_symbols = UndefinedSymbolsCollector() + + from ..kernelcreation import Typifier + from .eliminate_constants import EliminateConstants + from .lower_to_c import LowerToC + + self._typifiy = Typifier(ctx) + self._fold_constants = EliminateConstants(ctx) + self._lower_to_c = LowerToC(ctx) + + @overload + def __call__(self, node: PsBlock, vc: VectorizationContext) -> PsBlock: + pass + + @overload + def __call__(self, node: PsDeclaration, vc: VectorizationContext) -> PsDeclaration: + pass + + @overload + def __call__(self, node: PsAssignment, vc: VectorizationContext) -> PsAssignment: + pass + + @overload + def __call__(self, node: PsExpression, vc: VectorizationContext) -> PsExpression: + pass + + @overload + def __call__(self, node: PsAstNode, vc: VectorizationContext) -> PsAstNode: + pass + + def __call__(self, node: PsAstNode, vc: VectorizationContext) -> PsAstNode: + """Perform subtree vectorization. + + Args: + node: Root of the subtree that should be vectorized + vc: Object describing the current vectorization context + + Raises: + VectorizationError: If a node cannot be vectorized + """ + return self.visit(node, vc) + + def visit(self, node: PsAstNode, vc: VectorizationContext) -> PsAstNode: + """Vectorize a subtree.""" + + match node: + case PsBlock(stmts): + return PsBlock([self.visit(n, vc) for n in stmts]) + + case PsExpression(): + return self.visit_expr(node, vc) + + case PsDeclaration(_, rhs): + vec_symb = vc.vectorize_symbol(node.declared_symbol) + vec_lhs = PsExpression.make(vec_symb) + vec_rhs = self.visit_expr(rhs, vc) + return PsDeclaration(vec_lhs, vec_rhs) + + case PsAssignment(lhs, rhs): + if not isinstance(lhs, (PsMemAcc, PsBufferAcc)): + raise VectorizationError(f"Unable to vectorize assignment to {lhs}") + + lhs_vec = self.visit_expr(lhs, vc) + if not isinstance(lhs_vec, PsVecMemAcc): + raise VectorizationError( + f"Unable to vectorize memory write {node}:\n" + f"Index did not depend on axis counter." + ) + + rhs_vec = self.visit_expr(rhs, vc) + return PsAssignment(lhs_vec, rhs_vec) + + case _: + raise NotImplementedError(f"Vectorization of {node} is not implemented") + + def visit_expr(self, expr: PsExpression, vc: VectorizationContext) -> PsExpression: + """Vectorize an expression.""" + + vec_expr: PsExpression + scalar_type = expr.get_dtype() + + match expr: + # Invalids + case PsVectorOp() | PsAddressOf(): + raise VectorizationError(f"Unable to vectorize {type(expr)}: {expr}") + + # Symbols + case PsSymbolExpr(symb) if symb in vc.vectorized_symbols: + # Vectorize symbol + vector_symb = vc.vectorized_symbols[symb] + vec_expr = PsSymbolExpr(vector_symb) + + case PsSymbolExpr(symb) if symb == vc.axis.counter: + raise VectorizationError( + f"Unable to vectorize occurence of axis counter {symb} " + "since no vectorized version of the counter was present in the context." + ) + + # Symbols, constants, and literals that can be broadcast + case PsSymbolExpr() | PsConstantExpr() | PsLiteral(): + if isinstance(expr.dtype, PsScalarType): + # Broadcast constant or non-vectorized scalar symbol + vec_expr = PsVecBroadcast(vc.lanes, expr.clone()) + else: + # Cannot vectorize non-scalar constants or symbols + raise VectorizationError( + f"Unable to vectorize expression {expr} of non-scalar data type {expr.dtype}" + ) + + # Unary Ops + case PsCast(target_type, operand): + vec_expr = PsCast( + vc.vector_type(target_type), self.visit_expr(operand, vc) + ) + + case PsUnOp(operand): + vec_expr = type(expr)(self.visit_expr(operand, vc)) + + # Binary Ops + case PsBinOp(op1, op2): + vec_expr = type(expr)( + self.visit_expr(op1, vc), self.visit_expr(op2, vc) + ) + + # Math Functions + case PsCall(PsMathFunction(func), func_args): + vec_expr = PsCall( + PsMathFunction(func), + [self.visit_expr(arg, vc) for arg in func_args], + ) + + # Other Functions + case PsCall(func, _): + raise VectorizationError( + f"Unable to vectorize function call to {func}." + ) + + # Memory Accesses + case PsMemAcc(ptr, offset): + if not isinstance(ptr, PsSymbolExpr): + raise VectorizationError( + f"Unable to vectorize memory access by non-symbol pointer {ptr}" + ) + + idx_affine = self._index_as_affine(offset, vc) + if idx_affine is None: + vec_expr = PsVecBroadcast(vc.lanes, expr.clone()) + else: + stride: PsExpression | None = self._fold_constants( + self._typifiy(idx_affine.coeff * vc.axis.step) + ) + + if ( + isinstance(stride, PsConstantExpr) + and stride.constant.value == 1 + ): + # Contiguous access + stride = None + + vec_expr = PsVecMemAcc( + ptr.clone(), offset.clone(), vc.lanes, stride + ) + + case PsBufferAcc(ptr, indices): + buf = expr.buffer + + ctr_found = False + access_stride: PsExpression | None = None + + for i, idx in enumerate(indices): + idx_affine = self._index_as_affine(idx, vc) + if idx_affine is not None: + if ctr_found: + raise VectorizationError( + f"Unable to vectorize buffer access {expr}: " + f"Found multiple indices that depend on iteration counter {vc.axis.counter}." + ) + + ctr_found = True + + access_stride = stride = self._fold_constants( + self._typifiy( + idx_affine.coeff + * vc.axis.step + * PsExpression.make(buf.strides[i]) + ) + ) + + if ctr_found: + # Buffer access must be vectorized + assert access_stride is not None + + if ( + isinstance(access_stride, PsConstantExpr) + and access_stride.constant.value == 1 + ): + # Contiguous access + access_stride = None + + linearized_acc = self._lower_to_c(expr) + assert isinstance(linearized_acc, PsMemAcc) + + vec_expr = PsVecMemAcc( + ptr.clone(), + linearized_acc.offset.clone(), + vc.lanes, + access_stride, + ) + else: + # Buffer access is lane-invariant + vec_expr = PsVecBroadcast(vc.lanes, expr.clone()) + + case _: + raise NotImplementedError( + f"Vectorization of {type(expr)} is not implemented" + ) + + vec_expr.dtype = vc.vector_type(scalar_type) + return vec_expr + + def _index_as_affine( + self, idx: PsExpression, vc: VectorizationContext + ) -> Affine | None: + """Attempt to analyze an index expression as an affine expression of the axis counter.""" + + free_symbols = self._collect_symbols(idx) + + # Check if all symbols except for the axis counter are lane-invariant + for symb in free_symbols: + if symb != vc.axis.counter and symb in vc.vectorized_symbols: + raise VectorizationError( + "Unable to rewrite index as affine expression of axis counter: \n" + f" {idx}\n" + f"Expression depends on non-lane-invariant symbol {symb}" + ) + + if vc.axis.counter not in free_symbols: + # Index is lane-invariant + return None + + zero = self._factory.parse_index(0) + one = self._factory.parse_index(1) + + def lane_invariant(expr) -> bool: + return vc.axis.counter not in self._collect_symbols(expr) + + def collect(subexpr) -> Affine: + match subexpr: + case PsSymbolExpr(symb) if symb == vc.axis.counter: + return Affine(one, zero) + case _ if lane_invariant(subexpr): + return Affine(zero, subexpr) + case PsNeg(op): + return -collect(op) + case PsAdd(op1, op2): + return collect(op1) + collect(op2) + case PsSub(op1, op2): + return collect(op1) - collect(op2) + case PsMul(op1, op2) if lane_invariant(op1): + return op1 * collect(op2) + case PsMul(op1, op2) if lane_invariant(op2): + return collect(op1) * op2 + case PsDiv(op1, op2) if lane_invariant(op2): + return collect(op1) / op2 + case _: + raise VectorizationError( + "Unable to rewrite index as affine expression of axis counter: \n" + f" {idx}\n" + f"Encountered invalid subexpression {subexpr}" + ) + + return collect(idx) diff --git a/src/pystencils/backend/transformations/eliminate_constants.py b/src/pystencils/backend/transformations/eliminate_constants.py index 222f4a378c3063cf58e1d14f714a5ccbdc524964..961f4a04a6f816a2e1aea17e6e924e4369961f43 100644 --- a/src/pystencils/backend/transformations/eliminate_constants.py +++ b/src/pystencils/backend/transformations/eliminate_constants.py @@ -1,6 +1,8 @@ from typing import cast, Iterable, overload from collections import defaultdict +import numpy as np + from ..kernelcreation import KernelCreationContext, Typifier from ..ast import PsAstNode @@ -30,17 +32,19 @@ from ..ast.expressions import ( PsGt, PsNe, PsTernary, + PsCast, ) +from ..ast.vector import PsVecBroadcast from ..ast.util import AstEqWrapper from ..constants import PsConstant from ..memory import PsSymbol from ..functions import PsMathFunction from ...types import ( - PsIntegerType, - PsIeeeFloatType, PsNumericType, PsBoolType, + PsScalarType, + PsVectorType, PsTypeError, ) @@ -110,14 +114,19 @@ class EliminateConstants: """ def __init__( - self, ctx: KernelCreationContext, extract_constant_exprs: bool = False + self, + ctx: KernelCreationContext, + extract_constant_exprs: bool = False, + fold_integers: bool = True, + fold_relations: bool = True, + fold_floats: bool = False, ): self._ctx = ctx self._typify = Typifier(ctx) - self._fold_integers = True - self._fold_relations = True - self._fold_floats = False + self._fold_integers = fold_integers + self._fold_relations = fold_relations + self._fold_floats = fold_floats self._extract_constant_exprs = extract_constant_exprs @overload @@ -177,84 +186,109 @@ class EliminateConstants: expr.children = [r[0] for r in subtree_results] subtree_constness = [r[1] for r in subtree_results] - # Eliminate idempotence, dominance, and trivial relations + # Eliminate idempotence, dominance. constant (broad)casts, and trivial relations match expr: # Additive idempotence: Addition and subtraction of zero - case PsAdd(PsConstantExpr(c), other_op) if c.value == 0: + case PsAdd(PsConstantExpr(c), other_op) if np.all(c.value == 0): return other_op, all(subtree_constness) - case PsAdd(other_op, PsConstantExpr(c)) if c.value == 0: + case PsAdd(other_op, PsConstantExpr(c)) if np.all(c.value == 0): return other_op, all(subtree_constness) - case PsSub(other_op, PsConstantExpr(c)) if c.value == 0: + case PsSub(other_op, PsConstantExpr(c)) if np.all(c.value == 0): return other_op, all(subtree_constness) # Additive idempotence: Subtraction from zero - case PsSub(PsConstantExpr(c), other_op) if c.value == 0: - other_transformed, is_const = self.visit_expr(-other_op, ecc) + case PsSub(PsConstantExpr(c), other_op) if np.all(c.value == 0): + other_transformed, is_const = self.visit_expr( + self._typify(-other_op), ecc + ) return other_transformed, is_const # Multiplicative idempotence: Multiplication with and division by one - case PsMul(PsConstantExpr(c), other_op) if c.value == 1: + case PsMul(PsConstantExpr(c), other_op) if np.all(c.value == 1): return other_op, all(subtree_constness) - case PsMul(other_op, PsConstantExpr(c)) if c.value == 1: + case PsMul(other_op, PsConstantExpr(c)) if np.all(c.value == 1): return other_op, all(subtree_constness) case PsDiv(other_op, PsConstantExpr(c)) | PsIntDiv( other_op, PsConstantExpr(c) - ) if c.value == 1: + ) if np.all(c.value == 1): return other_op, all(subtree_constness) # Trivial remainder at division by one - case PsRem(other_op, PsConstantExpr(c)) if c.value == 1: + case PsRem(other_op, PsConstantExpr(c)) if np.all(c.value == 1): zero = self._typify(PsConstantExpr(PsConstant(0, c.get_dtype()))) return zero, True # Multiplicative dominance: 0 * x = 0 - case PsMul(PsConstantExpr(c), other_op) if c.value == 0: + case PsMul(PsConstantExpr(c), other_op) if np.all(c.value == 0): return PsConstantExpr(c), True - case PsMul(other_op, PsConstantExpr(c)) if c.value == 0: + case PsMul(other_op, PsConstantExpr(c)) if np.all(c.value == 0): return PsConstantExpr(c), True # Logical idempotence - case PsAnd(PsConstantExpr(c), other_op) if c.value: + case PsAnd(PsConstantExpr(c), other_op) if np.all(c.value): return other_op, all(subtree_constness) - case PsAnd(other_op, PsConstantExpr(c)) if c.value: + case PsAnd(other_op, PsConstantExpr(c)) if np.all(c.value): return other_op, all(subtree_constness) - case PsOr(PsConstantExpr(c), other_op) if not c.value: + case PsOr(PsConstantExpr(c), other_op) if not np.any(c.value): return other_op, all(subtree_constness) - case PsOr(other_op, PsConstantExpr(c)) if not c.value: + case PsOr(other_op, PsConstantExpr(c)) if not np.any(c.value): return other_op, all(subtree_constness) # Logical dominance - case PsAnd(PsConstantExpr(c), other_op) if not c.value: + case PsAnd(PsConstantExpr(c), other_op) if not np.any(c.value): return PsConstantExpr(c), True - case PsAnd(other_op, PsConstantExpr(c)) if not c.value: + case PsAnd(other_op, PsConstantExpr(c)) if not np.any(c.value): return PsConstantExpr(c), True - case PsOr(PsConstantExpr(c), other_op) if c.value: + case PsOr(PsConstantExpr(c), other_op) if np.all(c.value): return PsConstantExpr(c), True - case PsOr(other_op, PsConstantExpr(c)) if c.value: + case PsOr(other_op, PsConstantExpr(c)) if np.all(c.value): return PsConstantExpr(c), True + # Trivial (broad)casts + case PsCast(target_type, PsConstantExpr(c)): + assert isinstance(target_type, PsNumericType) + return PsConstantExpr(c.reinterpret_as(target_type)), True + + case PsVecBroadcast(lanes, PsConstantExpr(c)): + scalar_type = c.get_dtype() + assert isinstance(scalar_type, PsScalarType) + vec_type = PsVectorType(scalar_type, lanes) + return PsConstantExpr(PsConstant(c.value, vec_type)), True + # Trivial comparisons case ( PsEq(op1, op2) | PsGe(op1, op2) | PsLe(op1, op2) ) if op1.structurally_equal(op2): - true = self._typify(PsConstantExpr(PsConstant(True, PsBoolType()))) + arg_dtype = op1.get_dtype() + bool_type = ( + PsVectorType(PsBoolType(), arg_dtype.vector_entries) + if isinstance(arg_dtype, PsVectorType) + else PsBoolType() + ) + true = self._typify(PsConstantExpr(PsConstant(True, bool_type))) return true, True case ( PsNe(op1, op2) | PsGt(op1, op2) | PsLt(op1, op2) ) if op1.structurally_equal(op2): - false = self._typify(PsConstantExpr(PsConstant(False, PsBoolType()))) + arg_dtype = op1.get_dtype() + bool_type = ( + PsVectorType(PsBoolType(), arg_dtype.vector_entries) + if isinstance(arg_dtype, PsVectorType) + else PsBoolType() + ) + false = self._typify(PsConstantExpr(PsConstant(False, bool_type))) return false, True # Trivial ternaries @@ -270,11 +304,15 @@ class EliminateConstants: if all(subtree_constness): dtype = expr.get_dtype() - is_int = isinstance(dtype, PsIntegerType) - is_float = isinstance(dtype, PsIeeeFloatType) - is_bool = isinstance(dtype, PsBoolType) is_rel = isinstance(expr, PsRel) + if isinstance(dtype, PsNumericType): + is_int = dtype.is_int() + is_float = dtype.is_float() + is_bool = dtype.is_bool() + else: + is_int = is_float = is_bool = False + do_fold = ( is_bool or (self._fold_integers and is_int) @@ -317,8 +355,9 @@ class EliminateConstants: elif isinstance(expr, PsDiv): if is_int: from ...utils import c_intdiv + folded = PsConstant(c_intdiv(v1, v2), dtype) - elif isinstance(dtype, PsIeeeFloatType): + elif isinstance(dtype, PsNumericType) and dtype.is_float(): folded = PsConstant(v1 / v2, dtype) if folded is not None: diff --git a/src/pystencils/backend/transformations/loop_vectorizer.py b/src/pystencils/backend/transformations/loop_vectorizer.py new file mode 100644 index 0000000000000000000000000000000000000000..e01e657e3edf015f85fec2b3954ec40157051515 --- /dev/null +++ b/src/pystencils/backend/transformations/loop_vectorizer.py @@ -0,0 +1,235 @@ +import numpy as np +from enum import Enum, auto +from typing import cast, Callable, overload + +from ...types import PsVectorType, PsScalarType + +from ..kernelcreation import KernelCreationContext +from ..constants import PsConstant +from ..ast import PsAstNode +from ..ast.structural import PsLoop, PsBlock, PsDeclaration +from ..ast.expressions import PsExpression +from ..ast.vector import PsVecBroadcast +from ..ast.analysis import collect_undefined_symbols + +from .ast_vectorizer import VectorizationAxis, VectorizationContext, AstVectorizer +from .rewrite import substitute_symbols + + +class LoopVectorizer: + """Vectorize loops. + + The loop vectorizer provides methods to vectorize single loops inside an AST + using a given number of vector lanes. + During vectorization, the loop body is transformed using the `AstVectorizer`, + The loop's limits are adapted according to the number of vector lanes, + and a block treating trailing iterations is optionally added. + + Args: + ctx: The current kernel creation context + lanes: The number of vector lanes to use + trailing_iters: Mode for the treatment of trailing iterations + """ + + class TrailingItersTreatment(Enum): + """How to treat trailing iterations during loop vectorization.""" + + SCALAR_LOOP = auto() + """Cover trailing iterations using a scalar remainder loop.""" + + MASKED_BLOCK = auto() + """Cover trailing iterations using a masked block.""" + + NONE = auto() + """Assume that the loop iteration count is a multiple of the number of lanes + and do not cover any trailing iterations""" + + def __init__( + self, + ctx: KernelCreationContext, + lanes: int, + trailing_iters: TrailingItersTreatment = TrailingItersTreatment.SCALAR_LOOP, + ): + self._ctx = ctx + self._lanes = lanes + self._trailing_iters = trailing_iters + + from ..kernelcreation import Typifier + from .eliminate_constants import EliminateConstants + + self._typify = Typifier(ctx) + self._vectorize_ast = AstVectorizer(ctx) + self._fold = EliminateConstants(ctx) + + def vectorize_select_loops( + self, node: PsAstNode, predicate: Callable[[PsLoop], bool] + ) -> PsAstNode: + """Select and vectorize loops from a syntax tree according to a predicate. + + Finds each loop inside a subtree and evaluates ``predicate`` on them. + If ``predicate(loop)`` evaluates to `True`, the loop is vectorized. + + Loops nested inside a vectorized loop will not be processed. + + Args: + node: Root of the subtree to process + predicate: Callback telling the vectorizer which loops to vectorize + """ + match node: + case PsLoop() if predicate(node): + return self.vectorize_loop(node) + case PsExpression(): + return node + case _: + node.children = [ + self.vectorize_select_loops(c, predicate) for c in node.children + ] + return node + + def __call__(self, loop: PsLoop) -> PsLoop | PsBlock: + return self.vectorize_loop(loop) + + def vectorize_loop(self, loop: PsLoop) -> PsLoop | PsBlock: + """Vectorize the given loop.""" + scalar_ctr_expr = loop.counter + scalar_ctr = scalar_ctr_expr.symbol + + # Prepare vector counter + vector_ctr_dtype = PsVectorType( + cast(PsScalarType, scalar_ctr_expr.get_dtype()), self._lanes + ) + vector_ctr = self._ctx.duplicate_symbol(scalar_ctr, vector_ctr_dtype) + step_multiplier_val = np.array( + range(self._lanes), dtype=scalar_ctr_expr.get_dtype().numpy_dtype + ) + step_multiplier = PsExpression.make( + PsConstant(step_multiplier_val, vector_ctr_dtype) + ) + vector_counter_decl = self._type_fold( + PsDeclaration( + PsExpression.make(vector_ctr), + PsVecBroadcast(self._lanes, scalar_ctr_expr) + + step_multiplier * PsVecBroadcast(self._lanes, loop.step), + ) + ) + + # Prepare axis + axis = VectorizationAxis(scalar_ctr, vector_ctr, step=loop.step) + + # Prepare vectorization context + vc = VectorizationContext(self._ctx, self._lanes, axis) + + # Generate vectorized loop body + simd_body = self._vectorize_ast(loop.body, vc) + + if vector_ctr in collect_undefined_symbols(simd_body): + simd_body.statements.insert(0, vector_counter_decl) + + # Build new loop limits + simd_start = loop.start.clone() + + simd_step = self._ctx.get_new_symbol( + f"__{scalar_ctr.name}_simd_step", scalar_ctr.get_dtype() + ) + simd_step_decl = self._type_fold( + PsDeclaration( + PsExpression.make(simd_step), + loop.step.clone() * PsExpression.make(PsConstant(self._lanes)), + ) + ) + + # Each iteration must satisfy `ctr + step * (lanes - 1) < stop` + simd_stop = self._ctx.get_new_symbol( + f"__{scalar_ctr.name}_simd_stop", scalar_ctr.get_dtype() + ) + simd_stop_decl = self._type_fold( + PsDeclaration( + PsExpression.make(simd_stop), + loop.stop.clone() + - ( + PsExpression.make(PsConstant(self._lanes)) + - PsExpression.make(PsConstant(1)) + ) + * loop.step.clone(), + ) + ) + + simd_loop = PsLoop( + PsExpression.make(scalar_ctr), + simd_start, + PsExpression.make(simd_stop), + PsExpression.make(simd_step), + simd_body, + ) + + # Treat trailing iterations + match self._trailing_iters: + case LoopVectorizer.TrailingItersTreatment.SCALAR_LOOP: + trailing_start = self._ctx.get_new_symbol( + f"__{scalar_ctr.name}_trailing_start", scalar_ctr.get_dtype() + ) + trailing_start_decl = self._type_fold( + PsDeclaration( + PsExpression.make(trailing_start), + ( + ( + PsExpression.make(simd_stop) + - simd_start.clone() + - PsExpression.make(PsConstant(1)) + ) + / PsExpression.make(simd_step) + + PsExpression.make(PsConstant(1)) + ) + * PsExpression.make(simd_step) + + simd_start.clone(), + ) + ) + + trailing_ctr = self._ctx.duplicate_symbol(scalar_ctr) + trailing_loop_body = substitute_symbols( + loop.body.clone(), {scalar_ctr: PsExpression.make(trailing_ctr)} + ) + trailing_loop = PsLoop( + PsExpression.make(trailing_ctr), + PsExpression.make(trailing_start), + loop.stop.clone(), + loop.step.clone(), + trailing_loop_body, + ) + + return PsBlock( + [ + simd_stop_decl, + simd_step_decl, + simd_loop, + trailing_start_decl, + trailing_loop, + ] + ) + + case LoopVectorizer.TrailingItersTreatment.MASKED_BLOCK: + raise NotImplementedError() + + case LoopVectorizer.TrailingItersTreatment.NONE: + return PsBlock( + [ + simd_stop_decl, + simd_step_decl, + simd_loop, + ] + ) + + @overload + def _type_fold(self, node: PsExpression) -> PsExpression: + pass + + @overload + def _type_fold(self, node: PsDeclaration) -> PsDeclaration: + pass + + @overload + def _type_fold(self, node: PsAstNode) -> PsAstNode: + pass + + def _type_fold(self, node: PsAstNode) -> PsAstNode: + return self._fold(self._typify(node)) diff --git a/src/pystencils/backend/transformations/lower_to_c.py b/src/pystencils/backend/transformations/lower_to_c.py index ea832355bb1a53f94fc07cad670f86f98e5f6a2e..0576616f2f2989ae72887f5cd720f263017a6b0a 100644 --- a/src/pystencils/backend/transformations/lower_to_c.py +++ b/src/pystencils/backend/transformations/lower_to_c.py @@ -65,7 +65,7 @@ class LowerToC: return i summands: list[PsExpression] = [ - maybe_cast(cast(PsExpression, self.visit(idx))) * PsExpression.make(stride) + maybe_cast(cast(PsExpression, self.visit(idx).clone())) * PsExpression.make(stride) for idx, stride in zip(indices, buf.strides, strict=True) ] @@ -75,7 +75,7 @@ class LowerToC: else reduce(operator.add, summands) ) - mem_acc = PsMemAcc(bptr, linearized_idx) + mem_acc = PsMemAcc(bptr.clone(), linearized_idx) return self._typify.typify_expression( mem_acc, target_type=buf.element_type diff --git a/src/pystencils/backend/transformations/rewrite.py b/src/pystencils/backend/transformations/rewrite.py new file mode 100644 index 0000000000000000000000000000000000000000..59241c295f42eeaf60f4cd03a5138214fdbd6c50 --- /dev/null +++ b/src/pystencils/backend/transformations/rewrite.py @@ -0,0 +1,37 @@ +from typing import overload + +from ..memory import PsSymbol +from ..ast import PsAstNode +from ..ast.structural import PsBlock +from ..ast.expressions import PsExpression, PsSymbolExpr + + +@overload +def substitute_symbols(node: PsBlock, subs: dict[PsSymbol, PsExpression]) -> PsBlock: + pass + + +@overload +def substitute_symbols( + node: PsExpression, subs: dict[PsSymbol, PsExpression] +) -> PsExpression: + pass + + +@overload +def substitute_symbols( + node: PsAstNode, subs: dict[PsSymbol, PsExpression] +) -> PsAstNode: + pass + + +def substitute_symbols( + node: PsAstNode, subs: dict[PsSymbol, PsExpression] +) -> PsAstNode: + """Substitute expressions for symbols throughout a subtree.""" + match node: + case PsSymbolExpr(symb) if symb in subs: + return subs[symb].clone() + case _: + node.children = [substitute_symbols(c, subs) for c in node.children] + return node diff --git a/src/pystencils/backend/transformations/select_intrinsics.py b/src/pystencils/backend/transformations/select_intrinsics.py index 3fb484c154fbb4ab873deea3e9b1d83c2f4354e6..a0eb5ce8f94510a18e826d1defa423b60e0d808d 100644 --- a/src/pystencils/backend/transformations/select_intrinsics.py +++ b/src/pystencils/backend/transformations/select_intrinsics.py @@ -1,133 +1,108 @@ from __future__ import annotations -from typing import TypeVar, TYPE_CHECKING, cast -from enum import Enum, auto +from typing import cast -from ..ast.structural import PsAstNode, PsAssignment, PsStatement +from ..kernelcreation import KernelCreationContext +from ..memory import PsSymbol +from ..ast.structural import PsAstNode, PsDeclaration, PsAssignment, PsStatement from ..ast.expressions import PsExpression -from ...types import PsVectorType, deconstify -from ..ast.expressions import ( - PsVectorMemAcc, - PsSymbolExpr, - PsConstantExpr, - PsBinOp, - PsAdd, - PsSub, - PsMul, - PsDiv, -) -from ..exceptions import PsInternalCompilerError +from ...types import PsVectorType, constify, deconstify +from ..ast.expressions import PsSymbolExpr, PsConstantExpr, PsUnOp, PsBinOp +from ..ast.vector import PsVecMemAcc +from ..exceptions import MaterializationError -if TYPE_CHECKING: - from ..platforms import GenericVectorCpu +from ..platforms import GenericVectorCpu -__all__ = ["IntrinsicOps", "MaterializeVectorIntrinsics"] +__all__ = ["SelectIntrinsics"] -NodeT = TypeVar("NodeT", bound=PsAstNode) +class SelectionContext: + def __init__(self, ctx: KernelCreationContext, platform: GenericVectorCpu): + self._ctx = ctx + self._platform = platform + self._intrin_symbols: dict[PsSymbol, PsSymbol] = dict() + self._lane_mask: PsSymbol | None = None -class IntrinsicOps(Enum): - ADD = auto() - SUB = auto() - MUL = auto() - DIV = auto() - FMA = auto() - + def get_intrin_symbol(self, symb: PsSymbol) -> PsSymbol: + if symb not in self._intrin_symbols: + assert isinstance(symb.dtype, PsVectorType) + intrin_type = self._platform.type_intrinsic(deconstify(symb.dtype)) -class VectorizationError(Exception): - """Exception indicating a fatal error during vectorization.""" + if symb.dtype.const: + intrin_type = constify(intrin_type) + replacement = self._ctx.duplicate_symbol(symb, intrin_type) + self._intrin_symbols[symb] = replacement -class VecTypeCtx: - def __init__(self) -> None: - self._dtype: None | PsVectorType = None + return self._intrin_symbols[symb] - def get(self) -> PsVectorType | None: - return self._dtype - def set(self, dtype: PsVectorType): - dtype = deconstify(dtype) - if self._dtype is not None and dtype != self._dtype: - raise PsInternalCompilerError( - f"Ambiguous vector types: {self._dtype} and {dtype}" - ) - self._dtype = dtype +class SelectIntrinsics: + """Lower IR vector types to intrinsic vector types, and IR vector operations to intrinsic vector operations. + + This transformation will replace all vectorial IR elements by conforming implementations using + compiler intrinsics for the given execution platform. - def reset(self): - self._dtype = None + Args: + ctx: The current kernel creation context + platform: Platform object representing the target hardware, which provides the intrinsics + Raises: + MaterializationError: If a vector type or operation cannot be represented by intrinsics + on the given platform + """ -class MaterializeVectorIntrinsics: - def __init__(self, platform: GenericVectorCpu): + def __init__(self, ctx: KernelCreationContext, platform: GenericVectorCpu): + self._ctx = ctx self._platform = platform def __call__(self, node: PsAstNode) -> PsAstNode: - return self.visit(node) + return self.visit(node, SelectionContext(self._ctx, self._platform)) - def visit(self, node: PsAstNode) -> PsAstNode: + def visit(self, node: PsAstNode, sc: SelectionContext) -> PsAstNode: match node: - case PsAssignment(lhs, rhs) if isinstance(lhs, PsVectorMemAcc): - vc = VecTypeCtx() - vc.set(lhs.get_vector_type()) - store_arg = self.visit_expr(rhs, vc) - return PsStatement(self._platform.vector_store(lhs, store_arg)) - case PsExpression(): - return self.visit_expr(node, VecTypeCtx()) + case PsExpression() if isinstance(node.dtype, PsVectorType): + return self.visit_expr(node, sc) + + case PsDeclaration(lhs, rhs) if isinstance(lhs.dtype, PsVectorType): + lhs_new = cast(PsSymbolExpr, self.visit_expr(lhs, sc)) + rhs_new = self.visit_expr(rhs, sc) + return PsDeclaration(lhs_new, rhs_new) + + case PsAssignment(lhs, rhs) if isinstance(lhs, PsVecMemAcc): + new_rhs = self.visit_expr(rhs, sc) + return PsStatement(self._platform.vector_store(lhs, new_rhs)) + case _: - node.children = [self(c) for c in node.children] - return node + node.children = [self.visit(c, sc) for c in node.children] + + return node + + def visit_expr(self, expr: PsExpression, sc: SelectionContext) -> PsExpression: + if not isinstance(expr.dtype, PsVectorType): + return expr - def visit_expr(self, expr: PsExpression, vc: VecTypeCtx) -> PsExpression: match expr: case PsSymbolExpr(symb): - if isinstance(symb.dtype, PsVectorType): - intrin_type = self._platform.type_intrinsic(symb.dtype) - vc.set(symb.dtype) - symb.dtype = intrin_type - - return expr + return PsSymbolExpr(sc.get_intrin_symbol(symb)) case PsConstantExpr(c): - if isinstance(c.dtype, PsVectorType): - vc.set(c.dtype) - return self._platform.constant_vector(c) - else: - return expr - - case PsVectorMemAcc(): - vc.set(expr.get_vector_type()) + return self._platform.constant_intrinsic(c) + + case PsUnOp(operand): + op = self.visit_expr(operand, sc) + return self._platform.op_intrinsic(expr, [op]) + + case PsBinOp(operand1, operand2): + op1 = self.visit_expr(operand1, sc) + op2 = self.visit_expr(operand2, sc) + + return self._platform.op_intrinsic(expr, [op1, op2]) + + case PsVecMemAcc(): return self._platform.vector_load(expr) - case PsBinOp(op1, op2): - op1 = self.visit_expr(op1, vc) - op2 = self.visit_expr(op2, vc) - - vtype = vc.get() - if vtype is not None: - return self._platform.op_intrinsic( - _intrin_op(expr), vtype, [op1, op2] - ) - else: - return expr - - case expr: - expr.children = [ - self.visit_expr(cast(PsExpression, c), vc) for c in expr.children - ] - if vc.get() is not None: - raise VectorizationError(f"Don't know how to vectorize {expr}") - return expr - - -def _intrin_op(expr: PsBinOp) -> IntrinsicOps: - match expr: - case PsAdd(): - return IntrinsicOps.ADD - case PsSub(): - return IntrinsicOps.SUB - case PsMul(): - return IntrinsicOps.MUL - case PsDiv(): - return IntrinsicOps.DIV - case _: - assert False + case _: + raise MaterializationError( + f"Unable to select intrinsic implementation for {expr}" + ) diff --git a/src/pystencils/config.py b/src/pystencils/config.py index c8cf489f7a29857232ec9211479c26796850f4dc..db2fd545ad5b1caa3d891cae45418cea70088b68 100644 --- a/src/pystencils/config.py +++ b/src/pystencils/config.py @@ -7,7 +7,7 @@ from collections.abc import Collection from typing import Sequence from dataclasses import dataclass, InitVar -from .enums import Target +from .target import Target from .field import Field, FieldType from .types import PsIntegerType, UserTypeSpec, PsIeeeFloatType, create_type diff --git a/src/pystencils/datahandling/__init__.py b/src/pystencils/datahandling/__init__.py index 18053d2d9d6546bcb5ac2093f5f63c1633965a4e..76a494255a69c7cb880d362ff6eb1835a8f1e33a 100644 --- a/src/pystencils/datahandling/__init__.py +++ b/src/pystencils/datahandling/__init__.py @@ -3,7 +3,7 @@ import warnings from typing import Tuple, Union from .datahandling_interface import DataHandling -from ..enums import Target +from ..target import Target from .serial_datahandling import SerialDataHandling try: diff --git a/src/pystencils/datahandling/datahandling_interface.py b/src/pystencils/datahandling/datahandling_interface.py index 33b565e82fdce9a12e516ca4165fd53fe53799af..f42c4ef138e04eede2719cbce69f9975b656cb30 100644 --- a/src/pystencils/datahandling/datahandling_interface.py +++ b/src/pystencils/datahandling/datahandling_interface.py @@ -3,7 +3,7 @@ from typing import Callable, Dict, Iterable, Optional, Sequence, Tuple, Union import numpy as np -from pystencils.enums import Target +from pystencils.target import Target from pystencils.field import Field, FieldType @@ -82,7 +82,7 @@ class DataHandling(ABC): >>> dh = create_data_handling((20, 30)) >>> x, y =dh.add_arrays('x, y(9)') >>> print(dh.fields) - {'x': x: double[22,32], 'y': y(9): double[22,32]} + {'x': x: float64[22,32], 'y': y(9): float64[22,32]} >>> assert x == dh.fields['x'] >>> assert dh.fields['x'].shape == (22, 32) >>> assert dh.fields['y'].index_shape == (9,) diff --git a/src/pystencils/datahandling/serial_datahandling.py b/src/pystencils/datahandling/serial_datahandling.py index ba705f4b9f665ea9eaf81ad42863196c25901df8..8521dda1014ab4e031706a1b21f4ae3c28259a05 100644 --- a/src/pystencils/datahandling/serial_datahandling.py +++ b/src/pystencils/datahandling/serial_datahandling.py @@ -6,7 +6,7 @@ import numpy as np from pystencils.datahandling.blockiteration import SerialBlock from pystencils.datahandling.datahandling_interface import DataHandling -from pystencils.enums import Target +from pystencils.target import Target from pystencils.field import (Field, FieldType, create_numpy_array_with_layout, layout_string_to_tuple, spatial_layout_string_to_tuple) from pystencils.gpu.gpu_array_handler import GPUArrayHandler, GPUNotAvailableHandler diff --git a/src/pystencils/enums.py b/src/pystencils/enums.py index 23c255ef0949e02ac5b0af57551ceec1bf6cfee2..86048059d67e6132223825e3b94588b35e30796e 100644 --- a/src/pystencils/enums.py +++ b/src/pystencils/enums.py @@ -1,99 +1,11 @@ -from enum import Flag, auto +from .target import Target as _Target +from warnings import warn -class Target(Flag): - """ - The Target enumeration represents all possible targets that can be used for the code generation. - """ +warn( + "Importing anything from `pystencils.enums` is deprecated and the module will be removed in pystencils 2.1. " + "Import from `pystencils` instead.", + FutureWarning +) - # ------------------ Component Flags - Do Not Use Directly! ------------------------------------------- - - _CPU = auto() - - _VECTOR = auto() - - _X86 = auto() - _SSE = auto() - _AVX = auto() - _AVX512 = auto() - - _ARM = auto() - _NEON = auto() - _SVE = auto() - - _GPU = auto() - - _CUDA = auto() - - _SYCL = auto() - - _AUTOMATIC = auto() - - # ------------------ Actual Targets ------------------------------------------------------------------- - - CurrentCPU = _CPU | _AUTOMATIC - """ - Auto-best CPU target. - - `CurrentCPU` causes the code generator to automatically select a CPU target according to CPUs found - on the current machine and runtime environment. - """ - - GenericCPU = _CPU - """Generic CPU target. - - Generate the kernel for a generic multicore CPU architecture. This opens up all architecture-independent - optimizations including OpenMP, but no vectorization. - """ - - CPU = GenericCPU - """Alias for backward-compatibility""" - - X86_SSE = _CPU | _VECTOR | _X86 | _SSE - """x86 architecture with SSE vector extensions.""" - - X86_AVX = _CPU | _VECTOR | _X86 | _AVX - """x86 architecture with AVX vector extensions.""" - - X86_AVX512 = _CPU | _VECTOR | _X86 | _AVX512 - """x86 architecture with AVX512 vector extensions.""" - - ARM_NEON = _CPU | _VECTOR | _ARM | _NEON - """ARM architecture with NEON vector extensions""" - - ARM_SVE = _CPU | _VECTOR | _ARM | _SVE - """ARM architecture with SVE vector extensions""" - - CurrentGPU = _GPU | _AUTOMATIC - """Auto-best GPU target. - - `CurrentGPU` causes the code generator to automatically select a GPU target according to GPU devices - found on the current machine and runtime environment. - """ - - CUDA = _GPU | _CUDA - """Generic CUDA GPU target. - - Generate a CUDA kernel for a generic Nvidia GPU. - """ - - GPU = CUDA - """Alias for backward compatibility.""" - - SYCL = _GPU | _SYCL - """SYCL kernel target. - - Generate a function to be called within a SYCL parallel command. - """ - - def is_automatic(self) -> bool: - return Target._AUTOMATIC in self - - def is_cpu(self) -> bool: - return Target._CPU in self - - def is_vector_cpu(self) -> bool: - return self.is_cpu() and Target._VECTOR in self - - def is_gpu(self) -> bool: - return Target._GPU in self +Target = _Target diff --git a/src/pystencils/field.py b/src/pystencils/field.py index 51f01deb1d16fbf363e057c3d7494cd72595ea71..c813c6154319def9289776608569dec6e4854e32 100644 --- a/src/pystencils/field.py +++ b/src/pystencils/field.py @@ -866,7 +866,7 @@ def fields(description=None, index_dimensions=0, layout=None, Format string can be left out, field names are taken from keyword arguments. >>> fields(f1=arr_s, f2=arr_s) - [f1: double[20,20], f2: double[20,20]] + [f1: float64[20,20], f2: float64[20,20]] The keyword names ``index_dimension`` and ``layout`` have special meaning, don't use them for field names >>> f = fields(f=arr_v, index_dimensions=1) diff --git a/src/pystencils/gpu/periodicity.py b/src/pystencils/gpu/periodicity.py index 6569fbb0f14ab6b44add1c93cf8b2210699deef4..0e1ec754bec8d232f62751bfdb5dee072cabcca5 100644 --- a/src/pystencils/gpu/periodicity.py +++ b/src/pystencils/gpu/periodicity.py @@ -2,8 +2,7 @@ import numpy as np from itertools import product from pystencils import CreateKernelConfig, create_kernel -from pystencils import Assignment, Field -from pystencils.enums import Target +from pystencils import Assignment, Field, Target from pystencils.slicing import get_periodic_boundary_src_dst_slices, normalize_slice diff --git a/src/pystencils/kernelcreation.py b/src/pystencils/kernelcreation.py index 7d9ac7aa4465c264855d79ae7d56260e0dd698eb..a10fadf018236e2ca9275d8a39be681d3e3adc5f 100644 --- a/src/pystencils/kernelcreation.py +++ b/src/pystencils/kernelcreation.py @@ -1,7 +1,7 @@ from typing import cast, Sequence from dataclasses import replace -from .enums import Target +from .target import Target from .config import CreateKernelConfig from .backend import KernelFunction from .types import create_numeric_type, PsIntegerType diff --git a/src/pystencils/sympyextensions/__init__.py b/src/pystencils/sympyextensions/__init__.py index 847a4380b2ccdbac11c3142d19f3db476ee7880b..7431416c9eb9bcd4433dab76c32fb1b755501105 100644 --- a/src/pystencils/sympyextensions/__init__.py +++ b/src/pystencils/sympyextensions/__init__.py @@ -1,5 +1,6 @@ from .astnodes import ConditionalFieldAccess from .typed_sympy import TypedSymbol, CastFunc +from .pointers import mem_acc from .math import ( prod, @@ -34,6 +35,7 @@ __all__ = [ "ConditionalFieldAccess", "TypedSymbol", "CastFunc", + "mem_acc", "remove_higher_order_terms", "prod", "remove_small_floats", diff --git a/src/pystencils/target.py b/src/pystencils/target.py new file mode 100644 index 0000000000000000000000000000000000000000..7ed4f719d235d5075d517b2655d5fc812c1cd7c8 --- /dev/null +++ b/src/pystencils/target.py @@ -0,0 +1,164 @@ +from __future__ import annotations + +from enum import Flag, auto +from warnings import warn +from functools import cache + + +class Target(Flag): + """ + The Target enumeration represents all possible targets that can be used for code generation. + """ + + # ------------------ Component Flags - Do Not Use Directly! ------------------------------------------- + + _CPU = auto() + + _VECTOR = auto() + + _X86 = auto() + _SSE = auto() + _AVX = auto() + _AVX512 = auto() + _VL = auto() + _FP16 = auto() + + _ARM = auto() + _NEON = auto() + _SVE = auto() + + _GPU = auto() + + _CUDA = auto() + + _SYCL = auto() + + _AUTOMATIC = auto() + + # ------------------ Actual Targets ------------------------------------------------------------------- + + CurrentCPU = _CPU | _AUTOMATIC + """ + Auto-best CPU target. + + `CurrentCPU` causes the code generator to automatically select a CPU target according to CPUs found + on the current machine and runtime environment. + """ + + GenericCPU = _CPU + """Generic CPU target. + + Generate the kernel for a generic multicore CPU architecture. This opens up all architecture-independent + optimizations including OpenMP, but no vectorization. + """ + + CPU = GenericCPU + """Alias for backward-compatibility""" + + X86_SSE = _CPU | _VECTOR | _X86 | _SSE + """x86 architecture with SSE vector extensions.""" + + X86_AVX = _CPU | _VECTOR | _X86 | _AVX + """x86 architecture with AVX vector extensions.""" + + X86_AVX512 = _CPU | _VECTOR | _X86 | _AVX512 + """x86 architecture with AVX512 vector extensions.""" + + X86_AVX512_FP16 = _CPU | _VECTOR | _X86 | _AVX512 | _FP16 + """x86 architecture with AVX512 vector extensions and fp16-support.""" + + ARM_NEON = _CPU | _VECTOR | _ARM | _NEON + """ARM architecture with NEON vector extensions""" + + ARM_SVE = _CPU | _VECTOR | _ARM | _SVE + """ARM architecture with SVE vector extensions""" + + CurrentGPU = _GPU | _AUTOMATIC + """Auto-best GPU target. + + `CurrentGPU` causes the code generator to automatically select a GPU target according to GPU devices + found on the current machine and runtime environment. + """ + + CUDA = _GPU | _CUDA + """Generic CUDA GPU target. + + Generate a CUDA kernel for a generic Nvidia GPU. + """ + + GPU = CUDA + """Alias for backward compatibility.""" + + SYCL = _GPU | _SYCL + """SYCL kernel target. + + Generate a function to be called within a SYCL parallel command. + """ + + def is_automatic(self) -> bool: + return Target._AUTOMATIC in self + + def is_cpu(self) -> bool: + return Target._CPU in self + + def is_vector_cpu(self) -> bool: + return self.is_cpu() and Target._VECTOR in self + + def is_gpu(self) -> bool: + return Target._GPU in self + + @staticmethod + def auto_cpu() -> Target: + """Return the most capable vector CPU target available on the current machine.""" + avail_targets = _available_vector_targets() + if avail_targets: + return avail_targets.pop() + else: + return Target.GenericCPU + + @staticmethod + def available_vector_cpu_targets() -> list[Target]: + """Returns a list of available (vector) CPU targets, ordered from least to most capable.""" + return _available_vector_targets() + + +@cache +def _available_vector_targets() -> list[Target]: + """Returns available vector targets, sorted from leat to most capable.""" + + targets: list[Target] = [] + + import platform + + if platform.machine() in ["x86_64", "x86", "AMD64", "i386"]: + try: + from cpuinfo import get_cpu_info + except ImportError: + warn( + "Unable to determine available x86 vector CPU targets for this system: " + "py-cpuinfo is not available.", + UserWarning, + ) + return [] + + flags = set(get_cpu_info()["flags"]) + + if {"sse", "sse2", "ssse3", "sse4_1", "sse4_2"} < flags: + targets.append(Target.X86_SSE) + + if {"avx", "avx2"} < flags: + targets.append(Target.X86_AVX) + + if {"avx512f"} < flags: + targets.append(Target.X86_AVX512) + + if {"avx512_fp16"} < flags: + targets.append(Target.X86_AVX512_FP16) + else: + warn( + "Unable to determine available vector CPU targets for this system: " + f"unknown platform {platform.machine()}.", + UserWarning, + ) + + return targets diff --git a/src/pystencils/types/parsing.py b/src/pystencils/types/parsing.py index 8e7d27f58265c08461cba6b05373848112a6fee7..c7a54d1a0831f9f7d01039f19e8e9f1395fe09c4 100644 --- a/src/pystencils/types/parsing.py +++ b/src/pystencils/types/parsing.py @@ -64,6 +64,8 @@ def interpret_python_type(t: type) -> PsType: return PsSignedIntegerType(64) if t is float: return PsIeeeFloatType(64) + if t is bool: + return PsBoolType() if t is np.uint8: return PsUnsignedIntegerType(8) @@ -89,6 +91,9 @@ def interpret_python_type(t: type) -> PsType: return PsIeeeFloatType(32) if t is np.float64: return PsIeeeFloatType(64) + + if t is np.bool_: + return PsBoolType() raise ValueError(f"Could not interpret Python data type {t} as a pystencils type.") diff --git a/src/pystencils/types/types.py b/src/pystencils/types/types.py index d3d18720cf1ff3c4af14f6c276da52098adfbdd2..6e4f65b85b486dee715758c2529a9cd914483f50 100644 --- a/src/pystencils/types/types.py +++ b/src/pystencils/types/types.py @@ -6,7 +6,7 @@ from dataclasses import dataclass import numpy as np from .exception import PsTypeError -from .meta import PsType, constify, deconstify +from .meta import PsType, deconstify class PsCustomType(PsType): @@ -335,6 +335,8 @@ class PsScalarType(PsNumericType, ABC): class PsVectorType(PsNumericType): """Packed vector of numeric type. + The packed vector's element type will always be made non-const. + Args: element_type: Underlying scalar data type num_entries: Number of entries in the vector @@ -345,7 +347,7 @@ class PsVectorType(PsNumericType): ): super().__init__(const) self._vector_entries = vector_entries - self._scalar_type = constify(scalar_type) if const else deconstify(scalar_type) + self._scalar_type = deconstify(scalar_type) def __args__(self) -> tuple[Any, ...]: """ @@ -410,7 +412,7 @@ class PsVectorType(PsNumericType): raise PsTypeError("Cannot retrieve C type string for generic vector types.") def __str__(self) -> str: - return f"vector[{self._scalar_type}, {self._vector_entries}]" + return f"{self._scalar_type}<{self._vector_entries}>" def __repr__(self) -> str: return ( @@ -521,7 +523,7 @@ class PsIntegerType(PsScalarType, ABC): case w if w < 32: # Plain integer literals get at least type `int`, which is 32 bit in all relevant cases # So we need to explicitly cast to smaller types - return f"(({self._c_type_without_const()}) {value}{unsigned_suffix})" + return f"(({deconstify(self).c_string()}) {value}{unsigned_suffix})" case 32: # No suffix here - becomes `int`, which is 32 bit return f"{value}{unsigned_suffix}" @@ -544,12 +546,15 @@ class PsIntegerType(PsScalarType, ABC): raise PsTypeError(f"Could not interpret {value} as {repr(self)}") - def _c_type_without_const(self) -> str: + def _str_without_const(self) -> str: prefix = "" if self._signed else "u" - return f"{prefix}int{self._width}_t" + return f"{prefix}int{self._width}" def c_string(self) -> str: - return f"{self._const_string()}{self._c_type_without_const()}" + return f"{self._const_string()}{self._str_without_const()}_t" + + def __str__(self) -> str: + return f"{self._const_string()}{self._str_without_const()}" def __repr__(self) -> str: return f"PsIntegerType( width={self.width}, signed={self.signed}, const={self.const} )" @@ -694,5 +699,8 @@ class PsIeeeFloatType(PsScalarType): case _: assert False, "unreachable code" + def __str__(self) -> str: + return f"{self._const_string()}float{self._width}" + def __repr__(self) -> str: return f"PsIeeeFloatType( width={self.width}, const={self.const} )" diff --git a/src/pystencils/utils.py b/src/pystencils/utils.py index de98e44316e259c81bbcc3b3ce2aa7c490f7a5e8..a53eb82896ab635c3995be918d31a03326766d5d 100644 --- a/src/pystencils/utils.py +++ b/src/pystencils/utils.py @@ -12,6 +12,7 @@ import sympy as sp class DotDict(dict): """Normal dict with additional dot access for all keys""" + __getattr__ = dict.get __setattr__ = dict.__setitem__ __delattr__ = dict.__delitem__ @@ -105,7 +106,7 @@ def binary_numbers(n): result = list() for i in range(1 << n): binary_number = bin(i)[2:] - binary_number = '0' * (n - len(binary_number)) + binary_number + binary_number = "0" * (n - len(binary_number)) + binary_number result.append((list(map(int, binary_number)))) return result @@ -129,6 +130,7 @@ class LinearEquationSystem: {x: 7/2, y: 1/2} """ + def __init__(self, unknowns): size = len(unknowns) self._matrix = sp.zeros(size, size + 1) @@ -145,7 +147,7 @@ class LinearEquationSystem: def add_equation(self, linear_equation): """Add a linear equation as sympy expression. Implicit "-0" is assumed. Equation has to be linear and contain - only unknowns passed to the constructor otherwise a ValueError is raised. """ + only unknowns passed to the constructor otherwise a ValueError is raised.""" self._resize_if_necessary() linear_equation = linear_equation.expand() zero_row_idx = self.next_zero_row @@ -162,7 +164,7 @@ class LinearEquationSystem: self._reduced = False def add_equations(self, linear_equations): - """Add a sequence of equations. For details see `add_equation`. """ + """Add a sequence of equations. For details see `add_equation`.""" self._resize_if_necessary(len(linear_equations)) for eq in linear_equations: self.add_equation(eq) @@ -201,21 +203,21 @@ class LinearEquationSystem: non_zero_rows = self.next_zero_row num_unknowns = len(self.unknowns) if non_zero_rows == 0: - return 'multiple' + return "multiple" *row_begin, left, right = self._matrix.row(non_zero_rows - 1) if non_zero_rows > num_unknowns: - return 'none' + return "none" elif non_zero_rows == num_unknowns: if left == 0 and right != 0: - return 'none' + return "none" else: - return 'single' + return "single" elif non_zero_rows < num_unknowns: if right != 0 and left == 0 and all(e == 0 for e in row_begin): - return 'none' + return "none" else: - return 'multiple' + return "multiple" def solution(self): """Solves the system. Under- and overdetermined systems are supported. @@ -224,8 +226,9 @@ class LinearEquationSystem: def _resize_if_necessary(self, new_rows=1): if self.next_zero_row + new_rows > self._matrix.shape[0]: - self._matrix = self._matrix.row_insert(self._matrix.shape[0] + 1, - sp.zeros(new_rows, self._matrix.shape[1])) + self._matrix = self._matrix.row_insert( + self._matrix.shape[0] + 1, sp.zeros(new_rows, self._matrix.shape[1]) + ) def _update_next_zero_row(self): result = self._matrix.shape[0] @@ -253,7 +256,15 @@ class ContextVar: def c_intdiv(num, denom): """C-style integer division""" - return int(num / denom) + if isinstance(num, np.ndarray) or isinstance(denom, np.ndarray): + rtype = np.result_type(num, denom) + if not np.issubdtype(rtype, np.integer): + raise TypeError( + "Invalid numpy argument types to c_intdiv: Must be integers." + ) + return (num / denom).astype(rtype) + else: + return int(num / denom) def c_rem(num, denom): diff --git a/tests/frontend/test_pickle_support.py b/tests/frontend/test_pickle_support.py index 55ef2fb1cc4b833a73a8639e4cab78287d7169a1..b00fcad5960661a4052280a6ab05b20352f94602 100644 --- a/tests/frontend/test_pickle_support.py +++ b/tests/frontend/test_pickle_support.py @@ -2,6 +2,7 @@ from copy import copy, deepcopy from pystencils.field import Field from pystencils.sympyextensions import TypedSymbol +from pystencils.types import create_type def test_field_access(): @@ -15,4 +16,4 @@ def test_typed_symbol(): ts = TypedSymbol("s", "double") copy(ts) ts_copy = deepcopy(ts) - assert str(ts_copy.dtype).strip() == "double" + assert ts_copy.dtype == create_type("double") diff --git a/tests/kernelcreation/test_functions.py b/tests/kernelcreation/test_functions.py index 4094556a09dc89d0fe4616c8934332a9abaaa879..e16201f819161fc611ec16b09dc98a7d7abf9445 100644 --- a/tests/kernelcreation/test_functions.py +++ b/tests/kernelcreation/test_functions.py @@ -62,7 +62,7 @@ def test_unary_functions(target, function_name, dtype): xp = np sp_func, xp_func = unary_function(function_name, xp) - resolution: dtype = np.finfo(dtype).resolution + resolution = np.finfo(dtype).resolution inp = xp.array( [[0.1, 0.2, 0.3], [-0.8, -1.6, -12.592], [xp.pi, xp.e, 0.0]], dtype=dtype diff --git a/tests/nbackend/kernelcreation/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py index 988fa4bb8b10c2c243abfd3a171657ad6bf5e418..3defe4ad539b19d5c077a29bc96bf7878e4d5f0c 100644 --- a/tests/nbackend/kernelcreation/test_typification.py +++ b/tests/nbackend/kernelcreation/test_typification.py @@ -15,7 +15,6 @@ from pystencils.backend.ast.structural import ( PsBlock, ) from pystencils.backend.ast.expressions import ( - PsAddressOf, PsArrayInitList, PsCast, PsConstantExpr, @@ -35,9 +34,10 @@ from pystencils.backend.ast.expressions import ( PsTernary, PsMemAcc ) +from pystencils.backend.ast.vector import PsVecBroadcast from pystencils.backend.constants import PsConstant from pystencils.backend.functions import CFunction -from pystencils.types import constify, create_type, create_numeric_type +from pystencils.types import constify, create_type, create_numeric_type, PsVectorType from pystencils.types.quick import Fp, Int, Bool, Arr, Ptr from pystencils.backend.kernelcreation.context import KernelCreationContext from pystencils.backend.kernelcreation.freeze import FreezeExpressions @@ -574,10 +574,6 @@ def test_invalid_conditions(): with pytest.raises(TypificationError): typify(cond) - cond = PsConditional(PsAnd(p, PsOr(x, q)), PsBlock([])) - with pytest.raises(TypificationError): - typify(cond) - def test_typify_ternary(): ctx = KernelCreationContext() @@ -625,6 +621,34 @@ def test_cfunction(): _ = typify(PsCall(threeway, (x, p))) +def test_typify_integer_vectors(): + ctx = KernelCreationContext() + typify = Typifier(ctx) + + a, b, c = [PsExpression.make(ctx.get_symbol(name, PsVectorType(Int(32), 4))) for name in "abc"] + d, e = [PsExpression.make(ctx.get_symbol(name, Int(32))) for name in "de"] + + result = typify(a + (b / c) - a * c) + assert result.get_dtype() == PsVectorType(Int(32), 4) + + result = typify(PsVecBroadcast(4, d - e) - PsVecBroadcast(4, e / d)) + assert result.get_dtype() == PsVectorType(Int(32), 4) + + +def test_typify_bool_vectors(): + ctx = KernelCreationContext() + typify = Typifier(ctx) + + x, y = [PsExpression.make(ctx.get_symbol(name, PsVectorType(Fp(32), 4))) for name in "xy"] + p, q = [PsExpression.make(ctx.get_symbol(name, PsVectorType(Bool(), 4))) for name in "pq"] + + result = typify(PsAnd(PsOr(p, q), p)) + assert result.get_dtype() == PsVectorType(Bool(), 4) + + result = typify(PsAnd(PsLt(x, y), PsGe(y, x))) + assert result.get_dtype() == PsVectorType(Bool(), 4) + + def test_inference_fails(): ctx = KernelCreationContext() typify = Typifier(ctx) diff --git a/tests/nbackend/test_ast.py b/tests/nbackend/test_ast.py index 2408b8d867038a0f2fd5c4d8a5f22bc82312c701..5962208a7fda9598dc3aa2cf2b068dee3b054f3f 100644 --- a/tests/nbackend/test_ast.py +++ b/tests/nbackend/test_ast.py @@ -1,8 +1,12 @@ import pytest from pystencils import create_type -from pystencils.backend.kernelcreation import KernelCreationContext, AstFactory, Typifier -from pystencils.backend.memory import PsSymbol, BufferBasePtr +from pystencils.backend.kernelcreation import ( + KernelCreationContext, + AstFactory, + Typifier, +) +from pystencils.backend.memory import BufferBasePtr from pystencils.backend.constants import PsConstant from pystencils.backend.ast.expressions import ( PsExpression, @@ -12,6 +16,9 @@ from pystencils.backend.ast.expressions import ( PsSubscript, PsBufferAcc, PsSymbolExpr, + PsLe, + PsGe, + PsAnd, ) from pystencils.backend.ast.structural import ( PsStatement, @@ -23,7 +30,7 @@ from pystencils.backend.ast.structural import ( PsPragma, PsLoop, ) -from pystencils.types.quick import Fp, Ptr +from pystencils.types.quick import Fp, Ptr, Bool def test_cloning(): @@ -32,7 +39,9 @@ def test_cloning(): x, y, z, m = [PsExpression.make(ctx.get_symbol(name)) for name in "xyzm"] q = PsExpression.make(ctx.get_symbol("q", create_type("bool"))) - a, b, c = [PsExpression.make(ctx.get_symbol(name, ctx.index_dtype)) for name in "abc"] + a, b, c = [ + PsExpression.make(ctx.get_symbol(name, ctx.index_dtype)) for name in "abc" + ] c1 = PsExpression.make(PsConstant(3.0)) c2 = PsExpression.make(PsConstant(-1.0)) one_f = PsExpression.make(PsConstant(1.0)) @@ -42,7 +51,7 @@ def test_cloning(): assert not (orig is clone) assert type(orig) is type(clone) assert orig.structurally_equal(clone) - + if isinstance(orig, PsExpression): # Regression: Expression data types used to not be cloned assert orig.dtype == clone.dtype @@ -63,13 +72,7 @@ def test_cloning(): PsConditional( q, PsBlock([PsStatement(x + y)]), PsBlock([PsComment("hello world")]) ), - PsDeclaration( - m, - PsArrayInitList([ - [x, y, one_f + x], - [one_f, c2, z] - ]) - ), + PsDeclaration(m, PsArrayInitList([[x, y, one_f + x], [one_f, c2, z]])), PsPragma("omp parallel for"), PsLoop( a, @@ -84,7 +87,9 @@ def test_cloning(): PsPragma("#pragma clang loop vectorize(enable)"), PsStatement( PsMemAcc(PsCast(Ptr(Fp(32)), z), one_i) - + PsCast(Fp(32), PsSubscript(m, (one_i + one_i + one_i, b + one_i))) + + PsCast( + Fp(32), PsSubscript(m, (one_i + one_i + one_i, b + one_i)) + ) ), ] ), @@ -106,7 +111,10 @@ def test_buffer_acc(): f_buf = ctx.get_buffer(f) - f_acc = PsBufferAcc(f_buf.base_pointer, [PsExpression.make(i) for i in (a, b)] + [factory.parse_index(0)]) + f_acc = PsBufferAcc( + f_buf.base_pointer, + [PsExpression.make(i) for i in (a, b)] + [factory.parse_index(0)], + ) assert f_acc.buffer == f_buf assert f_acc.base_pointer.structurally_equal(PsSymbolExpr(f_buf.base_pointer)) @@ -121,11 +129,16 @@ def test_buffer_acc(): g_buf = ctx.get_buffer(g) - g_acc = PsBufferAcc(g_buf.base_pointer, [PsExpression.make(i) for i in (a, b)] + [factory.parse_index(2)]) + g_acc = PsBufferAcc( + g_buf.base_pointer, + [PsExpression.make(i) for i in (a, b)] + [factory.parse_index(2)], + ) assert g_acc.buffer == g_buf assert g_acc.base_pointer.structurally_equal(PsSymbolExpr(g_buf.base_pointer)) - second_bptr = PsExpression.make(ctx.get_symbol("data_g_interior", g_buf.base_pointer.dtype)) + second_bptr = PsExpression.make( + ctx.get_symbol("data_g_interior", g_buf.base_pointer.dtype) + ) second_bptr.symbol.add_property(BufferBasePtr(g_buf)) g_acc.base_pointer = second_bptr diff --git a/tests/nbackend/test_vectorization.py b/tests/nbackend/test_vectorization.py new file mode 100644 index 0000000000000000000000000000000000000000..0af614b23db4fee7056001334d185c171c66cbea --- /dev/null +++ b/tests/nbackend/test_vectorization.py @@ -0,0 +1,227 @@ +import pytest +import sympy as sp +import numpy as np +from dataclasses import dataclass +from itertools import chain + +from pystencils.backend.kernelcreation import ( + KernelCreationContext, + AstFactory, + FullIterationSpace, +) +from pystencils.backend.platforms import GenericVectorCpu, X86VectorArch, X86VectorCpu +from pystencils.backend.ast.structural import PsBlock +from pystencils.backend.transformations import ( + LoopVectorizer, + SelectIntrinsics, + LowerToC, +) +from pystencils.backend.constants import PsConstant +from pystencils.backend.kernelfunction import create_cpu_kernel_function +from pystencils.backend.jit import LegacyCpuJit + +from pystencils import Target, fields, Assignment, Field +from pystencils.field import create_numpy_array_with_layout +from pystencils.types import PsScalarType, PsIntegerType +from pystencils.types.quick import SInt, Fp + + +@dataclass +class VectorTestSetup: + platform: GenericVectorCpu + lanes: int + numeric_dtype: PsScalarType + index_dtype: PsIntegerType + + @property + def name(self) -> str: + if isinstance(self.platform, X86VectorCpu): + match self.platform.vector_arch: + case X86VectorArch.SSE: + isa = "SSE" + case X86VectorArch.AVX: + isa = "AVX" + case X86VectorArch.AVX512: + isa = "AVX512" + case X86VectorArch.AVX512_FP16: + isa = "AVX512_FP16" + else: + assert False + + return f"{isa}/{self.numeric_dtype}<{self.lanes}>/{self.index_dtype}" + + +def get_setups(target: Target) -> list[VectorTestSetup]: + match target: + case Target.X86_SSE: + sse_platform = X86VectorCpu(X86VectorArch.SSE) + return [ + VectorTestSetup(sse_platform, 4, Fp(32), SInt(32)), + VectorTestSetup(sse_platform, 2, Fp(64), SInt(64)), + ] + + case Target.X86_AVX: + avx_platform = X86VectorCpu(X86VectorArch.AVX) + return [ + VectorTestSetup(avx_platform, 4, Fp(32), SInt(32)), + VectorTestSetup(avx_platform, 8, Fp(32), SInt(32)), + VectorTestSetup(avx_platform, 2, Fp(64), SInt(64)), + VectorTestSetup(avx_platform, 4, Fp(64), SInt(64)), + ] + + case Target.X86_AVX512: + avx512_platform = X86VectorCpu(X86VectorArch.AVX512) + return [ + VectorTestSetup(avx512_platform, 4, Fp(32), SInt(32)), + VectorTestSetup(avx512_platform, 8, Fp(32), SInt(32)), + VectorTestSetup(avx512_platform, 16, Fp(32), SInt(32)), + VectorTestSetup(avx512_platform, 2, Fp(64), SInt(64)), + VectorTestSetup(avx512_platform, 4, Fp(64), SInt(64)), + VectorTestSetup(avx512_platform, 8, Fp(64), SInt(64)), + ] + + case Target.X86_AVX512_FP16: + avx512_platform = X86VectorCpu(X86VectorArch.AVX512_FP16) + return [ + VectorTestSetup(avx512_platform, 8, Fp(16), SInt(32)), + VectorTestSetup(avx512_platform, 16, Fp(16), SInt(32)), + VectorTestSetup(avx512_platform, 32, Fp(16), SInt(32)), + ] + + case _: + return [] + + +TEST_SETUPS: list[VectorTestSetup] = list( + chain.from_iterable(get_setups(t) for t in Target.available_vector_cpu_targets()) +) + +TEST_IDS = [t.name for t in TEST_SETUPS] + + +def create_vector_kernel( + assignments: list[Assignment], + field: Field, + setup: VectorTestSetup, + ghost_layers: int = 0, +): + ctx = KernelCreationContext( + default_dtype=setup.numeric_dtype, index_dtype=setup.index_dtype + ) + + factory = AstFactory(ctx) + + ispace = FullIterationSpace.create_with_ghost_layers(ctx, ghost_layers, field) + ctx.set_iteration_space(ispace) + + body = PsBlock([factory.parse_sympy(asm) for asm in assignments]) + + loop_order = field.layout + loop_nest = factory.loops_from_ispace(ispace, body, loop_order) + + for field in ctx.fields: + # Set inner strides to one to ensure packed memory access + buf = ctx.get_buffer(field) + buf.strides[0] = PsConstant(1, ctx.index_dtype) + + vectorize = LoopVectorizer(ctx, setup.lanes) + loop_nest = vectorize.vectorize_select_loops( + loop_nest, lambda l: l.counter.symbol.name == "ctr_0" + ) + + select_intrin = SelectIntrinsics(ctx, setup.platform) + loop_nest = select_intrin(loop_nest) + + lower = LowerToC(ctx) + loop_nest = lower(loop_nest) + + func = create_cpu_kernel_function( + ctx, + setup.platform, + PsBlock([loop_nest]), + "vector_kernel", + Target.CPU, + LegacyCpuJit(), + ) + + kernel = func.compile() + return kernel + + +@pytest.mark.parametrize("setup", TEST_SETUPS, ids=TEST_IDS) +@pytest.mark.parametrize("ghost_layers", [0, 2]) +def test_update_kernel(setup: VectorTestSetup, ghost_layers: int): + src, dst = fields(f"src(2), dst(4): {setup.numeric_dtype}[2D]", layout="fzyx") + + x = sp.symbols("x_:4") + + update = [ + Assignment(x[0], src[0, 0](0) + src[0, 0](1)), + Assignment(x[1], src[0, 0](0) - src[0, 0](1)), + Assignment(x[2], src[0, 0](0) * src[0, 0](1)), + Assignment(x[3], src[0, 0](0) / src[0, 0](1)), + Assignment(dst.center(0), x[0]), + Assignment(dst.center(1), x[1]), + Assignment(dst.center(2), x[2]), + Assignment(dst.center(3), x[3]), + ] + + kernel = create_vector_kernel(update, src, setup, ghost_layers) + + shape = (23, 17) + + rgen = np.random.default_rng(seed=1648) + src_arr = create_numpy_array_with_layout( + shape + (2,), layout=(2, 1, 0), dtype=setup.numeric_dtype.numpy_dtype + ) + rgen.random(dtype=setup.numeric_dtype.numpy_dtype, out=src_arr) + + dst_arr = create_numpy_array_with_layout( + shape + (4,), layout=(2, 1, 0), dtype=setup.numeric_dtype.numpy_dtype + ) + dst_arr[:] = 0.0 + + check_arr = np.zeros_like(dst_arr) + check_arr[:, :, 0] = src_arr[:, :, 0] + src_arr[:, :, 1] + check_arr[:, :, 1] = src_arr[:, :, 0] - src_arr[:, :, 1] + check_arr[:, :, 2] = src_arr[:, :, 0] * src_arr[:, :, 1] + check_arr[:, :, 3] = src_arr[:, :, 0] / src_arr[:, :, 1] + + kernel(src=src_arr, dst=dst_arr) + + resolution = np.finfo(setup.numeric_dtype.numpy_dtype).resolution + gls = ghost_layers + + np.testing.assert_allclose( + dst_arr[gls:-gls, gls:-gls, :], + check_arr[gls:-gls, gls:-gls, :], + rtol=resolution, + ) + + if gls != 0: + for i in range(gls): + np.testing.assert_equal(dst_arr[i, :, :], 0.0) + np.testing.assert_equal(dst_arr[-i, :, :], 0.0) + np.testing.assert_equal(dst_arr[:, i, :], 0.0) + np.testing.assert_equal(dst_arr[:, -i, :], 0.0) + + +@pytest.mark.parametrize("setup", TEST_SETUPS, ids=TEST_IDS) +def test_trailing_iterations(setup: VectorTestSetup): + f = fields(f"f(1): {setup.numeric_dtype}[1D]", layout="fzyx") + + update = [Assignment(f(0), 2 * f(0))] + + kernel = create_vector_kernel(update, f, setup) + + for trailing_iters in range(setup.lanes): + shape = (setup.lanes * 12 + trailing_iters, 1) + f_arr = create_numpy_array_with_layout( + shape, layout=(1, 0), dtype=setup.numeric_dtype.numpy_dtype + ) + + f_arr[:] = 1.0 + + kernel(f=f_arr) + + np.testing.assert_equal(f_arr, 2.0) diff --git a/tests/nbackend/transformations/test_ast_vectorizer.py b/tests/nbackend/transformations/test_ast_vectorizer.py new file mode 100644 index 0000000000000000000000000000000000000000..ea425349529c45b94317a98d2d9f305933c9ba60 --- /dev/null +++ b/tests/nbackend/transformations/test_ast_vectorizer.py @@ -0,0 +1,505 @@ +import sympy as sp +import pytest + +from pystencils import Assignment, TypedSymbol, fields, FieldType +from pystencils.sympyextensions import CastFunc, mem_acc +from pystencils.sympyextensions.pointers import AddressOf + +from pystencils.backend.constants import PsConstant +from pystencils.backend.kernelcreation import ( + KernelCreationContext, + AstFactory, + FullIterationSpace, + Typifier, +) +from pystencils.backend.transformations import ( + VectorizationAxis, + VectorizationContext, + AstVectorizer, +) +from pystencils.backend.ast import dfs_preorder +from pystencils.backend.ast.structural import PsBlock, PsDeclaration, PsAssignment +from pystencils.backend.ast.expressions import ( + PsSymbolExpr, + PsConstantExpr, + PsExpression, + PsCast, + PsMemAcc, + PsCall +) +from pystencils.backend.functions import CFunction +from pystencils.backend.ast.vector import PsVecBroadcast, PsVecMemAcc +from pystencils.backend.exceptions import VectorizationError +from pystencils.types import PsVectorType, deconstify, create_type + + +def test_vectorize_expressions(): + x, y, z, w = sp.symbols("x, y, z, w") + + ctx = KernelCreationContext() + factory = AstFactory(ctx) + typify = Typifier(ctx) + + for s in (x, y, z, w): + _ = factory.parse_sympy(s) + + ctr = ctx.get_symbol("ctr", ctx.index_dtype) + + axis = VectorizationAxis(ctr) + vc = VectorizationContext(ctx, 4, axis) + vc.vectorize_symbol(ctx.get_symbol("x")) + vc.vectorize_symbol(ctx.get_symbol("w")) + + vectorize = AstVectorizer(ctx) + + for expr in [ + factory.parse_sympy(-x * y + 13 * z - 4 * (x / w) * (x + z)), + factory.parse_sympy(sp.sin(x + z) - sp.cos(w)), + factory.parse_sympy(y**2 - x**2), + typify(- factory.parse_sympy(x / (w**2))), # place the negation outside, since SymPy would remove it + factory.parse_sympy(13 + (1 / w) - sp.exp(x) * 24), + ]: + vec_expr = vectorize.visit(expr, vc) + + # Must be a clone + assert vec_expr is not expr + + scalar_type = ctx.default_dtype + vector_type = PsVectorType(scalar_type, 4) + + for subexpr in dfs_preorder(vec_expr): + match subexpr: + case PsSymbolExpr(symb) if symb.name in "yz": + # These are not vectorized, but broadcast + assert symb.dtype == scalar_type + assert subexpr.dtype == scalar_type + case PsConstantExpr(c): + assert deconstify(c.get_dtype()) == scalar_type + assert subexpr.dtype == scalar_type + case PsSymbolExpr(symb): + assert symb.name not in "xw" + assert symb.get_dtype() == vector_type + assert subexpr.dtype == vector_type + case PsVecBroadcast(lanes, operand): + assert lanes == 4 + assert subexpr.dtype == vector_type + assert subexpr.dtype.scalar_type == operand.dtype + case PsExpression(): + # All other expressions are vectorized + assert subexpr.dtype == vector_type + + +def test_vectorize_casts_and_counter(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + + ctr = ctx.get_symbol("ctr", ctx.index_dtype) + vec_ctr = ctx.get_symbol("vec_ctr", PsVectorType(ctx.index_dtype, 4)) + + vectorize = AstVectorizer(ctx) + + axis = VectorizationAxis(ctr, vec_ctr) + vc = VectorizationContext(ctx, 4, axis) + + expr = factory.parse_sympy(CastFunc(sp.Symbol("ctr"), create_type("float32"))) + vec_expr = vectorize.visit(expr, vc) + + assert isinstance(vec_expr, PsCast) + assert ( + vec_expr.dtype + == vec_expr.target_type + == PsVectorType(create_type("float32"), 4) + ) + + assert isinstance(vec_expr.operand, PsSymbolExpr) + assert vec_expr.operand.symbol == vec_ctr + assert vec_expr.operand.dtype == PsVectorType(ctx.index_dtype, 4) + + +def test_invalid_vectorization(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + typify = Typifier(ctx) + + ctr = ctx.get_symbol("ctr", ctx.index_dtype) + + vectorize = AstVectorizer(ctx) + + axis = VectorizationAxis(ctr) + vc = VectorizationContext(ctx, 4, axis) + + expr = factory.parse_sympy(CastFunc(sp.Symbol("ctr"), create_type("float32"))) + + with pytest.raises(VectorizationError): + # Fails since no vectorized counter was specified + _ = vectorize.visit(expr, vc) + + expr = PsExpression.make( + ctx.get_symbol("x_v", PsVectorType(create_type("float32"), 4)) + ) + + with pytest.raises(VectorizationError): + # Fails since this symbol is already vectorial + _ = vectorize.visit(expr, vc) + + func = CFunction("compute", [ctx.default_dtype], ctx.default_dtype) + expr = typify(PsCall(func, [PsExpression.make(ctx.get_symbol("x"))])) + + with pytest.raises(VectorizationError): + # Can't vectorize unknown function + _ = vectorize.visit(expr, vc) + + +def test_vectorize_declarations(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + + x, y, z, w = sp.symbols("x, y, z, w") + ctr = TypedSymbol("ctr", ctx.index_dtype) + + vectorize = AstVectorizer(ctx) + + axis = VectorizationAxis( + ctx.get_symbol("ctr", ctx.index_dtype), + ctx.get_symbol("vec_ctr", PsVectorType(ctx.index_dtype, 4)), + ) + vc = VectorizationContext(ctx, 4, axis) + + block = PsBlock( + [ + factory.parse_sympy(asm) + for asm in [ + Assignment(x, CastFunc.as_numeric(ctr)), + Assignment(y, sp.cos(x)), + Assignment(z, x**2 + 2 * y / 4), + Assignment(w, -x + y - z), + ] + ] + ) + + vec_block = vectorize.visit(block, vc) + assert vec_block is not block + assert isinstance(vec_block, PsBlock) + + for symb_name, decl in zip("xyzw", vec_block.statements): + symb = ctx.get_symbol(symb_name) + assert symb in vc.vectorized_symbols + + assert isinstance(decl, PsDeclaration) + assert decl.declared_symbol == vc.vectorized_symbols[symb] + assert ( + decl.lhs.dtype + == decl.declared_symbol.dtype + == PsVectorType(ctx.default_dtype, 4) + ) + + +def test_duplicate_declarations(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + + x, y = sp.symbols("x, y") + + vectorize = AstVectorizer(ctx) + + axis = VectorizationAxis( + ctx.get_symbol("ctr", ctx.index_dtype), + ) + vc = VectorizationContext(ctx, 4, axis) + + block = PsBlock( + [ + factory.parse_sympy(asm) + for asm in [ + Assignment(y, sp.cos(x)), + Assignment(y, 21), + ] + ] + ) + + with pytest.raises(VectorizationError): + _ = vectorize.visit(block, vc) + + +def test_reject_symbol_assignments(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + + x, y = sp.symbols("x, y") + + vectorize = AstVectorizer(ctx) + + axis = VectorizationAxis( + ctx.get_symbol("ctr", ctx.index_dtype), + ) + vc = VectorizationContext(ctx, 4, axis) + + asm = PsAssignment(factory.parse_sympy(x), factory.parse_sympy(3 + y)) + + with pytest.raises(VectorizationError): + _ = vectorize.visit(asm, vc) + + +def test_vectorize_memory_assignments(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + typify = Typifier(ctx) + vectorize = AstVectorizer(ctx) + + x, y = sp.symbols("x, y") + + ctr = TypedSymbol("ctr", ctx.index_dtype) + i = TypedSymbol("i", ctx.index_dtype) + axis = VectorizationAxis( + ctx.get_symbol("ctr", ctx.index_dtype), + ) + vc = VectorizationContext(ctx, 4, axis) + + ptr = TypedSymbol("ptr", create_type("float64 *")) + + asm = typify( + PsAssignment( + factory.parse_sympy(mem_acc(ptr, 3 * ctr + 2)), + factory.parse_sympy(x + y * mem_acc(ptr, ctr + 3)) + ) + ) + + vec_asm = vectorize.visit(asm, vc) + assert isinstance(vec_asm, PsAssignment) + assert isinstance(vec_asm.lhs, PsVecMemAcc) + + field = fields("field(1): [2D]", field_type=FieldType.CUSTOM) + asm = factory.parse_sympy( + Assignment( + field.absolute_access((ctr, i), (0,)), + x + y * field.absolute_access((ctr + 1, i), (0,)), + ) + ) + + vec_asm = vectorize.visit(asm, vc) + assert isinstance(vec_asm, PsAssignment) + assert isinstance(vec_asm.lhs, PsVecMemAcc) + + +def test_invalid_memory_assignments(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + typify = Typifier(ctx) + vectorize = AstVectorizer(ctx) + + x, y = sp.symbols("x, y") + + ctr = TypedSymbol("ctr", ctx.index_dtype) + axis = VectorizationAxis( + ctx.get_symbol("ctr", ctx.index_dtype), + ) + vc = VectorizationContext(ctx, 4, axis) + + i = TypedSymbol("i", ctx.index_dtype) + + ptr = TypedSymbol("ptr", create_type("float64 *")) + + # Cannot vectorize assignment to LHS that does not depend on axis counter + asm = typify( + PsAssignment( + factory.parse_sympy(mem_acc(ptr, 3 * i + 2)), + factory.parse_sympy(x + y * mem_acc(ptr, ctr + 3)) + ) + ) + + with pytest.raises(VectorizationError): + _ = vectorize.visit(asm, vc) + + +def test_vectorize_mem_acc(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + typify = Typifier(ctx) + vectorize = AstVectorizer(ctx) + + ctr = TypedSymbol("ctr", ctx.index_dtype) + axis = VectorizationAxis( + ctx.get_symbol("ctr", ctx.index_dtype), + ) + vc = VectorizationContext(ctx, 4, axis) + + i = TypedSymbol("i", ctx.index_dtype) + j = TypedSymbol("j", ctx.index_dtype) + + ptr = TypedSymbol("ptr", create_type("float64 *")) + + # Lane-invariant index + acc = factory.parse_sympy(mem_acc(ptr, 3 * i + 5 * j)) + + vec_acc = vectorize.visit(acc, vc) + assert isinstance(vec_acc, PsVecBroadcast) + assert vec_acc.operand is not acc + assert vec_acc.operand.structurally_equal(acc) + + # Counter as index + acc = factory.parse_sympy(mem_acc(ptr, ctr)) + assert isinstance(acc, PsMemAcc) + + vec_acc = vectorize.visit(acc, vc) + assert isinstance(vec_acc, PsVecMemAcc) + assert vec_acc.pointer is not acc.pointer + assert vec_acc.pointer.structurally_equal(acc.pointer) + assert vec_acc.offset is not acc.offset + assert vec_acc.offset.structurally_equal(acc.offset) + assert vec_acc.stride is None + assert vec_acc.vector_entries == 4 + + # Simple affine + acc = factory.parse_sympy(mem_acc(ptr, 3 * i + 5 * ctr)) + assert isinstance(acc, PsMemAcc) + + vec_acc = vectorize.visit(acc, vc) + assert isinstance(vec_acc, PsVecMemAcc) + assert vec_acc.pointer is not acc.pointer + assert vec_acc.pointer.structurally_equal(acc.pointer) + assert vec_acc.offset is not acc.offset + assert vec_acc.offset.structurally_equal(acc.offset) + assert vec_acc.stride.structurally_equal(factory.parse_index(5)) + assert vec_acc.vector_entries == 4 + + # More complex, nested affine + acc = factory.parse_sympy(mem_acc(ptr, j * i + 2 * (5 + j * ctr) + 2 * ctr)) + assert isinstance(acc, PsMemAcc) + + vec_acc = vectorize.visit(acc, vc) + assert isinstance(vec_acc, PsVecMemAcc) + assert vec_acc.pointer is not acc.pointer + assert vec_acc.pointer.structurally_equal(acc.pointer) + assert vec_acc.offset is not acc.offset + assert vec_acc.offset.structurally_equal(acc.offset) + assert vec_acc.stride.structurally_equal(factory.parse_index(2 * j + 2)) + assert vec_acc.vector_entries == 4 + + # Even more complex affine + idx = - factory.parse_index(ctr) / factory.parse_index(i) - factory.parse_index(ctr) * factory.parse_index(j) + acc = typify(PsMemAcc(factory.parse_sympy(ptr), idx)) + assert isinstance(acc, PsMemAcc) + + vec_acc = vectorize.visit(acc, vc) + assert isinstance(vec_acc, PsVecMemAcc) + assert vec_acc.pointer is not acc.pointer + assert vec_acc.pointer.structurally_equal(acc.pointer) + assert vec_acc.offset is not acc.offset + assert vec_acc.offset.structurally_equal(acc.offset) + assert vec_acc.stride.structurally_equal(factory.parse_index(-1) / factory.parse_index(i) - factory.parse_index(j)) + assert vec_acc.vector_entries == 4 + + # Mixture of strides in affine and axis + vc = VectorizationContext(ctx, 4, VectorizationAxis(ctx.get_symbol("ctr"), step=factory.parse_index(3))) + + acc = factory.parse_sympy(mem_acc(ptr, 3 * i + 5 * ctr)) + assert isinstance(acc, PsMemAcc) + + vec_acc = vectorize.visit(acc, vc) + assert isinstance(vec_acc, PsVecMemAcc) + assert vec_acc.pointer is not acc.pointer + assert vec_acc.pointer.structurally_equal(acc.pointer) + assert vec_acc.offset is not acc.offset + assert vec_acc.offset.structurally_equal(acc.offset) + assert vec_acc.stride.structurally_equal(factory.parse_index(15)) + assert vec_acc.vector_entries == 4 + + +def test_invalid_mem_acc(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + vectorize = AstVectorizer(ctx) + + ctr = TypedSymbol("ctr", ctx.index_dtype) + axis = VectorizationAxis( + ctx.get_symbol("ctr", ctx.index_dtype), + ) + vc = VectorizationContext(ctx, 4, axis) + + i = TypedSymbol("i", ctx.index_dtype) + j = TypedSymbol("j", ctx.index_dtype) + ptr = TypedSymbol("ptr", create_type("float64 *")) + + # Non-symbol pointer + acc = factory.parse_sympy(mem_acc(AddressOf(mem_acc(ptr, 10)), 3 * i + ctr * (3 + ctr))) + + with pytest.raises(VectorizationError): + _ = vectorize.visit(acc, vc) + + # Non-affine index + acc = factory.parse_sympy(mem_acc(ptr, 3 * i + ctr * (3 + ctr))) + + with pytest.raises(VectorizationError): + _ = vectorize.visit(acc, vc) + + # Non lane-invariant index + vc.vectorize_symbol(ctx.get_symbol("j", ctx.index_dtype)) + + acc = factory.parse_sympy(mem_acc(ptr, 3 * j + ctr)) + + with pytest.raises(VectorizationError): + _ = vectorize.visit(acc, vc) + + +def test_vectorize_buffer_acc(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + vectorize = AstVectorizer(ctx) + + field = fields("f(3): [3D]", layout="fzyx") + ispace = FullIterationSpace.create_with_ghost_layers(ctx, 0, archetype_field=field) + ctx.set_iteration_space(ispace) + + ctr = ispace.dimensions_in_loop_order()[-1].counter + + axis = VectorizationAxis(ctr) + vc = VectorizationContext(ctx, 4, axis) + + buf = ctx.get_buffer(field) + + acc = factory.parse_sympy(field[-1, -1, -1](2)) + + # Buffer strides are symbolic -> expect strided access + vec_acc = vectorize.visit(acc, vc) + assert isinstance(vec_acc, PsVecMemAcc) + assert vec_acc.stride is not None + assert vec_acc.stride.structurally_equal(PsExpression.make(buf.strides[0])) + + # Set buffer stride to one + buf.strides[0] = PsConstant(1, dtype=ctx.index_dtype) + + # Expect non-strided access + vec_acc = vectorize.visit(acc, vc) + assert isinstance(vec_acc, PsVecMemAcc) + assert vec_acc.stride is None + + +def test_invalid_buffer_acc(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + vectorize = AstVectorizer(ctx) + + field = fields("field(3): [3D]", field_type=FieldType.CUSTOM) + + ctr, i, j = [TypedSymbol(n, ctx.index_dtype) for n in ("ctr", "i", "j")] + + axis = VectorizationAxis(ctx.get_symbol("ctr", ctx.index_dtype)) + vc = VectorizationContext(ctx, 4, axis) + + # Counter occurs in more than one index + acc = factory.parse_sympy(field.absolute_access((ctr, i, ctr + j), (1,))) + + with pytest.raises(VectorizationError): + _ = vectorize.visit(acc, vc) + + # Counter occurs in index dimension + acc = factory.parse_sympy(field.absolute_access((ctr, i, j), (ctr,))) + + with pytest.raises(VectorizationError): + _ = vectorize.visit(acc, vc) + + # Counter occurs quadratically + acc = factory.parse_sympy(field.absolute_access(((ctr + i) * ctr, i, j), (1,))) + + with pytest.raises(VectorizationError): + _ = vectorize.visit(acc, vc) diff --git a/tests/nbackend/transformations/test_constant_elimination.py b/tests/nbackend/transformations/test_constant_elimination.py index 4c18970086b4537dec2ae974ddf6242da57b591e..00df4a8a96a2fd898197ee1605985a3294de3275 100644 --- a/tests/nbackend/transformations/test_constant_elimination.py +++ b/tests/nbackend/transformations/test_constant_elimination.py @@ -1,3 +1,7 @@ +from typing import Any +import pytest +import numpy as np + from pystencils.backend.kernelcreation import KernelCreationContext, Typifier from pystencils.backend.ast.expressions import PsExpression, PsConstantExpr from pystencils.backend.memory import PsSymbol @@ -12,204 +16,298 @@ from pystencils.backend.ast.expressions import ( PsGt, PsTernary, PsRem, - PsIntDiv + PsIntDiv, ) from pystencils.types.quick import Int, Fp, Bool - -x, y, z = [PsExpression.make(PsSymbol(name, Fp(32))) for name in "xyz"] -p, q, r = [PsExpression.make(PsSymbol(name, Int(32))) for name in "pqr"] -a, b, c = [PsExpression.make(PsSymbol(name, Bool())) for name in "abc"] - -f3p5 = PsExpression.make(PsConstant(3.5, Fp(32))) -f42 = PsExpression.make(PsConstant(42, Fp(32))) - -f0 = PsExpression.make(PsConstant(0.0, Fp(32))) -f1 = PsExpression.make(PsConstant(1.0, Fp(32))) - -i0 = PsExpression.make(PsConstant(0, Int(32))) -i1 = PsExpression.make(PsConstant(1, Int(32))) -im1 = PsExpression.make(PsConstant(-1, Int(32))) - -i3 = PsExpression.make(PsConstant(3, Int(32))) -i4 = PsExpression.make(PsConstant(4, Int(32))) -im3 = PsExpression.make(PsConstant(-3, Int(32))) -i12 = PsExpression.make(PsConstant(12, Int(32))) - -true = PsExpression.make(PsConstant(True, Bool())) -false = PsExpression.make(PsConstant(False, Bool())) - - -def test_idempotence(): +from pystencils.types import PsVectorType, create_numeric_type + + +class Exprs: + def __init__(self, mode: str): + self._mode = mode + + if mode == "scalar": + self._itype = Int(32) + self._ftype = Fp(32) + self._btype = Bool() + else: + self._itype = PsVectorType(Int(32), 4) + self._ftype = PsVectorType(Fp(32), 4) + self._btype = PsVectorType(Bool(), 4) + + self.x, self.y, self.z = [ + PsExpression.make(PsSymbol(name, self._ftype)) for name in "xyz" + ] + self.p, self.q, self.r = [ + PsExpression.make(PsSymbol(name, self._itype)) for name in "pqr" + ] + self.a, self.b, self.c = [ + PsExpression.make(PsSymbol(name, self._btype)) for name in "abc" + ] + + self.true = PsExpression.make(PsConstant(True, self._btype)) + self.false = PsExpression.make(PsConstant(False, self._btype)) + + def __call__(self, val) -> Any: + match val: + case int(): + return PsExpression.make(PsConstant(val, self._itype)) + case float(): + return PsExpression.make(PsConstant(val, self._ftype)) + case np.ndarray(): + return PsExpression.make( + PsConstant( + val, PsVectorType(create_numeric_type(val.dtype), len(val)) + ) + ) + case _: + raise ValueError() + + +@pytest.fixture(scope="module", params=["scalar", "vector"]) +def exprs(request): + return Exprs(request.param) + + +def test_idempotence(exprs): + e = exprs ctx = KernelCreationContext() typify = Typifier(ctx) elim = EliminateConstants(ctx) - expr = typify(f42 * (f1 + f0) - f0) + expr = typify(e(42.0) * (e(1.0) + e(0.0)) - e(0.0)) result = elim(expr) - assert isinstance(result, PsConstantExpr) and result.structurally_equal(f42) + assert isinstance(result, PsConstantExpr) and result.structurally_equal(e(42.0)) - expr = typify((x + f0) * f3p5 + (f1 * y + f0) * f42) + expr = typify((e.x + e(0.0)) * e(3.5) + (e(1.0) * e.y + e(0.0)) * e(42.0)) result = elim(expr) - assert result.structurally_equal(x * f3p5 + y * f42) + assert result.structurally_equal(e.x * e(3.5) + e.y * e(42.0)) - expr = typify((f3p5 * f1) + (f42 * f1)) + expr = typify((e(3.5) * e(1.0)) + (e(42.0) * e(1.0))) result = elim(expr) # do not fold floats by default - assert expr.structurally_equal(f3p5 + f42) + assert expr.structurally_equal(e(3.5) + e(42.0)) - expr = typify(f1 * x + f0 + (f0 + f0 + f1 + f0) * y) + expr = typify(e(1.0) * e.x + e(0.0) + (e(0.0) + e(0.0) + e(1.0) + e(0.0)) * e.y) result = elim(expr) - assert result.structurally_equal(x + y) + assert result.structurally_equal(e.x + e.y) + + expr = typify(e(0.0) - e(3.2)) + result = elim(expr) + assert result.structurally_equal(-e(3.2)) -def test_int_folding(): +def test_int_folding(exprs): + e = exprs ctx = KernelCreationContext() typify = Typifier(ctx) elim = EliminateConstants(ctx) - expr = typify((i1 * p + i1 * -i3) + i1 * i12) + expr = typify((e(1) * e.p + e(1) * -e(3)) + e(1) * e(12)) result = elim(expr) - assert result.structurally_equal((p + im3) + i12) + assert result.structurally_equal((e.p + e(-3)) + e(12)) - expr = typify((i1 + i1 + i1 + i0 + i0 + i1) * (i1 + i1 + i1)) + expr = typify((e(1) + e(1) + e(1) + e(0) + e(0) + e(1)) * (e(1) + e(1) + e(1))) result = elim(expr) - assert result.structurally_equal(i12) + assert result.structurally_equal(e(12)) -def test_zero_dominance(): +def test_zero_dominance(exprs): + e = exprs ctx = KernelCreationContext() typify = Typifier(ctx) elim = EliminateConstants(ctx) - expr = typify((f0 * x) + (y * f0) + f1) + expr = typify((e(0.0) * e.x) + (e.y * e(0.0)) + e(1.0)) result = elim(expr) - assert result.structurally_equal(f1) + assert result.structurally_equal(e(1.0)) - expr = typify((i3 + i12 * (p + q) + p / (i3 * q)) * i0) + expr = typify((e(3) + e(12) * (e.p + e.q) + e.p / (e(3) * e.q)) * e(0)) result = elim(expr) - assert result.structurally_equal(i0) + assert result.structurally_equal(e(0)) -def test_divisions(): +def test_divisions(exprs): + e = exprs ctx = KernelCreationContext() typify = Typifier(ctx) elim = EliminateConstants(ctx) - expr = typify(f3p5 / f1) + expr = typify(e(3.5) / e(1.0)) result = elim(expr) - assert result.structurally_equal(f3p5) + assert result.structurally_equal(e(3.5)) - expr = typify(i3 / i1) + expr = typify(e(3) / e(1)) result = elim(expr) - assert result.structurally_equal(i3) + assert result.structurally_equal(e(3)) - expr = typify(PsRem(i3, i1)) + expr = typify(PsRem(e(3), e(1))) result = elim(expr) - assert result.structurally_equal(i0) + assert result.structurally_equal(e(0)) - expr = typify(PsIntDiv(i12, i3)) + expr = typify(PsIntDiv(e(12), e(3))) result = elim(expr) - assert result.structurally_equal(i4) + assert result.structurally_equal(e(4)) - expr = typify(i12 / i3) + expr = typify(e(12) / e(3)) result = elim(expr) - assert result.structurally_equal(i4) + assert result.structurally_equal(e(4)) - expr = typify(PsIntDiv(i4, i3)) + expr = typify(PsIntDiv(e(4), e(3))) result = elim(expr) - assert result.structurally_equal(i1) + assert result.structurally_equal(e(1)) - expr = typify(PsIntDiv(-i4, i3)) + expr = typify(PsIntDiv(-e(4), e(3))) result = elim(expr) - assert result.structurally_equal(im1) + assert result.structurally_equal(e(-1)) - expr = typify(PsIntDiv(i4, -i3)) + expr = typify(PsIntDiv(e(4), -e(3))) result = elim(expr) - assert result.structurally_equal(im1) + assert result.structurally_equal(e(-1)) - expr = typify(PsIntDiv(-i4, -i3)) + expr = typify(PsIntDiv(-e(4), -e(3))) result = elim(expr) - assert result.structurally_equal(i1) + assert result.structurally_equal(e(1)) - expr = typify(PsRem(i4, i3)) + expr = typify(PsRem(e(4), e(3))) result = elim(expr) - assert result.structurally_equal(i1) + assert result.structurally_equal(e(1)) - expr = typify(PsRem(-i4, i3)) + expr = typify(PsRem(-e(4), e(3))) result = elim(expr) - assert result.structurally_equal(im1) + assert result.structurally_equal(e(-1)) - expr = typify(PsRem(i4, -i3)) + expr = typify(PsRem(e(4), -e(3))) result = elim(expr) - assert result.structurally_equal(i1) + assert result.structurally_equal(e(1)) - expr = typify(PsRem(-i4, -i3)) + expr = typify(PsRem(-e(4), -e(3))) result = elim(expr) - assert result.structurally_equal(im1) + assert result.structurally_equal(e(-1)) -def test_boolean_folding(): +def test_fold_floats(exprs): + e = exprs + ctx = KernelCreationContext() + typify = Typifier(ctx) + elim = EliminateConstants(ctx, fold_floats=True) + + expr = typify(e(8.0) / e(2.0)) + result = elim(expr) + assert result.structurally_equal(e(4.0)) + + expr = typify(e(3.0) * e(12.0) / e(6.0)) + result = elim(expr) + assert result.structurally_equal(e(6.0)) + + +def test_boolean_folding(exprs): + e = exprs ctx = KernelCreationContext() typify = Typifier(ctx) elim = EliminateConstants(ctx) - expr = typify(PsNot(PsAnd(false, PsOr(true, a)))) + expr = typify(PsNot(PsAnd(e.false, PsOr(e.true, e.a)))) + result = elim(expr) + assert result.structurally_equal(e.true) + + expr = typify(PsOr(PsAnd(e.a, e.b), PsNot(e.false))) + result = elim(expr) + assert result.structurally_equal(e.true) + + expr = typify(PsAnd(e.c, PsAnd(e.true, PsAnd(e.a, PsOr(e.false, e.b))))) result = elim(expr) - assert result.structurally_equal(true) + assert result.structurally_equal(PsAnd(e.c, PsAnd(e.a, e.b))) - expr = typify(PsOr(PsAnd(a, b), PsNot(false))) + expr = typify(PsAnd(e.false, PsAnd(e.c, e.a))) result = elim(expr) - assert result.structurally_equal(true) + assert result.structurally_equal(e.false) - expr = typify(PsAnd(c, PsAnd(true, PsAnd(a, PsOr(false, b))))) + expr = typify(PsAnd(PsOr(e.a, e.false), e.false)) result = elim(expr) - assert result.structurally_equal(PsAnd(c, PsAnd(a, b))) + assert result.structurally_equal(e.false) -def test_relations_folding(): +def test_relations_folding(exprs): + e = exprs ctx = KernelCreationContext() typify = Typifier(ctx) elim = EliminateConstants(ctx) - expr = typify(PsGt(p * i0, - i1)) + expr = typify(PsGt(e.p * e(0), -e(1))) result = elim(expr) - assert result.structurally_equal(true) + assert result.structurally_equal(e.true) - expr = typify(PsEq(i1 + i1 + i1, i3)) + expr = typify(PsEq(e(1) + e(1) + e(1), e(3))) result = elim(expr) - assert result.structurally_equal(true) + assert result.structurally_equal(e.true) - expr = typify(PsEq(- i1, - i3)) + expr = typify(PsEq(-e(1), -e(3))) result = elim(expr) - assert result.structurally_equal(false) + assert result.structurally_equal(e.false) - expr = typify(PsEq(x + y, f1 * (x + y))) + expr = typify(PsEq(e.x + e.y, e(1.0) * (e.x + e.y))) result = elim(expr) - assert result.structurally_equal(true) + assert result.structurally_equal(e.true) - expr = typify(PsGt(x + y, f1 * (x + y))) + expr = typify(PsGt(e.x + e.y, e(1.0) * (e.x + e.y))) result = elim(expr) - assert result.structurally_equal(false) + assert result.structurally_equal(e.false) def test_ternary_folding(): + e = Exprs("scalar") + ctx = KernelCreationContext() typify = Typifier(ctx) elim = EliminateConstants(ctx) - expr = typify(PsTernary(true, x, y)) + expr = typify(PsTernary(e.true, e.x, e.y)) + result = elim(expr) + assert result.structurally_equal(e.x) + + expr = typify(PsTernary(e.false, e.x, e.y)) result = elim(expr) - assert result.structurally_equal(x) + assert result.structurally_equal(e.y) + + expr = typify( + PsTernary(PsGt(e(1), e(0)), PsTernary(PsEq(e(1), e(12)), e.x, e.y), e.z) + ) + result = elim(expr) + assert result.structurally_equal(e.y) + + expr = typify(PsTernary(PsGt(e.x, e.y), e.x + e(0.0), e.y * e(1.0))) + result = elim(expr) + assert result.structurally_equal(PsTernary(PsGt(e.x, e.y), e.x, e.y)) + + +def test_fold_vectors(): + e = Exprs("vector") + + ctx = KernelCreationContext() + typify = Typifier(ctx) + elim = EliminateConstants(ctx, fold_floats=True) - expr = typify(PsTernary(false, x, y)) + expr = typify( + e(np.array([1, 3, 2, -4])) + - e(np.array([5, -1, -2, 6])) * e(np.array([1, -1, 1, -1])) + ) result = elim(expr) - assert result.structurally_equal(y) + assert result.structurally_equal(e(np.array([-4, 2, 4, 2]))) - expr = typify(PsTernary(PsGt(i1, i0), PsTernary(PsEq(i1, i12), x, y), z)) + expr = typify( + e(np.array([3.0, 1.0, 2.0, 4.0])) * e(np.array([1.0, -1.0, 1.0, -1.0])) + + e(np.array([2.0, 3.0, 1.0, 4.0])) + ) result = elim(expr) - assert result.structurally_equal(y) + assert result.structurally_equal(e(np.array([5.0, 2.0, 3.0, 0.0]))) - expr = typify(PsTernary(PsGt(x, y), x + f0, y * f1)) + expr = typify( + PsOr( + PsNot(e(np.array([False, False, True, True]))), + e(np.array([False, True, False, True])), + ) + ) result = elim(expr) - assert result.structurally_equal(PsTernary(PsGt(x, y), x, y)) + assert result.structurally_equal(e(np.array([True, True, False, True])))