Skip to content
Snippets Groups Projects
Commit 1328db5a authored by Frederik Hennig's avatar Frederik Hennig
Browse files

some work on function modelling

parent 03651cbf
No related merge requests found
Pipeline #62752 failed with stages
in 2 minutes and 3 seconds
...@@ -17,11 +17,33 @@ TODO: Figure out the best way to describe function signatures and overloads for ...@@ -17,11 +17,33 @@ TODO: Figure out the best way to describe function signatures and overloads for
from sys import intern from sys import intern
import pymbolic.primitives as pb import pymbolic.primitives as pb
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum
from .types import PsAbstractType from .types import PsAbstractType
from .typed_expressions import ExprOrConstant from .typed_expressions import ExprOrConstant
class MathFunctions(Enum):
"""Mathematical functions supported by the backend.
Each platform has to materialize these functions to a concrete implementation.
"""
Exp = ("exp", 1)
Sin = ("sin", 1)
Cos = ("cos", 1)
Tan = ("tan", 1)
Abs = ("abs", 1)
Min = ("min", 2)
Max = ("max", 2)
def __init__(self, func_name, arg_count):
self.function_name = func_name
self.arg_count = arg_count
class PsFunction(pb.FunctionSymbol, ABC): class PsFunction(pb.FunctionSymbol, ABC):
@property @property
@abstractmethod @abstractmethod
...@@ -29,6 +51,40 @@ class PsFunction(pb.FunctionSymbol, ABC): ...@@ -29,6 +51,40 @@ class PsFunction(pb.FunctionSymbol, ABC):
"Number of arguments this function takes" "Number of arguments this function takes"
class CFunction(PsFunction):
"""A concrete C function."""
def __init__(self, qualified_name: str, arg_count: int):
self._qname = qualified_name
self._arg_count = arg_count
@property
def qualified_name(self) -> str:
return self._qname
@property
def arg_count(self) -> int:
return self._arg_count
class PsMathFunction(PsFunction):
"""Homogenously typed mathematical functions."""
init_arg_names = ("func",)
mapper_method = intern("map_math_function")
def __init__(self, func: MathFunctions) -> None:
self._func = func
@property
def func(self) -> MathFunctions:
return self._func
@property
def arg_count(self) -> int:
return self._func.arg_count
class Deref(PsFunction): class Deref(PsFunction):
"""Dereferences a pointer.""" """Dereferences a pointer."""
......
...@@ -22,6 +22,7 @@ from ..types import constify, make_type, PsStructType ...@@ -22,6 +22,7 @@ from ..types import constify, make_type, PsStructType
from ..typed_expressions import PsTypedVariable from ..typed_expressions import PsTypedVariable
from ..arrays import PsArrayAccess from ..arrays import PsArrayAccess
from ..exceptions import PsInputError from ..exceptions import PsInputError
from ..functions import PsMathFunction, MathFunctions
class FreezeError(Exception): class FreezeError(Exception):
...@@ -150,9 +151,28 @@ class FreezeExpressions(SympyToPymbolicMapper): ...@@ -150,9 +151,28 @@ class FreezeExpressions(SympyToPymbolicMapper):
else: else:
return PsArrayAccess(ptr, index) return PsArrayAccess(ptr, index)
def map_Function(self, func: sp.Function): def map_Function(self, func: sp.Function) -> pb.Call:
"""Map a SymPy function to a backend-supported function symbol. """Map SymPy function calls by mapping sympy function classes to backend-supported function symbols.
SymPy functions are frozen to an instance of `nbackend.functions.PsFunction`. SymPy functions are frozen to an instance of `nbackend.functions.PsFunction`.
""" """
raise NotImplementedError() match func:
case sp.Abs():
func_symbol = PsMathFunction(MathFunctions.Abs)
case sp.exp():
func_symbol = PsMathFunction(MathFunctions.Exp)
case sp.sin():
func_symbol = PsMathFunction(MathFunctions.Sin)
case sp.cos():
func_symbol = PsMathFunction(MathFunctions.Cos)
case sp.tan():
func_symbol = PsMathFunction(MathFunctions.Tan)
case sp.Min():
func_symbol = PsMathFunction(MathFunctions.Min)
case sp.Max():
func_symbol = PsMathFunction(MathFunctions.Max)
case _:
raise FreezeError(f"Unsupported function: {func}")
args = tuple(self.rec(arg) for arg in func.args)
return pb.Call(func_symbol, args)
...@@ -10,6 +10,7 @@ from ..types import PsAbstractType, PsNumericType, PsStructType, deconstify ...@@ -10,6 +10,7 @@ from ..types import PsAbstractType, PsNumericType, PsStructType, deconstify
from ..typed_expressions import PsTypedVariable, PsTypedConstant, ExprOrConstant from ..typed_expressions import PsTypedVariable, PsTypedConstant, ExprOrConstant
from ..arrays import PsArrayAccess from ..arrays import PsArrayAccess
from ..ast import PsAstNode, PsBlock, PsExpression, PsAssignment from ..ast import PsAstNode, PsBlock, PsExpression, PsAssignment
from ..functions import PsMathFunction
__all__ = ["Typifier"] __all__ = ["Typifier"]
...@@ -208,10 +209,13 @@ class Typifier(Mapper): ...@@ -208,10 +209,13 @@ class Typifier(Mapper):
# Functions # Functions
def map_call(self, expr: pb.Call, tc: TypeContext) -> pb.Call: def map_call(self, expr: pb.Call, tc: TypeContext) -> pb.Call:
""" func = expr.function
TODO: Figure out how to describe function signatures args = expr.parameters
""" match func:
raise NotImplementedError() case PsMathFunction():
return pb.Call(func, tuple(self.rec(arg, tc) for arg in args))
case _:
raise TypificationError(f"Don't know how to typify calls to {func}")
# Internals # Internals
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment