From 89fe68a693ea69f2c004f0651dd532f6506965bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20H=C3=B6nig?= <jan.hoenig@fau.de> Date: Wed, 16 Feb 2022 15:50:31 +0100 Subject: [PATCH] Removed inster_casts --- pystencils/typing/__init__.py | 4 +- pystencils/typing/utilities.py | 95 ---------------------------------- 2 files changed, 2 insertions(+), 97 deletions(-) diff --git a/pystencils/typing/__init__.py b/pystencils/typing/__init__.py index 5bb560d10..ae4483da4 100644 --- a/pystencils/typing/__init__.py +++ b/pystencils/typing/__init__.py @@ -5,7 +5,7 @@ from pystencils.typing.types import (is_supported_type, numpy_name_to_c, Abstrac from pystencils.typing.typed_sympy import (assumptions_from_dtype, TypedSymbol, FieldStrideSymbol, FieldShapeSymbol, FieldPointerSymbol) from pystencils.typing.utilities import (typed_symbols, get_base_type, result_type, collate_types, - get_type_of_expression, insert_casts, get_next_parent_of_type, parents_of_type) + get_type_of_expression, get_next_parent_of_type, parents_of_type) __all__ = ['CastFunc', 'BooleanCastFunc', 'VectorMemoryAccess', 'ReinterpretCastFunc', 'PointerArithmeticFunc', @@ -13,4 +13,4 @@ __all__ = ['CastFunc', 'BooleanCastFunc', 'VectorMemoryAccess', 'ReinterpretCast 'VectorType', 'PointerType', 'StructType', 'create_type', 'assumptions_from_dtype', 'TypedSymbol', 'FieldStrideSymbol', 'FieldShapeSymbol', 'FieldPointerSymbol', 'typed_symbols', 'get_base_type', 'result_type', 'collate_types', - 'get_type_of_expression', 'insert_casts', 'get_next_parent_of_type', 'parents_of_type'] + 'get_type_of_expression', 'get_next_parent_of_type', 'parents_of_type'] diff --git a/pystencils/typing/utilities.py b/pystencils/typing/utilities.py index 6a43c7984..7d37e3886 100644 --- a/pystencils/typing/utilities.py +++ b/pystencils/typing/utilities.py @@ -211,101 +211,6 @@ if int(sympy_version[0]) * 100 + int(sympy_version[1]) >= 109: sp.Basic.__reduce_ex__ = basic_reduce_ex -def insert_casts(node): - """Checks the types and inserts casts and pointer arithmetic where necessary. - - Args: - node: the head node of the ast - - Returns: - modified AST - """ - from pystencils.astnodes import SympyAssignment, ResolvedFieldAccess, LoopOverCoordinate, Block - - def cast(zipped_args_types, target_dtype): - """ - Adds casts to the arguments if their type differs from the target type - :param zipped_args_types: a zipped list of args and types - :param target_dtype: The target data type - :return: args with possible casts - """ - casted_args = [] - for argument, data_type in zipped_args_types: - if data_type.numpy_dtype != target_dtype.numpy_dtype: # ignoring const - casted_args.append(CastFunc(argument, target_dtype)) - else: - casted_args.append(argument) - return casted_args - - def pointer_arithmetic(expr_args): - """ - Creates a valid pointer arithmetic function - :param expr_args: Arguments of the add expression - :return: pointer_arithmetic_func - """ - pointer = None - new_args = [] - for arg, data_type in expr_args: - if data_type.func is PointerType: - assert pointer is None - pointer = arg - for arg, data_type in expr_args: - if arg != pointer: - assert data_type.is_int() or data_type.is_uint() - new_args.append(arg) - new_args = sp.Add(*new_args) if len(new_args) > 0 else new_args - return PointerArithmeticFunc(pointer, new_args) - - if isinstance(node, sp.AtomicExpr) or isinstance(node, CastFunc): - return node - args = [] - for arg in node.args: - args.append(insert_casts(arg)) - # TODO indexed, LoopOverCoordinate - if node.func in (sp.Add, sp.Mul, sp.Or, sp.And, sp.Pow, sp.Eq, sp.Ne, sp.Lt, sp.Le, sp.Gt, sp.Ge): - # TODO optimize pow, don't cast integer on double - types = [get_type_of_expression(arg) for arg in args] - assert len(types) > 0 - # Never ever, ever collate to float type for boolean functions! - target = collate_types(types, forbid_collation_to_float=isinstance(node.func, BooleanFunction)) - zipped = list(zip(args, types)) - if target.func is PointerType: - assert node.func is sp.Add - return pointer_arithmetic(zipped) - else: - return node.func(*cast(zipped, target)) - elif node.func is SympyAssignment: - lhs = args[0] - rhs = args[1] - target = get_type_of_expression(lhs) - if target.func is PointerType: - return node.func(*args) # TODO fix, not complete - else: - return node.func(lhs, *cast([(rhs, get_type_of_expression(rhs))], target)) - elif node.func is ResolvedFieldAccess: - return node - elif node.func is Block: - for old_arg, new_arg in zip(node.args, args): - node.replace(old_arg, new_arg) - return node - elif node.func is LoopOverCoordinate: - for old_arg, new_arg in zip(node.args, args): - node.replace(old_arg, new_arg) - return node - elif node.func is sp.Piecewise: - expressions = [expr for (expr, _) in args] - types = [get_type_of_expression(expr) for expr in expressions] - target = collate_types(types) - zipped = list(zip(expressions, types)) - casted_expressions = cast(zipped, target) - args = [ - arg.func(*[expr, arg.cond]) - for (arg, expr) in zip(args, casted_expressions) - ] - - return node.func(*args) - - def get_next_parent_of_type(node, parent_type): """Returns the next parent node of given type or None, if root is reached. -- GitLab