diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py index ed94f9c8b1355ec94b4af7896ab273c58abaa990..8a66457a9f516f5089c75888c4cb753a296005b2 100644 --- a/src/pystencils/backend/ast/expressions.py +++ b/src/pystencils/backend/ast/expressions.py @@ -349,6 +349,14 @@ class PsCall(PsExpression): def function(self) -> PsFunction: return self._function + @function.setter + def function(self, func: PsFunction): + if func.arg_count != self._function.arg_count: + raise ValueError( + "Current and replacement function must have the same number of parameters." + ) + self._function = func + @property def args(self) -> tuple[PsExpression, ...]: return tuple(self._args) diff --git a/src/pystencils/backend/exceptions.py b/src/pystencils/backend/exceptions.py index 29434d8ef7643fef0dcc9ca530d0e6a63d13d908..4c081224913bfc61c4f542501ce8f4b5a1ddc59c 100644 --- a/src/pystencils/backend/exceptions.py +++ b/src/pystencils/backend/exceptions.py @@ -1,18 +1,21 @@ +"""Errors and Exceptions raised by the backend during kernel translation.""" + + class PsInternalCompilerError(Exception): - pass + """Indicates an internal error during kernel translation, most likely due to a bug inside pystencils.""" class PsOptionsError(Exception): - pass + """Indicates an option clash in the `CreateKernelConfig`.""" class PsInputError(Exception): - pass + """Indicates unsupported user input to the translation system.""" class KernelConstraintsError(Exception): - pass + """Indicates a constraint violation in the symbolic kernel.""" -class PsMalformedAstException(Exception): - pass +class MaterializationError(Exception): + """Indicates a fatal error during materialization of any abstract kernel component.""" diff --git a/src/pystencils/backend/functions.py b/src/pystencils/backend/functions.py index 39d5019ba4b14de11305563acaa1f9b26f2e60f9..313b622beaa151d294b8bcf6c66d730830ce2497 100644 --- a/src/pystencils/backend/functions.py +++ 
b/src/pystencils/backend/functions.py @@ -36,8 +36,8 @@ class MathFunctions(Enum): Abs = ("abs", 1) - Min = ("fmin", 2) - Max = ("fmax", 2) + Min = ("min", 2) + Max = ("max", 2) Pow = ("pow", 2) diff --git a/src/pystencils/backend/platforms/generic_cpu.py b/src/pystencils/backend/platforms/generic_cpu.py index 08e87ed82e014434e8febbb2de912f09ab8b6ed1..a83bad4badf89a95f0d3efbd61c0bc355b534b2a 100644 --- a/src/pystencils/backend/platforms/generic_cpu.py +++ b/src/pystencils/backend/platforms/generic_cpu.py @@ -1,7 +1,11 @@ from typing import Sequence from abc import ABC, abstractmethod +from ..functions import CFunction, PsMathFunction, MathFunctions +from ...types import PsType, PsIeeeFloatType + from .platform import Platform +from ..exceptions import MaterializationError from ..kernelcreation.iteration_space import ( IterationSpace, @@ -23,6 +27,14 @@ from ..transformations.select_intrinsics import IntrinsicOps class GenericCpu(Platform): + """Generic CPU platform. + + The `GenericCpu` platform models the following execution environment: + + - Generic multicore CPU architecture + - Iteration space represented by a loop nest, kernels are executed as a whole + - C standard library math functions available (``#include <math.h>`` or ``#include <cmath>``) + """ @property def required_headers(self) -> set[str]: @@ -38,6 +50,27 @@ class GenericCpu(Platform): else: assert False, "unreachable code" + def select_function( + self, math_function: PsMathFunction, dtype: PsType + ) -> CFunction: + func = math_function.func + if isinstance(dtype, PsIeeeFloatType) and dtype.width in (32, 64): + match func: + case ( + MathFunctions.Exp + | MathFunctions.Sin + | MathFunctions.Cos + | MathFunctions.Tan + | MathFunctions.Pow + ): + return CFunction(func.function_name, func.num_args) + case MathFunctions.Abs | MathFunctions.Min | MathFunctions.Max: + return CFunction("f" + func.function_name, func.num_args) + + raise MaterializationError( + f"No implementation available for function 
{math_function} on data type {dtype}" + ) + # Internals def _create_domain_loops( @@ -94,21 +127,18 @@ class GenericCpu(Platform): return PsBlock([loop]) -class IntrinsicsError(Exception): - """Exception indicating a fatal error during intrinsic materialization.""" - - class GenericVectorCpu(GenericCpu, ABC): + """Base class for CPU platforms with vectorization support through intrinsics.""" @abstractmethod def type_intrinsic(self, vector_type: PsVectorType) -> PsCustomType: """Return the intrinsic vector type for the given generic vector type, - or raise an `IntrinsicsError` if type is not supported.""" + or raise a `MaterializationError` if type is not supported.""" @abstractmethod def constant_vector(self, c: PsConstant) -> PsExpression: """Return an expression that initializes a constant vector, - or raise an `IntrinsicsError` if not supported.""" + or raise a `MaterializationError` if not supported.""" @abstractmethod def op_intrinsic( @@ -116,14 +146,14 @@ class GenericVectorCpu(GenericCpu, ABC): ) -> PsExpression: """Return an expression intrinsically invoking the given operation on the given arguments with the given vector type, - or raise an `IntrinsicsError` if not supported.""" + or raise a `MaterializationError` if not supported.""" @abstractmethod def vector_load(self, acc: PsVectorArrayAccess) -> PsExpression: """Return an expression intrinsically performing a vector load, - or raise an `IntrinsicsError` if not supported.""" + or raise a `MaterializationError` if not supported.""" @abstractmethod def vector_store(self, acc: PsVectorArrayAccess, arg: PsExpression) -> PsExpression: """Return an expression intrinsically performing a vector store, - or raise an `IntrinsicsError` if not supported.""" + or raise a `MaterializationError` if not supported.""" diff --git a/src/pystencils/backend/platforms/generic_gpu.py b/src/pystencils/backend/platforms/generic_gpu.py index 
79ab6f9ec24a424c0ef08c1393f72fcf0898f8ff..64c0cd3e94b8bbae675417f165e9351a4d122f72 100644 --- a/src/pystencils/backend/platforms/generic_gpu.py +++ b/src/pystencils/backend/platforms/generic_gpu.py @@ -1,3 +1,5 @@ +from ..functions import CFunction, PsMathFunction +from ...types import PsType from .platform import Platform from ..kernelcreation.iteration_space import ( @@ -54,6 +56,9 @@ class GenericGpu(Platform): ] return indices[:dim] + + def select_function(self, math_function: PsMathFunction, dtype: PsType) -> CFunction: + raise NotImplementedError() # Internals def _guard_full_iteration_space( diff --git a/src/pystencils/backend/platforms/platform.py b/src/pystencils/backend/platforms/platform.py index 3fedf7c01f6e186d0c92efa4141aa3b9e7fb6ccd..2c718ae5fd329197c3a67a26d51c9a737f63271f 100644 --- a/src/pystencils/backend/platforms/platform.py +++ b/src/pystencils/backend/platforms/platform.py @@ -1,6 +1,8 @@ from abc import ABC, abstractmethod from ..ast.structural import PsBlock +from ..functions import PsMathFunction, CFunction +from ...types import PsType from ..kernelcreation.context import KernelCreationContext from ..kernelcreation.iteration_space import IterationSpace @@ -28,3 +30,13 @@ class Platform(ABC): self, block: PsBlock, ispace: IterationSpace ) -> PsBlock: pass + + @abstractmethod + def select_function( + self, math_function: PsMathFunction, dtype: PsType + ) -> CFunction: + """Select an implementation for the given function on the given data type. + + If no viable implementation exists, raise a `MaterializationError`. 
+ """ + pass diff --git a/src/pystencils/backend/platforms/x86.py b/src/pystencils/backend/platforms/x86.py index fe5e1c76db8e106bb99324599b79e0e8d43c4da8..fa5af4655810943c47f503329a8c41ce3baa36c5 100644 --- a/src/pystencils/backend/platforms/x86.py +++ b/src/pystencils/backend/platforms/x86.py @@ -13,7 +13,8 @@ from ..transformations.select_intrinsics import IntrinsicOps from ...types import PsCustomType, PsVectorType from ..constants import PsConstant -from .generic_cpu import GenericVectorCpu, IntrinsicsError +from ..exceptions import MaterializationError +from .generic_cpu import GenericVectorCpu from ...types.quick import Fp, SInt from ..functions import CFunction @@ -46,7 +47,7 @@ class X86VectorArch(Enum): case 512 if self >= X86VectorArch.AVX512: prefix = "_mm512" case other: - raise IntrinsicsError( + raise MaterializationError( f"X86/{self} does not support vector width {other}" ) @@ -64,7 +65,7 @@ class X86VectorArch(Enum): case SInt(width): suffix = f"epi{width}" case _: - raise IntrinsicsError( + raise MaterializationError( f"X86/{self} does not support scalar type {scalar_type}" ) @@ -110,12 +111,12 @@ class X86VectorCpu(GenericVectorCpu): case SInt(_): suffix = "i" case _: - raise IntrinsicsError( + raise MaterializationError( f"X86/{self._vector_arch} does not support scalar type {scalar_type}" ) if vector_type.width > self._vector_arch.max_vector_width: - raise IntrinsicsError( + raise MaterializationError( f"X86/{self._vector_arch} does not support {vector_type}" ) return PsCustomType(f"__m{vector_type.width}{suffix}") diff --git a/src/pystencils/backend/transformations/__init__.py b/src/pystencils/backend/transformations/__init__.py index b4c0e8bbd32fe738b29828ac16a4bdd1401e4fae..8ef35e4fbcfd864be05c1d9bcedc7ee83b3d6a96 100644 --- a/src/pystencils/backend/transformations/__init__.py +++ b/src/pystencils/backend/transformations/__init__.py @@ -1,9 +1,11 @@ from .eliminate_constants import EliminateConstants from .erase_anonymous_structs import 
EraseAnonymousStructTypes +from .select_functions import SelectFunctions from .select_intrinsics import MaterializeVectorIntrinsics __all__ = [ "EliminateConstants", "EraseAnonymousStructTypes", + "SelectFunctions", "MaterializeVectorIntrinsics", ] diff --git a/src/pystencils/backend/transformations/select_functions.py b/src/pystencils/backend/transformations/select_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..c4085b2bc913a885cdd9f6a42de19a8f2c2ab404 --- /dev/null +++ b/src/pystencils/backend/transformations/select_functions.py @@ -0,0 +1,24 @@ +from ..platforms import Platform +from ..ast import PsAstNode +from ..ast.expressions import PsCall +from ..functions import PsMathFunction + + +class SelectFunctions: + """Traverse the AST to replace all instances of `PsMathFunction` by their implementation + provided by the given `Platform`.""" + + def __init__(self, platform: Platform): + self._platform = platform + + def __call__(self, node: PsAstNode) -> PsAstNode: + self.visit(node) + return node + + def visit(self, node: PsAstNode): + for c in node.children: + self.visit(c) + + if isinstance(node, PsCall) and isinstance(node.function, PsMathFunction): + impl = self._platform.select_function(node.function, node.get_dtype()) + node.function = impl diff --git a/src/pystencils/kernelcreation.py b/src/pystencils/kernelcreation.py index 79cf92f8bf882749ade1afd0c8c686cdfb21d5a0..3f17e66bd7e8df097e48c1c2607e13d9844d9990 100644 --- a/src/pystencils/kernelcreation.py +++ b/src/pystencils/kernelcreation.py @@ -25,7 +25,7 @@ from .backend.kernelcreation.iteration_space import ( ) from .backend.ast.analysis import collect_required_headers, collect_undefined_symbols -from .backend.transformations import EraseAnonymousStructTypes, EliminateConstants +from .backend.transformations import EraseAnonymousStructTypes, EliminateConstants, SelectFunctions from .sympyextensions import AssignmentCollection, Assignment @@ -104,6 +104,9 @@ def 
create_kernel( from .backend.kernelcreation import optimize_cpu optimize_cpu(ctx, platform, kernel_ast, config.cpu_optim) + select_functions = SelectFunctions(platform) + kernel_ast = cast(PsBlock, select_functions(kernel_ast)) + assert config.jit is not None return create_kernel_function( ctx, kernel_ast, config.function_name, config.target, config.jit