From 1328db5ac556e8ad84e2dc64fbf2392d3ac1c293 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Tue, 20 Feb 2024 14:13:17 +0100
Subject: [PATCH] some work on function modelling

---
 src/pystencils/backend/functions.py           | 56 +++++++++++++++++++
 .../backend/kernelcreation/freeze.py          | 26 ++++++++-
 .../backend/kernelcreation/typification.py    | 12 ++--
 3 files changed, 87 insertions(+), 7 deletions(-)

diff --git a/src/pystencils/backend/functions.py b/src/pystencils/backend/functions.py
index a04eba761..e2474e52e 100644
--- a/src/pystencils/backend/functions.py
+++ b/src/pystencils/backend/functions.py
@@ -17,11 +17,33 @@ TODO: Figure out the best way to describe function signatures and overloads for
 from sys import intern
 import pymbolic.primitives as pb
 from abc import ABC, abstractmethod
+from enum import Enum
 
 from .types import PsAbstractType
 from .typed_expressions import ExprOrConstant
 
 
+class MathFunctions(Enum):
+    """Mathematical functions supported by the backend.
+
+    Each platform has to materialize these functions to a concrete implementation.
+    """
+
+    Exp = ("exp", 1)
+    Sin = ("sin", 1)
+    Cos = ("cos", 1)
+    Tan = ("tan", 1)
+
+    Abs = ("abs", 1)
+
+    Min = ("min", 2)
+    Max = ("max", 2)
+
+    def __init__(self, func_name, arg_count):
+        self.function_name = func_name
+        self.arg_count = arg_count
+
+
 class PsFunction(pb.FunctionSymbol, ABC):
     @property
     @abstractmethod
@@ -29,6 +51,40 @@ class PsFunction(pb.FunctionSymbol, ABC):
         "Number of arguments this function takes"
 
 
+class CFunction(PsFunction):
+    """A concrete C function."""
+
+    def __init__(self, qualified_name: str, arg_count: int):
+        self._qname = qualified_name
+        self._arg_count = arg_count
+
+    @property
+    def qualified_name(self) -> str:
+        return self._qname
+
+    @property
+    def arg_count(self) -> int:
+        return self._arg_count
+
+
+class PsMathFunction(PsFunction):
+    """Homogenously typed mathematical functions."""
+
+    init_arg_names = ("func",)
+    mapper_method = intern("map_math_function")
+
+    def __init__(self, func: MathFunctions) -> None:
+        self._func = func
+
+    @property
+    def func(self) -> MathFunctions:
+        return self._func
+
+    @property
+    def arg_count(self) -> int:
+        return self._func.arg_count
+
+
 class Deref(PsFunction):
     """Dereferences a pointer."""
 
diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py
index fb18cc978..c241ce2a7 100644
--- a/src/pystencils/backend/kernelcreation/freeze.py
+++ b/src/pystencils/backend/kernelcreation/freeze.py
@@ -22,6 +22,7 @@ from ..types import constify, make_type, PsStructType
 from ..typed_expressions import PsTypedVariable
 from ..arrays import PsArrayAccess
 from ..exceptions import PsInputError
+from ..functions import PsMathFunction, MathFunctions
 
 
 class FreezeError(Exception):
@@ -150,9 +151,28 @@ class FreezeExpressions(SympyToPymbolicMapper):
         else:
             return PsArrayAccess(ptr, index)
 
-    def map_Function(self, func: sp.Function):
-        """Map a SymPy function to a backend-supported function symbol.
+    def map_Function(self, func: sp.Function) -> pb.Call:
+        """Map SymPy function calls by mapping sympy function classes to backend-supported function symbols.
 
         SymPy functions are frozen to an instance of `nbackend.functions.PsFunction`.
         """
-        raise NotImplementedError()
+        match func:
+            case sp.Abs():
+                func_symbol = PsMathFunction(MathFunctions.Abs)
+            case sp.exp():
+                func_symbol = PsMathFunction(MathFunctions.Exp)
+            case sp.sin():
+                func_symbol = PsMathFunction(MathFunctions.Sin)
+            case sp.cos():
+                func_symbol = PsMathFunction(MathFunctions.Cos)
+            case sp.tan():
+                func_symbol = PsMathFunction(MathFunctions.Tan)
+            case sp.Min():
+                func_symbol = PsMathFunction(MathFunctions.Min)
+            case sp.Max():
+                func_symbol = PsMathFunction(MathFunctions.Max)
+            case _:
+                raise FreezeError(f"Unsupported function: {func}")
+
+        args = tuple(self.rec(arg) for arg in func.args)
+        return pb.Call(func_symbol, args)
diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py
index 893f3dc03..adb4c0cf7 100644
--- a/src/pystencils/backend/kernelcreation/typification.py
+++ b/src/pystencils/backend/kernelcreation/typification.py
@@ -10,6 +10,7 @@ from ..types import PsAbstractType, PsNumericType, PsStructType, deconstify
 from ..typed_expressions import PsTypedVariable, PsTypedConstant, ExprOrConstant
 from ..arrays import PsArrayAccess
 from ..ast import PsAstNode, PsBlock, PsExpression, PsAssignment
+from ..functions import PsMathFunction
 
 __all__ = ["Typifier"]
 
@@ -208,10 +209,13 @@ class Typifier(Mapper):
     #   Functions
 
     def map_call(self, expr: pb.Call, tc: TypeContext) -> pb.Call:
-        """
-        TODO: Figure out how to describe function signatures
-        """
-        raise NotImplementedError()
+        func = expr.function
+        args = expr.parameters
+        match func:
+            case PsMathFunction():
+                return pb.Call(func, tuple(self.rec(arg, tc) for arg in args))
+            case _:
+                raise TypificationError(f"Don't know how to typify calls to {func}")
 
     #   Internals
 
-- 
GitLab