diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index 4f7d82b6eb9d4d96e15a0be1ccbfbfa36de0cf68..8437bdb6801ecc5adc717dec2983e5d6eda0baf0 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -11,12 +11,11 @@ from sympy.logic.boolalg import BooleanFalse, BooleanTrue from pystencils.astnodes import KernelFunction, LoopOverCoordinate, Node from pystencils.cpu.vectorization import vec_all, vec_any, CachelineSize -from pystencils.typing import ( - PointerType, VectorType, CastFunc, create_type, get_type_of_expression, - ReinterpretCastFunc, VectorMemoryAccess, BasicType, TypedSymbol) +from pystencils.data_types import ( + PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression, + reinterpret_cast_func, vector_memory_access, BasicType, TypedSymbol) from pystencils.enums import Backend from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt -from pystencils.functions import DivFunc, AddressOf from pystencils.integer_functions import ( bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor, int_div, int_power_of_2, modulo_ceil) @@ -31,6 +30,8 @@ __all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSy HEADER_REGEX = re.compile(r'^[<"].*[">]$') +KERNCRAFT_NO_TERNARY_MODE = False + def generate_c(ast_node: Node, signature_only: bool = False, @@ -218,7 +219,7 @@ class CBackend: return getattr(self, method_name)(node) raise NotImplementedError(f"{self.__class__.__name__} does not support node of type {node.__class__.__name__}") - def _print_AbstractType(self, node): + def _print_Type(self, node): return str(node) def _print_KernelFunction(self, node): @@ -273,9 +274,9 @@ class CBackend: self.sympy_printer.doprint(node.lhs), self.sympy_printer.doprint(node.rhs)) else: - lhs_type = get_type_of_expression(node.lhs) # TOOD: this should have been typed + lhs_type = get_type_of_expression(node.lhs) printed_mask = "" - if type(lhs_type) is VectorType and isinstance(node.lhs, CastFunc): + if type(lhs_type) is VectorType and isinstance(node.lhs, cast_func): arg, data_type, aligned, nontemporal, mask, stride = node.lhs.args instr = 'storeU' if aligned: @@ -288,20 +289,20 @@ class CBackend: self._vector_instruction_set['load' + instr[-1]].format('{0}', **self._kwargs), '{1}', '{2}', **self._kwargs), **self._kwargs) printed_mask = self.sympy_printer.doprint(mask) - if data_type.base_type.c_name == 'double': + if data_type.base_type.base_name == 'double': if self._vector_instruction_set['double'] == '__m256d': printed_mask = f"_mm256_castpd_si256({printed_mask})" elif self._vector_instruction_set['double'] == '__m128d': printed_mask = f"_mm_castpd_si128({printed_mask})" - elif data_type.base_type.c_name == 'float': + elif data_type.base_type.base_name == 'float': if self._vector_instruction_set['float'] == '__m256': printed_mask = f"_mm256_castps_si256({printed_mask})" elif self._vector_instruction_set['float'] == '__m128': printed_mask = f"_mm_castps_si128({printed_mask})" - rhs_type = get_type_of_expression(node.rhs) # TOOD: vector only??? + rhs_type = get_type_of_expression(node.rhs) if type(rhs_type) is not VectorType: - rhs = CastFunc(node.rhs, VectorType(rhs_type)) + rhs = cast_func(node.rhs, VectorType(rhs_type)) else: rhs = node.rhs @@ -321,7 +322,7 @@ class CBackend: if stride == 1: offset = offset.subs({node.lhs.args[0].field.spatial_strides[0]: 1}) size = sp.Mul(*node.lhs.args[0].field.spatial_shape) - element_size = 8 if data_type.base_type.c_name == 'double' else 4 + element_size = 8 if data_type.base_type.base_name == 'double' else 4 size_cond = f"({offset} + {CachelineSize.symbol/element_size}) < {size}" pre_code = f"if ({first_cond} && {size_cond}) " + "{\n\t" + \ self._vector_instruction_set['cachelineZero'].format(ptr, **self._kwargs) + ';\n}\n' @@ -415,7 +416,7 @@ class CBackend: return self._print_Block(node.true_block) elif type(node.condition_expr) is BooleanFalse: return self._print_Block(node.false_block) - cond_type = get_type_of_expression(node.condition_expr) # TODO: Could be vector or bool? + cond_type = get_type_of_expression(node.condition_expr) if isinstance(cond_type, VectorType): raise ValueError("Problem with Conditional inside vectorized loop - use vec_any or vec_all") condition_expr = self.sympy_printer.doprint(node.condition_expr) @@ -435,19 +436,23 @@ class CustomSympyPrinter(CCodePrinter): def __init__(self): super(CustomSympyPrinter, self).__init__() + self._float_type = create_type("float32") def _print_Pow(self, expr): """Don't use std::pow function, for small integer exponents, write as multiplication""" if not expr.free_symbols: - raise NotImplementedError("This pow should be simplified already?") - # return self._typed_number(expr.evalf(), get_type_of_expression(expr.base)) - return super(CustomSympyPrinter, self)._print_Pow(expr) + return self._typed_number(expr.evalf(17), get_type_of_expression(expr.base)) - # TODO don't print ones in sp.Mul + if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8: + return f"({self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False))})" + elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0: + return f"1 / ({self._print(sp.Mul(*([expr.base] * -expr.exp), evaluate=False))})" + else: + return super(CustomSympyPrinter, self)._print_Pow(expr) def _print_Rational(self, expr): """Evaluate all rationals i.e. print 0.25 instead of 1.0/4.0""" - res = str(expr.evalf().num) + res = str(expr.evalf(17)) return res def _print_Equality(self, expr): @@ -478,15 +483,15 @@ class CustomSympyPrinter(CCodePrinter): } if hasattr(expr, 'to_c'): return expr.to_c(self._print) - if isinstance(expr, ReinterpretCastFunc): + if isinstance(expr, reinterpret_cast_func): arg, data_type = expr.args return f"*(({self._print(PointerType(data_type, restrict=False))})(& {self._print(arg)}))" - elif isinstance(expr, AddressOf): + elif isinstance(expr, address_of): assert len(expr.args) == 1, "address_of must only have one argument" return f"&({self._print(expr.args[0])})" - elif isinstance(expr, CastFunc): + elif isinstance(expr, cast_func): arg, data_type = expr.args - if arg.is_Number and not isinstance(arg, (sp.core.numbers.Infinity, sp.core.numbers.NegativeInfinity)): + if isinstance(arg, sp.Number) and arg.is_finite: return self._typed_number(arg, data_type) else: return f"(({data_type})({self._print(arg)}))" @@ -500,6 +505,8 @@ class CustomSympyPrinter(CCodePrinter): return f"({self._print(1 / sp.sqrt(expr.args[0]))})" elif isinstance(expr, sp.Abs): return f"abs({self._print(expr.args[0])})" + elif isinstance(expr, sp.Max): + return self._print(expr) elif isinstance(expr, sp.Mod): if expr.args[0].is_integer and expr.args[1].is_integer: return f"({self._print(expr.args[0])} % {self._print(expr.args[1])})" @@ -511,8 +518,6 @@ class CustomSympyPrinter(CCodePrinter): return f"(1 << ({self._print(expr.args[0])}))" elif expr.func == int_div: return f"(({self._print(expr.args[0])}) / ({self._print(expr.args[1])}))" - elif expr.func == DivFunc: - return f'(({self._print(expr.divisor)}) / ({self._print(expr.dividend)}))' else: name = expr.name if hasattr(expr, 'name') else expr.__class__.__name__ arg_str = ', '.join(self._print(a) for a in expr.args) @@ -535,6 +540,52 @@ class CustomSympyPrinter(CCodePrinter): else: return res + def _print_Sum(self, expr): + template = """[&]() {{ + {dtype} sum = ({dtype}) 0; + for ( {iterator_dtype} {var} = {start}; {condition}; {var} += {increment} ) {{ + sum += {expr}; + }} + return sum; +}}()""" + var = expr.limits[0][0] + start = expr.limits[0][1] + end = expr.limits[0][2] + code = template.format( + dtype=get_type_of_expression(expr.args[0]), + iterator_dtype='int', + var=self._print(var), + start=self._print(start), + end=self._print(end), + expr=self._print(expr.function), + increment=str(1), + condition=self._print(var) + ' <= ' + self._print(end) # if start < end else '>=' + ) + return code + + def _print_Product(self, expr): + template = """[&]() {{ + {dtype} product = ({dtype}) 1; + for ( {iterator_dtype} {var} = {start}; {condition}; {var} += {increment} ) {{ + product *= {expr}; + }} + return product; +}}()""" + var = expr.limits[0][0] + start = expr.limits[0][1] + end = expr.limits[0][2] + code = template.format( + dtype=get_type_of_expression(expr.args[0]), + iterator_dtype='int', + var=self._print(var), + start=self._print(start), + end=self._print(end), + expr=self._print(expr.function), + increment=str(1), + condition=self._print(var) + ' <= ' + self._print(end) # if start < end else '>=' + ) + return code + def _print_ConditionalFieldAccess(self, node): return self._print(sp.Piecewise((node.outofbounds_value, node.outofbounds_condition), (node.access, True))) @@ -558,6 +609,27 @@ class CustomSympyPrinter(CCodePrinter): return f"(({a} < {b}) ? {a} : {b})" return inner_print_min(expr.args) + def _print_re(self, expr): + return f"real({self._print(expr.args[0])})" + + def _print_im(self, expr): + return f"imag({self._print(expr.args[0])})" + + def _print_ImaginaryUnit(self, expr): + return "complex<double>{0,1}" + + def _print_TypedImaginaryUnit(self, expr): + if expr.dtype.numpy_dtype == np.complex64: + return "complex<float>{0,1}" + elif expr.dtype.numpy_dtype == np.complex128: + return "complex<double>{0,1}" + else: + raise NotImplementedError( + "only complex64 and complex128 supported") + + def _print_Complex(self, expr): + return self._typed_number(expr, np.complex64) + # noinspection PyPep8Naming class VectorizedCustomSympyPrinter(CustomSympyPrinter): @@ -576,22 +648,22 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): return None def _print_Abs(self, expr): - if 'abs' in self.instruction_set and isinstance(expr.args[0], VectorMemoryAccess): + if 'abs' in self.instruction_set and isinstance(expr.args[0], vector_memory_access): return self.instruction_set['abs'].format(self._print(expr.args[0]), **self._kwargs) return super()._print_Abs(expr) def _print_Function(self, expr): - if isinstance(expr, VectorMemoryAccess): + if isinstance(expr, vector_memory_access): arg, data_type, aligned, _, mask, stride = expr.args if stride != 1: return self.instruction_set['loadS'].format(f"& {self._print(arg)}", stride, **self._kwargs) instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU'] return instruction.format(f"& {self._print(arg)}", **self._kwargs) - elif isinstance(expr, CastFunc): + elif isinstance(expr, cast_func): arg, data_type = expr.args if type(data_type) is VectorType: # vector_memory_access is a cast_func itself so it should't be directly inside a cast_func - assert not isinstance(arg, VectorMemoryAccess) + assert not isinstance(arg, vector_memory_access) if isinstance(arg, sp.Tuple): is_boolean = get_type_of_expression(arg[0]) == create_type("bool") is_integer = get_type_of_expression(arg[0]) == create_type("int") @@ -675,12 +747,12 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): # special treatment for all-integer args, for loop index arithmetic until we have proper int vectorization suffix = "" - if all([(type(e) is CastFunc and str(e.dtype) == self.instruction_set['int']) or isinstance(e, sp.Integer) + if all([(type(e) is cast_func and str(e.dtype) == self.instruction_set['int']) or isinstance(e, sp.Integer) or (type(e) is TypedSymbol and isinstance(e.dtype, BasicType) and e.dtype.is_int()) for e in args]): - dtype = set([e.dtype for e in args if type(e) is CastFunc]) + dtype = set([e.dtype for e in args if type(e) is cast_func]) assert len(dtype) == 1 dtype = dtype.pop() - args = [CastFunc(e, dtype) if (isinstance(e, sp.Integer) or isinstance(e, TypedSymbol)) else e + args = [cast_func(e, dtype) if (isinstance(e, sp.Integer) or isinstance(e, TypedSymbol)) else e for e in args] suffix = "int" @@ -808,9 +880,12 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): result = self._print(expr.args[-1][0]) for true_expr, condition in reversed(expr.args[:-1]): - if isinstance(condition, CastFunc) and get_type_of_expression(condition.args[0]) == create_type("bool"): - result = "(({}) ? ({}) : ({}))".format(self._print(condition.args[0]), self._print(true_expr), - result, **self._kwargs) + if isinstance(condition, cast_func) and get_type_of_expression(condition.args[0]) == create_type("bool"): + if not KERNCRAFT_NO_TERNARY_MODE: + result = "(({}) ? ({}) : ({}))".format(self._print(condition.args[0]), self._print(true_expr), + result, **self._kwargs) + else: + print("Warning - skipping ternary op") else: # noinspection SpellCheckingInspection result = self.instruction_set['blendv'].format(result, self._print(true_expr), self._print(condition), diff --git a/pystencils/simp/simplifications.py b/pystencils/simp/simplifications.py index 955f6b73d37421c8d51cb04b871945571ffce7bd..720abb52ad0f66e030fa7b5f922b8ab0771124bd 100644 --- a/pystencils/simp/simplifications.py +++ b/pystencils/simp/simplifications.py @@ -3,10 +3,12 @@ from typing import Callable, List, Sequence, Union from collections import defaultdict import sympy as sp +from sympy.codegen.rewriting import optims_c99, optimize +from sympy.codegen.rewriting import ReplaceOptim from pystencils.assignment import Assignment -from pystencils.astnodes import Node -from pystencils.field import Field +from pystencils.astnodes import Node, SympyAssignment +from pystencils.field import AbstractField, Field from pystencils.sympyextensions import subs_additive, is_constant, recursive_collect @@ -162,7 +164,7 @@ def add_subexpressions_for_sums(ac): for eq in ac.all_assignments: search_addends(eq.rhs) - addends = [a for a in addends if not isinstance(a, sp.Symbol) or isinstance(a, Field.Access)] + addends = [a for a in addends if not isinstance(a, sp.Symbol) or isinstance(a, AbstractField.AbstractAccess)] new_symbol_gen = ac.subexpression_symbol_generator substitutions = {addend: new_symbol for new_symbol, addend in zip(new_symbol_gen, addends)} return ac.new_with_substitutions(substitutions, True, substitute_on_lhs=False) @@ -225,29 +227,22 @@ def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]): return f -# TODO Markus -# TODO: make this really work for Assignmentcollections -# TODO: this function should ONLY evaluate -# TODO: do the optims_c99 elsewhere optionally -# def apply_sympy_optimisations(ac: AssignmentCollection): -# """ Evaluates constant expressions (e.g. :math:`\\sqrt{3}` will be replaced by its floating point representation) -# and applies the default sympy optimisations. See sympy.codegen.rewriting -# """ -# -# # Evaluates all constant terms -# -# assignments = ac.all_assignments -# -# evaluate_constant_terms = ReplaceOptim(lambda e: hasattr(e, 'is_constant') and e.is_constant and not e.is_integer, -# lambda p: p.evalf()) -# -# sympy_optimisations = [evaluate_constant_terms] + list(optims_c99) -# -# assignments = [Assignment(a.lhs, optimize(a.rhs, sympy_optimisations)) -# if hasattr(a, 'lhs') -# else a for a in assignments] -# assignments_nodes = [a.atoms(SympyAssignment) for a in assignments] -# for a in chain.from_iterable(assignments_nodes): -# a.optimize(sympy_optimisations) -# -# return AssignmentCollection(assignments) +def apply_sympy_optimisations(assignments): + """ Evaluates constant expressions (e.g. :math:`\\sqrt{3}` will be replaced by its floating point representation) + and applies the default sympy optimisations. See sympy.codegen.rewriting + """ + + # Evaluates all constant terms + evaluate_constant_terms = ReplaceOptim(lambda e: hasattr(e, 'is_constant') and e.is_constant and not e.is_integer, + lambda p: p.evalf(17)) + + sympy_optimisations = [evaluate_constant_terms] + list(optims_c99) + + assignments = [Assignment(a.lhs, optimize(a.rhs, sympy_optimisations)) + if hasattr(a, 'lhs') + else a for a in assignments] + assignments_nodes = [a.atoms(SympyAssignment) for a in assignments] + for a in chain.from_iterable(assignments_nodes): + a.optimize(sympy_optimisations) + + return assignments diff --git a/pystencils_tests/test_types.py b/pystencils_tests/test_types.py index 8ac96f84e732debd8e1ff50426c483b6b5523868..b6a7cd81cf8b7618ab69f6e0dd69094f93de3238 100644 --- a/pystencils_tests/test_types.py +++ b/pystencils_tests/test_types.py @@ -1,93 +1,24 @@ -import pytest - -import pystencils.config import sympy as sp import numpy as np import pystencils as ps -from pystencils.typing import TypedSymbol, get_type_of_expression, VectorType, collate_types, \ - typed_symbols, CastFunc, PointerArithmeticFunc, PointerType, result_type, BasicType - - -def test_result_type(): - i = np.dtype('int32') - l = np.dtype('int64') - ui = np.dtype('uint32') - ul = np.dtype('uint64') - f = np.dtype('float32') - d = np.dtype('float64') - b = np.dtype('bool') - - assert result_type(i, l) == l - assert result_type(l, i) == l - assert result_type(ui, i) == i - assert result_type(ui, l) == l - assert result_type(ul, i) == i - assert result_type(ul, l) == l - assert result_type(d, f) == d - assert result_type(f, d) == d - assert result_type(i, f) == f - assert result_type(l, f) == f - assert result_type(ui, f) == f - assert result_type(ul, f) == f - assert result_type(i, d) == d - assert result_type(l, d) == d - assert result_type(ui, d) == d - assert result_type(ul, d) == d - assert result_type(b, i) == i - assert result_type(b, l) == l - assert result_type(b, ui) == ui - assert result_type(b, ul) == ul - assert result_type(b, f) == f - assert result_type(b, d) == d - - -@pytest.mark.parametrize('dtype', ('float64', 'float32', 'int64', 'int32', 'uint32', 'uint64')) -def test_simple_add(dtype): - constant = 1.0 - if dtype[0] in 'ui': - constant = 1 - f = ps.fields(f"f: {dtype}[1D]") - d = TypedSymbol("d", dtype) - - test_arr = np.array([constant], dtype=dtype) - - ur = ps.Assignment(f[0], f[0] + d) - - ast = ps.create_kernel(ur) - code = ps.get_code_str(ast) - kernel = ast.compile() - kernel(f=test_arr, d=constant) - - assert test_arr[0] == constant+constant - - -@pytest.mark.parametrize('dtype1', ('float64', 'float32', 'int64', 'int32', 'uint32', 'uint64')) -@pytest.mark.parametrize('dtype2', ('float64', 'float32', 'int64', 'int32', 'uint32', 'uint64')) -def test_mixed_add(dtype1, dtype2): - - constant = 1 - f = ps.fields(f"f: {dtype1}[1D]") - g = ps.fields(f"g: {dtype2}[1D]") +from pystencils import data_types +from pystencils.data_types import TypedSymbol, get_type_of_expression, VectorType, collate_types, create_type, \ + typed_symbols, type_all_numbers, matrix_symbols, cast_func, pointer_arithmetic_func, PointerType - test_f = np.array([constant], dtype=dtype1) - test_g = np.array([constant], dtype=dtype2) - ur = ps.Assignment(f[0], f[0] + g[0]) +def test_parsing(): + assert str(data_types.create_composite_type_from_string("const double *")) == "double const *" + assert str(data_types.create_composite_type_from_string("double const *")) == "double const *" - # TODO Markus: check for the logging if colate_types(dtype1, dtype2) != dtype1 - ast = ps.create_kernel(ur) - code = ps.get_code_str(ast) - kernel = ast.compile() - kernel(f=test_f, g=test_g) + t1 = data_types.create_composite_type_from_string("const double * const * const restrict") + t2 = data_types.create_composite_type_from_string(str(t1)) + assert t1 == t2 - assert test_f[0] == constant+constant - -# TODO vector def test_collation(): - double_type = BasicType('float64') - float_type = BasicType('float32') + double_type = create_type("double") + float_type = create_type("float32") double4_type = VectorType(double_type, 4) float4_type = VectorType(float_type, 4) assert collate_types([double_type, float_type]) == double_type @@ -96,23 +27,20 @@ def test_collation(): def test_vector_type(): - double_type = BasicType('float64') - float_type = BasicType('float32') + double_type = create_type("double") + float_type = create_type("float32") double4_type = VectorType(double_type, 4) float4_type = VectorType(float_type, 4) assert double4_type.item_size == 4 assert float4_type.item_size == 4 - double4_type2 = VectorType(double_type, 4) - assert double4_type == double4_type2 - assert double4_type != 4 - assert double4_type != float4_type + assert not double4_type == 4 def test_pointer_type(): - double_type = BasicType('float64') - float_type = BasicType('float32') + double_type = create_type("double") + float_type = create_type("float32") double4_type = PointerType(double_type, restrict=True) float4_type = PointerType(float_type, restrict=False) @@ -144,104 +72,96 @@ def test_assumptions(): assert x.shape[0].is_nonnegative assert (2 * x.shape[0]).is_nonnegative assert (2 * x.shape[0]).is_integer - assert (TypedSymbol('a', BasicType('uint64'))).is_nonnegative - assert (TypedSymbol('a', BasicType('uint64'))).is_positive is None - assert (TypedSymbol('a', BasicType('uint64')) + 1).is_positive + assert (TypedSymbol('a', create_type('uint64'))).is_nonnegative + assert (TypedSymbol('a', create_type('uint64'))).is_positive is None + assert (TypedSymbol('a', create_type('uint64')) + 1).is_positive assert (x.shape[0] + 1).is_real -@pytest.mark.parametrize('dtype', ('float64', 'float32')) -def test_sqrt_of_integer(dtype): +def test_sqrt_of_integer(): """Regression test for bug where sqrt(3) was classified as integer""" - f = ps.fields(f'f: {dtype}[1D]') - tmp = sp.symbols('tmp') + f = ps.fields("f: [1D]") + tmp = sp.symbols("tmp") assignments = [ps.Assignment(tmp, sp.sqrt(3)), ps.Assignment(f[0], tmp)] - arr = np.array([1], dtype=dtype) - # TODO Jupyter add auto lhs float/double problem - config = pystencils.config.CreateKernelConfig(data_type=dtype, default_number_float=dtype) + arr_double = np.array([1], dtype=np.float64) + kernel = ps.create_kernel(assignments).compile() + kernel(f=arr_double) + assert 1.7 < arr_double[0] < 1.8 - ast = ps.create_kernel(assignments, config=config) - kernel = ast.compile() - kernel(f=arr) - assert 1.7 < arr[0] < 1.8 + f = ps.fields("f: float32[1D]") + tmp = sp.symbols("tmp") - code = ps.get_code_str(ast) - constant = '1.7320508075688772f' - if dtype == 'float32': - assert constant in code - else: - assert constant not in code + assignments = [ps.Assignment(tmp, sp.sqrt(3)), + ps.Assignment(f[0], tmp)] + arr_single = np.array([1], dtype=np.float32) + config = ps.CreateKernelConfig(data_type="float32") + kernel = ps.create_kernel(assignments, config=config).compile() + kernel(f=arr_single) + + code = ps.get_code_str(kernel.ast) + # ps.show_code(kernel.ast) + # 1.7320508075688772935 --> it is actually correct to round to ...773. This was wrong before !282 + assert "1.7320508075688773f" in code + assert 1.7 < arr_single[0] < 1.8 -@pytest.mark.parametrize('dtype', ('float64', 'float32')) -def test_integer_comparision(dtype): - f = ps.fields(f"f: {dtype}[2D]") - d = TypedSymbol("dir", "int64") +def test_integer_comparision(): + f = ps.fields("f [2D]") + d = sp.Symbol("dir") ur = ps.Assignment(f[0, 0], sp.Piecewise((0, sp.Equality(d, 1)), (f[0, 0], True))) ast = ps.create_kernel(ur) code = ps.get_code_str(ast) - # There should be an explicit cast for the integer zero to the type of the field on the rhs - if dtype == 'float64': - t = "_data_f_00[_stride_f_1*ctr_1] = ((((dir) == (1))) ? (0.0): (_data_f_00[_stride_f_1*ctr_1]));" - else: - t = "_data_f_00[_stride_f_1*ctr_1] = ((((dir) == (1))) ? (0.0f): (_data_f_00[_stride_f_1*ctr_1]));" - assert t in code + assert "_data_f_00[_stride_f_1*ctr_1] = ((((dir) == (1))) ? (0.0): (_data_f_00[_stride_f_1*ctr_1]));" in code -def test_typed_symbols_dtype(): +def test_Basic_data_type(): assert typed_symbols(("s", "f"), np.uint) == typed_symbols("s, f", np.uint) t_symbols = typed_symbols(("s", "f"), np.uint) s = t_symbols[0] assert t_symbols[0] == TypedSymbol("s", np.uint) assert s.dtype.is_uint() + assert s.dtype.is_complex() == 0 - assert typed_symbols("s", np.float64).dtype.c_name == 'double' - assert typed_symbols("s", np.float32).dtype.c_name == 'float' + assert typed_symbols("s", str).dtype.is_other() + assert typed_symbols("s", bool).dtype.is_other() + assert typed_symbols("s", np.void).dtype.is_other() + + assert typed_symbols("s", np.float64).dtype.base_name == 'double' + # removed for old sympy version + # assert typed_symbols(("s"), np.float64).dtype.sympy_dtype == typed_symbols(("s"), float).dtype.sympy_dtype + + f, g = ps.fields("f, g : double[2D]") + + expr = ps.Assignment(f.center(), 2 * g.center() + 5) + new_expr = type_all_numbers(expr, np.float64) + + assert "cast_func(2, double)" in str(new_expr) + assert "cast_func(5, double)" in str(new_expr) + + m = matrix_symbols("a, b", np.uint, 3, 3) + assert len(m) == 2 + m = m[0] + for i, elem in enumerate(m): + assert elem == TypedSymbol(f"a{i}", np.uint) + assert elem.dtype.is_uint() assert TypedSymbol("s", np.uint).canonical == TypedSymbol("s", np.uint) assert TypedSymbol("s", np.uint).reversed == TypedSymbol("s", np.uint) def test_cast_func(): - assert CastFunc(TypedSymbol("s", np.uint), np.int64).canonical == TypedSymbol("s", np.uint).canonical + assert cast_func(TypedSymbol("s", np.uint), np.int64).canonical == TypedSymbol("s", np.uint).canonical - a = CastFunc(5, np.uint) + a = cast_func(5, np.uint) assert a.is_negative is False assert a.is_nonnegative def test_pointer_arithmetic_func(): - assert PointerArithmeticFunc(TypedSymbol("s", np.uint), 1).canonical == TypedSymbol("s", np.uint).canonical - - -def test_division(): - f = ps.fields('f(10): float32[2D]') - m, tau = sp.symbols("m, tau") - - up = [ps.Assignment(tau, 1 / (0.5 + (3.0 * m))), - ps.Assignment(f.center, tau)] - config = pystencils.config.CreateKernelConfig(data_type='float32', default_number_float='float32') - ast = ps.create_kernel(up, config=config) - code = ps.get_code_str(ast) - - assert "((1.0f) / (m*3.0f + 0.5f))" in code - - -def test_pow(): - f = ps.fields('f(10): float32[2D]') - m, tau = sp.symbols("m, tau") - - up = [ps.Assignment(tau, m ** 1.5), - ps.Assignment(f.center, tau)] - - config = pystencils.config.CreateKernelConfig(data_type="float32", default_number_float='float32') - ast = ps.create_kernel(up, config=config) - code = ps.get_code_str(ast) - - assert "1.5f" in code + assert pointer_arithmetic_func(TypedSymbol("s", np.uint), 1).canonical == TypedSymbol("s", np.uint).canonical