From 21fb888aa743c1f8c3bf117f778566688bb9c1ae Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Fri, 15 Mar 2024 10:07:01 +0100
Subject: [PATCH] move integer_functions and fast_approximations into
 sympyextensions module

---
 .../backend/kernelcreation/freeze.py          |  3 +-
 src/pystencils/boundaries/inkernel.py         |  2 +-
 src/pystencils/sympyextensions/astnodes.py    |  2 +-
 .../fast_approximation.py                     |  0
 .../integer_functions.py                      | 66 +++++++------------
 src/pystencils/sympyextensions/math.py        |  2 +-
 src/pystencils/timeloop.py                    |  2 +-
 tests/test_fast_approximation.py              |  2 +-
 tests/test_math_functions.py                  |  2 +-
 tests/test_sympyextensions.py                 |  2 +-
 10 files changed, 33 insertions(+), 50 deletions(-)
 rename src/pystencils/{ => sympyextensions}/fast_approximation.py (100%)
 rename src/pystencils/{ => sympyextensions}/integer_functions.py (65%)

diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py
index ecdcf2f94..ebaf22812 100644
--- a/src/pystencils/backend/kernelcreation/freeze.py
+++ b/src/pystencils/backend/kernelcreation/freeze.py
@@ -53,9 +53,8 @@ class FreezeExpressions:
 
      - Augmented Assignments
      - AddressOf
-     - Conditionals (+ frontend class)
      - Relations (sp.Relational)
-     - pystencils.integer_functions
+     - pystencils.sympyextensions.integer_functions
      - pystencils.sympyextensions.bit_masks
      - GPU fast approximations (pystencils.fast_approximation)
      - ConditionalFieldAccess
diff --git a/src/pystencils/boundaries/inkernel.py b/src/pystencils/boundaries/inkernel.py
index 7cd9e628b..057704b13 100644
--- a/src/pystencils/boundaries/inkernel.py
+++ b/src/pystencils/boundaries/inkernel.py
@@ -4,7 +4,7 @@ from pystencils.boundaries.boundaryhandling import DEFAULT_FLAG_TYPE
 from pystencils.sympyextensions import TypedSymbol
 from pystencils.types import create_type
 from pystencils.field import Field
-from pystencils.integer_functions import bitwise_and
+from pystencils.sympyextensions.integer_functions import bitwise_and
 
 
 def add_neumann_boundary(eqs, fields, flag_field, boundary_flag="neumann_flag", inverse_flag=False):
diff --git a/src/pystencils/sympyextensions/astnodes.py b/src/pystencils/sympyextensions/astnodes.py
index 8483977d8..4fdc0f612 100644
--- a/src/pystencils/sympyextensions/astnodes.py
+++ b/src/pystencils/sympyextensions/astnodes.py
@@ -4,7 +4,7 @@ import uuid
 from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Union
 
 import sympy as sp
-from sympy.codegen.ast import Assignment, AugmentedAssignment, AddAugmentedAssignment
+from sympy.codegen.ast import Assignment, AugmentedAssignment
 from sympy.printing.latex import LatexPrinter
 import numpy as np
 
diff --git a/src/pystencils/fast_approximation.py b/src/pystencils/sympyextensions/fast_approximation.py
similarity index 100%
rename from src/pystencils/fast_approximation.py
rename to src/pystencils/sympyextensions/fast_approximation.py
diff --git a/src/pystencils/integer_functions.py b/src/pystencils/sympyextensions/integer_functions.py
similarity index 65%
rename from src/pystencils/integer_functions.py
rename to src/pystencils/sympyextensions/integer_functions.py
index 90e79798d..f9c156971 100644
--- a/src/pystencils/integer_functions.py
+++ b/src/pystencils/sympyextensions/integer_functions.py
@@ -1,9 +1,4 @@
-# TODO #47 move to a module functions
-import numpy as np
 import sympy as sp
-
-from pystencils.sympyextensions import CastFunc
-from pystencils.types import create_type
 from pystencils.sympyextensions import is_integer_sequence
 
 
@@ -11,22 +6,7 @@ class IntegerFunctionTwoArgsMixIn(sp.Function):
     is_integer = True
 
     def __new__(cls, arg1, arg2):
-        args = []
-        for a in (arg1, arg2):
-            if isinstance(a, sp.Number) or isinstance(a, int):
-                args.append(CastFunc(a, create_type("int")))
-            elif isinstance(a, np.generic):
-                args.append(CastFunc(a, a.dtype))
-            else:
-                args.append(a)
-
-        for a in args:
-            try:
-                dtype = get_type_of_expression(a)
-                if not dtype.is_int():
-                    raise ValueError("Argument to integer function is not an int but " + str(dtype))
-            except NotImplementedError:
-                raise ValueError("Integer functions can only be constructed with typed expressions")
+        args = [arg1, arg2]
         return super().__new__(cls, *args)
 
     def _eval_evalf(self, *pargs, **kwargs):
@@ -100,11 +80,12 @@ class modulo_floor(sp.Function):
         else:
             return super().__new__(cls, integer, divisor)
 
-    def to_c(self, print_func):
-        dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1])))
-        assert dtype.is_int()
-        return "({dtype})(({0}) / ({1})) * ({1})".format(print_func(self.args[0]),
-                                                         print_func(self.args[1]), dtype=dtype)
+    #   TODO: Implement this in FreezeExpressions
+    # def to_c(self, print_func):
+    #     dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1])))
+    #     assert dtype.is_int()
+    #     return "({dtype})(({0}) / ({1})) * ({1})".format(print_func(self.args[0]),
+    #                                                      print_func(self.args[1]), dtype=dtype)
 
 
 # noinspection PyPep8Naming
@@ -132,11 +113,12 @@ class modulo_ceil(sp.Function):
         else:
             return super().__new__(cls, integer, divisor)
 
-    def to_c(self, print_func):
-        dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1])))
-        assert dtype.is_int()
-        code = "(({0}) % ({1}) == 0 ? {0} : (({dtype})(({0}) / ({1}))+1) * ({1}))"
-        return code.format(print_func(self.args[0]), print_func(self.args[1]), dtype=dtype)
+    #   TODO: Implement this in FreezeExpressions
+    # def to_c(self, print_func):
+    #     dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1])))
+    #     assert dtype.is_int()
+    #     code = "(({0}) % ({1}) == 0 ? {0} : (({dtype})(({0}) / ({1}))+1) * ({1}))"
+    #     return code.format(print_func(self.args[0]), print_func(self.args[1]), dtype=dtype)
 
 
 # noinspection PyPep8Naming
@@ -162,11 +144,12 @@ class div_ceil(sp.Function):
         else:
             return super().__new__(cls, integer, divisor)
 
-    def to_c(self, print_func):
-        dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1])))
-        assert dtype.is_int()
-        code = "( ({0}) % ({1}) == 0 ? ({dtype})({0}) / ({dtype})({1}) : ( ({dtype})({0}) / ({dtype})({1}) ) +1 )"
-        return code.format(print_func(self.args[0]), print_func(self.args[1]), dtype=dtype)
+    #   TODO: Implement this in FreezeExpressions
+    # def to_c(self, print_func):
+    #     dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1])))
+    #     assert dtype.is_int()
+    #     code = "( ({0}) % ({1}) == 0 ? ({dtype})({0}) / ({dtype})({1}) : ( ({dtype})({0}) / ({dtype})({1}) ) +1 )"
+    #     return code.format(print_func(self.args[0]), print_func(self.args[1]), dtype=dtype)
 
 
 # noinspection PyPep8Naming
@@ -192,8 +175,9 @@ class div_floor(sp.Function):
         else:
             return super().__new__(cls, integer, divisor)
 
-    def to_c(self, print_func):
-        dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1])))
-        assert dtype.is_int()
-        code = "(({dtype})({0}) / ({dtype})({1}))"
-        return code.format(print_func(self.args[0]), print_func(self.args[1]), dtype=dtype)
+    #   TODO: Implement this in FreezeExpressions
+    # def to_c(self, print_func):
+    #     dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1])))
+    #     assert dtype.is_int()
+    #     code = "(({dtype})({0}) / ({dtype})({1}))"
+    #     return code.format(print_func(self.args[0]), print_func(self.args[1]), dtype=dtype)
diff --git a/src/pystencils/sympyextensions/math.py b/src/pystencils/sympyextensions/math.py
index 508132ce5..21a98ad78 100644
--- a/src/pystencils/sympyextensions/math.py
+++ b/src/pystencils/sympyextensions/math.py
@@ -549,7 +549,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr], List[Assignment]],
     Returns:
         dict with 'adds', 'muls' and 'divs' keys
     """
-    from pystencils.fast_approximation import fast_sqrt, fast_inv_sqrt, fast_division
+    from pystencils.sympyextensions.fast_approximation import fast_sqrt, fast_inv_sqrt, fast_division
 
     result = {'adds': 0, 'muls': 0, 'divs': 0, 'sqrts': 0,
               'fast_sqrts': 0, 'fast_inv_sqrts': 0, 'fast_div': 0}
diff --git a/src/pystencils/timeloop.py b/src/pystencils/timeloop.py
index 55129afb6..5c3438680 100644
--- a/src/pystencils/timeloop.py
+++ b/src/pystencils/timeloop.py
@@ -1,6 +1,6 @@
 import time
 
-from pystencils.integer_functions import modulo_ceil
+from pystencils.sympyextensions.integer_functions import modulo_ceil
 
 
 class TimeLoop:
diff --git a/tests/test_fast_approximation.py b/tests/test_fast_approximation.py
index e211d6897..375bcabac 100644
--- a/tests/test_fast_approximation.py
+++ b/tests/test_fast_approximation.py
@@ -2,7 +2,7 @@ import pytest
 import sympy as sp
 
 import pystencils as ps
-from pystencils.fast_approximation import (
+from pystencils.sympyextensions.fast_approximation import (
     fast_division, fast_inv_sqrt, fast_sqrt, insert_fast_divisions, insert_fast_sqrts)
 
 
diff --git a/tests/test_math_functions.py b/tests/test_math_functions.py
index 1fd393788..6bd80644a 100644
--- a/tests/test_math_functions.py
+++ b/tests/test_math_functions.py
@@ -2,7 +2,7 @@ import pytest
 import sympy as sp
 import numpy as np
 import pystencils as ps
-from pystencils.fast_approximation import fast_division
+from pystencils.sympyextensions.fast_approximation import fast_division
 
 
 @pytest.mark.parametrize('dtype', ["float64", "float32"])
diff --git a/tests/test_sympyextensions.py b/tests/test_sympyextensions.py
index 7afa99810..05c119968 100644
--- a/tests/test_sympyextensions.py
+++ b/tests/test_sympyextensions.py
@@ -15,7 +15,7 @@ from pystencils.sympyextensions import scalar_product
 from pystencils.sympyextensions import kronecker_delta
 
 from pystencils import Assignment
-from pystencils.fast_approximation import (fast_division, fast_inv_sqrt, fast_sqrt,
+from pystencils.sympyextensions.fast_approximation import (fast_division, fast_inv_sqrt, fast_sqrt,
                                            insert_fast_divisions, insert_fast_sqrts)
 
 
-- 
GitLab