diff --git a/src/pystencils/backend/arrays.py b/src/pystencils/backend/arrays.py index 233f5d275357083a0fce56e5f450a1c5e934d3d6..4cc3ad03780ee137363547a3b7d3635d96b5d70a 100644 --- a/src/pystencils/backend/arrays.py +++ b/src/pystencils/backend/arrays.py @@ -35,7 +35,6 @@ all occurences of the shape and stride variables with their constant value:: """ - from __future__ import annotations from sys import intern @@ -54,7 +53,7 @@ from .types import ( PsSignedIntegerType, PsScalarType, PsVectorType, - PsTypeError + PsTypeError, ) from .typed_expressions import PsTypedVariable, ExprOrConstant, PsTypedConstant @@ -301,34 +300,65 @@ class PsArrayAccess(pb.Subscript): class PsVectorArrayAccess(pb.AlgebraicLeaf): mapper_method = intern("map_vector_array_access") - def __init__(self, base_ptr: PsArrayBasePointer, base_index: ExprOrConstant, vector_width: int, stride: int = 1): + init_arg_names = ("base_ptr", "base_index", "vector_entries", "stride", "alignment") + + def __getinitargs__(self): + return ( + self._base_ptr, + self._base_index, + self._vector_type.vector_entries, + self._stride, + self._alignment, + ) + + def __init__( + self, + base_ptr: PsArrayBasePointer, + base_index: ExprOrConstant, + vector_entries: int, + stride: int = 1, + alignment: int = 0, + ): element_type = base_ptr.array.element_type if not isinstance(element_type, PsScalarType): - raise PsTypeError("Cannot generate vector accesses to arrays with non-scalar elements") + raise PsTypeError( + "Cannot generate vector accesses to arrays with non-scalar elements" + ) self._base_ptr = base_ptr self._base_index = base_index - self._vector_type = PsVectorType(element_type, vector_width, const=element_type.const) + self._vector_type = PsVectorType( + element_type, vector_entries, const=element_type.const + ) self._stride = stride + self._alignment = alignment @property def base_ptr(self) -> PsArrayBasePointer: return self._base_ptr - + @property def array(self) -> PsLinearizedArray: return self._base_ptr.array - + @property def base_index(self) -> ExprOrConstant: return self._base_index + @property + def vector_entries(self) -> int: + return self._vector_type.vector_entries + @property def dtype(self) -> PsVectorType: """Data type of this expression, i.e. the resulting generic vector type""" return self._vector_type - + @property def stride(self) -> int: return self._stride + + @property + def alignment(self) -> int: + return self._alignment diff --git a/src/pystencils/backend/ast/collectors.py b/src/pystencils/backend/ast/collectors.py index e64efa1a270bf0e198f8a483b1dc215021414954..e2488f95b87e084473b712dbe8e94d5ecc0cec42 100644 --- a/src/pystencils/backend/ast/collectors.py +++ b/src/pystencils/backend/ast/collectors.py @@ -7,12 +7,21 @@ from pymbolic.mapper import Collector from pymbolic.mapper.dependency import DependencyMapper from .kernelfunction import PsKernelFunction -from .nodes import PsAstNode, PsExpression, PsAssignment, PsDeclaration, PsLoop, PsBlock +from .nodes import ( + PsAstNode, + PsExpression, + PsStatement, + PsAssignment, + PsDeclaration, + PsLoop, + PsBlock, +) +from ..arrays import PsVectorArrayAccess from ..typed_expressions import PsTypedVariable, PsTypedConstant from ..exceptions import PsMalformedAstException, PsInternalCompilerError -class UndefinedVariablesCollector: +class UndefinedVariablesCollector(DependencyMapper): """Collector for undefined variables. This class implements an AST visitor that collects all `PsTypedVariable`s that have been used @@ -20,7 +29,7 @@ class UndefinedVariablesCollector: """ def __init__(self) -> None: - self._pb_dep_mapper = DependencyMapper( + super().__init__( include_subscripts=False, include_lookups=False, include_calls=False, @@ -37,7 +46,7 @@ class UndefinedVariablesCollector: return self(block) case PsExpression(expr): - variables: set[Variable] = self._pb_dep_mapper(expr) + variables: set[Variable] = self.rec(expr) for var in variables: if not isinstance(var, PsTypedVariable): @@ -47,6 +56,9 @@ class UndefinedVariablesCollector: return cast(set[PsTypedVariable], variables) + case PsStatement(expr): + return self(expr) + case PsAssignment(lhs, rhs): undefined_vars = self(lhs) | self(rhs) if isinstance(lhs.expression, PsTypedVariable): @@ -62,7 +74,7 @@ class UndefinedVariablesCollector: case PsLoop(ctr, start, stop, step, body): undefined_vars = self(start) | self(stop) | self(step) | self(body) - undefined_vars.remove(ctr.symbol) + undefined_vars.discard(ctr) return undefined_vars case unknown: @@ -77,7 +89,7 @@ class UndefinedVariablesCollector: case PsDeclaration(lhs, _): return {lhs.symbol} - case PsAssignment() | PsExpression() | PsLoop() | PsBlock(): + case PsStatement() | PsAssignment() | PsExpression() | PsLoop() | PsBlock(): return set() case unknown: @@ -85,6 +97,11 @@ class UndefinedVariablesCollector: f"Don't know how to collect declared variables from {unknown}" ) + def map_vector_array_access( + self, vacc: PsVectorArrayAccess + ) -> set[PsTypedVariable]: + return {vacc.base_ptr} | self.rec(vacc.base_index) + def collect_undefined_variables(node: PsAstNode) -> set[PsTypedVariable]: return UndefinedVariablesCollector()(node) diff --git a/src/pystencils/backend/ast/kernelfunction.py b/src/pystencils/backend/ast/kernelfunction.py index dd3029a254c5a84f4dcdb5dcb600d4a4e9a9ed66..deca2cf187b37ad2e9c6630eb52fbf07bc3945cb 100644 --- a/src/pystencils/backend/ast/kernelfunction.py +++ b/src/pystencils/backend/ast/kernelfunction.py @@ -8,7 +8,6 @@ from pymbolic.mapper.dependency import DependencyMapper from .nodes import PsAstNode, PsBlock, failing_cast from ..constraints import PsKernelConstraint -from ..platforms import Platform from ..typed_expressions import PsTypedVariable from ..arrays import PsLinearizedArray, PsArrayBasePointer, PsArrayAssocVar from ..jit import JitBase, no_jit @@ -70,7 +69,12 @@ class PsKernelFunction(PsAstNode): __match_args__ = ("body",) def __init__( - self, body: PsBlock, target: Target, name: str, required_headers: set[str], jit: JitBase = no_jit + self, + body: PsBlock, + target: Target, + name: str, + required_headers: set[str], + jit: JitBase = no_jit, ): self._body: PsBlock = body self._target = target @@ -110,7 +114,7 @@ class PsKernelFunction(PsAstNode): def instruction_set(self) -> str | None: """For backward compatibility""" return None - + @property def required_headers(self) -> set[str]: return self._required_headers diff --git a/src/pystencils/backend/ast/nodes.py b/src/pystencils/backend/ast/nodes.py index 03f1fa2976c236d07f0998450c008e7ac7410b65..f3609ffb71a6b14a71ef9c047919fe914e202136 100644 --- a/src/pystencils/backend/ast/nodes.py +++ b/src/pystencils/backend/ast/nodes.py @@ -7,7 +7,7 @@ from pymbolic.primitives import Variable from abc import ABC, abstractmethod from ..typed_expressions import ExprOrConstant -from ..arrays import PsArrayAccess +from ..arrays import PsArrayAccess, PsVectorArrayAccess from .util import failing_cast @@ -31,11 +31,11 @@ class PsAstNode(ABC): @abstractmethod def get_children(self) -> tuple[PsAstNode, ...]: - ... + pass @abstractmethod def set_child(self, idx: int, c: PsAstNode): - ... + pass def __eq__(self, other: object) -> bool: if not isinstance(other, PsAstNode): @@ -112,7 +112,7 @@ class PsLvalueExpr(PsExpression): """Wrapper around pymbolics expressions that may occur at the left-hand side of an assignment""" def __init__(self, expr: PsLvalue): - if not isinstance(expr, (Variable, PsArrayAccess)): + if not isinstance(expr, (Variable, PsArrayAccess, PsVectorArrayAccess)): raise TypeError("Expression was not a valid lvalue") super(PsLvalueExpr, self).__init__(expr) @@ -136,7 +136,7 @@ class PsSymbolExpr(PsLvalueExpr): class PsStatement(PsAstNode): - __match_args__ = ("expression") + __match_args__ = ("expression",) def __init__(self, expr: PsExpression): self._expression = expr @@ -144,21 +144,21 @@ class PsStatement(PsAstNode): @property def expression(self) -> PsExpression: return self._expression - + @expression.setter def expression(self, expr: PsExpression): self._expression = expr def get_children(self) -> tuple[PsAstNode, ...]: return (self._expression,) - + def set_child(self, idx: int, c: PsAstNode): idx = [0][idx] assert idx == 0 self._expression = failing_cast(PsExpression, c) -PsLvalue: TypeAlias = Variable | PsArrayAccess +PsLvalue: TypeAlias = Variable | PsArrayAccess | PsVectorArrayAccess """Types of expressions that may occur on the left-hand side of assignments.""" diff --git a/src/pystencils/backend/emission.py b/src/pystencils/backend/emission.py index 4c18bd1c945b5b022393863a148e259e8e14e3f0..3d1a2127902285e4c6d54ac27fb6d5f7e751ccc9 100644 --- a/src/pystencils/backend/emission.py +++ b/src/pystencils/backend/emission.py @@ -16,11 +16,10 @@ from .ast import ( ) from .ast.kernelfunction import PsKernelFunction from .typed_expressions import PsTypedVariable -from .functions import Deref, AddressOf, Cast +from .functions import Deref, AddressOf, Cast, CFunction def emit_code(kernel: PsKernelFunction): - # TODO: Specialize for different targets printer = CAstPrinter() return printer.print(kernel) @@ -35,6 +34,9 @@ class CExpressionsPrinter(CCodeMapper): def map_cast(self, cast: Cast, enclosing_prec): return f"({cast.target_type.c_string()})" + def map_c_function(self, func: CFunction, enclosing_prec): + return func.qualified_name + class CAstPrinter: def __init__(self, indent_width=3): @@ -70,14 +72,14 @@ class CAstPrinter: return self.indent("{ }") self._current_indent_level += self._indent_width - interior = "\n".join(self.visit(c) for c in block.children) + interior = "\n".join(self.visit(c) for c in block.children) + "\n" self._current_indent_level -= self._indent_width return self.indent("{\n") + interior + self.indent("}\n") @visit.case(PsExpression) def pymb_expression(self, expr: PsExpression): return self._expr_printer(expr.expression) - + @visit.case(PsStatement) def statement(self, stmt: PsStatement): return self.indent(f"{self.visit(stmt.expression)};") diff --git a/src/pystencils/backend/functions.py b/src/pystencils/backend/functions.py index e2474e52e7c62496acc1fd4ddb119fa5c56a9b62..90bc09c5ef1dd647fca273d6a45b2d279410ff6e 100644 --- a/src/pystencils/backend/functions.py +++ b/src/pystencils/backend/functions.py @@ -45,6 +45,9 @@ class MathFunctions(Enum): class PsFunction(pb.FunctionSymbol, ABC): + + mapper_method = intern("map_ps_function") + @property @abstractmethod def arg_count(self) -> int: @@ -54,6 +57,8 @@ class PsFunction(pb.FunctionSymbol, ABC): class CFunction(PsFunction): """A concrete C function.""" + mapper_method = intern("map_c_function") + def __init__(self, qualified_name: str, arg_count: int): self._qname = qualified_name self._arg_count = arg_count diff --git a/src/pystencils/backend/jit/legacy_cpu.py b/src/pystencils/backend/jit/legacy_cpu.py index df8eab673b01d06ba6cf2f6f769268affeb36adf..6a7e63d147623ef983dd6f691ef134cdd8a0b75c 100644 --- a/src/pystencils/backend/jit/legacy_cpu.py +++ b/src/pystencils/backend/jit/legacy_cpu.py @@ -43,6 +43,7 @@ Then 'cl.exe' is used to compile. For Windows compilers the qualifier should be ``__restrict`` """ + from appdirs import user_cache_dir, user_config_dir from collections import OrderedDict import importlib.util diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index c241ce2a7e9c7db5328d9c5c5acc11cb43bc8e9d..bdcf5120a3f943f50e167ec3f4450e6341fd646c 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -20,7 +20,7 @@ from ..ast.nodes import ( ) from ..types import constify, make_type, PsStructType from ..typed_expressions import PsTypedVariable -from ..arrays import PsArrayAccess +from ..arrays import PsArrayAccess, PsVectorArrayAccess from ..exceptions import PsInputError from ..functions import PsMathFunction, MathFunctions @@ -64,7 +64,7 @@ class FreezeExpressions(SympyToPymbolicMapper): if isinstance(lhs, pb.Variable): return PsDeclaration(PsSymbolExpr(lhs), PsExpression(rhs)) - elif isinstance(lhs, PsArrayAccess): + elif isinstance(lhs, (PsArrayAccess, PsVectorArrayAccess)): # todo return PsAssignment(PsLvalueExpr(lhs), PsExpression(rhs)) else: assert False, "That should not have happened." diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index adb4c0cf739beff4208773bfa232f7fab40bdd82..d251f0acd558c3348a80bd169bd24f86b496aa9c 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -8,7 +8,7 @@ from pymbolic.mapper import Mapper from .context import KernelCreationContext from ..types import PsAbstractType, PsNumericType, PsStructType, deconstify from ..typed_expressions import PsTypedVariable, PsTypedConstant, ExprOrConstant -from ..arrays import PsArrayAccess +from ..arrays import PsArrayAccess, PsVectorArrayAccess from ..ast import PsAstNode, PsBlock, PsExpression, PsAssignment from ..functions import PsMathFunction @@ -178,6 +178,15 @@ class Typifier(Mapper): index = self.rec(access.index_tuple[0], TypeContext(self._ctx.index_dtype)) return PsArrayAccess(access.base_ptr, index) + def map_vector_array_access( + self, access: PsVectorArrayAccess, tc: TypeContext + ) -> PsVectorArrayAccess: + self._apply_target_type(access, access.dtype, tc) + base_index = self.rec(access.base_index, TypeContext(self._ctx.index_dtype)) + return PsVectorArrayAccess( + access.base_ptr, base_index, access.dtype.vector_entries, access.stride + ) + def map_lookup(self, lookup: pb.Lookup, tc: TypeContext) -> pb.Lookup: aggr_tc = TypeContext(None) aggregate = self.rec(lookup.aggregate, aggr_tc) diff --git a/src/pystencils/backend/platforms/__init__.py b/src/pystencils/backend/platforms/__init__.py index 7de61a12bbe6fc0b42f3282993ae699b7cfee9f0..61db873d20e6e06da0dc2ca90876d979b9302b2e 100644 --- a/src/pystencils/backend/platforms/__init__.py +++ b/src/pystencils/backend/platforms/__init__.py @@ -1,4 +1,11 @@ from .platform import Platform from .generic_cpu import GenericCpu, GenericVectorCpu +from .x86 import X86VectorCpu, X86VectorArch -__all__ = ["Platform", "GenericCpu", "GenericVectorCpu"] +__all__ = [ + "Platform", + "GenericCpu", + "GenericVectorCpu", + "X86VectorCpu", + "X86VectorArch", +] diff --git a/src/pystencils/backend/platforms/platform.py b/src/pystencils/backend/platforms/platform.py index ddac626157ca85b94f7d289a9461b87f4d3703cb..d6bf54a570f1cbe8e63c6359d080ec712efacad1 100644 --- a/src/pystencils/backend/platforms/platform.py +++ b/src/pystencils/backend/platforms/platform.py @@ -27,8 +27,8 @@ class Platform(ABC): def materialize_iteration_space( self, block: PsBlock, ispace: IterationSpace ) -> PsBlock: - ... + pass @abstractmethod def optimize(self, kernel: PsBlock) -> PsBlock: - ... + pass diff --git a/src/pystencils/backend/platforms/x86.py b/src/pystencils/backend/platforms/x86.py new file mode 100644 index 0000000000000000000000000000000000000000..c4a9f0d4ce8f27a730fd66b42124e299ccba1c45 --- /dev/null +++ b/src/pystencils/backend/platforms/x86.py @@ -0,0 +1,192 @@ +from __future__ import annotations +from enum import Enum +from functools import cache +from typing import Sequence + +from pymbolic.primitives import Expression + +from ..arrays import PsVectorArrayAccess +from ..transformations.vector_intrinsics import IntrinsicOps +from ..typed_expressions import PsTypedConstant +from ..types import PsCustomType, PsVectorType +from ..functions import address_of + +from .generic_cpu import GenericVectorCpu, IntrinsicsError + +from ..types.quick import Fp, SInt +from ..functions import CFunction + + +class X86VectorArch(Enum): + SSE = 128 + AVX = 256 + AVX512 = 512 + + def __ge__(self, other: X86VectorArch) -> bool: + return self.value >= other.value + + def __gt__(self, other: X86VectorArch) -> bool: + return self.value > other.value + + def __str__(self) -> str: + return self.name + + @property + def max_vector_width(self) -> int: + return self.value + + def intrin_prefix(self, vtype: PsVectorType) -> str: + match vtype.width: + case 128 if self >= X86VectorArch.SSE: + prefix = "_mm" + case 256 if self >= X86VectorArch.AVX: + prefix = "_mm256" + case 512 if self >= X86VectorArch.AVX512: + prefix = "_mm512" + case other: + raise IntrinsicsError( + f"X86/{self} does not support vector width {other}" + ) + + return prefix + + def intrin_suffix(self, vtype: PsVectorType) -> str: + scalar_type = vtype.scalar_type + match scalar_type: + case Fp(16) if self >= X86VectorArch.AVX512: + suffix = "ph" + case Fp(32): + suffix = "ps" + case Fp(64): + suffix = "pd" + case SInt(width): + suffix = f"epi{width}" + case _: + raise IntrinsicsError( + f"X86/{self} does not support scalar type {scalar_type}" + ) + + return suffix + + +class X86VectorCpu(GenericVectorCpu): + """Platform modelling the X86 SSE/AVX/AVX512 vector architectures. + + All intrinsics information is extracted from + https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html. + """ + + def __init__(self, vector_arch: X86VectorArch): + self._vector_arch = vector_arch + + @property + def required_headers(self) -> set[str]: + if self._vector_arch == X86VectorArch.SSE: + headers = { + "<immintrin.h>", + "<xmmintrin.h>", + "<emmintrin.h>", + "<pmmintrin.h>", + "<tmmintrin.h>", + "<smmintrin.h>", + "<nmmintrin.h>", + } + else: + headers = {"<immintrin.h>"} + + return super().required_headers | headers + + def type_intrinsic(self, vector_type: PsVectorType) -> PsCustomType: + scalar_type = vector_type.scalar_type + match scalar_type: + case Fp(16) if self._vector_arch >= X86VectorArch.AVX512: + suffix = "h" + case Fp(32): + suffix = "" + case Fp(64): + suffix = "d" + case SInt(_): + suffix = "i" + case _: + raise IntrinsicsError( + f"X86/{self._vector_arch} does not support scalar type {scalar_type}" + ) + + if vector_type.width > self._vector_arch.max_vector_width: + raise IntrinsicsError( + f"X86/{self._vector_arch} does not support {vector_type}" + ) + return PsCustomType(f"__m{vector_type.width}{suffix}") + + def constant_vector(self, c: PsTypedConstant) -> Expression: + vtype = c.dtype + assert isinstance(vtype, PsVectorType) + + prefix = self._vector_arch.intrin_prefix(vtype) + suffix = self._vector_arch.intrin_suffix(vtype) + set_func = CFunction(f"{prefix}_set_{suffix}", vtype.vector_entries) + + values = c.value + return set_func(*values) + + def op_intrinsic( + self, op: IntrinsicOps, vtype: PsVectorType, args: Sequence[Expression] + ) -> Expression: + func = _x86_op_intrin(self._vector_arch, op, vtype) + return func(*args) + + def vector_load(self, acc: PsVectorArrayAccess) -> Expression: + if acc.stride == 1: + load_func = _x86_packed_load(self._vector_arch, acc.dtype, False) + return load_func(address_of(acc.base_ptr[acc.base_index])) + else: + raise NotImplementedError("Gather loads not implemented yet.") + + def vector_store(self, acc: PsVectorArrayAccess, arg: Expression) -> Expression: + if acc.stride == 1: + store_func = _x86_packed_store(self._vector_arch, acc.dtype, False) + return store_func(address_of(acc.base_ptr[acc.base_index]), arg) + else: + raise NotImplementedError("Scatter stores not implemented yet.") + + +@cache +def _x86_packed_load( + varch: X86VectorArch, vtype: PsVectorType, aligned: bool +) -> CFunction: + prefix = varch.intrin_prefix(vtype) + suffix = varch.intrin_suffix(vtype) + return CFunction(f"{prefix}_load{'' if aligned else 'u'}_{suffix}", 1) + + +@cache +def _x86_packed_store( + varch: X86VectorArch, vtype: PsVectorType, aligned: bool +) -> CFunction: + prefix = varch.intrin_prefix(vtype) + suffix = varch.intrin_suffix(vtype) + return CFunction(f"{prefix}_store{'' if aligned else 'u'}_{suffix}", 2) + + +@cache +def _x86_op_intrin( + varch: X86VectorArch, op: IntrinsicOps, vtype: PsVectorType +) -> CFunction: + prefix = varch.intrin_prefix(vtype) + suffix = varch.intrin_suffix(vtype) + + match op: + case IntrinsicOps.ADD: + opstr = "add" + case IntrinsicOps.SUB: + opstr = "sub" + case IntrinsicOps.MUL: + opstr = "mul" + case IntrinsicOps.DIV: + opstr = "div" + case IntrinsicOps.FMA: + opstr = "fmadd" + case _: + assert False + + return CFunction(f"{prefix}_{opstr}_{suffix}", 3 if op == IntrinsicOps.FMA else 2) diff --git a/src/pystencils/backend/transformations/__init__.py b/src/pystencils/backend/transformations/__init__.py index 1abd074af598a8b86e34a705f8e10410dc202664..c630e580ba29f8bd75bf9e779ff8b05f8d5cb541 100644 --- a/src/pystencils/backend/transformations/__init__.py +++ b/src/pystencils/backend/transformations/__init__.py @@ -1,5 +1,4 @@ from .erase_anonymous_structs import EraseAnonymousStructTypes +from .vector_intrinsics import MaterializeVectorIntrinsics -__all__ = [ - "EraseAnonymousStructTypes" -] \ No newline at end of file +__all__ = ["EraseAnonymousStructTypes", "MaterializeVectorIntrinsics"] diff --git a/src/pystencils/backend/transformations/vector_intrinsics.py b/src/pystencils/backend/transformations/vector_intrinsics.py index 4cfd112c978bcd4713e7cf250075a35dd0eb1969..03c29015030e8cd273a46fd8534425c7f6ec4656 100644 --- a/src/pystencils/backend/transformations/vector_intrinsics.py +++ b/src/pystencils/backend/transformations/vector_intrinsics.py @@ -1,3 +1,4 @@ +from __future__ import annotations from typing import TypeVar, TYPE_CHECKING from enum import Enum, auto @@ -5,7 +6,7 @@ import pymbolic.primitives as pb from pymbolic.mapper import IdentityMapper from ..ast import PsAstNode, PsExpression, PsAssignment, PsStatement -from ..types import PsVectorType +from ..types import PsVectorType, deconstify from ..typed_expressions import PsTypedVariable, PsTypedConstant, ExprOrConstant from ..arrays import PsVectorArrayAccess from ..exceptions import PsInternalCompilerError @@ -39,8 +40,11 @@ class VecTypeCtx: return self._dtype def set(self, dtype: PsVectorType): - if self._dtype is not None: - raise PsInternalCompilerError("Ambiguous vector types.") + dtype = deconstify(dtype) + if self._dtype is not None and dtype != self._dtype: + raise PsInternalCompilerError( + f"Ambiguous vector types: {self._dtype} and {dtype}" + ) self._dtype = dtype def reset(self): @@ -57,35 +61,42 @@ class MaterializeVectorIntrinsics(IdentityMapper): # descend into expr node.expression = self.rec(expr, VecTypeCtx()) return node - case PsAssignment(lhs, rhs) if isinstance(lhs.expression, PsVectorArrayAccess): + case PsAssignment(lhs, rhs) if isinstance( + lhs.expression, PsVectorArrayAccess + ): vc = VecTypeCtx() vc.set(lhs.expression.dtype) store_arg = self.rec(rhs.expression, vc) - return PsStatement(PsExpression(self._platform.vector_store(lhs.expression, store_arg))) + return PsStatement( + PsExpression(self._platform.vector_store(lhs.expression, store_arg)) + ) case other: - for c in other.children: - self(c) - return node + other.children = (self(c) for c in other.children) + return node - def map_typed_variable(self, tv: PsTypedVariable, vc: VecTypeCtx) -> PsTypedVariable: + def map_typed_variable( + self, tv: PsTypedVariable, vc: VecTypeCtx + ) -> PsTypedVariable: if isinstance(tv.dtype, PsVectorType): intrin_type = self._platform.type_intrinsic(tv.dtype) vc.set(tv.dtype) return PsTypedVariable(tv.name, intrin_type) else: return tv - + def map_constant(self, c: PsTypedConstant, vc: VecTypeCtx) -> ExprOrConstant: if isinstance(c.dtype, PsVectorType): vc.set(c.dtype) return self._platform.constant_vector(c) else: return c - - def map_vector_array_access(self, acc: PsVectorArrayAccess, vc: VecTypeCtx) -> pb.Expression: + + def map_vector_array_access( + self, acc: PsVectorArrayAccess, vc: VecTypeCtx + ) -> pb.Expression: vc.set(acc.dtype) return self._platform.vector_load(acc) - + def map_sum(self, expr: pb.Sum, vc: VecTypeCtx) -> pb.Expression: args = [self.rec(arg, vc) for arg in expr.children] vtype = vc.get() @@ -95,7 +106,7 @@ class MaterializeVectorIntrinsics(IdentityMapper): return self._platform.op_intrinsic(IntrinsicOps.ADD, vtype, args) else: return expr - + def map_product(self, expr: pb.Product, vc: VecTypeCtx) -> pb.Expression: args = [self.rec(arg, vc) for arg in expr.children] vtype = vc.get() diff --git a/src/pystencils/backend/types/basic_types.py b/src/pystencils/backend/types/basic_types.py index 178eeafd7b0e14b4a3259dac75d102cb215abd71..8b7fdf794facbe78443cff9ac5a8f98f9c7a1a25 100644 --- a/src/pystencils/backend/types/basic_types.py +++ b/src/pystencils/backend/types/basic_types.py @@ -66,20 +66,23 @@ class PsAbstractType(ABC): return "const " if self._const else "" @abstractmethod - def c_string(self) -> str: ... + def c_string(self) -> str: + pass # ------------------------------------------------------------------------------------------- # Dunder Methods # ------------------------------------------------------------------------------------------- @abstractmethod - def __eq__(self, other: object) -> bool: ... + def __eq__(self, other: object) -> bool: + pass def __str__(self) -> str: return self.c_string() @abstractmethod - def __hash__(self) -> int: ... + def __hash__(self) -> int: + pass class PsCustomType(PsAbstractType): @@ -272,16 +275,20 @@ class PsNumericType(PsAbstractType, ABC): """ @abstractmethod - def is_int(self) -> bool: ... + def is_int(self) -> bool: + pass @abstractmethod - def is_sint(self) -> bool: ... + def is_sint(self) -> bool: + pass @abstractmethod - def is_uint(self) -> bool: ... + def is_uint(self) -> bool: + pass @abstractmethod - def is_float(self) -> bool: ... + def is_float(self) -> bool: + pass class PsScalarType(PsNumericType, ABC): @@ -295,6 +302,11 @@ class PsScalarType(PsNumericType, ABC): PsTypeError: If the given value's type is not the numeric type's compiler-internal representation. """ + @property + @abstractmethod + def width(self) -> int: + """Return this type's width in bits.""" + def is_int(self) -> bool: return isinstance(self, PsIntegerType) @@ -317,10 +329,10 @@ class PsVectorType(PsNumericType): """ def __init__( - self, scalar_type: PsScalarType, vector_width: int, const: bool = False + self, scalar_type: PsScalarType, vector_entries: int, const: bool = False ): super().__init__(const) - self._vector_width = vector_width + self._vector_entries = vector_entries self._scalar_type = constify(scalar_type) if const else deconstify(scalar_type) @property @@ -328,8 +340,12 @@ class PsVectorType(PsNumericType): return self._scalar_type @property - def vector_width(self) -> int: - return self._vector_width + def vector_entries(self) -> int: + return self._vector_entries + + @property + def width(self) -> int: + return self._scalar_type.width * self._vector_entries def is_int(self) -> bool: return self._scalar_type.is_int() @@ -348,23 +364,23 @@ class PsVectorType(PsNumericType): if self._scalar_type.itemsize is None: return None else: - return self._vector_width * self._scalar_type.itemsize + return self._vector_entries * self._scalar_type.itemsize @property def numpy_dtype(self): - return np.dtype((self._scalar_type.numpy_dtype, (self._vector_width,))) + return np.dtype((self._scalar_type.numpy_dtype, (self._vector_entries,))) def create_constant(self, value: Any) -> Any: if ( isinstance(value, np.ndarray) and value.dtype == self.scalar_type.numpy_dtype - and value.shape == (self._vector_width,) + and value.shape == (self._vector_entries,) ): return value.copy() element = self._scalar_type.create_constant(value) return np.array( - [element] * self._vector_width, dtype=self.scalar_type.numpy_dtype + [element] * self._vector_entries, dtype=self.scalar_type.numpy_dtype ) def __eq__(self, other: object) -> bool: @@ -374,12 +390,12 @@ class PsVectorType(PsNumericType): return ( self._base_equal(other) and self._scalar_type == other._scalar_type - and self._vector_width == other._vector_width + and self._vector_entries == other._vector_entries ) def __hash__(self) -> int: return hash( - ("PsVectorType", self._scalar_type, self._vector_width, self._const) + ("PsVectorType", self._scalar_type, self._vector_entries, self._const) ) def c_string(self) -> str: @@ -388,10 +404,13 @@ class PsVectorType(PsNumericType): ) def __str__(self) -> str: - return f"vector[{self._scalar_type}, {self._vector_width}]" + return f"vector[{self._scalar_type}, {self._vector_entries}]" def __repr__(self) -> str: - return f"PsVectorType( scalar_type={repr(self._scalar_type)}, vector_width={self._vector_width}, const={self.const} )" + return ( + f"PsVectorType( scalar_type={repr(self._scalar_type)}, " + f"vector_width={self._vector_entries}, const={self.const} )" + ) class PsIntegerType(PsScalarType, ABC):