Skip to content
Snippets Groups Projects
Commit a011bb5f authored by Markus Holzer's avatar Markus Holzer
Browse files

Delete support of sympy SUM and sympy PROD

parent 0c9e9fcd
No related merge requests found
...@@ -536,52 +536,6 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -536,52 +536,6 @@ class CustomSympyPrinter(CCodePrinter):
else: else:
return res return res
def _print_Sum(self, expr):
template = """[&]() {{
{dtype} sum = ({dtype}) 0;
for ( {iterator_dtype} {var} = {start}; {condition}; {var} += {increment} ) {{
sum += {expr};
}}
return sum;
}}()"""
var = expr.limits[0][0]
start = expr.limits[0][1]
end = expr.limits[0][2]
code = template.format(
dtype=get_type_of_expression(expr.args[0]),
iterator_dtype='int',
var=self._print(var),
start=self._print(start),
end=self._print(end),
expr=self._print(expr.function),
increment=str(1),
condition=self._print(var) + ' <= ' + self._print(end) # if start < end else '>='
)
return code
def _print_Product(self, expr):
template = """[&]() {{
{dtype} product = ({dtype}) 1;
for ( {iterator_dtype} {var} = {start}; {condition}; {var} += {increment} ) {{
product *= {expr};
}}
return product;
}}()"""
var = expr.limits[0][0]
start = expr.limits[0][1]
end = expr.limits[0][2]
code = template.format(
dtype=get_type_of_expression(expr.args[0]),
iterator_dtype='int',
var=self._print(var),
start=self._print(start),
end=self._print(end),
expr=self._print(expr.function),
increment=str(1),
condition=self._print(var) + ' <= ' + self._print(end) # if start < end else '>='
)
return code
def _print_ConditionalFieldAccess(self, node): def _print_ConditionalFieldAccess(self, node):
return self._print(sp.Piecewise((node.outofbounds_value, node.outofbounds_condition), (node.access, True))) return self._print(sp.Piecewise((node.outofbounds_value, node.outofbounds_condition), (node.access, True)))
......
...@@ -18,118 +18,88 @@ import pystencils as ps ...@@ -18,118 +18,88 @@ import pystencils as ps
from pystencils.typing import create_type from pystencils.typing import create_type
@pytest.mark.parametrize('default_assignment_simplifications', [False, True]) @pytest.mark.parametrize('dtype', ["float64", "float32"])
def test_sum(default_assignment_simplifications): def test_sum(dtype):
sum = sp.Sum(sp.abc.k, (sp.abc.k, 1, 100)) sum = sp.Sum(sp.abc.k, (sp.abc.k, 1, 100))
expanded_sum = sum.doit() expanded_sum = sum.doit()
print(sum) # print(sum)
print(expanded_sum) # print(expanded_sum)
x = ps.fields('x: float32[1d]') x = ps.fields(f'x: {dtype}[1d]')
assignments = ps.AssignmentCollection({x.center(): sum}) assignments = ps.AssignmentCollection({x.center(): sum})
config = pystencils.config.CreateKernelConfig(default_assignment_simplifications=default_assignment_simplifications) ast = ps.create_kernel(assignments)
ast = ps.create_kernel(assignments, config=config)
code = ps.get_code_str(ast)
kernel = ast.compile()
print(code)
if default_assignment_simplifications is False:
assert 'double sum' in code
array = np.zeros((10,), np.float32)
kernel(x=array)
assert np.allclose(array, int(expanded_sum) * np.ones_like(array))
@pytest.mark.parametrize('default_assignment_simplifications', [False, True])
def test_sum_use_float(default_assignment_simplifications):
sum = sympy.Sum(sp.abc.k, (sp.abc.k, 1, 100))
expanded_sum = sum.doit()
print(sum)
print(expanded_sum)
x = ps.fields('x: float32[1d]')
assignments = ps.AssignmentCollection({x.center(): sum})
config = pystencils.config.CreateKernelConfig(default_assignment_simplifications=default_assignment_simplifications,
data_type=create_type('float32'))
ast = ps.create_kernel(assignments, config=config)
code = ps.get_code_str(ast) code = ps.get_code_str(ast)
kernel = ast.compile() kernel = ast.compile()
print(code) # ps.show_code(ast)
if default_assignment_simplifications is False:
assert 'float sum' in code
array = np.zeros((10,), np.float32) if dtype == "float32":
assert "5050.0f;" in code
array = np.zeros((10,), dtype=dtype)
kernel(x=array) kernel(x=array)
assert np.allclose(array, int(expanded_sum) * np.ones_like(array)) assert np.allclose(array, int(expanded_sum) * np.ones_like(array))
@pytest.mark.parametrize('default_assignment_simplifications', [False, True]) @pytest.mark.parametrize('dtype', ["int32", "int64", "float64", "float32"])
def test_product(default_assignment_simplifications): def test_product(dtype):
k = ps.TypedSymbol('k', create_type('int64')) k = ps.TypedSymbol('k', create_type(dtype))
sum = sympy.Product(k, (k, 1, 10)) sum = sympy.Product(k, (k, 1, 10))
expanded_sum = sum.doit() expanded_sum = sum.doit()
print(sum) # print(sum)
print(expanded_sum) # print(expanded_sum)
x = ps.fields('x: int64[1d]') x = ps.fields(f'x: {dtype}[1d]')
assignments = ps.AssignmentCollection({x.center(): sum}) assignments = ps.AssignmentCollection({x.center(): sum})
config = pystencils.config.CreateKernelConfig(default_assignment_simplifications=default_assignment_simplifications) config = pystencils.config.CreateKernelConfig()
ast = ps.create_kernel(assignments, config=config) ast = ps.create_kernel(assignments, config=config)
code = ps.get_code_str(ast) code = ps.get_code_str(ast)
kernel = ast.compile() kernel = ast.compile()
print(code) # print(code)
if default_assignment_simplifications is False: if dtype == "int64" or dtype == "int32":
assert 'int64_t product' in code assert '3628800;' in code
elif dtype == "float32":
array = np.zeros((10,), np.int64) assert '3628800.0f;' in code
else:
assert '3628800.0;' in code
array = np.zeros((10,), dtype=dtype)
kernel(x=array) kernel(x=array)
assert np.allclose(array, int(expanded_sum) * np.ones_like(array)) assert np.allclose(array, int(expanded_sum) * np.ones_like(array))
# TODO: See Issue !55
def test_prod_var_limit(): # def test_prod_var_limit():
#
k = ps.TypedSymbol('k', create_type('int64')) # k = ps.TypedSymbol('k', create_type('int64'))
limit = ps.TypedSymbol('limit', create_type('int64')) # limit = ps.TypedSymbol('limit', create_type('int64'))
#
sum = sympy.Sum(k, (k, 1, limit)) # sum = sympy.Sum(k, (k, 1, limit))
expanded_sum = sum.replace(limit, 100).doit() # expanded_sum = sum.replace(limit, 100).doit()
#
print(sum) # print(sum)
print(expanded_sum) # print(expanded_sum)
#
x = ps.fields('x: int64[1d]') # x = ps.fields('x: int64[1d]')
#
assignments = ps.AssignmentCollection({x.center(): sum}) # assignments = ps.AssignmentCollection({x.center(): sum})
#
ast = ps.create_kernel(assignments) # ast = ps.create_kernel(assignments)
ps.show_code(ast) # ps.show_code(ast)
kernel = ast.compile() # kernel = ast.compile()
#
array = np.zeros((10,), np.int64) # array = np.zeros((10,), np.int64)
#
kernel(x=array, limit=100) # kernel(x=array, limit=100)
#
assert np.allclose(array, int(expanded_sum) * np.ones_like(array)) # assert np.allclose(array, int(expanded_sum) * np.ones_like(array))
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment