Skip to content
Snippets Groups Projects
Commit 1328db5a authored by Frederik Hennig's avatar Frederik Hennig
Browse files

some work on function modelling

parent 03651cbf
Branches
Tags
No related merge requests found
Pipeline #62752 failed with stages
in 2 minutes and 3 seconds
......@@ -17,11 +17,33 @@ TODO: Figure out the best way to describe function signatures and overloads for
from sys import intern
import pymbolic.primitives as pb
from abc import ABC, abstractmethod
from enum import Enum
from .types import PsAbstractType
from .typed_expressions import ExprOrConstant
class MathFunctions(Enum):
"""Mathematical functions supported by the backend.
Each platform has to materialize these functions to a concrete implementation.
"""
Exp = ("exp", 1)
Sin = ("sin", 1)
Cos = ("cos", 1)
Tan = ("tan", 1)
Abs = ("abs", 1)
Min = ("min", 2)
Max = ("max", 2)
def __init__(self, func_name, arg_count):
self.function_name = func_name
self.arg_count = arg_count
class PsFunction(pb.FunctionSymbol, ABC):
@property
@abstractmethod
......@@ -29,6 +51,40 @@ class PsFunction(pb.FunctionSymbol, ABC):
"Number of arguments this function takes"
class CFunction(PsFunction):
"""A concrete C function."""
def __init__(self, qualified_name: str, arg_count: int):
self._qname = qualified_name
self._arg_count = arg_count
@property
def qualified_name(self) -> str:
return self._qname
@property
def arg_count(self) -> int:
return self._arg_count
class PsMathFunction(PsFunction):
"""Homogenously typed mathematical functions."""
init_arg_names = ("func",)
mapper_method = intern("map_math_function")
def __init__(self, func: MathFunctions) -> None:
self._func = func
@property
def func(self) -> MathFunctions:
return self._func
@property
def arg_count(self) -> int:
return self._func.arg_count
class Deref(PsFunction):
"""Dereferences a pointer."""
......
......@@ -22,6 +22,7 @@ from ..types import constify, make_type, PsStructType
from ..typed_expressions import PsTypedVariable
from ..arrays import PsArrayAccess
from ..exceptions import PsInputError
from ..functions import PsMathFunction, MathFunctions
class FreezeError(Exception):
......@@ -150,9 +151,28 @@ class FreezeExpressions(SympyToPymbolicMapper):
else:
return PsArrayAccess(ptr, index)
def map_Function(self, func: sp.Function):
"""Map a SymPy function to a backend-supported function symbol.
def map_Function(self, func: sp.Function) -> pb.Call:
"""Map SymPy function calls by mapping sympy function classes to backend-supported function symbols.
SymPy functions are frozen to an instance of `nbackend.functions.PsFunction`.
"""
raise NotImplementedError()
match func:
case sp.Abs():
func_symbol = PsMathFunction(MathFunctions.Abs)
case sp.exp():
func_symbol = PsMathFunction(MathFunctions.Exp)
case sp.sin():
func_symbol = PsMathFunction(MathFunctions.Sin)
case sp.cos():
func_symbol = PsMathFunction(MathFunctions.Cos)
case sp.tan():
func_symbol = PsMathFunction(MathFunctions.Tan)
case sp.Min():
func_symbol = PsMathFunction(MathFunctions.Min)
case sp.Max():
func_symbol = PsMathFunction(MathFunctions.Max)
case _:
raise FreezeError(f"Unsupported function: {func}")
args = tuple(self.rec(arg) for arg in func.args)
return pb.Call(func_symbol, args)
......@@ -10,6 +10,7 @@ from ..types import PsAbstractType, PsNumericType, PsStructType, deconstify
from ..typed_expressions import PsTypedVariable, PsTypedConstant, ExprOrConstant
from ..arrays import PsArrayAccess
from ..ast import PsAstNode, PsBlock, PsExpression, PsAssignment
from ..functions import PsMathFunction
__all__ = ["Typifier"]
......@@ -208,10 +209,13 @@ class Typifier(Mapper):
# Functions
def map_call(self, expr: pb.Call, tc: TypeContext) -> pb.Call:
"""
TODO: Figure out how to describe function signatures
"""
raise NotImplementedError()
func = expr.function
args = expr.parameters
match func:
case PsMathFunction():
return pb.Call(func, tuple(self.rec(arg, tc) for arg in args))
case _:
raise TypificationError(f"Don't know how to typify calls to {func}")
# Internals
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment