diff --git a/src/pystencils/backend/functions.py b/src/pystencils/backend/functions.py index a04eba7612a5c61b5e1d88846b1c2e18ebef68c5..e2474e52e7c62496acc1fd4ddb119fa5c56a9b62 100644 --- a/src/pystencils/backend/functions.py +++ b/src/pystencils/backend/functions.py @@ -17,11 +17,33 @@ TODO: Figure out the best way to describe function signatures and overloads for from sys import intern import pymbolic.primitives as pb from abc import ABC, abstractmethod +from enum import Enum from .types import PsAbstractType from .typed_expressions import ExprOrConstant +class MathFunctions(Enum): + """Mathematical functions supported by the backend. + + Each platform has to materialize these functions to a concrete implementation. + """ + + Exp = ("exp", 1) + Sin = ("sin", 1) + Cos = ("cos", 1) + Tan = ("tan", 1) + + Abs = ("abs", 1) + + Min = ("min", 2) + Max = ("max", 2) + + def __init__(self, func_name, arg_count): + self.function_name = func_name + self.arg_count = arg_count + + class PsFunction(pb.FunctionSymbol, ABC): @property @abstractmethod @@ -29,6 +51,40 @@ class PsFunction(pb.FunctionSymbol, ABC): "Number of arguments this function takes" +class CFunction(PsFunction): + """A concrete C function.""" + + def __init__(self, qualified_name: str, arg_count: int): + self._qname = qualified_name + self._arg_count = arg_count + + @property + def qualified_name(self) -> str: + return self._qname + + @property + def arg_count(self) -> int: + return self._arg_count + + +class PsMathFunction(PsFunction): + """Homogenously typed mathematical functions.""" + + init_arg_names = ("func",) + mapper_method = intern("map_math_function") + + def __init__(self, func: MathFunctions) -> None: + self._func = func + + @property + def func(self) -> MathFunctions: + return self._func + + @property + def arg_count(self) -> int: + return self._func.arg_count + + class Deref(PsFunction): """Dereferences a pointer.""" diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index fb18cc97891538d92e927172bbd18bb6c3dd58ce..c241ce2a7e9c7db5328d9c5c5acc11cb43bc8e9d 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -22,6 +22,7 @@ from ..types import constify, make_type, PsStructType from ..typed_expressions import PsTypedVariable from ..arrays import PsArrayAccess from ..exceptions import PsInputError +from ..functions import PsMathFunction, MathFunctions class FreezeError(Exception): @@ -150,9 +151,28 @@ class FreezeExpressions(SympyToPymbolicMapper): else: return PsArrayAccess(ptr, index) - def map_Function(self, func: sp.Function): - """Map a SymPy function to a backend-supported function symbol. + def map_Function(self, func: sp.Function) -> pb.Call: + """Map SymPy function calls by mapping sympy function classes to backend-supported function symbols. SymPy functions are frozen to an instance of `nbackend.functions.PsFunction`. """ - raise NotImplementedError() + match func: + case sp.Abs(): + func_symbol = PsMathFunction(MathFunctions.Abs) + case sp.exp(): + func_symbol = PsMathFunction(MathFunctions.Exp) + case sp.sin(): + func_symbol = PsMathFunction(MathFunctions.Sin) + case sp.cos(): + func_symbol = PsMathFunction(MathFunctions.Cos) + case sp.tan(): + func_symbol = PsMathFunction(MathFunctions.Tan) + case sp.Min(): + func_symbol = PsMathFunction(MathFunctions.Min) + case sp.Max(): + func_symbol = PsMathFunction(MathFunctions.Max) + case _: + raise FreezeError(f"Unsupported function: {func}") + + args = tuple(self.rec(arg) for arg in func.args) + return pb.Call(func_symbol, args) diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index 893f3dc03d3e6a548766b8f162332d2e284032c0..adb4c0cf739beff4208773bfa232f7fab40bdd82 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -10,6 +10,7 @@ from ..types import PsAbstractType, PsNumericType, PsStructType, deconstify from ..typed_expressions import PsTypedVariable, PsTypedConstant, ExprOrConstant from ..arrays import PsArrayAccess from ..ast import PsAstNode, PsBlock, PsExpression, PsAssignment +from ..functions import PsMathFunction __all__ = ["Typifier"] @@ -208,10 +209,13 @@ class Typifier(Mapper): # Functions def map_call(self, expr: pb.Call, tc: TypeContext) -> pb.Call: - """ - TODO: Figure out how to describe function signatures - """ - raise NotImplementedError() + func = expr.function + args = expr.parameters + match func: + case PsMathFunction(): + return pb.Call(func, tuple(self.rec(arg, tc) for arg in args)) + case _: + raise TypificationError(f"Don't know how to typify calls to {func}") # Internals