diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index ab35ff10819b1056a8495cc86d4afd879e4551f9..b5dc66e10c40af4849e7904f8dcfff267166b7f7 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -491,10 +491,7 @@ class CustomSympyPrinter(CCodePrinter): return f"&({self._print(expr.args[0])})" elif isinstance(expr, CastFunc): arg, data_type = expr.args - if isinstance(arg, sp.Number) and arg.is_finite: - return self._typed_number(arg, data_type) - else: - return f"(({data_type})({self._print(arg)}))" + return f"(({data_type})({self._print(arg)}))" elif isinstance(expr, fast_division): return f"({self._print(expr.args[0] / expr.args[1])})" elif isinstance(expr, fast_sqrt): diff --git a/pystencils/cpu/kernelcreation.py b/pystencils/cpu/kernelcreation.py index b3817c8f5c11ac07a8084f215ac04768deb15c8c..4e0573c89e9b66c92f81a97a97256460e75f711e 100644 --- a/pystencils/cpu/kernelcreation.py +++ b/pystencils/cpu/kernelcreation.py @@ -4,7 +4,7 @@ import sympy as sp import numpy as np import pystencils.astnodes as ast -from pystencils.assignment import Assignment +from pystencils.simp.assignment_collection import AssignmentCollection from pystencils.config import CreateKernelConfig from pystencils.enums import Target, Backend from pystencils.astnodes import Block, KernelFunction, LoopOverCoordinate, SympyAssignment @@ -17,12 +17,8 @@ from pystencils.transformations import ( move_constants_before_loop, parse_base_pointer_info, resolve_buffer_accesses, resolve_field_accesses, split_inner_loop) -from pystencils.kernel_contrains_check import KernelConstraintsCheck -AssignmentOrAstNodeList = List[Union[Assignment, ast.Node]] - - -def create_kernel(assignments: AssignmentOrAstNodeList, config: CreateKernelConfig, split_groups) -> KernelFunction: +def create_kernel(assignments: AssignmentCollection, config: CreateKernelConfig) -> KernelFunction: """Creates an abstract syntax tree for a kernel function, by taking a list of update rules. Loops are created according to the field accesses in the equations. @@ -31,8 +27,6 @@ def create_kernel(assignments: AssignmentOrAstNodeList, config: CreateKernelConf assignments: list of sympy equations, containing accesses to :class:`pystencils.field.Field`. Defining the update rules of the kernel config: create kernel config - split_groups: Specification on how to split up inner loop into multiple loops. For details see - transformation :func:`pystencils.transformation.split_inner_loop` Returns: AST node representing a function, that can be printed as C or CUDA code @@ -41,8 +35,13 @@ def create_kernel(assignments: AssignmentOrAstNodeList, config: CreateKernelConf type_info = config.data_type iteration_slice = config.iteration_slice ghost_layers = config.ghost_layers - skip_independence_check = config.skip_independence_check - allow_double_writes = config.allow_double_writes + fields_written = assignments.bound_fields + fields_read = assignments.free_fields + + split_groups = () + if 'split_groups' in assignments.simplification_hints: + split_groups = assignments.simplification_hints['split_groups'] + assignments = assignments.all_assignments # TODO: try to delete def type_symbol(term): @@ -56,12 +55,7 @@ def create_kernel(assignments: AssignmentOrAstNodeList, config: CreateKernelConf else: raise ValueError("Term has to be field access or symbol") - check = KernelConstraintsCheck(check_independence_condition=skip_independence_check, - check_double_write_condition=allow_double_writes) - check.visit(assignments) - - fields_read = check.fields_read - fields_written = check.fields_written + # TODO move add_types to create_domain_kernel or create_kernel assignments = add_types(assignments, config) @@ -78,7 +72,6 @@ def create_kernel(assignments: AssignmentOrAstNodeList, config: CreateKernelConf ast_node = KernelFunction(loop_node, Target.CPU, Backend.C, compile_function=make_python_function, ghost_layers=ghost_layer_info, function_name=function_name, assignments=assignments) - # TODO move split groups here if split_groups: typed_split_groups = [[type_symbol(s) for s in split_group] for split_group in split_groups] split_inner_loop(ast_node, typed_split_groups) @@ -100,7 +93,7 @@ def create_kernel(assignments: AssignmentOrAstNodeList, config: CreateKernelConf return ast_node -def create_indexed_kernel(assignments: AssignmentOrAstNodeList, index_fields, function_name="kernel", +def create_indexed_kernel(assignments: AssignmentCollection, index_fields, function_name="kernel", type_info=None, coordinate_names=('x', 'y', 'z')) -> KernelFunction: """ Similar to :func:`create_kernel`, but here not all cells of a field are updated but only cells with diff --git a/pystencils/kernel_contrains_check.py b/pystencils/kernel_contrains_check.py index 74f56666651563d5d49df141819cc44a64835889..42204822a29321d611189ae9f9afd021344a5305 100644 --- a/pystencils/kernel_contrains_check.py +++ b/pystencils/kernel_contrains_check.py @@ -4,6 +4,7 @@ from typing import Union import sympy as sp from sympy.codegen import Assignment +from pystencils.simp import AssignmentCollection from pystencils import astnodes as ast, TypedSymbol from pystencils.field import Field from pystencils.transformations import NestedScopes @@ -41,7 +42,9 @@ class KernelConstraintsCheck: self.check_double_write_condition = check_double_write_condition def visit(self, obj): - if isinstance(obj, list) or isinstance(obj, tuple): + if isinstance(obj, AssignmentCollection): + [self.visit(e) for e in obj.all_assignments] + elif isinstance(obj, list) or isinstance(obj, tuple): [self.visit(e) for e in obj] elif isinstance(obj, (sp.Eq, ast.SympyAssignment, Assignment)): self.process_assignment(obj) diff --git a/pystencils/kernelcreation.py b/pystencils/kernelcreation.py index 5ee7817f170beb24688aca6ce3940e63ae3a0b89..2635eb035e7b3ff7ef84fec071bdbd8ef96019d8 100644 --- a/pystencils/kernelcreation.py +++ b/pystencils/kernelcreation.py @@ -12,7 +12,7 @@ from pystencils.enums import Target, Backend from pystencils.field import Field, FieldType from pystencils.gpucuda.indexing import indexing_creator_from_params from pystencils.simp.assignment_collection import AssignmentCollection -from pystencils.simp.simplifications import apply_sympy_optimisations +from pystencils.kernel_contrains_check import KernelConstraintsCheck from pystencils.simplificationfactory import create_simplification_strategy from pystencils.stencil import direction_string_to_offset, inverse_direction_string from pystencils.transformations import ( @@ -62,6 +62,8 @@ def create_kernel(assignments: Union[Assignment, List[Assignment], AssignmentCol if isinstance(assignments, Assignment): assignments = [assignments] assert assignments, "Assignments must not be empty!" + if isinstance(assignments, list): + assignments = AssignmentCollection(assignments) if config.index_fields: return create_indexed_kernel(assignments, config=config) @@ -69,7 +71,7 @@ def create_kernel(assignments: Union[Assignment, List[Assignment], AssignmentCol return create_domain_kernel(assignments, config=config) -def create_domain_kernel(assignments: List[Assignment], *, config: CreateKernelConfig): +def create_domain_kernel(assignments: AssignmentCollection, *, config: CreateKernelConfig): """ Creates abstract syntax tree (AST) of kernel, using a list of update equations. @@ -82,6 +84,7 @@ def create_domain_kernel(assignments: List[Assignment], *, config: CreateKernelC can be compiled with through its 'compile()' member Example: + # TODO change to assignment collection >>> import pystencils as ps >>> import numpy as np >>> s, d = ps.fields('s, d: [2D]') @@ -98,6 +101,7 @@ def create_domain_kernel(assignments: List[Assignment], *, config: CreateKernelC [0., 4., 4., 4., 0.], [0., 0., 0., 0., 0.]]) """ + # --- applying first default simplifications try: if config.default_assignment_simplifications and isinstance(assignments, AssignmentCollection): @@ -107,20 +111,18 @@ def create_domain_kernel(assignments: List[Assignment], *, config: CreateKernelC warnings.warn(f"It was not possible to apply the default pystencils optimisations to the " f"AssignmentCollection due to the following problem :{e}") - # TODO: shift to CPU - # ---- Normalizing parameters - split_groups = () - if isinstance(assignments, AssignmentCollection): - if 'split_groups' in assignments.simplification_hints: - split_groups = assignments.simplification_hints['split_groups'] - assignments = assignments.all_assignments + assignments.evaluate_terms() - try: - if config.default_assignment_simplifications: - assignments = apply_sympy_optimisations(assignments) - except Exception as e: - warnings.warn(f"It was not possible to apply the default SymPy optimisations to the " - f"Assignments due to the following problem :{e}") + # --- eval + # TODO split apply_sympy_optimisations and do the eval here + + # FUTURE WORK from here we shouldn't NEED sympy + # --- check constrains + check = KernelConstraintsCheck(check_independence_condition=config.skip_independence_check, + check_double_write_condition=config.allow_double_writes) + check.visit(assignments) + assert assignments.bound_fields == check.fields_written, f'WTF' + assert assignments.rhs_fields == check.fields_read, f'WTF' # ---- Creating ast ast = None @@ -128,7 +130,7 @@ def create_domain_kernel(assignments: List[Assignment], *, config: CreateKernelC if config.backend == Backend.C: from pystencils.cpu import add_openmp, create_kernel # TODO: data type keyword should be unified to data_type - ast = create_kernel(assignments, config=config, split_groups=split_groups) + ast = create_kernel(assignments, config=config) for optimization in config.cpu_prepend_optimizations: optimization(ast) omp_collapse = None @@ -170,7 +172,7 @@ def create_domain_kernel(assignments: List[Assignment], *, config: CreateKernelC return ast -def create_indexed_kernel(assignments: List[Assignment], *, config: CreateKernelConfig): +def create_indexed_kernel(assignments: AssignmentCollection, *, config: CreateKernelConfig): """ Similar to :func:`create_kernel`, but here not all cells of a field are updated but only cells with coordinates which are stored in an index field. This traversal method can e.g. be used for boundary handling. @@ -212,6 +214,8 @@ import pystencils.kernel_creation_config >>> import pystencils as ps [0. , 0. , 0. , 4.3, 0. ], [0. , 0. , 0. , 0. , 0. ]]) """ + # TODO do this in backends + assignments = assignments.all_assignments ast = None if config.target == Target.CPU and config.backend == Backend.C: from pystencils.cpu import add_openmp, create_indexed_kernel diff --git a/pystencils/simp/assignment_collection.py b/pystencils/simp/assignment_collection.py index f493b08e931bf69f0b9713255c0eee87e9bace66..7309e7d87bb3402ece30f5008baed0514a99b770 100644 --- a/pystencils/simp/assignment_collection.py +++ b/pystencils/simp/assignment_collection.py @@ -3,6 +3,7 @@ from copy import copy from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Union import sympy as sp +from sympy.codegen.rewriting import ReplaceOptim, optimize import pystencils from pystencils.assignment import Assignment @@ -107,16 +108,21 @@ class AssignmentCollection: return self.subexpressions + self.main_assignments @property - def free_symbols(self) -> Set[sp.Symbol]: - """All symbols used in the assignment collection, which do not occur as left hand sides in any assignment.""" - free_symbols = set() + def rhs_symbols(self) -> Set[sp.Symbol]: + """All symbols used in the assignment collection, which occur on the rhs of any assignment.""" + rhs_symbols = set() for eq in self.all_assignments: if isinstance(eq, Assignment): - free_symbols.update(eq.rhs.atoms(sp.Symbol)) + rhs_symbols.update(eq.rhs.atoms(sp.Symbol)) elif isinstance(eq, pystencils.astnodes.Node): - free_symbols.update(eq.undefined_symbols) + rhs_symbols.update(eq.undefined_symbols) - return free_symbols - self.bound_symbols + return rhs_symbols + + @property + def free_symbols(self) -> Set[sp.Symbol]: + """All symbols used in the assignment collection, which do not occur as left hand sides in any assignment.""" + return self.rhs_symbols - self.bound_symbols @property def bound_symbols(self) -> Set[sp.Symbol]: @@ -132,10 +138,15 @@ class AssignmentCollection: assignment.symbols_defined for assignment in self.all_assignments if isinstance(assignment, pystencils.astnodes.Node) ] - ) + ) return bound_symbols_set + @property + def rhs_fields(self): + """All fields accessed in the assignment collection, which do not occur as left hand sides in any assignment.""" + return {s.field for s in self.rhs_symbols if hasattr(s, 'field')} + @property def free_fields(self): """All fields accessed in the assignment collection, which do not occur as left hand sides in any assignment.""" @@ -152,7 +163,7 @@ class AssignmentCollection: return (set( [assignment.lhs for assignment in self.main_assignments if isinstance(assignment, Assignment)] ).union(*[assignment.symbols_defined for assignment in self.main_assignments if isinstance( - assignment, pystencils.astnodes.Node)] + assignment, pystencils.astnodes.Node)] )) @property @@ -214,6 +225,7 @@ class AssignmentCollection: return {s: func(*args, **kwargs) for s, func in lambdas.items()} return f + # ---------------------------- Creating new modified collections --------------------------------------------------- def copy(self, @@ -353,10 +365,26 @@ class AssignmentCollection: new_assignment = [fast_subs(eq, substitution_dict) for eq in self.main_assignments] return self.copy(new_assignment, kept_subexpressions) + def evaluate_terms(self): + + evaluate_constant_terms = ReplaceOptim( + lambda e: hasattr(e, 'is_constant') and e.is_constant and not e.is_integer, + lambda p: p.evalf()) + + sympy_optimisations = [evaluate_constant_terms] + + self.subexpressions = [Assignment(a.lhs, optimize(a.rhs, sympy_optimisations)) + if hasattr(a, 'lhs') + else a for a in self.subexpressions] + + self.main_assignments = [Assignment(a.lhs, optimize(a.rhs, sympy_optimisations)) + if hasattr(a, 'lhs') + else a for a in self.main_assignments] # ----------------------------------------- Display and Printing ------------------------------------------------- def _repr_html_(self): """Interface to Jupyter notebook, to display as a nicely formatted HTML table""" + def make_html_equation_table(equations): no_border = 'style="border:none"' html_table = '<table style="border:none; width: 100%; ">' diff --git a/pystencils/simp/simplifications.py b/pystencils/simp/simplifications.py index c36ba558e0a788a9d1751f2057d4647104d7e5a1..955f6b73d37421c8d51cb04b871945571ffce7bd 100644 --- a/pystencils/simp/simplifications.py +++ b/pystencils/simp/simplifications.py @@ -3,12 +3,10 @@ from typing import Callable, List, Sequence, Union from collections import defaultdict import sympy as sp -from sympy.codegen.rewriting import optims_c99, optimize -from sympy.codegen.rewriting import ReplaceOptim from pystencils.assignment import Assignment -from pystencils.astnodes import Node, SympyAssignment -from pystencils.field import Field, Field +from pystencils.astnodes import Node +from pystencils.field import Field from pystencils.sympyextensions import subs_additive, is_constant, recursive_collect @@ -227,22 +225,29 @@ def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]): return f -def apply_sympy_optimisations(assignments): - """ Evaluates constant expressions (e.g. :math:`\\sqrt{3}` will be replaced by its floating point representation) - and applies the default sympy optimisations. See sympy.codegen.rewriting - """ - - # Evaluates all constant terms - evaluate_constant_terms = ReplaceOptim(lambda e: hasattr(e, 'is_constant') and e.is_constant and not e.is_integer, - lambda p: p.evalf()) - - sympy_optimisations = [evaluate_constant_terms] + list(optims_c99) - - assignments = [Assignment(a.lhs, optimize(a.rhs, sympy_optimisations)) - if hasattr(a, 'lhs') - else a for a in assignments] - assignments_nodes = [a.atoms(SympyAssignment) for a in assignments] - for a in chain.from_iterable(assignments_nodes): - a.optimize(sympy_optimisations) - - return assignments +# TODO Markus +# TODO: make this really work for Assignmentcollections +# TODO: this function should ONLY evaluate +# TODO: do the optims_c99 elsewhere optionally +# def apply_sympy_optimisations(ac: AssignmentCollection): +# """ Evaluates constant expressions (e.g. :math:`\\sqrt{3}` will be replaced by its floating point representation) +# and applies the default sympy optimisations. See sympy.codegen.rewriting +# """ +# +# # Evaluates all constant terms +# +# assignments = ac.all_assignments +# +# evaluate_constant_terms = ReplaceOptim(lambda e: hasattr(e, 'is_constant') and e.is_constant and not e.is_integer, +# lambda p: p.evalf()) +# +# sympy_optimisations = [evaluate_constant_terms] + list(optims_c99) +# +# assignments = [Assignment(a.lhs, optimize(a.rhs, sympy_optimisations)) +# if hasattr(a, 'lhs') +# else a for a in assignments] +# assignments_nodes = [a.atoms(SympyAssignment) for a in assignments] +# for a in chain.from_iterable(assignments_nodes): +# a.optimize(sympy_optimisations) +# +# return AssignmentCollection(assignments) diff --git a/pystencils/typing/leaf_typing.py b/pystencils/typing/leaf_typing.py index 6c51691ff8a56a0a27829ccd66e4c409cf461545..df36c0d9175b4b01572e94570c9a4ce0cd41e5c9 100644 --- a/pystencils/typing/leaf_typing.py +++ b/pystencils/typing/leaf_typing.py @@ -5,6 +5,11 @@ import logging import numpy as np import sympy as sp +from sympy import Piecewise +from sympy.functions.elementary.piecewise import ExprCondPair +from sympy.codegen import Assignment +from sympy.logic.boolalg import BooleanFunction +from sympy.logic.boolalg import BooleanAtom from pystencils import astnodes as ast from pystencils.field import Field @@ -13,8 +18,6 @@ from pystencils.typing.utilities import get_type_of_expression, collate_types from pystencils.typing.cast_functions import CastFunc from pystencils.typing.typed_sympy import TypedSymbol from pystencils.utils import ContextVar -from sympy.codegen import Assignment -from sympy.logic.boolalg import BooleanFunction class TypeAdder: @@ -93,6 +96,8 @@ class TypeAdder: def figure_out_type(self, expr) -> Tuple[Any, BasicType]: # TODO or abstract type? vector type? # Trivial cases from pystencils.field import Field + import pystencils.integer_functions + from pystencils.bit_masks import flag_cond if isinstance(expr, Field.Access): return expr, expr.dtype @@ -104,24 +109,56 @@ class TypeAdder: elif isinstance(expr, np.generic): assert False, f'Why do we have a np.generic in rhs???? {expr}' elif isinstance(expr, sp.Number): - if expr.is_Float: - data_type = self.default_number_float.get() - elif expr.is_Integer: + if expr.is_Integer: data_type = self.default_number_int.get() + elif expr.is_Float or expr.is_Rational: + data_type = self.default_number_float.get() else: assert False, f'{sp.Number} is neither Float nor Integer' return expr, data_type - # TODO add everything in between + elif isinstance(expr, BooleanAtom): + return expr, BasicType('bool') + elif isinstance(expr, sp.Equality): + new_args = [self.figure_out_type(arg)[0] for arg in expr.args] + new_eq = sp.Equality(*new_args) + return new_eq, BasicType('bool') + elif isinstance(expr, CastFunc): + raise NotImplementedError('CastFunc') + elif isinstance(expr, BooleanFunction) or \ + type(expr, ) in pystencils.integer_functions.__dict__.values(): + raise NotImplementedError('BooleanFunction') + elif isinstance(expr, flag_cond): + # do not process the arguments to the bit shift - they must remain integers + raise NotImplementedError('flag_cond') elif isinstance(expr, sp.Mul): + raise NotImplementedError('sp.Mul') # TODO can we ignore this and move it to general expr handling, i.e. removing Mul? - args_types = [self.figure_out_type(arg) for arg in expr.args if arg not in (-1, 1)] - return None # TODO collate types + # args_types = [self.figure_out_type(arg) for arg in expr.args if arg not in (-1, 1)] elif isinstance(expr, sp.Indexed): - self.apply_type(expr, BasicType('uintp')) # TODO double check - return None + raise NotImplementedError('sp.Indexed') elif isinstance(expr, sp.Pow): - # TODO sp.Pow should know a type - return None # TODO + args_types = [self.figure_out_type(arg) for arg in expr.args] + collated_type = collate_types([t for _, t in args_types]) + return expr, collated_type + elif isinstance(expr, ExprCondPair): + expr_expr, expr_type = self.figure_out_type(expr.expr) + condition, condition_type = self.figure_out_type(expr.cond) + if condition_type != BasicType('bool'): + logging.warning(f'Condition "{condition}" is of type "{condition_type}" and not "bool"') + return expr.func(expr_expr, condition), expr_type + elif isinstance(expr, Piecewise): + args_types = [self.figure_out_type(arg) for arg in expr.args] + collated_type = collate_types([t for _, t in args_types]) + new_args = [] + for a, t in args_types: + if t != collated_type: + if isinstance(a, ExprCondPair): + new_args.append(a.func(CastFunc(a.expr, collated_type), a.cond)) + else: + new_args.append(CastFunc(a, collated_type)) + else: + new_args.append(a) + return expr.func(*new_args) if new_args else expr, collated_type else: args_types = [self.figure_out_type(arg) for arg in expr.args] collated_type = collate_types([t for _, t in args_types]) @@ -190,6 +227,6 @@ class TypeAdder: def process_lhs(self, lhs: Union[Field.Access, TypedSymbol, sp.Symbol]): if not isinstance(lhs, (Field.Access, TypedSymbol)): - return TypedSymbol(lhs.name, self._type_for_symbol[lhs.name]) + return TypedSymbol(lhs.name, self.type_for_symbol[lhs.name]) else: return lhs diff --git a/pystencils_tests/test_types.py b/pystencils_tests/test_types.py index 795e5c26c91f7a6ac2294650c4050c5822b226f5..d1b0d33cb84910d56679acf7269ffa55ae86f1e3 100644 --- a/pystencils_tests/test_types.py +++ b/pystencils_tests/test_types.py @@ -84,7 +84,7 @@ def test_mixed_add(dtype1, dtype2): assert test_f[0] == constant+constant -# TODO redo following tests +# TODO vector def test_collation(): double_type = BasicType('float64') float_type = BasicType('float32') @@ -95,8 +95,9 @@ def test_collation(): assert collate_types([double4_type, float4_type]) == double4_type +# TODO this def test_vector_type(): - double_type = BasicType("double") + double_type = BasicType('float64') float_type = BasicType('float32') double4_type = VectorType(double_type, 4) float4_type = VectorType(float_type, 4) @@ -147,36 +148,33 @@ def test_assumptions(): assert (x.shape[0] + 1).is_real -def test_sqrt_of_integer(): +@pytest.mark.parametrize('dtype', ('float64', 'float32')) +def test_sqrt_of_integer(dtype): """Regression test for bug where sqrt(3) was classified as integer""" - f = ps.fields("f: [1D]") - tmp = sp.symbols("tmp") - - assignments = [ps.Assignment(tmp, sp.sqrt(3)), - ps.Assignment(f[0], tmp)] - arr_double = np.array([1], dtype=np.float64) - kernel = ps.create_kernel(assignments).compile() - kernel(f=arr_double) - assert 1.7 < arr_double[0] < 1.8 - - f = ps.fields("f: float32[1D]") - tmp = sp.symbols("tmp") + f = ps.fields(f'f: {dtype}[1D]') + tmp = sp.symbols('tmp') assignments = [ps.Assignment(tmp, sp.sqrt(3)), ps.Assignment(f[0], tmp)] - arr_single = np.array([1], dtype=np.float32) - config = pystencils.config.CreateKernelConfig(data_type="float32") - kernel = ps.create_kernel(assignments, config=config).compile() - kernel(f=arr_single) + arr = np.array([1], dtype=dtype) + config = pystencils.config.CreateKernelConfig(data_type=dtype) - code = ps.get_code_str(kernel.ast) + ast = ps.create_kernel(assignments, config=config) + kernel = ast.compile() + kernel(f=arr) + assert 1.7 < arr[0] < 1.8 - assert "1.7320508075688772f" in code - assert 1.7 < arr_single[0] < 1.8 + code = ps.get_code_str(ast) + constant = '1.7320508075688772f' + if dtype == 'float32': + assert constant in code + else: + assert constant not in code -def test_integer_comparision(): - f = ps.fields("f [2D]") +@pytest.mark.parametrize('dtype', ('float64', 'float32')) +def test_integer_comparision(dtype): + f = ps.fields(f"f: {dtype}[2D]") d = sp.Symbol("dir") ur = ps.Assignment(f[0, 0], sp.Piecewise((0, sp.Equality(d, 1)), (f[0, 0], True))) @@ -184,9 +182,17 @@ def test_integer_comparision(): ast = ps.create_kernel(ur) code = ps.get_code_str(ast) - assert "_data_f_00[_stride_f_1*ctr_1] = ((((dir) == (1))) ? (0.0): (_data_f_00[_stride_f_1*ctr_1]));" in code + print(code) + # There should be an explicit cast for the integer zero to the type of the field on the rhs + if dtype == 'float64': + t = "_data_f_00[_stride_f_1*ctr_1] = ((((dir) == (1))) ? (((double)(0))): (_data_f_00[_stride_f_1*ctr_1]));" + else: + t = "_data_f_00[_stride_f_1*ctr_1] = ((((dir) == (1))) ? (((float)(0))): (_data_f_00[_stride_f_1*ctr_1]));" + + assert t in code +# TODO this def test_Basic_data_type(): assert typed_symbols(("s", "f"), np.uint) == typed_symbols("s, f", np.uint) t_symbols = typed_symbols(("s", "f"), np.uint) @@ -223,6 +229,7 @@ def test_Basic_data_type(): assert TypedSymbol("s", np.uint).reversed == TypedSymbol("s", np.uint) +# TODO this def test_cast_func(): assert CastFunc(TypedSymbol("s", np.uint), np.int64).canonical == TypedSymbol("s", np.uint).canonical @@ -235,6 +242,7 @@ def test_pointer_arithmetic_func(): assert PointerArithmeticFunc(TypedSymbol("s", np.uint), 1).canonical == TypedSymbol("s", np.uint).canonical +# TODO this def test_division(): f = ps.fields('f(10): float32[2D]') m, tau = sp.symbols("m, tau") @@ -248,6 +256,7 @@ def test_division(): assert "1.0f" in code +# TODO this def test_pow(): f = ps.fields('f(10): float32[2D]') m, tau = sp.symbols("m, tau")