From f0de552fe1f21b5d964430ad75c28c200fec0f9e Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Wed, 27 Mar 2024 17:16:39 +0100 Subject: [PATCH] Add support for function materialization. - Add `select_function` to `Platform` and implement it for `GenericCpu` - Add `SelectFunctions` AST pass Squashed commit of the following: commit e6b3aa856d2471b645f39d20673eb5259f7b1e1c Author: Frederik Hennig <frederik.hennig@fau.de> Date: Wed Mar 27 17:15:55 2024 +0100 add select_function stub to GenericGpu commit 679428f053d6bfd3a5ba4d1d1ff6c903577e9ed3 Author: Frederik Hennig <frederik.hennig@fau.de> Date: Wed Mar 27 17:14:16 2024 +0100 Remove IntrinsicsError in favor of MaterializationError. Move the latter to `backend.exceptions`. commit 1a01cfde2f39b3394935c2a313312097d4cf66e0 Merge: d73d24c 0e4677d Author: Frederik Hennig <frederik.hennig@fau.de> Date: Wed Mar 27 17:07:31 2024 +0100 Merge branch 'backend-rework' into b_function_mat commit d73d24cf038950f2129e717a43dd09df11d725c0 Author: Frederik Hennig <frederik.hennig@fau.de> Date: Tue Mar 26 18:19:08 2024 +0100 introduce SelectFunctions and select_function protocol commit 671f0578a39e452504243019dab28d93f0114082 Author: Frederik Hennig <frederik.hennig@fau.de> Date: Tue Mar 26 16:39:43 2024 +0100 Fix documentation for Typifier and PsExpression commit 3ec258517ad8a510118265184b5dc7805128dcd3 Author: Frederik Hennig <frederik.hennig@fau.de> Date: Mon Mar 25 17:14:21 2024 +0100 Typing refactor: - Annotate all expressions with types - Refactor Typifier for cleaner information flow and better readability - Have iteration space and transformers typify newly created AST nodes --- src/pystencils/backend/ast/expressions.py | 8 ++++ src/pystencils/backend/exceptions.py | 15 +++--- src/pystencils/backend/functions.py | 4 +- .../backend/platforms/generic_cpu.py | 48 +++++++++++++++---- .../backend/platforms/generic_gpu.py | 5 ++ src/pystencils/backend/platforms/platform.py | 12 +++++ 
src/pystencils/backend/platforms/x86.py | 11 +++-- .../backend/transformations/__init__.py | 2 + .../transformations/select_functions.py | 24 ++++++++++ src/pystencils/kernelcreation.py | 5 +- 10 files changed, 111 insertions(+), 23 deletions(-) create mode 100644 src/pystencils/backend/transformations/select_functions.py diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py index ed94f9c8b..8a66457a9 100644 --- a/src/pystencils/backend/ast/expressions.py +++ b/src/pystencils/backend/ast/expressions.py @@ -349,6 +349,14 @@ class PsCall(PsExpression): def function(self) -> PsFunction: return self._function + @function.setter + def function(self, func: PsFunction): + if func.arg_count != self._function.arg_count: + raise ValueError( + "Current and replacement function must have the same number of parameters." + ) + self._function = func + @property def args(self) -> tuple[PsExpression, ...]: return tuple(self._args) diff --git a/src/pystencils/backend/exceptions.py b/src/pystencils/backend/exceptions.py index 29434d8ef..4c0812249 100644 --- a/src/pystencils/backend/exceptions.py +++ b/src/pystencils/backend/exceptions.py @@ -1,18 +1,21 @@ +"""Errors and Exceptions raised by the backend during kernel translation.""" + + class PsInternalCompilerError(Exception): - pass + """Indicates an internal error during kernel translation, most likely due to a bug inside pystencils.""" class PsOptionsError(Exception): - pass + """Indicates an option clash in the `CreateKernelConfig`.""" class PsInputError(Exception): - pass + """Indicates unsupported user input to the translation system""" class KernelConstraintsError(Exception): - pass + """Indicates a constraint violation in the symbolic kernel""" -class PsMalformedAstException(Exception): - pass +class MaterializationError(Exception): + """Indicates a fatal error during materialization of any abstract kernel component.""" diff --git a/src/pystencils/backend/functions.py 
b/src/pystencils/backend/functions.py index 39d5019ba..313b622be 100644 --- a/src/pystencils/backend/functions.py +++ b/src/pystencils/backend/functions.py @@ -36,8 +36,8 @@ class MathFunctions(Enum): Abs = ("abs", 1) - Min = ("fmin", 2) - Max = ("fmax", 2) + Min = ("min", 2) + Max = ("max", 2) Pow = ("pow", 2) diff --git a/src/pystencils/backend/platforms/generic_cpu.py b/src/pystencils/backend/platforms/generic_cpu.py index 08e87ed82..a83bad4ba 100644 --- a/src/pystencils/backend/platforms/generic_cpu.py +++ b/src/pystencils/backend/platforms/generic_cpu.py @@ -1,7 +1,11 @@ from typing import Sequence from abc import ABC, abstractmethod +from ..functions import CFunction, PsMathFunction, MathFunctions +from ...types import PsType, PsIeeeFloatType + from .platform import Platform +from ..exceptions import MaterializationError from ..kernelcreation.iteration_space import ( IterationSpace, @@ -23,6 +27,14 @@ from ..transformations.select_intrinsics import IntrinsicOps class GenericCpu(Platform): + """Generic CPU platform. 
+ + The `GenericCpu` platform models the following execution environment: + + - Generic multicore CPU architecture + - Iteration space represented by a loop nest, kernels are executed as a whole + - C standard library math functions available (``#include <math.h>`` or ``#include <cmath>``) + """ @property def required_headers(self) -> set[str]: @@ -38,6 +50,27 @@ class GenericCpu(Platform): else: assert False, "unreachable code" + def select_function( + self, math_function: PsMathFunction, dtype: PsType + ) -> CFunction: + func = math_function.func + if isinstance(dtype, PsIeeeFloatType) and dtype.width in (32, 64): + match func: + case ( + MathFunctions.Exp + | MathFunctions.Sin + | MathFunctions.Cos + | MathFunctions.Tan + | MathFunctions.Pow + ): + return CFunction(func.function_name, func.num_args) + case MathFunctions.Abs | MathFunctions.Min | MathFunctions.Max: + return CFunction("f" + func.function_name, func.num_args) + + raise MaterializationError( + f"No implementation available for function {math_function} on data type {dtype}" + ) + # Internals def _create_domain_loops( @@ -94,21 +127,18 @@ class GenericCpu(Platform): return PsBlock([loop]) -class IntrinsicsError(Exception): - """Exception indicating a fatal error during intrinsic materialization.""" - - class GenericVectorCpu(GenericCpu, ABC): + """Base class for CPU platforms with vectorization support through intrinsics.""" @abstractmethod def type_intrinsic(self, vector_type: PsVectorType) -> PsCustomType: """Return the intrinsic vector type for the given generic vector type, - or raise an `IntrinsicsError` if type is not supported.""" + or raise a `MaterializationError` if type is not supported.""" @abstractmethod def constant_vector(self, c: PsConstant) -> PsExpression: """Return an expression that initializes a constant vector, - or raise an `IntrinsicsError` if not supported.""" + or raise a `MaterializationError` if not supported.""" @abstractmethod def op_intrinsic( @@ -116,14 +146,14 @@ 
class GenericVectorCpu(GenericCpu, ABC): ) -> PsExpression: """Return an expression intrinsically invoking the given operation on the given arguments with the given vector type, - or raise an `IntrinsicsError` if not supported.""" + or raise a `MaterializationError` if not supported.""" @abstractmethod def vector_load(self, acc: PsVectorArrayAccess) -> PsExpression: """Return an expression intrinsically performing a vector load, - or raise an `IntrinsicsError` if not supported.""" + or raise a `MaterializationError` if not supported.""" @abstractmethod def vector_store(self, acc: PsVectorArrayAccess, arg: PsExpression) -> PsExpression: """Return an expression intrinsically performing a vector store, - or raise an `IntrinsicsError` if not supported.""" + or raise a `MaterializationError` if not supported.""" diff --git a/src/pystencils/backend/platforms/generic_gpu.py b/src/pystencils/backend/platforms/generic_gpu.py index 79ab6f9ec..64c0cd3e9 100644 --- a/src/pystencils/backend/platforms/generic_gpu.py +++ b/src/pystencils/backend/platforms/generic_gpu.py @@ -1,3 +1,5 @@ +from pystencils.backend.functions import CFunction, PsMathFunction +from pystencils.types.basic_types import PsType from .platform import Platform from ..kernelcreation.iteration_space import ( @@ -54,6 +56,9 @@ class GenericGpu(Platform): ] return indices[:dim] + + def select_function(self, math_function: PsMathFunction, dtype: PsType) -> CFunction: + raise NotImplementedError() # Internals def _guard_full_iteration_space( diff --git a/src/pystencils/backend/platforms/platform.py b/src/pystencils/backend/platforms/platform.py index 3fedf7c01..2c718ae5f 100644 --- a/src/pystencils/backend/platforms/platform.py +++ b/src/pystencils/backend/platforms/platform.py @@ -1,6 +1,8 @@ from abc import ABC, abstractmethod from ..ast.structural import PsBlock +from ..functions import PsMathFunction, CFunction +from ...types import PsType from ..kernelcreation.context import KernelCreationContext from 
..kernelcreation.iteration_space import IterationSpace @@ -28,3 +30,13 @@ class Platform(ABC): self, block: PsBlock, ispace: IterationSpace ) -> PsBlock: pass + + @abstractmethod + def select_function( + self, math_function: PsMathFunction, dtype: PsType + ) -> CFunction: + """Select an implementation for the given function on the given data type. + + If no viable implementation exists, raise a `MaterializationError`. + """ + pass diff --git a/src/pystencils/backend/platforms/x86.py b/src/pystencils/backend/platforms/x86.py index fe5e1c76d..fa5af4655 100644 --- a/src/pystencils/backend/platforms/x86.py +++ b/src/pystencils/backend/platforms/x86.py @@ -13,7 +13,8 @@ from ..transformations.select_intrinsics import IntrinsicOps from ...types import PsCustomType, PsVectorType from ..constants import PsConstant -from .generic_cpu import GenericVectorCpu, IntrinsicsError +from ..exceptions import MaterializationError +from .generic_cpu import GenericVectorCpu from ...types.quick import Fp, SInt from ..functions import CFunction @@ -46,7 +47,7 @@ class X86VectorArch(Enum): case 512 if self >= X86VectorArch.AVX512: prefix = "_mm512" case other: - raise IntrinsicsError( + raise MaterializationError( f"X86/{self} does not support vector width {other}" ) @@ -64,7 +65,7 @@ class X86VectorArch(Enum): case SInt(width): suffix = f"epi{width}" case _: - raise IntrinsicsError( + raise MaterializationError( f"X86/{self} does not support scalar type {scalar_type}" ) @@ -110,12 +111,12 @@ class X86VectorCpu(GenericVectorCpu): case SInt(_): suffix = "i" case _: - raise IntrinsicsError( + raise MaterializationError( f"X86/{self._vector_arch} does not support scalar type {scalar_type}" ) if vector_type.width > self._vector_arch.max_vector_width: - raise IntrinsicsError( + raise MaterializationError( f"X86/{self._vector_arch} does not support {vector_type}" ) return PsCustomType(f"__m{vector_type.width}{suffix}") diff --git a/src/pystencils/backend/transformations/__init__.py 
b/src/pystencils/backend/transformations/__init__.py index b4c0e8bbd..8ef35e4fb 100644 --- a/src/pystencils/backend/transformations/__init__.py +++ b/src/pystencils/backend/transformations/__init__.py @@ -1,9 +1,11 @@ from .eliminate_constants import EliminateConstants from .erase_anonymous_structs import EraseAnonymousStructTypes +from .select_functions import SelectFunctions from .select_intrinsics import MaterializeVectorIntrinsics __all__ = [ "EliminateConstants", "EraseAnonymousStructTypes", + "SelectFunctions", "MaterializeVectorIntrinsics", ] diff --git a/src/pystencils/backend/transformations/select_functions.py b/src/pystencils/backend/transformations/select_functions.py new file mode 100644 index 000000000..c4085b2bc --- /dev/null +++ b/src/pystencils/backend/transformations/select_functions.py @@ -0,0 +1,24 @@ +from ..platforms import Platform +from ..ast import PsAstNode +from ..ast.expressions import PsCall +from ..functions import PsMathFunction + + +class SelectFunctions: + """Traverse the AST to replace all instances of `PsMathFunction` by their implementation + provided by the given `Platform`.""" + + def __init__(self, platform: Platform): + self._platform = platform + + def __call__(self, node: PsAstNode) -> PsAstNode: + self.visit(node) + return node + + def visit(self, node: PsAstNode): + for c in node.children: + self.visit(c) + + if isinstance(node, PsCall) and isinstance(node.function, PsMathFunction): + impl = self._platform.select_function(node.function, node.get_dtype()) + node.function = impl diff --git a/src/pystencils/kernelcreation.py b/src/pystencils/kernelcreation.py index 79cf92f8b..3f17e66bd 100644 --- a/src/pystencils/kernelcreation.py +++ b/src/pystencils/kernelcreation.py @@ -25,7 +25,7 @@ from .backend.kernelcreation.iteration_space import ( ) from .backend.ast.analysis import collect_required_headers, collect_undefined_symbols -from .backend.transformations import EraseAnonymousStructTypes, EliminateConstants +from 
.backend.transformations import EraseAnonymousStructTypes, EliminateConstants, SelectFunctions from .sympyextensions import AssignmentCollection, Assignment @@ -104,6 +104,9 @@ def create_kernel( from .backend.kernelcreation import optimize_cpu optimize_cpu(ctx, platform, kernel_ast, config.cpu_optim) + select_functions = SelectFunctions(platform) + kernel_ast = cast(PsBlock, select_functions(kernel_ast)) + assert config.jit is not None return create_kernel_function( ctx, kernel_ast, config.function_name, config.target, config.jit -- GitLab