From 1328db5ac556e8ad84e2dc64fbf2392d3ac1c293 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Tue, 20 Feb 2024 14:13:17 +0100 Subject: [PATCH] some work on function modelling --- src/pystencils/backend/functions.py | 56 +++++++++++++++++++ .../backend/kernelcreation/freeze.py | 26 ++++++++- .../backend/kernelcreation/typification.py | 12 ++-- 3 files changed, 87 insertions(+), 7 deletions(-) diff --git a/src/pystencils/backend/functions.py b/src/pystencils/backend/functions.py index a04eba761..e2474e52e 100644 --- a/src/pystencils/backend/functions.py +++ b/src/pystencils/backend/functions.py @@ -17,11 +17,33 @@ TODO: Figure out the best way to describe function signatures and overloads for from sys import intern import pymbolic.primitives as pb from abc import ABC, abstractmethod +from enum import Enum from .types import PsAbstractType from .typed_expressions import ExprOrConstant +class MathFunctions(Enum): + """Mathematical functions supported by the backend. + + Each platform has to materialize these functions to a concrete implementation. + """ + + Exp = ("exp", 1) + Sin = ("sin", 1) + Cos = ("cos", 1) + Tan = ("tan", 1) + + Abs = ("abs", 1) + + Min = ("min", 2) + Max = ("max", 2) + + def __init__(self, func_name, arg_count): + self.function_name = func_name + self.arg_count = arg_count + + class PsFunction(pb.FunctionSymbol, ABC): @property @abstractmethod @@ -29,6 +51,40 @@ class PsFunction(pb.FunctionSymbol, ABC): "Number of arguments this function takes" +class CFunction(PsFunction): + """A concrete C function.""" + + def __init__(self, qualified_name: str, arg_count: int): + self._qname = qualified_name + self._arg_count = arg_count + + @property + def qualified_name(self) -> str: + return self._qname + + @property + def arg_count(self) -> int: + return self._arg_count + + +class PsMathFunction(PsFunction): + """Homogenously typed mathematical functions.""" + + init_arg_names = ("func",) + mapper_method = intern("map_math_function") + + def __init__(self, func: MathFunctions) -> None: + self._func = func + + @property + def func(self) -> MathFunctions: + return self._func + + @property + def arg_count(self) -> int: + return self._func.arg_count + + class Deref(PsFunction): """Dereferences a pointer.""" diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index fb18cc978..c241ce2a7 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -22,6 +22,7 @@ from ..types import constify, make_type, PsStructType from ..typed_expressions import PsTypedVariable from ..arrays import PsArrayAccess from ..exceptions import PsInputError +from ..functions import PsMathFunction, MathFunctions class FreezeError(Exception): @@ -150,9 +151,28 @@ class FreezeExpressions(SympyToPymbolicMapper): else: return PsArrayAccess(ptr, index) - def map_Function(self, func: sp.Function): - """Map a SymPy function to a backend-supported function symbol. + def map_Function(self, func: sp.Function) -> pb.Call: + """Map SymPy function calls by mapping sympy function classes to backend-supported function symbols. SymPy functions are frozen to an instance of `nbackend.functions.PsFunction`. """ - raise NotImplementedError() + match func: + case sp.Abs(): + func_symbol = PsMathFunction(MathFunctions.Abs) + case sp.exp(): + func_symbol = PsMathFunction(MathFunctions.Exp) + case sp.sin(): + func_symbol = PsMathFunction(MathFunctions.Sin) + case sp.cos(): + func_symbol = PsMathFunction(MathFunctions.Cos) + case sp.tan(): + func_symbol = PsMathFunction(MathFunctions.Tan) + case sp.Min(): + func_symbol = PsMathFunction(MathFunctions.Min) + case sp.Max(): + func_symbol = PsMathFunction(MathFunctions.Max) + case _: + raise FreezeError(f"Unsupported function: {func}") + + args = tuple(self.rec(arg) for arg in func.args) + return pb.Call(func_symbol, args) diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index 893f3dc03..adb4c0cf7 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -10,6 +10,7 @@ from ..types import PsAbstractType, PsNumericType, PsStructType, deconstify from ..typed_expressions import PsTypedVariable, PsTypedConstant, ExprOrConstant from ..arrays import PsArrayAccess from ..ast import PsAstNode, PsBlock, PsExpression, PsAssignment +from ..functions import PsMathFunction __all__ = ["Typifier"] @@ -208,10 +209,13 @@ class Typifier(Mapper): # Functions def map_call(self, expr: pb.Call, tc: TypeContext) -> pb.Call: - """ - TODO: Figure out how to describe function signatures - """ - raise NotImplementedError() + func = expr.function + args = expr.parameters + match func: + case PsMathFunction(): + return pb.Call(func, tuple(self.rec(arg, tc) for arg in args)) + case _: + raise TypificationError(f"Don't know how to typify calls to {func}") # Internals -- GitLab