From 89fe68a693ea69f2c004f0651dd532f6506965bf Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jan=20H=C3=B6nig?= <jan.hoenig@fau.de>
Date: Wed, 16 Feb 2022 15:50:31 +0100
Subject: [PATCH] Removed inster_casts

---
 pystencils/typing/__init__.py  |  4 +-
 pystencils/typing/utilities.py | 95 ----------------------------------
 2 files changed, 2 insertions(+), 97 deletions(-)

diff --git a/pystencils/typing/__init__.py b/pystencils/typing/__init__.py
index 5bb560d10..ae4483da4 100644
--- a/pystencils/typing/__init__.py
+++ b/pystencils/typing/__init__.py
@@ -5,7 +5,7 @@ from pystencils.typing.types import (is_supported_type, numpy_name_to_c, Abstrac
 from pystencils.typing.typed_sympy import (assumptions_from_dtype, TypedSymbol, FieldStrideSymbol, FieldShapeSymbol,
                                            FieldPointerSymbol)
 from pystencils.typing.utilities import (typed_symbols, get_base_type, result_type, collate_types,
-                                         get_type_of_expression, insert_casts, get_next_parent_of_type, parents_of_type)
+                                         get_type_of_expression, get_next_parent_of_type, parents_of_type)
 
 
 __all__ = ['CastFunc', 'BooleanCastFunc', 'VectorMemoryAccess', 'ReinterpretCastFunc', 'PointerArithmeticFunc',
@@ -13,4 +13,4 @@ __all__ = ['CastFunc', 'BooleanCastFunc', 'VectorMemoryAccess', 'ReinterpretCast
            'VectorType', 'PointerType', 'StructType', 'create_type',
            'assumptions_from_dtype', 'TypedSymbol', 'FieldStrideSymbol', 'FieldShapeSymbol', 'FieldPointerSymbol',
            'typed_symbols', 'get_base_type', 'result_type', 'collate_types',
-           'get_type_of_expression', 'insert_casts', 'get_next_parent_of_type', 'parents_of_type']
+           'get_type_of_expression', 'get_next_parent_of_type', 'parents_of_type']
diff --git a/pystencils/typing/utilities.py b/pystencils/typing/utilities.py
index 6a43c7984..7d37e3886 100644
--- a/pystencils/typing/utilities.py
+++ b/pystencils/typing/utilities.py
@@ -211,101 +211,6 @@ if int(sympy_version[0]) * 100 + int(sympy_version[1]) >= 109:
     sp.Basic.__reduce_ex__ = basic_reduce_ex
 
 
-def insert_casts(node):
-    """Checks the types and inserts casts and pointer arithmetic where necessary.
-
-    Args:
-        node: the head node of the ast
-
-    Returns:
-        modified AST
-    """
-    from pystencils.astnodes import SympyAssignment, ResolvedFieldAccess, LoopOverCoordinate, Block
-
-    def cast(zipped_args_types, target_dtype):
-        """
-        Adds casts to the arguments if their type differs from the target type
-        :param zipped_args_types: a zipped list of args and types
-        :param target_dtype: The target data type
-        :return: args with possible casts
-        """
-        casted_args = []
-        for argument, data_type in zipped_args_types:
-            if data_type.numpy_dtype != target_dtype.numpy_dtype:  # ignoring const
-                casted_args.append(CastFunc(argument, target_dtype))
-            else:
-                casted_args.append(argument)
-        return casted_args
-
-    def pointer_arithmetic(expr_args):
-        """
-        Creates a valid pointer arithmetic function
-        :param expr_args: Arguments of the add expression
-        :return: pointer_arithmetic_func
-        """
-        pointer = None
-        new_args = []
-        for arg, data_type in expr_args:
-            if data_type.func is PointerType:
-                assert pointer is None
-                pointer = arg
-        for arg, data_type in expr_args:
-            if arg != pointer:
-                assert data_type.is_int() or data_type.is_uint()
-                new_args.append(arg)
-        new_args = sp.Add(*new_args) if len(new_args) > 0 else new_args
-        return PointerArithmeticFunc(pointer, new_args)
-
-    if isinstance(node, sp.AtomicExpr) or isinstance(node, CastFunc):
-        return node
-    args = []
-    for arg in node.args:
-        args.append(insert_casts(arg))
-    # TODO indexed, LoopOverCoordinate
-    if node.func in (sp.Add, sp.Mul, sp.Or, sp.And, sp.Pow, sp.Eq, sp.Ne, sp.Lt, sp.Le, sp.Gt, sp.Ge):
-        # TODO optimize pow, don't cast integer on double
-        types = [get_type_of_expression(arg) for arg in args]
-        assert len(types) > 0
-        # Never ever, ever collate to float type for boolean functions!
-        target = collate_types(types, forbid_collation_to_float=isinstance(node.func, BooleanFunction))
-        zipped = list(zip(args, types))
-        if target.func is PointerType:
-            assert node.func is sp.Add
-            return pointer_arithmetic(zipped)
-        else:
-            return node.func(*cast(zipped, target))
-    elif node.func is SympyAssignment:
-        lhs = args[0]
-        rhs = args[1]
-        target = get_type_of_expression(lhs)
-        if target.func is PointerType:
-            return node.func(*args)  # TODO fix, not complete
-        else:
-            return node.func(lhs, *cast([(rhs, get_type_of_expression(rhs))], target))
-    elif node.func is ResolvedFieldAccess:
-        return node
-    elif node.func is Block:
-        for old_arg, new_arg in zip(node.args, args):
-            node.replace(old_arg, new_arg)
-        return node
-    elif node.func is LoopOverCoordinate:
-        for old_arg, new_arg in zip(node.args, args):
-            node.replace(old_arg, new_arg)
-        return node
-    elif node.func is sp.Piecewise:
-        expressions = [expr for (expr, _) in args]
-        types = [get_type_of_expression(expr) for expr in expressions]
-        target = collate_types(types)
-        zipped = list(zip(expressions, types))
-        casted_expressions = cast(zipped, target)
-        args = [
-            arg.func(*[expr, arg.cond])
-            for (arg, expr) in zip(args, casted_expressions)
-        ]
-
-    return node.func(*args)
-
-
 def get_next_parent_of_type(node, parent_type):
     """Returns the next parent node of given type or None, if root is reached.
 
-- 
GitLab