diff --git a/src/pystencils/backend/functions.py b/src/pystencils/backend/functions.py index 73634539505e789963b62072f14a28c0586a65a5..18c2277cf76102f9265114853f97b8e2eb50cc67 100644 --- a/src/pystencils/backend/functions.py +++ b/src/pystencils/backend/functions.py @@ -100,8 +100,12 @@ class NumericLimitsFunctions(Enum): Each platform has to materialize these functions to a concrete implementation. """ - min = ("min", 0) - max = ("max", 0) + Min = ("min", 0) + Max = ("max", 0) + + def __init__(self, func_name, num_args): + self.function_name = func_name + self.num_args = num_args class PsMathFunction(PsFunction): diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index ae728dd4984cbb9814a120b7ab6c753c91254ea8..9a34303e21071137c6e929d423660837fffdd6d0 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -206,10 +206,10 @@ class FreezeExpressions: # TODO: unsure if sp.Min & sp.Max are mapped by map_Min/map_Max afterwards case "min": op = sp.Min - init_val = NumericLimitsFunctions("min") + init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Min), []) case "max": op = sp.Max - init_val = NumericLimitsFunctions("max") + init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Max), []) case _: raise FreezeError(f"Unsupported reduced assignment: {expr.op}.") diff --git a/src/pystencils/backend/platforms/generic_cpu.py b/src/pystencils/backend/platforms/generic_cpu.py index 27df6aee4fd79843d5e80a95113d4f49c7a6f04a..e1a34564d195ac16312768735facdf33990c7239 100644 --- a/src/pystencils/backend/platforms/generic_cpu.py +++ b/src/pystencils/backend/platforms/generic_cpu.py @@ -62,7 +62,7 @@ class GenericCpu(Platform): dtype = call.get_dtype() arg_types = (dtype,) * func.num_args - if isinstance(dtype, PsScalarType) and func in (NumericLimitsFunctions.min, NumericLimitsFunctions.max): + if isinstance(dtype, PsScalarType) and func in (NumericLimitsFunctions.Min, NumericLimitsFunctions.Max): cfunc = CFunction(f"{dtype.c_string()}_{func.function_name}".capitalize(), arg_types, dtype) call.function = cfunc return call