From 3a8cc5ae34aeefc2eced5aef335a67f3d48e7a53 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jan=20H=C3=B6nig?= <jan.hoenig@fau.de>
Date: Wed, 8 Dec 2021 15:55:31 +0100
Subject: [PATCH] Minor fixes

---
 pystencils/backends/cbackend.py          | 38 +++++----------------
 pystencils/functions.py                  | 26 +++++++++++++++
 pystencils/simp/assignment_collection.py | 19 ++++++++---
 pystencils/typing/cast_functions.py      |  8 +++--
 pystencils/typing/leaf_typing.py         | 12 +++----
 pystencils/typing/types.py               |  7 +---
 pystencils_tests/test_types.py           | 42 +++++-------------------
 7 files changed, 70 insertions(+), 82 deletions(-)
 create mode 100644 pystencils/functions.py

diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py
index b5dc66e10..4aa1d0964 100644
--- a/pystencils/backends/cbackend.py
+++ b/pystencils/backends/cbackend.py
@@ -16,6 +16,7 @@ from pystencils.typing import (
     ReinterpretCastFunc, VectorMemoryAccess, BasicType, TypedSymbol)
 from pystencils.enums import Backend
 from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
+from pystencils.functions import DivFunc
 from pystencils.integer_functions import (
     bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor,
     int_div, int_power_of_2, modulo_ceil)
@@ -436,19 +437,14 @@ class CustomSympyPrinter(CCodePrinter):
 
     def __init__(self):
         super(CustomSympyPrinter, self).__init__()
-        self._float_type = create_type("float32")
 
     def _print_Pow(self, expr):
         """Don't use std::pow function, for small integer exponents, write as multiplication"""
         if not expr.free_symbols:
             return self._typed_number(expr.evalf(), get_type_of_expression(expr.base))
+        return super(CustomSympyPrinter, self)._print_Pow(expr)
 
-        if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
-            return f"({self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False))})"
-        elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0:
-            return f"1 / ({self._print(sp.Mul(*([expr.base] * -expr.exp), evaluate=False))})"
-        else:
-            return super(CustomSympyPrinter, self)._print_Pow(expr)
+    # TODO don't print ones in sp.Mul
 
     def _print_Rational(self, expr):
         """Evaluate all rationals i.e. print 0.25 instead of 1.0/4.0"""
@@ -491,7 +487,10 @@ class CustomSympyPrinter(CCodePrinter):
             return f"&({self._print(expr.args[0])})"
         elif isinstance(expr, CastFunc):
             arg, data_type = expr.args
-            return f"(({data_type})({self._print(arg)}))"
+            if arg.is_Number:
+                return self._typed_number(arg, data_type)
+            else:
+                return f"(({data_type})({self._print(arg)}))"
         elif isinstance(expr, fast_division):
             return f"({self._print(expr.args[0] / expr.args[1])})"
         elif isinstance(expr, fast_sqrt):
@@ -515,6 +514,8 @@ class CustomSympyPrinter(CCodePrinter):
             return f"(1 << ({self._print(expr.args[0])}))"
         elif expr.func == int_div:
             return f"(({self._print(expr.args[0])}) / ({self._print(expr.args[1])}))"
+        elif expr.func == DivFunc:
+            return f'(({self._print(expr.divisor)}) / ({self._print(expr.dividend)}))'
         else:
             name = expr.name if hasattr(expr, 'name') else expr.__class__.__name__
             arg_str = ', '.join(self._print(a) for a in expr.args)
@@ -606,27 +607,6 @@ class CustomSympyPrinter(CCodePrinter):
             return f"(({a} < {b}) ? {a} : {b})"
         return inner_print_min(expr.args)
 
-    def _print_re(self, expr):
-        return f"real({self._print(expr.args[0])})"
-
-    def _print_im(self, expr):
-        return f"imag({self._print(expr.args[0])})"
-
-    def _print_ImaginaryUnit(self, expr):
-        return "complex<double>{0,1}"
-
-    def _print_TypedImaginaryUnit(self, expr):
-        if expr.dtype.numpy_dtype == np.complex64:
-            return "complex<float>{0,1}"
-        elif expr.dtype.numpy_dtype == np.complex128:
-            return "complex<double>{0,1}"
-        else:
-            raise NotImplementedError(
-                "only complex64 and complex128 supported")
-
-    def _print_Complex(self, expr):
-        return self._typed_number(expr, np.complex64)
-
 
 # noinspection PyPep8Naming
 class VectorizedCustomSympyPrinter(CustomSympyPrinter):
diff --git a/pystencils/functions.py b/pystencils/functions.py
new file mode 100644
index 000000000..b1f349622
--- /dev/null
+++ b/pystencils/functions.py
@@ -0,0 +1,26 @@
+import sympy as sp
+
+
+class DivFunc(sp.Function):
+    # TODO: documentation
+    is_Atom = True
+    is_real = True
+
+    def __new__(cls, *args, **kwargs):
+        if len(args) != 2:
+            raise ValueError(f'{cls} takes only 2 arguments, instead {len(args)} received!')
+        divisor, dividend, *other_args = args
+
+        return sp.Function.__new__(cls, divisor, dividend, *other_args, **kwargs)
+
+    def _eval_evalf(self, *args, **kwargs):
+        return self.divisor.evalf() / self.dividend.evalf()
+
+    @property
+    def divisor(self):
+        return self.args[0]
+
+    @property
+    def dividend(self):
+        return self.args[1]
+
diff --git a/pystencils/simp/assignment_collection.py b/pystencils/simp/assignment_collection.py
index 7309e7d87..5a6f0d010 100644
--- a/pystencils/simp/assignment_collection.py
+++ b/pystencils/simp/assignment_collection.py
@@ -7,6 +7,7 @@ from sympy.codegen.rewriting import ReplaceOptim, optimize
 
 import pystencils
 from pystencils.assignment import Assignment
+from pystencils.functions import DivFunc
 from pystencils.simp.simplifications import (sort_assignments_topologically, transform_lhs_and_rhs, transform_rhs)
 from pystencils.sympyextensions import count_operations, fast_subs
 
@@ -371,15 +372,23 @@ class AssignmentCollection:
             lambda e: hasattr(e, 'is_constant') and e.is_constant and not e.is_integer,
             lambda p: p.evalf())
 
-        sympy_optimisations = [evaluate_constant_terms]
+        evaluate_pow = ReplaceOptim(
+            lambda e: e.is_Pow and e.exp.is_Integer and abs(e.exp) <= 8,
+            lambda p: (
+                sp.UnevaluatedExpr(sp.Mul(*([p.base] * +p.exp), evaluate=False)) if p.exp > 0 else
+                DivFunc(sp.Integer(1), sp.Mul(*([p.base] * -p.exp), evaluate=False))
+            ))
+
+        sympy_optimisations = [evaluate_constant_terms, evaluate_pow]
 
         self.subexpressions = [Assignment(a.lhs, optimize(a.rhs, sympy_optimisations))
-                       if hasattr(a, 'lhs')
-                       else a for a in self.subexpressions]
+                               if hasattr(a, 'lhs')
+                               else a for a in self.subexpressions]
 
         self.main_assignments = [Assignment(a.lhs, optimize(a.rhs, sympy_optimisations))
-                       if hasattr(a, 'lhs')
-                       else a for a in self.main_assignments]
+                                 if hasattr(a, 'lhs')
+                                 else a for a in self.main_assignments]
+
     # ----------------------------------------- Display and Printing   -------------------------------------------------
 
     def _repr_html_(self):
diff --git a/pystencils/typing/cast_functions.py b/pystencils/typing/cast_functions.py
index e93a410a8..2e29b74bb 100644
--- a/pystencils/typing/cast_functions.py
+++ b/pystencils/typing/cast_functions.py
@@ -8,16 +8,18 @@ from pystencils.typing.typed_sympy import TypedSymbol
 
 class CastFunc(sp.Function):
     # TODO: documentation
-    # TODO: move function to `functions.py`
     is_Atom = True
 
     def __new__(cls, *args, **kwargs):
         if len(args) != 2:
             pass
         expr, dtype, *other_args = args
+
+        # If we have two consecutive casts, throw the inner one away
+        if isinstance(expr, CastFunc):
+            expr = expr.args[0]
         if not isinstance(dtype, AbstractType):
-            raise NotImplementedError(f'{dtype} is not a subclass of AbstractType')
-            dtype = create_type(dtype)
+            dtype = BasicType(dtype)
         # to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well
         # however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads
         # to problems when for example comparing cast_func's for equality
diff --git a/pystencils/typing/leaf_typing.py b/pystencils/typing/leaf_typing.py
index df36c0d91..ae66d3849 100644
--- a/pystencils/typing/leaf_typing.py
+++ b/pystencils/typing/leaf_typing.py
@@ -115,7 +115,7 @@ class TypeAdder:
                 data_type = self.default_number_float.get()
             else:
                 assert False, f'{sp.Number} is neither Float nor Integer'
-            return expr, data_type
+            return CastFunc(expr, data_type), data_type
         elif isinstance(expr, BooleanAtom):
             return expr, BasicType('bool')
         elif isinstance(expr, sp.Equality):
@@ -130,16 +130,16 @@ class TypeAdder:
         elif isinstance(expr, flag_cond):
             #   do not process the arguments to the bit shift - they must remain integers
             raise NotImplementedError('flag_cond')
-        elif isinstance(expr, sp.Mul):
-            raise NotImplementedError('sp.Mul')
-            # TODO can we ignore this and move it to general expr handling, i.e. removing Mul?
-            # args_types = [self.figure_out_type(arg) for arg in expr.args if arg not in (-1, 1)]
+        #elif isinstance(expr, sp.Mul):
+        #    raise NotImplementedError('sp.Mul')
+        #    # TODO can we ignore this and move it to general expr handling, i.e. removing Mul?
+        #    # args_types = [self.figure_out_type(arg) for arg in expr.args if arg not in (-1, 1)]
         elif isinstance(expr, sp.Indexed):
             raise NotImplementedError('sp.Indexed')
         elif isinstance(expr, sp.Pow):
             args_types = [self.figure_out_type(arg) for arg in expr.args]
             collated_type = collate_types([t for _, t in args_types])
-            return expr, collated_type
+            return expr.func(*[a for a, _ in args_types]), collated_type
         elif isinstance(expr, ExprCondPair):
             expr_expr, expr_type = self.figure_out_type(expr.expr)
             condition, condition_type = self.figure_out_type(expr.cond)
diff --git a/pystencils/typing/types.py b/pystencils/typing/types.py
index 9ec46d5c1..9bab35c5f 100644
--- a/pystencils/typing/types.py
+++ b/pystencils/typing/types.py
@@ -1,9 +1,8 @@
-from abc import ABC, abstractmethod
+from abc import abstractmethod
 from typing import Union
 
 import numpy as np
 import sympy as sp
-import sympy.codegen.ast
 
 
 def is_supported_type(dtype: np.dtype):
@@ -86,10 +85,6 @@ class BasicType(AbstractType):
     def base_type(self):
         return None
 
-    @property
-    def sympy_dtype(self):
-        return getattr(sympy.codegen.ast, str(self.numpy_dtype))
-
     @property
     def item_size(self):  # TODO: what is this? Do we want self.numpy_type.itemsize????
         return 1
diff --git a/pystencils_tests/test_types.py b/pystencils_tests/test_types.py
index d1b0d33cb..631855d17 100644
--- a/pystencils_tests/test_types.py
+++ b/pystencils_tests/test_types.py
@@ -172,6 +172,7 @@ def test_sqrt_of_integer(dtype):
         assert constant not in code
 
 
+# TODO this
 @pytest.mark.parametrize('dtype', ('float64', 'float32'))
 def test_integer_comparision(dtype):
     f = ps.fields(f"f: {dtype}[2D]")
@@ -182,7 +183,6 @@ def test_integer_comparision(dtype):
     ast = ps.create_kernel(ur)
     code = ps.get_code_str(ast)
 
-    print(code)
     # There should be an explicit cast for the integer zero to the type of the field on the rhs
     if dtype == 'float64':
         t = "_data_f_00[_stride_f_1*ctr_1] = ((((dir) == (1))) ? (((double)(0))): (_data_f_00[_stride_f_1*ctr_1]));"
@@ -192,44 +192,21 @@ def test_integer_comparision(dtype):
     assert t in code
 
 
-# TODO this
-def test_Basic_data_type():
+def test_typed_symbols_dtype():
     assert typed_symbols(("s", "f"), np.uint) == typed_symbols("s, f", np.uint)
     t_symbols = typed_symbols(("s", "f"), np.uint)
     s = t_symbols[0]
 
     assert t_symbols[0] == TypedSymbol("s", np.uint)
     assert s.dtype.is_uint()
-    assert s.dtype.is_complex() == 0
-
-    assert typed_symbols("s", str).dtype.is_other()
-    assert typed_symbols("s", bool).dtype.is_other()
-    assert typed_symbols("s", np.void).dtype.is_other()
 
     assert typed_symbols("s", np.float64).dtype.c_name == 'double'
-    # removed for old sympy version
-    # assert typed_symbols(("s"), np.float64).dtype.sympy_dtype == typed_symbols(("s"), float).dtype.sympy_dtype
-
-    f, g = ps.fields("f, g : double[2D]")
-
-    expr = ps.Assignment(f.center(), 2 * g.center() + 5)
-    new_expr = type_all_numbers(expr, np.float64)
-
-    assert "cast_func(2, double)" in str(new_expr)
-    assert "cast_func(5, double)" in str(new_expr)
-
-    m = matrix_symbols("a, b", np.uint, 3, 3)
-    assert len(m) == 2
-    m = m[0]
-    for i, elem in enumerate(m):
-        assert elem == TypedSymbol(f"a{i}", np.uint)
-        assert elem.dtype.is_uint()
+    assert typed_symbols("s", np.float32).dtype.c_name == 'float'
 
     assert TypedSymbol("s", np.uint).canonical == TypedSymbol("s", np.uint)
     assert TypedSymbol("s", np.uint).reversed == TypedSymbol("s", np.uint)
 
 
-# TODO this
 def test_cast_func():
     assert CastFunc(TypedSymbol("s", np.uint), np.int64).canonical == TypedSymbol("s", np.uint).canonical
 
@@ -242,21 +219,19 @@ def test_pointer_arithmetic_func():
     assert PointerArithmeticFunc(TypedSymbol("s", np.uint), 1).canonical == TypedSymbol("s", np.uint).canonical
 
 
-# TODO this
 def test_division():
     f = ps.fields('f(10): float32[2D]')
     m, tau = sp.symbols("m, tau")
 
-    up = [ps.Assignment(tau, 1.0 / (0.5 + (3.0 * m))),
+    up = [ps.Assignment(tau, 1 / (0.5 + (3.0 * m))),
           ps.Assignment(f.center, tau)]
-
-    ast = ps.create_kernel(up, config=pystencils.config.CreateKernelConfig(data_type="float32"))
+    config = pystencils.config.CreateKernelConfig(data_type='float32', default_number_float='float32')
+    ast = ps.create_kernel(up, config=config)
     code = ps.get_code_str(ast)
 
-    assert "1.0f" in code
+    assert "((1.0f) / (m*3.0f + 0.5f))" in code
 
 
-# TODO this
 def test_pow():
     f = ps.fields('f(10): float32[2D]')
     m, tau = sp.symbols("m, tau")
@@ -264,7 +239,8 @@ def test_pow():
     up = [ps.Assignment(tau, m ** 1.5),
           ps.Assignment(f.center, tau)]
 
-    ast = ps.create_kernel(up, config=pystencils.config.CreateKernelConfig(data_type="float32"))
+    config = pystencils.config.CreateKernelConfig(data_type="float32", default_number_float='float32')
+    ast = ps.create_kernel(up, config=config)
     code = ps.get_code_str(ast)
 
     assert "1.5f" in code
-- 
GitLab