# -*- coding: utf-8 -*-
#
# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
#
# Distributed under terms of the GPLv3 license.
"""
"""
import sympy as sp
import sympy.abc
@pytest.mark.parametrize('dtype', ["float64", "float32"])
def test_sum(dtype):
    """Check a kernel that assigns a constant ``sympy.Sum`` to a 1D field.

    The sum of ``k`` for ``k = 1..100`` is written to every cell of ``x``.
    For ``float32`` the generated code is expected to contain the literal
    ``5050.0f;``; for both dtypes the compiled kernel must fill the array
    with the closed-form value of the sum.

    Args:
        dtype: Name of the field's data type (``"float64"`` or ``"float32"``).
    """
    sum_expr = sp.Sum(sp.abc.k, (sp.abc.k, 1, 100))
    # Closed-form value of the sum, used as the reference result below.
    expanded_sum = sum_expr.doit()

    x = ps.fields(f'x: {dtype}[1d]')
    assignments = ps.AssignmentCollection({x.center(): sum_expr})

    ast = ps.create_kernel(assignments)
    code = ps.get_code_str(ast)
    kernel = ast.compile()

    if dtype == "float32":
        assert "5050.0f;" in code

    array = np.zeros((10,), dtype=dtype)
    kernel(x=array)

    assert np.allclose(array, int(expanded_sum) * np.ones_like(array))
@pytest.mark.parametrize('dtype', ["int32", "int64", "float64", "float32"])
def test_product(dtype):
    """Check a kernel that assigns a constant ``sympy.Product`` to a 1D field.

    The product of ``k`` for ``k = 1..10`` (with ``k`` typed as ``dtype``) is
    written to every cell of ``x``. The generated code is expected to contain
    the literal ``3628800`` formatted for the dtype, and the compiled kernel
    must fill the array with the closed-form value of the product.

    Args:
        dtype: Name of the data type used for both the symbol and the field.
    """
    k = ps.TypedSymbol('k', create_type(dtype))
    product = sympy.Product(k, (k, 1, 10))
    # Closed-form value of the product, used as the reference result below.
    expanded_product = product.doit()

    x = ps.fields(f'x: {dtype}[1d]')
    assignments = ps.AssignmentCollection({x.center(): product})

    config = pystencils.config.CreateKernelConfig()
    ast = ps.create_kernel(assignments, config=config)
    code = ps.get_code_str(ast)
    kernel = ast.compile()

    # The literal's spelling in the generated code depends on the dtype.
    if dtype in ("int64", "int32"):
        assert '3628800;' in code
    elif dtype == "float32":
        assert '3628800.0f;' in code
    else:
        assert '3628800.0;' in code

    array = np.zeros((10,), dtype=dtype)
    kernel(x=array)

    assert np.allclose(array, int(expanded_product) * np.ones_like(array))
# TODO: See Issue !55
# def test_prod_var_limit():
#
# k = ps.TypedSymbol('k', create_type('int64'))
# limit = ps.TypedSymbol('limit', create_type('int64'))
#
# sum = sympy.Sum(k, (k, 1, limit))
# expanded_sum = sum.replace(limit, 100).doit()
#
# print(sum)
# print(expanded_sum)
#
# x = ps.fields('x: int64[1d]')
#
# assignments = ps.AssignmentCollection({x.center(): sum})
#
# ast = ps.create_kernel(assignments)
# ps.show_code(ast)
# kernel = ast.compile()
#
# array = np.zeros((10,), np.int64)
#
# kernel(x=array, limit=100)
#
# assert np.allclose(array, int(expanded_sum) * np.ones_like(array))