From 0c9e9fcdff09fc8ed4146ca1cb8cd247975b3d48 Mon Sep 17 00:00:00 2001 From: Markus Holzer <markus.holzer@fau.de> Date: Fri, 10 Dec 2021 14:17:07 +0100 Subject: [PATCH] Support min max and conditional Field Access --- pystencils/backends/cbackend.py | 2 - pystencils/config.py | 1 + pystencils/typing/leaf_typing.py | 65 +++++++++++-------- pystencils_tests/test_Min_Max.py | 65 +++++++++++++++++-- .../test_conditional_field_access.py | 30 ++++----- pystencils_tests/test_types.py | 6 +- 6 files changed, 113 insertions(+), 56 deletions(-) diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index 4aa1d0964..5b27f7461 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -501,8 +501,6 @@ class CustomSympyPrinter(CCodePrinter): return f"({self._print(1 / sp.sqrt(expr.args[0]))})" elif isinstance(expr, sp.Abs): return f"abs({self._print(expr.args[0])})" - elif isinstance(expr, sp.Max): - return self._print(expr) elif isinstance(expr, sp.Mod): if expr.args[0].is_integer and expr.args[1].is_integer: return f"({self._print(expr.args[0])} % {self._print(expr.args[1])})" diff --git a/pystencils/config.py b/pystencils/config.py index 97f2e2d8a..936a92cf2 100644 --- a/pystencils/config.py +++ b/pystencils/config.py @@ -31,6 +31,7 @@ class CreateKernelConfig: """ # TODO: config should check that the datatype is a Numpy type # TODO: check for the python types and issue warnings + # TODO: QoL default_number_float and default_number_int should be data_type if they are not specified by the user data_type: Union[str, Dict[str, BasicType]] = 'float64' """ Data type used for all untyped symbols (i.e. non-fields), can also be a dict from symbol name to type diff --git a/pystencils/typing/leaf_typing.py b/pystencils/typing/leaf_typing.py index 1c84ffd65..04bfacbf4 100644 --- a/pystencils/typing/leaf_typing.py +++ b/pystencils/typing/leaf_typing.py @@ -1,6 +1,5 @@ from collections import namedtuple, defaultdict -from copy import copy -from typing import Union, Dict, Tuple, Any +from typing import Union, Tuple, Any import logging import numpy as np @@ -14,8 +13,9 @@ from sympy.logic.boolalg import BooleanFunction from sympy.logic.boolalg import BooleanAtom from pystencils import astnodes as ast +from pystencils.functions import DivFunc from pystencils.field import Field -from pystencils.typing.types import AbstractType, BasicType, create_type +from pystencils.typing.types import BasicType, create_type from pystencils.typing.utilities import get_type_of_expression, collate_types from pystencils.typing.cast_functions import CastFunc, BooleanCastFunc from pystencils.typing.typed_sympy import TypedSymbol @@ -40,9 +40,9 @@ class TypeAdder: """ FieldAndIndex = namedtuple('FieldAndIndex', ['field', 'index']) - def __init__(self, type_for_symbol: Dict[str, BasicType], default_number_float: BasicType, + def __init__(self, type_for_symbol: defaultdict[str, BasicType], default_number_float: BasicType, default_number_int: BasicType): - self.type_for_symbol = ContextVar(type_for_symbol) + self.type_for_symbol = type_for_symbol self.default_number_float = ContextVar(default_number_float) self.default_number_int = ContextVar(default_number_int) @@ -72,13 +72,16 @@ class TypeAdder: def process_assignment(self, assignment: Union[sp.Eq, ast.SympyAssignment, Assignment]) -> ast.SympyAssignment: # for checks it is crucial to process rhs before lhs to catch e.g. a = a + 1 new_rhs, rhs_type = self.figure_out_type(assignment.rhs) - # TODO: - dt = copy(rhs_type) # The copy is necessary because BasicType has sympy shinanigans - dd = defaultdict(lambda: BasicType(dt)) - dd.update(self.type_for_symbol.get()) - with self.type_for_symbol(dd): - new_lhs, lhs_type = self.figure_out_type(assignment.lhs) - # TODO add symbol to dict with type if defined! + + lhs = assignment.lhs + if not isinstance(lhs, (Field.Access, TypedSymbol)): + if isinstance(lhs, sp.Symbol): + self.type_for_symbol[lhs.name] = rhs_type + else: + raise ValueError(f'Lhs: `{lhs}` is not a subtype of sp.Symbol') + new_lhs, lhs_type = self.figure_out_type(lhs) + assert isinstance(new_lhs, (Field.Access, TypedSymbol)) + if lhs_type != rhs_type: logging.warning(f'Lhs"{new_lhs} of type "{lhs_type}" is assigned with a different datatype ' f'rhs: "{new_rhs}" of type "{rhs_type}".') @@ -89,7 +92,8 @@ class TypeAdder: # Type System Specification # - Defined Types: TypedSymbol, Field, Field.Access, ...? # - Indexed: always unsigned_integer64 - # - Undefined Types: Symbol - Is specified in Config in the dict or as 'default_type' + # - Undefined Types: Symbol + # - Is specified in Config in the dict or as 'default_type' or behaves like `auto` in the case of lhs. # - Constants/Numbers: Are either integer or floating. The precision and sign is specified via config # - Example: 1.4 config:float32 -> float32 # - Expressions deduce types from arguments @@ -100,7 +104,7 @@ class TypeAdder: # Possible Problems - Do we need to support this? # - Mixture in expression with int and float # - Mixture in expression with uint64 and sint64 - + # TODO: Lowest log level should log all casts ----> cast factory, make cast should contain logging def figure_out_type(self, expr) -> Tuple[Any, BasicType]: # TODO or abstract type? vector type? # Trivial cases from pystencils.field import Field @@ -113,7 +117,7 @@ class TypeAdder: elif isinstance(expr, TypedSymbol): return expr, expr.dtype elif isinstance(expr, sp.Symbol): - t = TypedSymbol(expr.name, self.type_for_symbol.get()[expr.name]) # TODO with or without name + t = TypedSymbol(expr.name, self.type_for_symbol[expr.name]) # TODO with or without name return t, t.dtype elif isinstance(expr, np.generic): assert False, f'Why do we have a np.generic in rhs???? {expr}' @@ -139,6 +143,22 @@ class TypeAdder: elif isinstance(expr, CastFunc): new_expr, _ = self.figure_out_type(expr.expr) return expr.func(*[new_expr, expr.dtype]), expr.dtype + elif isinstance(expr, ast.ConditionalFieldAccess): + access, access_type = self.figure_out_type(expr.access) + value, value_type = self.figure_out_type(expr.outofbounds_value) + condition, condition_type = self.figure_out_type(expr.outofbounds_condition) + assert condition_type == bool_type + collated_type = collate_types([access_type, value_type]) + if collated_type == access_type: + new_access = access + else: + logging.warning(f"In {expr} the Field Access had to be casted to {collated_type}. This is " + f"probably due to a type missmatch of the Field and the value of " + f"ConditionalFieldAccess") + new_access = CastFunc(access, collated_type) + + new_value = value if value_type == collated_type else CastFunc(value, collated_type) + return expr.func(new_access, condition, new_value), collated_type elif isinstance(expr, BooleanFunction): args_types = [self.figure_out_type(a) for a in expr.args] new_args = [a if t.dtype_eq(bool_type) else BooleanCastFunc(a, bool_type) for a, t in args_types] @@ -177,16 +197,15 @@ class TypeAdder: else: new_args.append(a) return expr.func(*new_args) if new_args else expr, collated_type - else: + elif isinstance(expr, (sp.Add, sp.Mul, sp.Abs, sp.Min, sp.Max, DivFunc)): args_types = [self.figure_out_type(arg) for arg in expr.args] collated_type = collate_types([t for _, t in args_types]) new_args = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types] return expr.func(*new_args) if new_args else expr, collated_type + else: + raise NotImplementedError(f'expr {expr} unknown to typing') - def apply_type(self, expr, data_type: AbstractType): - pass - - def process_expression(self, rhs, type_constants=True): # TODO default_type as parameter + def process_expression(self, rhs, type_constants=True): # TODO DELETE import pystencils.integer_functions from pystencils.bit_masks import flag_cond @@ -242,9 +261,3 @@ class TypeAdder: else: new_args = [self.process_expression(arg, type_constants) for arg in rhs.args] return rhs.func(*new_args) if new_args else rhs - - def process_lhs(self, lhs: Union[Field.Access, TypedSymbol, sp.Symbol]): - if not isinstance(lhs, (Field.Access, TypedSymbol)): - return TypedSymbol(lhs.name, self.type_for_symbol.get()[lhs.name]) - else: - return lhs diff --git a/pystencils_tests/test_Min_Max.py b/pystencils_tests/test_Min_Max.py index c227fbf14..7fb48b18d 100644 --- a/pystencils_tests/test_Min_Max.py +++ b/pystencils_tests/test_Min_Max.py @@ -6,31 +6,37 @@ import pystencils from pystencils.datahandling import create_data_handling +@pytest.mark.parametrize('dtype', ["float64", "float32"]) @pytest.mark.parametrize('sympy_function', [sp.Min, sp.Max]) -def test_max(sympy_function): +def test_max(dtype, sympy_function): dh = create_data_handling(domain_size=(10, 10), periodicity=True) - x = dh.add_array('x', values_per_cell=1) + x = dh.add_array('x', values_per_cell=1, dtype=dtype) dh.fill("x", 0.0, ghost_layers=True) - y = dh.add_array('y', values_per_cell=1) + y = dh.add_array('y', values_per_cell=1, dtype=dtype) dh.fill("y", 1.0, ghost_layers=True) - z = dh.add_array('z', values_per_cell=1) + z = dh.add_array('z', values_per_cell=1, dtype=dtype) dh.fill("z", 2.0, ghost_layers=True) + config = pystencils.CreateKernelConfig(default_number_float=dtype) + # test sp.Max with one argument assignment_1 = pystencils.Assignment(x.center, sympy_function(y.center + 3.3)) - ast_1 = pystencils.create_kernel(assignment_1) + ast_1 = pystencils.create_kernel(assignment_1, config=config) kernel_1 = ast_1.compile() + # pystencils.show_code(ast_1) # test sp.Max with two arguments assignment_2 = pystencils.Assignment(x.center, sympy_function(0.5, y.center - 1.5)) - ast_2 = pystencils.create_kernel(assignment_2) + ast_2 = pystencils.create_kernel(assignment_2, config=config) kernel_2 = ast_2.compile() + # pystencils.show_code(ast_2) # test sp.Max with many arguments assignment_3 = pystencils.Assignment(x.center, sympy_function(z.center, 4.5, y.center - 1.5, y.center + z.center)) - ast_3 = pystencils.create_kernel(assignment_3) + ast_3 = pystencils.create_kernel(assignment_3, config=config) kernel_3 = ast_3.compile() + # pystencils.show_code(ast_3) if sympy_function is sp.Max: results = [4.3, 0.5, 4.5] @@ -43,3 +49,48 @@ def test_max(sympy_function): assert numpy.all(dh.gather_array('x') == results[1]) dh.run_kernel(kernel_3) assert numpy.all(dh.gather_array('x') == results[2]) + + +@pytest.mark.parametrize('dtype', ["int64", 'int32']) +@pytest.mark.parametrize('sympy_function', [sp.Min, sp.Max]) +def test_max_integer(dtype, sympy_function): + dh = create_data_handling(domain_size=(10, 10), periodicity=True) + + x = dh.add_array('x', values_per_cell=1, dtype=dtype) + dh.fill("x", 0, ghost_layers=True) + y = dh.add_array('y', values_per_cell=1, dtype=dtype) + dh.fill("y", 1, ghost_layers=True) + z = dh.add_array('z', values_per_cell=1, dtype=dtype) + dh.fill("z", 2, ghost_layers=True) + + config = pystencils.CreateKernelConfig(default_number_int=dtype) + + # test sp.Max with one argument + assignment_1 = pystencils.Assignment(x.center, sympy_function(y.center + 3)) + ast_1 = pystencils.create_kernel(assignment_1, config=config) + kernel_1 = ast_1.compile() + # pystencils.show_code(ast_1) + + # test sp.Max with two arguments + assignment_2 = pystencils.Assignment(x.center, sympy_function(1, y.center - 1)) + ast_2 = pystencils.create_kernel(assignment_2, config=config) + kernel_2 = ast_2.compile() + # pystencils.show_code(ast_2) + + # test sp.Max with many arguments + assignment_3 = pystencils.Assignment(x.center, sympy_function(z.center, 4, y.center - 1, y.center + z.center)) + ast_3 = pystencils.create_kernel(assignment_3, config=config) + kernel_3 = ast_3.compile() + # pystencils.show_code(ast_3) + + if sympy_function is sp.Max: + results = [4, 1, 4] + else: + results = [4, 0, 0] + + dh.run_kernel(kernel_1) + assert numpy.all(dh.gather_array('x') == results[0]) + dh.run_kernel(kernel_2) + assert numpy.all(dh.gather_array('x') == results[1]) + dh.run_kernel(kernel_3) + assert numpy.all(dh.gather_array('x') == results[2]) diff --git a/pystencils_tests/test_conditional_field_access.py b/pystencils_tests/test_conditional_field_access.py index f39d4767e..f8026c7dc 100644 --- a/pystencils_tests/test_conditional_field_access.py +++ b/pystencils_tests/test_conditional_field_access.py @@ -35,11 +35,11 @@ def add_fixed_constant_boundary_handling(assignments, with_cse): for a in assignment.rhs.atoms(Field.Access) if not a.is_absolute_access })) for assignment in assignments.all_assignments] - subs = [{a: ConditionalFieldAccess(a, is_out_of_bound( - sp.Matrix(a.offsets) + x_vector(ndim), common_shape)) - for a in assignment.rhs.atoms(Field.Access) if not a.is_absolute_access - } for assignment in assignments.all_assignments] - print(subs) + # subs = [{a: ConditionalFieldAccess(a, is_out_of_bound( + # sp.Matrix(a.offsets) + x_vector(ndim), common_shape)) + # for a in assignment.rhs.atoms(Field.Access) if not a.is_absolute_access + # } for assignment in assignments.all_assignments] + # print(subs) if with_cse: safe_assignments = sympy_cse(ps.AssignmentCollection(safe_assignments)) @@ -48,24 +48,20 @@ def add_fixed_constant_boundary_handling(assignments, with_cse): return ps.AssignmentCollection(safe_assignments) +@pytest.mark.parametrize('dtype', ('float64', 'float32')) @pytest.mark.parametrize('with_cse', (False, 'with_cse')) -def test_boundary_check(with_cse): - if not with_cse: - return True +def test_boundary_check(dtype, with_cse): + f, g = ps.fields(f"f, g : {dtype}[2D]") + stencil = ps.Assignment(g[0, 0], (f[1, 0] + f[-1, 0] + f[0, 1] + f[0, -1]) / 4) - f, g = ps.fields("f, g : [2D]") - stencil = ps.Assignment(g[0, 0], - (f[1, 0] + f[-1, 0] + f[0, 1] + f[0, -1]) / 4) - - f_arr = np.random.rand(10, 10) + f_arr = np.random.rand(10, 10).astype(dtype=dtype) g_arr = np.zeros_like(f_arr) - # kernel(f=f_arr, g=g_arr) assignments = add_fixed_constant_boundary_handling(ps.AssignmentCollection([stencil]), with_cse) - print(assignments) - kernel_checked = ps.create_kernel(assignments, ghost_layers=0).compile() - ps.show_code(kernel_checked) + config = ps.CreateKernelConfig(data_type=dtype, default_number_float=dtype, ghost_layers=0) + kernel_checked = ps.create_kernel(assignments, config=config).compile() + # ps.show_code(kernel_checked) # No SEGFAULT, please!! kernel_checked(f=f_arr, g=g_arr) diff --git a/pystencils_tests/test_types.py b/pystencils_tests/test_types.py index c55816b05..164d941cf 100644 --- a/pystencils_tests/test_types.py +++ b/pystencils_tests/test_types.py @@ -159,7 +159,8 @@ def test_sqrt_of_integer(dtype): assignments = [ps.Assignment(tmp, sp.sqrt(3)), ps.Assignment(f[0], tmp)] arr = np.array([1], dtype=dtype) - config = pystencils.config.CreateKernelConfig(data_type=dtype) + # TODO Jupyter add auto lhs float/double problem + config = pystencils.config.CreateKernelConfig(data_type=dtype, default_number_float=dtype) ast = ps.create_kernel(assignments, config=config) kernel = ast.compile() @@ -189,9 +190,6 @@ def test_integer_comparision(dtype): t = "_data_f_00[_stride_f_1*ctr_1] = ((((dir) == (1))) ? (0.0): (_data_f_00[_stride_f_1*ctr_1]));" else: t = "_data_f_00[_stride_f_1*ctr_1] = ((((dir) == (1))) ? (0.0f): (_data_f_00[_stride_f_1*ctr_1]));" - - print(code) - assert t in code -- GitLab