diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index bb25a6a4ca9e119886d3b553ab937b188bad1173..d49b3ff96645625e4076639908bdd13d67c2e5e6 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -330,10 +330,10 @@ class CustomSympyPrinter(CCodePrinter): def __init__(self): super(CustomSympyPrinter, self).__init__() self._float_type = create_type("float32") - if 'Min' in self.known_functions: - del self.known_functions['Min'] - if 'Max' in self.known_functions: - del self.known_functions['Max'] + #if 'Min' in self.known_functions: + # del self.known_functions['Min'] + # if 'Max' not in self.known_functions: + # self.known_functions.update({'Max': 'Max'}) def _print_Pow(self, expr): """Don't use std::pow function, for small integer exponents, write as multiplication""" @@ -402,6 +402,8 @@ class CustomSympyPrinter(CCodePrinter): return f"({self._print(1 / sp.sqrt(expr.args[0]))})" elif isinstance(expr, sp.Abs): return f"abs({self._print(expr.args[0])})" + elif isinstance(expr, sp.Max): + return self._print(expr) elif isinstance(expr, sp.Mod): if expr.args[0].is_integer and expr.args[1].is_integer: return f"({self._print(expr.args[0])} % {self._print(expr.args[1])})" @@ -476,8 +478,25 @@ class CustomSympyPrinter(CCodePrinter): def _print_ConditionalFieldAccess(self, node): return self._print(sp.Piecewise((node.outofbounds_value, node.outofbounds_condition), (node.access, True))) - _print_Max = C89CodePrinter._print_Max - _print_Min = C89CodePrinter._print_Min + def _print_Max(self, expr): + def inner_print_max(args): + if len(args) == 1: + return self._print(args[0]) + half = len(args) // 2 + a = inner_print_max(args[:half]) + b = inner_print_max(args[half:]) + return f"(({a} > {b}) ? {a} : {b})" + return inner_print_max(expr.args) + + def _print_Min(self, expr): + def inner_print_min(args): + if len(args) == 1: + return self._print(args[0]) + half = len(args) // 2 + a = inner_print_min(args[:half]) + b = inner_print_min(args[half:]) + return f"(({a} < {b}) ? {a} : {b})" + return inner_print_min(expr.args) def _print_re(self, expr): return f"real({self._print(expr.args[0])})" @@ -575,6 +594,9 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): result = self.instruction_set['&'].format(result, item) return result + def _print_Max(self, expr): + return "test" + def _print_Or(self, expr): result = self._scalarFallback('_print_Or', expr) if result: diff --git a/pystencils_tests/test_Min_Max.py b/pystencils_tests/test_Min_Max.py new file mode 100644 index 0000000000000000000000000000000000000000..18cd2d99d10ec0b03a38630d07679f3191281493 --- /dev/null +++ b/pystencils_tests/test_Min_Max.py @@ -0,0 +1,70 @@ +import sympy +import numpy +import pystencils +from pystencils.datahandling import create_data_handling + + +def test_max(): + dh = create_data_handling(domain_size=(10, 10), periodicity=True) + + x = dh.add_array('x', values_per_cell=1) + dh.fill("x", 0.0, ghost_layers=True) + y = dh.add_array('y', values_per_cell=1) + dh.fill("y", 1.0, ghost_layers=True) + z = dh.add_array('z', values_per_cell=1) + dh.fill("z", 2.0, ghost_layers=True) + + # test sp.Max with one argument + assignment_1 = pystencils.Assignment(x.center, sympy.Max(y.center + 3.3)) + ast_1 = pystencils.create_kernel(assignment_1) + kernel_1 = ast_1.compile() + + # test sp.Max with two arguments + assignment_2 = pystencils.Assignment(x.center, sympy.Max(0.5, y.center - 1.5)) + ast_2 = pystencils.create_kernel(assignment_2) + kernel_2 = ast_2.compile() + + # test sp.Max with many arguments + assignment_3 = pystencils.Assignment(x.center, sympy.Max(z.center, 4.5, y.center - 1.5, y.center + z.center)) + ast_3 = pystencils.create_kernel(assignment_3) + kernel_3 = ast_3.compile() + + dh.run_kernel(kernel_1) + assert numpy.all(dh.cpu_arrays["x"] == 4.3) + dh.run_kernel(kernel_2) + assert numpy.all(dh.cpu_arrays["x"] == 0.5) + dh.run_kernel(kernel_3) + assert numpy.all(dh.cpu_arrays["x"] == 4.5) + + +def test_min(): + dh = create_data_handling(domain_size=(10, 10), periodicity=True) + + x = dh.add_array('x', values_per_cell=1) + dh.fill("x", 0.0, ghost_layers=True) + y = dh.add_array('y', values_per_cell=1) + dh.fill("y", 1.0, ghost_layers=True) + z = dh.add_array('z', values_per_cell=1) + dh.fill("z", 2.0, ghost_layers=True) + + # test sp.Min with one argument + assignment_1 = pystencils.Assignment(x.center, sympy.Min(y.center + 3.3)) + ast_1 = pystencils.create_kernel(assignment_1) + kernel_1 = ast_1.compile() + + # test sp.Min with two arguments + assignment_2 = pystencils.Assignment(x.center, sympy.Min(0.5, y.center - 1.5)) + ast_2 = pystencils.create_kernel(assignment_2) + kernel_2 = ast_2.compile() + + # test sp.Min with many arguments + assignment_3 = pystencils.Assignment(x.center, sympy.Min(z.center, 4.5, y.center - 1.5, y.center + z.center)) + ast_3 = pystencils.create_kernel(assignment_3) + kernel_3 = ast_3.compile() + + dh.run_kernel(kernel_1) + assert numpy.all(dh.cpu_arrays["x"] == 4.3) + dh.run_kernel(kernel_2) + assert numpy.all(dh.cpu_arrays["x"] == - 0.5) + dh.run_kernel(kernel_3) + assert numpy.all(dh.cpu_arrays["x"] == - 0.5)