diff --git a/src/pystencils/nbackend/functions.py b/src/pystencils/nbackend/functions.py new file mode 100644 index 0000000000000000000000000000000000000000..e7dc4e6cb26e59cb0a1748fe45966356e5c76930 --- /dev/null +++ b/src/pystencils/nbackend/functions.py @@ -0,0 +1,24 @@ +""" +Functions supported by pystencils. + +Every supported function might require handling logic in the following modules: + + - In `freeze.FreezeExpressions`, a case in `map_Function` or a separate mapper method to catch its frontend variant + - In each backend platform, a case in `materialize_functions` to map the function onto a concrete C/C++ implementation + - If very special typing rules apply, a case in `typification.Typifier`. + +In most cases, typification of function applications will require no special handling. + +TODO: Maybe add a way for the user to register additional functions +TODO: Figure out the best way to describe function signatures and overloads for typing +""" + +import pymbolic.primitives as pb +from abc import ABC, abstractmethod + + +class PsFunction(pb.FunctionSymbol, ABC): + @property + @abstractmethod + def arg_count(self) -> int: + "Number of arguments this function takes" diff --git a/src/pystencils/nbackend/kernelcreation/freeze.py b/src/pystencils/nbackend/kernelcreation/freeze.py index 64d3fa2e9ec05be1eadb71bc8c11ae68fda978de..cab2ab70ad2a442f0cb4a9faea80845a0526b266 100644 --- a/src/pystencils/nbackend/kernelcreation/freeze.py +++ b/src/pystencils/nbackend/kernelcreation/freeze.py @@ -118,3 +118,10 @@ class FreezeExpressions(SympyToPymbolicMapper): index = summands[0] if len(summands) == 1 else pb.Sum(summands) return PsArrayAccess(ptr, index) + + def map_Function(self, func: sp.Function): + """Map a SymPy function to a backend-supported function symbol. + + SymPy functions are frozen to an instance of `nbackend.functions.PsFunction`. + """ + raise NotImplementedError() diff --git a/src/pystencils/nbackend/kernelcreation/kernelcreation.py b/src/pystencils/nbackend/kernelcreation/kernelcreation.py index 07614da9ac7915bd14a27fa2c50f3dfbaf8faf59..f29cd9a13488502caffae966e0f9ad1cf1840066 100644 --- a/src/pystencils/nbackend/kernelcreation/kernelcreation.py +++ b/src/pystencils/nbackend/kernelcreation/kernelcreation.py @@ -9,7 +9,6 @@ from .freeze import FreezeExpressions from .typification import Typifier from .options import KernelCreationOptions from .iteration_space import ( - IterationSpace, create_sparse_iteration_space, create_full_iteration_space, ) @@ -21,11 +20,10 @@ def create_kernel(assignments: AssignmentCollection, options: KernelCreationOpti analysis = KernelAnalysis(ctx) analysis(assignments) - ispace: IterationSpace = ( - create_sparse_iteration_space(ctx, assignments) - if len(ctx.fields.index_fields) > 0 or ctx.options.index_field is not None - else create_full_iteration_space(ctx, assignments) - ) + if len(ctx.fields.index_fields) > 0 or ctx.options.index_field is not None: + ispace = create_sparse_iteration_space(ctx, assignments) + else: + ispace = create_full_iteration_space(ctx, assignments) ctx.set_iteration_space(ispace) @@ -37,22 +35,22 @@ def create_kernel(assignments: AssignmentCollection, options: KernelCreationOpti match options.target: case Target.CPU: - from .platform import BasicCpuGen + from .platform import BasicCpu # TODO: CPU platform should incorporate instruction set info, OpenMP, etc. - platform_generator = BasicCpuGen(ctx) + platform = BasicCpu(ctx) case _: # TODO: CUDA/HIP platform # TODO: SYCL platform (?) raise NotImplementedError("Target platform not implemented") - kernel_ast = platform_generator.materialize_iteration_space(kernel_body, ispace) + kernel_ast = platform.materialize_iteration_space(kernel_body, ispace) # 7. Apply optimizations # - Vectorization # - OpenMP # - Loop Splitting, Tiling, Blocking - kernel_ast = platform_generator.optimize(kernel_ast) + kernel_ast = platform.optimize(kernel_ast) function = PsKernelFunction(kernel_ast, options.target, name=options.function_name) function.add_constraints(*ctx.constraints) diff --git a/src/pystencils/nbackend/kernelcreation/platform/__init__.py b/src/pystencils/nbackend/kernelcreation/platform/__init__.py index 85b5af9c0c0a95a3a5ec6398664d40309b40daaf..20e2c0aae07e1d300793a3b6102b9c6b0536f83f 100644 --- a/src/pystencils/nbackend/kernelcreation/platform/__init__.py +++ b/src/pystencils/nbackend/kernelcreation/platform/__init__.py @@ -1,5 +1,5 @@ -from .basic_cpu import BasicCpuGen +from .basic_cpu import BasicCpu __all__ = [ - 'BasicCpuGen' + 'BasicCpu' ] diff --git a/src/pystencils/nbackend/kernelcreation/platform/basic_cpu.py b/src/pystencils/nbackend/kernelcreation/platform/basic_cpu.py index 347061f19114b87829059ee0314cbd24b6dc7f0c..f5deaf0d65a0a19d568c7dead77bace5efbac517 100644 --- a/src/pystencils/nbackend/kernelcreation/platform/basic_cpu.py +++ b/src/pystencils/nbackend/kernelcreation/platform/basic_cpu.py @@ -1,4 +1,4 @@ -from .platform import PlatformGen +from .platform import Platform from ..iteration_space import ( IterationSpace, @@ -11,7 +11,7 @@ from ...typed_expressions import PsTypedConstant from ...arrays import PsArrayAccess -class BasicCpuGen(PlatformGen): +class BasicCpu(Platform): def materialize_iteration_space( self, body: PsBlock, ispace: IterationSpace ) -> PsBlock: diff --git a/src/pystencils/nbackend/kernelcreation/platform/platform.py b/src/pystencils/nbackend/kernelcreation/platform/platform.py index 17dfa23f6ed3d847a9eab60cf22a8c05c31b6b88..b2c7f899ea0842bf77e4d4cb759a4b5a9b3493b6 100644 --- a/src/pystencils/nbackend/kernelcreation/platform/platform.py +++ b/src/pystencils/nbackend/kernelcreation/platform/platform.py @@ -6,7 +6,7 @@ from ..context import KernelCreationContext from ..iteration_space import IterationSpace -class PlatformGen(ABC): +class Platform(ABC): """Abstract base class for all supported platforms. The platform performs all target-dependent tasks during code generation: diff --git a/src/pystencils/nbackend/kernelcreation/typification.py b/src/pystencils/nbackend/kernelcreation/typification.py index 9bc9a462f881af26fced8c6f812877db8f81dec3..41f14431b1c263fcfef7a19fbd7c73346cf591dc 100644 --- a/src/pystencils/nbackend/kernelcreation/typification.py +++ b/src/pystencils/nbackend/kernelcreation/typification.py @@ -115,7 +115,9 @@ class Typifier(Mapper): ) -> tuple[PsArrayAccess, PsNumericType]: self._check_target_type(access, access.dtype, target_type) index, _ = self.rec(access.index_tuple[0], self._ctx.options.index_dtype) - return PsArrayAccess(access.base_ptr, index), cast(PsNumericType, deconstify(access.dtype)) + return PsArrayAccess(access.base_ptr, index), cast( + PsNumericType, deconstify(access.dtype) + ) # Arithmetic Expressions @@ -156,6 +158,16 @@ class Typifier(Mapper): new_args, dtype = self._homogenize(expr, expr.children, target_type) return pb.Product(new_args), dtype + def map_call( + self, expr: pb.Call, target_type: PsNumericType | None + ) -> tuple[pb.Call, PsNumericType]: + """ + TODO: Figure out the best way to typify functions + + - How to propagate target_type in the face of multiple overloads? + """ + raise NotImplementedError() + def _check_target_type( self, expr: ExprOrConstant,