Commit 3c02d58e authored by Markus Holzer's avatar Markus Holzer
Browse files

Implemented Min and Max printer

parent 0653f52d
......@@ -330,10 +330,10 @@ class CustomSympyPrinter(CCodePrinter):
def __init__(self):
super(CustomSympyPrinter, self).__init__()
self._float_type = create_type("float32")
if 'Min' in self.known_functions:
del self.known_functions['Min']
if 'Max' in self.known_functions:
del self.known_functions['Max']
#if 'Min' in self.known_functions:
# del self.known_functions['Min']
# if 'Max' not in self.known_functions:
# self.known_functions.update({'Max': 'Max'})
def _print_Pow(self, expr):
"""Don't use std::pow function, for small integer exponents, write as multiplication"""
......@@ -402,6 +402,8 @@ class CustomSympyPrinter(CCodePrinter):
return f"({self._print(1 / sp.sqrt(expr.args[0]))})"
elif isinstance(expr, sp.Abs):
return f"abs({self._print(expr.args[0])})"
elif isinstance(expr, sp.Max):
return self._print(expr)
elif isinstance(expr, sp.Mod):
if expr.args[0].is_integer and expr.args[1].is_integer:
return f"({self._print(expr.args[0])} % {self._print(expr.args[1])})"
......@@ -476,8 +478,25 @@ class CustomSympyPrinter(CCodePrinter):
def _print_ConditionalFieldAccess(self, node):
return self._print(sp.Piecewise((node.outofbounds_value, node.outofbounds_condition), (node.access, True)))
_print_Max = C89CodePrinter._print_Max
_print_Min = C89CodePrinter._print_Min
def _print_Max(self, expr):
def inner_print_max(args):
if len(args) == 1:
return self._print(args[0])
half = len(args) // 2
a = inner_print_max(args[:half])
b = inner_print_max(args[half:])
return f"(({a} > {b}) ? {a} : {b})"
return inner_print_max(expr.args)
def _print_Min(self, expr):
def inner_print_min(args):
if len(args) == 1:
return self._print(args[0])
half = len(args) // 2
a = inner_print_min(args[:half])
b = inner_print_min(args[half:])
return f"(({a} < {b}) ? {a} : {b})"
return inner_print_min(expr.args)
def _print_re(self, expr):
return f"real({self._print(expr.args[0])})"
......@@ -575,6 +594,9 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
result = self.instruction_set['&'].format(result, item)
return result
def _print_Max(self, expr):
return "test"
def _print_Or(self, expr):
result = self._scalarFallback('_print_Or', expr)
if result:
......
import sympy
import numpy
import pystencils
from pystencils.datahandling import create_data_handling
def test_max():
dh = create_data_handling(domain_size=(10, 10), periodicity=True)
x = dh.add_array('x', values_per_cell=1)
dh.fill("x", 0.0, ghost_layers=True)
y = dh.add_array('y', values_per_cell=1)
dh.fill("y", 1.0, ghost_layers=True)
z = dh.add_array('z', values_per_cell=1)
dh.fill("z", 2.0, ghost_layers=True)
# test sp.Max with one argument
assignment_1 = pystencils.Assignment(x.center, sympy.Max(y.center + 3.3))
ast_1 = pystencils.create_kernel(assignment_1)
kernel_1 = ast_1.compile()
# test sp.Max with two arguments
assignment_2 = pystencils.Assignment(x.center, sympy.Max(0.5, y.center - 1.5))
ast_2 = pystencils.create_kernel(assignment_2)
kernel_2 = ast_2.compile()
# test sp.Max with many arguments
assignment_3 = pystencils.Assignment(x.center, sympy.Max(z.center, 4.5, y.center - 1.5, y.center + z.center))
ast_3 = pystencils.create_kernel(assignment_3)
kernel_3 = ast_3.compile()
dh.run_kernel(kernel_1)
assert numpy.all(dh.cpu_arrays["x"] == 4.3)
dh.run_kernel(kernel_2)
assert numpy.all(dh.cpu_arrays["x"] == 0.5)
dh.run_kernel(kernel_3)
assert numpy.all(dh.cpu_arrays["x"] == 4.5)
def test_min():
dh = create_data_handling(domain_size=(10, 10), periodicity=True)
x = dh.add_array('x', values_per_cell=1)
dh.fill("x", 0.0, ghost_layers=True)
y = dh.add_array('y', values_per_cell=1)
dh.fill("y", 1.0, ghost_layers=True)
z = dh.add_array('z', values_per_cell=1)
dh.fill("z", 2.0, ghost_layers=True)
# test sp.Min with one argument
assignment_1 = pystencils.Assignment(x.center, sympy.Min(y.center + 3.3))
ast_1 = pystencils.create_kernel(assignment_1)
kernel_1 = ast_1.compile()
# test sp.Min with two arguments
assignment_2 = pystencils.Assignment(x.center, sympy.Min(0.5, y.center - 1.5))
ast_2 = pystencils.create_kernel(assignment_2)
kernel_2 = ast_2.compile()
# test sp.Min with many arguments
assignment_3 = pystencils.Assignment(x.center, sympy.Min(z.center, 4.5, y.center - 1.5, y.center + z.center))
ast_3 = pystencils.create_kernel(assignment_3)
kernel_3 = ast_3.compile()
dh.run_kernel(kernel_1)
assert numpy.all(dh.cpu_arrays["x"] == 4.3)
dh.run_kernel(kernel_2)
assert numpy.all(dh.cpu_arrays["x"] == - 0.5)
dh.run_kernel(kernel_3)
assert numpy.all(dh.cpu_arrays["x"] == - 0.5)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment