From 43bb69a0648157308b1fa006f2977d36101123fc Mon Sep 17 00:00:00 2001
From: Markus Holzer <markus.holzer@fau.de>
Date: Fri, 19 Nov 2021 10:17:20 +0000
Subject: [PATCH] Introduce default assignment simplifications

---
 .gitlab-ci.yml                             |  7 +-
 doc/index.rst                              |  1 +
 doc/sphinx/simplifications.rst             | 18 ++++-
 pystencils/cpu/cpujit.py                   |  4 +-
 pystencils/cpu/kernelcreation.py           |  2 +-
 pystencils/kernelcreation.py               | 40 +++++++++--
 pystencils/simp/assignment_collection.py   | 17 +++--
 pystencils/simp/simplifications.py         | 25 ++++++-
 pystencils/simp/subexpression_insertion.py |  4 +-
 pystencils/simplificationfactory.py        | 18 +++++
 pystencils/sympyextensions.py              |  2 +-
 pystencils_tests/test_Min_Max.py           | 55 ++++-----------
 pystencils_tests/test_astnodes.py          | 35 +++++-----
 pystencils_tests/test_simplifications.py   | 51 ++++++++++++++
 pystencils_tests/test_sum_prod.py          | 81 +++++++++++-----------
 15 files changed, 239 insertions(+), 121 deletions(-)
 create mode 100644 pystencils/simplificationfactory.py

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index d7e410860..c4e477014 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -142,17 +142,13 @@ ubuntu:
     - sed -i 's/--doctest-modules //g' pytest.ini
     - env
     - pip3 list
-    - pytest-3 -v -n $NUM_CORES --cov-report html --cov-report term --cov=. --junitxml=report.xml pystencils_tests/test_*vec*.py pystencils_tests/test_random.py
-    - python3 -m coverage xml
+    - pytest-3 -v -n $NUM_CORES --junitxml=report.xml pystencils_tests/test_*vec*.py pystencils_tests/test_random.py
   tags:
     - docker
     - AVX
   artifacts:
     when: always
-    paths:
-      - coverage_report
     reports:
-      cobertura: coverage.xml
       junit: report.xml
 
 arm64v8:
@@ -252,6 +248,7 @@ pycodegen-integration:
     - pip install -e pystencils/
     - pip install -e lbmpy/
     - pip install -e pygrandchem/
+    - cmake --version
     - ./install_walberla.sh
     - export NUM_CORES=$(nproc --all)
     - mkdir -p ~/.config/matplotlib
diff --git a/doc/index.rst b/doc/index.rst
index a52161449..3030d3361 100644
--- a/doc/index.rst
+++ b/doc/index.rst
@@ -14,5 +14,6 @@ pystencils can help you to generate blazingly fast code for image processing, nu
 
 .. image:: /img/pystencils_arch_block_diagram.svg
     :height: 450px
+    :align: center
 
 
diff --git a/doc/sphinx/simplifications.rst b/doc/sphinx/simplifications.rst
index ca1508c5b..63f394ca5 100644
--- a/doc/sphinx/simplifications.rst
+++ b/doc/sphinx/simplifications.rst
@@ -10,13 +10,27 @@ AssignmentCollection
    :members:
 
 
+SimplificationStrategy
+======================
+
+.. autoclass:: pystencils.simp.SimplificationStrategy
+    :members:
+
 Simplifications
 ===============
 
-.. automodule:: pystencils.simp
-   :members:
+.. automodule:: pystencils.simp.simplifications
+    :members:
+
+Subexpression insertion
+=======================
 
+The subexpression insertions have the goal to insert subexpressions which will not reduce the number of FLOPs.
+For example a constant value kept as subexpression will lead to a new variable in the code which will occupy
+a register slot. On the other side a single variable could just be inserted in all assignments.
 
+.. automodule:: pystencils.simp.subexpression_insertion
+    :members:
 
 
 
diff --git a/pystencils/cpu/cpujit.py b/pystencils/cpu/cpujit.py
index ac1e8ac4f..ff577c442 100644
--- a/pystencils/cpu/cpujit.py
+++ b/pystencils/cpu/cpujit.py
@@ -554,7 +554,7 @@ class ExtensionModuleCode:
                       if os.path.exists(os.path.join(os.path.dirname(__file__), '..', 'include', h[1:-1]))]
         header_hash = b''.join([hashlib.sha256(open(h, 'rb').read()).digest() for h in ps_headers])
 
-        includes = "\n".join(["#include %s" % (include_file,) for include_file in header_list])
+        includes = "\n".join([f"#include {include_file}" for include_file in header_list])
         self._code_string += includes
         self._code_string += "\n"
         self._code_string += f"#define RESTRICT {restrict_qualifier} \n"
@@ -563,7 +563,7 @@ class ExtensionModuleCode:
 
         for ast, name in zip(self._ast_nodes, self._function_names):
             old_name = ast.function_name
-            ast.function_name = "kernel_" + name
+            ast.function_name = f"kernel_{name}"
             self._code_string += generate_c(ast, custom_backend=self._custom_backend)
             self._code_string += create_function_boilerplate_code(ast.get_parameters(), name, ast)
             ast.function_name = old_name
diff --git a/pystencils/cpu/kernelcreation.py b/pystencils/cpu/kernelcreation.py
index 4237dcd34..865beefa9 100644
--- a/pystencils/cpu/kernelcreation.py
+++ b/pystencils/cpu/kernelcreation.py
@@ -205,5 +205,5 @@ def add_openmp(ast_node, schedule="static", num_threads=True, collapse=None, ass
 
         prefix = f"#pragma omp for schedule({schedule})"
         if collapse:
-            prefix += " collapse(%d)" % (collapse, )
+            prefix += f" collapse({collapse})"
         loop_to_parallelize.prefix_lines.append(prefix)
diff --git a/pystencils/kernelcreation.py b/pystencils/kernelcreation.py
index e50851ae5..ac4412256 100644
--- a/pystencils/kernelcreation.py
+++ b/pystencils/kernelcreation.py
@@ -14,6 +14,8 @@ from pystencils.enums import Target, Backend
 from pystencils.field import Field, FieldType
 from pystencils.gpucuda.indexing import indexing_creator_from_params
 from pystencils.simp.assignment_collection import AssignmentCollection
+from pystencils.simp.simplifications import apply_sympy_optimisations
+from pystencils.simplificationfactory import create_simplification_strategy
 from pystencils.stencil import direction_string_to_offset, inverse_direction_string
 from pystencils.transformations import (
     loop_blocking, move_constants_before_loop, remove_conditionals_in_staggered_kernel)
@@ -83,6 +85,18 @@ class CreateKernelConfig:
     Dict with indexing parameters (constructor parameters of indexing class)
     e.g. for 'block' one can specify '{'block_size': (20, 20, 10) }'.
     """
+    default_assignment_simplifications: bool = False
+    """
+    If `True` default simplifications are first performed on the Assignments. If problems occur during the
+    simplification a warning will be thrown. 
+    Furthermore, it is essential to know that this is a two-stage process. The first stage of the process acts 
+    on the level of the `AssignmentCollection`.  In this part, `create_simplification_strategy` 
+    from pystencils.simplificationfactory will be used to apply optimisations like insertion of constants to 
+    remove pressure from the registers. Thus the first part of the optimisations can only be executed if 
+    an `AssignmentCollection` is passed. The second part of the optimisation acts on the level of each Assignment 
+    individually. In this stage, all optimisations from `sympy.codegen.rewriting.optims_c99` are applied 
+    to each Assignment. Thus this stage can also be applied if a list of Assignments is passed.
+    """
     cpu_prepend_optimizations: List[Callable] = field(default_factory=list)
     """
     List of extra optimizations to perform first on the AST.
@@ -195,8 +209,8 @@ def create_domain_kernel(assignments: List[Assignment], *, config: CreateKernelC
         >>> import numpy as np
         >>> s, d = ps.fields('s, d: [2D]')
         >>> assignment = ps.Assignment(d[0,0], s[0, 1] + s[0, -1] + s[1, 0] + s[-1, 0])
-        >>> config = ps.CreateKernelConfig(cpu_openmp=True)
-        >>> kernel_ast = ps.kernelcreation.create_domain_kernel([assignment], config=config)
+        >>> kernel_config = ps.CreateKernelConfig(cpu_openmp=True)
+        >>> kernel_ast = ps.kernelcreation.create_domain_kernel([assignment], config=kernel_config)
         >>> kernel = kernel_ast.compile()
         >>> d_arr = np.zeros([5, 5])
         >>> kernel(d=d_arr, s=np.ones([5, 5]))
@@ -207,6 +221,15 @@ def create_domain_kernel(assignments: List[Assignment], *, config: CreateKernelC
                [0., 4., 4., 4., 0.],
                [0., 0., 0., 0., 0.]])
     """
+    # --- applying first default simplifications
+    try:
+        if config.default_assignment_simplifications and isinstance(assignments, AssignmentCollection):
+            simplification = create_simplification_strategy()
+            assignments = simplification(assignments)
+    except Exception as e:
+        warnings.warn(f"It was not possible to apply the default pystencils optimisations to the "
+                      f"AssignmentCollection due to the following problem :{e}")
+
     # ----  Normalizing parameters
     split_groups = ()
     if isinstance(assignments, AssignmentCollection):
@@ -214,6 +237,13 @@ def create_domain_kernel(assignments: List[Assignment], *, config: CreateKernelC
             split_groups = assignments.simplification_hints['split_groups']
         assignments = assignments.all_assignments
 
+    try:
+        if config.default_assignment_simplifications:
+            assignments = apply_sympy_optimisations(assignments)
+    except Exception as e:
+        warnings.warn(f"It was not possible to apply the default SymPy optimisations to the "
+                      f"Assignments due to the following problem :{e}")
+
     # ----  Creating ast
     ast = None
     if config.target == Target.CPU:
@@ -304,9 +334,9 @@ def create_indexed_kernel(assignments: List[Assignment], *, config: CreateKernel
         >>>
         >>> # Additional values  stored in index field can be accessed in the kernel as well
         >>> s, d = ps.fields('s, d: [2D]')
-        >>> assignment = ps.Assignment(d[0,0], 2 * s[0, 1] + 2 * s[1, 0] + idx_field('val'))
-        >>> config = ps.CreateKernelConfig(index_fields=[idx_field], coordinate_names=('x', 'y'))
-        >>> kernel_ast = ps.create_indexed_kernel([assignment], config=config)
+        >>> assignment = ps.Assignment(d[0, 0], 2 * s[0, 1] + 2 * s[1, 0] + idx_field('val'))
+        >>> kernel_config = ps.CreateKernelConfig(index_fields=[idx_field], coordinate_names=('x', 'y'))
+        >>> kernel_ast = ps.create_indexed_kernel([assignment], config=kernel_config)
         >>> kernel = kernel_ast.compile()
         >>> d_arr = np.zeros([5, 5])
         >>> kernel(s=np.ones([5, 5]), d=d_arr, idx=index_arr)
diff --git a/pystencils/simp/assignment_collection.py b/pystencils/simp/assignment_collection.py
index 33102dee5..f493b08e9 100644
--- a/pystencils/simp/assignment_collection.py
+++ b/pystencils/simp/assignment_collection.py
@@ -19,13 +19,14 @@ class AssignmentCollection:
     Additionally a dictionary of simplification hints is stored, which are set by the functions that create
     assignment collections to transport information to the simplification system.
 
-    Attributes:
-        main_assignments: list of assignments
-        subexpressions: list of assignments defining subexpressions used in main equations
-        simplification_hints: dict that is used to annotate the assignment collection with hints that are
+    Args:
+        main_assignments: List of assignments. Main assignments are characterised, that the right hand side of each
+                          assignment is a field access. Thus the generated equations write on arrays.
+        subexpressions: List of assignments defining subexpressions used in main equations
+        simplification_hints: Dict that is used to annotate the assignment collection with hints that are
                               used by the simplification system. See documentation of the simplification rules for
                               potentially required hints and their meaning.
-        subexpression_symbol_generator: generator for new symbols that are used when new subexpressions are added
+        subexpression_symbol_generator: Generator for new symbols that are used when new subexpressions are added
                                         used to get new symbols that are unique for this AssignmentCollection
 
     """
@@ -33,9 +34,13 @@ class AssignmentCollection:
     # ------------------------------- Creation & Inplace Manipulation --------------------------------------------------
 
     def __init__(self, main_assignments: Union[List[Assignment], Dict[sp.Expr, sp.Expr]],
-                 subexpressions: Union[List[Assignment], Dict[sp.Expr, sp.Expr]] = {},
+                 subexpressions: Union[List[Assignment], Dict[sp.Expr, sp.Expr]] = None,
                  simplification_hints: Optional[Dict[str, Any]] = None,
                  subexpression_symbol_generator: Iterator[sp.Symbol] = None) -> None:
+
+        if subexpressions is None:
+            subexpressions = {}
+
         if isinstance(main_assignments, Dict):
             main_assignments = [Assignment(k, v)
                                 for k, v in main_assignments.items()]
diff --git a/pystencils/simp/simplifications.py b/pystencils/simp/simplifications.py
index 7345ad6f3..114b86a40 100644
--- a/pystencils/simp/simplifications.py
+++ b/pystencils/simp/simplifications.py
@@ -3,9 +3,11 @@ from typing import Callable, List, Sequence, Union
 from collections import defaultdict
 
 import sympy as sp
+from sympy.codegen.rewriting import optims_c99, optimize
+from sympy.codegen.rewriting import ReplaceOptim
 
 from pystencils.assignment import Assignment
-from pystencils.astnodes import Node
+from pystencils.astnodes import Node, SympyAssignment
 from pystencils.field import AbstractField, Field
 from pystencils.sympyextensions import subs_additive, is_constant, recursive_collect
 
@@ -223,3 +225,24 @@ def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]):
 
     f.__name__ = operation.__name__
     return f
+
+
+def apply_sympy_optimisations(assignments):
+    """ Evaluates constant expressions (e.g. :math:`\\sqrt{3}` will be replaced by its floating point representation)
+        and applies the default sympy optimisations. See sympy.codegen.rewriting
+    """
+
+    # Evaluates all constant terms
+    evaluate_constant_terms = ReplaceOptim(lambda e: hasattr(e, 'is_constant') and e.is_constant and not e.is_integer,
+                                           lambda p: p.evalf())
+
+    sympy_optimisations = [evaluate_constant_terms] + list(optims_c99)
+
+    assignments = [Assignment(a.lhs, optimize(a.rhs, sympy_optimisations))
+                   if hasattr(a, 'lhs')
+                   else a for a in assignments]
+    assignments_nodes = [a.atoms(SympyAssignment) for a in assignments]
+    for a in chain.from_iterable(assignments_nodes):
+        a.optimize(sympy_optimisations)
+
+    return assignments
diff --git a/pystencils/simp/subexpression_insertion.py b/pystencils/simp/subexpression_insertion.py
index 9293b56be..27ac79826 100644
--- a/pystencils/simp/subexpression_insertion.py
+++ b/pystencils/simp/subexpression_insertion.py
@@ -4,7 +4,7 @@ from pystencils.sympyextensions import is_constant
 #   Subexpression Insertion
 
 
-def insert_subexpressions(ac, selection_callback, skip=set()):
+def insert_subexpressions(ac, selection_callback, skip=None):
     """
         Removes a number of subexpressions from an assignment collection by
         inserting their right-hand side wherever they occur.
@@ -16,6 +16,8 @@ def insert_subexpressions(ac, selection_callback, skip=set()):
          - skip: Set of symbols (left-hand sides of subexpressions) that should be 
             ignored even if qualified by the callback.
     """
+    if skip is None:
+        skip = set()
     i = 0
     while i < len(ac.subexpressions):
         exp = ac.subexpressions[i]
diff --git a/pystencils/simplificationfactory.py b/pystencils/simplificationfactory.py
new file mode 100644
index 000000000..50ee2d7f8
--- /dev/null
+++ b/pystencils/simplificationfactory.py
@@ -0,0 +1,18 @@
+from pystencils.simp import (SimplificationStrategy, insert_constants, insert_symbol_times_minus_one,
+                             insert_constant_multiples, insert_constant_additions, insert_squares, insert_zeros)
+
+
+def create_simplification_strategy():
+    """
+    Creates a default simplification `ps.simp.SimplificationStrategy`. The idea behind the default simplification
+    strategy is to reduce the number of subexpressions by inserting single constants and to evaluate constant
+    terms beforehand.
+    """
+    s = SimplificationStrategy()
+    s.add(insert_symbol_times_minus_one)
+    s.add(insert_constant_multiples)
+    s.add(insert_constant_additions)
+    s.add(insert_squares)
+    s.add(insert_zeros)
+    s.add(insert_constants)
+    s.add(lambda ac: ac.new_without_unused_subexpressions())
diff --git a/pystencils/sympyextensions.py b/pystencils/sympyextensions.py
index 29b524eef..f63328d81 100644
--- a/pystencils/sympyextensions.py
+++ b/pystencils/sympyextensions.py
@@ -453,7 +453,7 @@ def recursive_collect(expr, symbols, order_by_occurences=False):
     return rec_sum
 
 
-def count_operations(term: Union[sp.Expr, List[sp.Expr]],
+def count_operations(term: Union[sp.Expr, List[sp.Expr], List[Assignment]],
                      only_type: Optional[str] = 'real') -> Dict[str, int]:
     """Counts the number of additions, multiplications and division.
 
diff --git a/pystencils_tests/test_Min_Max.py b/pystencils_tests/test_Min_Max.py
index 18cd2d99d..c227fbf14 100644
--- a/pystencils_tests/test_Min_Max.py
+++ b/pystencils_tests/test_Min_Max.py
@@ -1,10 +1,13 @@
-import sympy
+import pytest
+
+import sympy as sp
 import numpy
 import pystencils
 from pystencils.datahandling import create_data_handling
 
 
-def test_max():
+@pytest.mark.parametrize('sympy_function', [sp.Min, sp.Max])
+def test_max(sympy_function):
     dh = create_data_handling(domain_size=(10, 10), periodicity=True)
 
     x = dh.add_array('x', values_per_cell=1)
@@ -15,56 +18,28 @@ def test_max():
     dh.fill("z", 2.0, ghost_layers=True)
 
     # test sp.Max with one argument
-    assignment_1 = pystencils.Assignment(x.center, sympy.Max(y.center + 3.3))
+    assignment_1 = pystencils.Assignment(x.center, sympy_function(y.center + 3.3))
     ast_1 = pystencils.create_kernel(assignment_1)
     kernel_1 = ast_1.compile()
 
     # test sp.Max with two arguments
-    assignment_2 = pystencils.Assignment(x.center, sympy.Max(0.5, y.center - 1.5))
+    assignment_2 = pystencils.Assignment(x.center, sympy_function(0.5, y.center - 1.5))
     ast_2 = pystencils.create_kernel(assignment_2)
     kernel_2 = ast_2.compile()
 
     # test sp.Max with many arguments
-    assignment_3 = pystencils.Assignment(x.center, sympy.Max(z.center, 4.5, y.center - 1.5, y.center + z.center))
+    assignment_3 = pystencils.Assignment(x.center, sympy_function(z.center, 4.5, y.center - 1.5, y.center + z.center))
     ast_3 = pystencils.create_kernel(assignment_3)
     kernel_3 = ast_3.compile()
 
-    dh.run_kernel(kernel_1)
-    assert numpy.all(dh.cpu_arrays["x"] == 4.3)
-    dh.run_kernel(kernel_2)
-    assert numpy.all(dh.cpu_arrays["x"] == 0.5)
-    dh.run_kernel(kernel_3)
-    assert numpy.all(dh.cpu_arrays["x"] == 4.5)
-
-
-def test_min():
-    dh = create_data_handling(domain_size=(10, 10), periodicity=True)
-
-    x = dh.add_array('x', values_per_cell=1)
-    dh.fill("x", 0.0, ghost_layers=True)
-    y = dh.add_array('y', values_per_cell=1)
-    dh.fill("y", 1.0, ghost_layers=True)
-    z = dh.add_array('z', values_per_cell=1)
-    dh.fill("z", 2.0, ghost_layers=True)
-
-    # test sp.Min with one argument
-    assignment_1 = pystencils.Assignment(x.center, sympy.Min(y.center + 3.3))
-    ast_1 = pystencils.create_kernel(assignment_1)
-    kernel_1 = ast_1.compile()
-
-    # test sp.Min with two arguments
-    assignment_2 = pystencils.Assignment(x.center, sympy.Min(0.5, y.center - 1.5))
-    ast_2 = pystencils.create_kernel(assignment_2)
-    kernel_2 = ast_2.compile()
-
-    # test sp.Min with many arguments
-    assignment_3 = pystencils.Assignment(x.center, sympy.Min(z.center, 4.5, y.center - 1.5, y.center + z.center))
-    ast_3 = pystencils.create_kernel(assignment_3)
-    kernel_3 = ast_3.compile()
+    if sympy_function is sp.Max:
+        results = [4.3, 0.5, 4.5]
+    else:
+        results = [4.3, -0.5, -0.5]
 
     dh.run_kernel(kernel_1)
-    assert numpy.all(dh.cpu_arrays["x"] == 4.3)
+    assert numpy.all(dh.gather_array('x') == results[0])
     dh.run_kernel(kernel_2)
-    assert numpy.all(dh.cpu_arrays["x"] == - 0.5)
+    assert numpy.all(dh.gather_array('x') == results[1])
     dh.run_kernel(kernel_3)
-    assert numpy.all(dh.cpu_arrays["x"] == - 0.5)
+    assert numpy.all(dh.gather_array('x') == results[2])
diff --git a/pystencils_tests/test_astnodes.py b/pystencils_tests/test_astnodes.py
index 385c4e223..688f63ed9 100644
--- a/pystencils_tests/test_astnodes.py
+++ b/pystencils_tests/test_astnodes.py
@@ -1,24 +1,19 @@
 import pytest
+import sys
 import sympy as sp
 
 import pystencils as ps
 from pystencils import Assignment
 from pystencils.astnodes import Block, LoopOverCoordinate, SkipIteration, SympyAssignment
 
-sympy_numeric_version = [int(x, 10) for x in sp.__version__.split('.') if x.isdigit()]
-if len(sympy_numeric_version) < 3:
-    sympy_numeric_version.append(0)
-sympy_numeric_version.reverse()
-sympy_version = sum(x * (100 ** i) for i, x in enumerate(sympy_numeric_version))
-
 dst = ps.fields('dst(8): double[2D]')
 s = sp.symbols('s_:8')
 x = sp.symbols('x')
 y = sp.symbols('y')
 
+python_version = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
+
 
-@pytest.mark.skipif(sympy_version < 10501,
-                    reason="Old Sympy Versions behave differently which wont be supported in the near future")
 def test_kernel_function():
     assignments = [
         Assignment(dst[0, 0](0), s[0]),
@@ -44,8 +39,6 @@ def test_skip_iteration():
     assert skipped.undefined_symbols == set()
 
 
-@pytest.mark.skipif(sympy_version < 10501,
-                    reason="Old Sympy Versions behave differently which wont be supported in the near future")
 def test_block():
     assignments = [
         Assignment(dst[0, 0](0), s[0]),
@@ -92,17 +85,23 @@ def test_loop_over_coordinate():
     assert loop.step == 2
 
 
-def test_sympy_assignment():
-    pytest.importorskip('sympy.codegen.rewriting')
-    from sympy.codegen.rewriting import optims_c99
+@pytest.mark.parametrize('default_assignment_simplifications', [False, True])
+@pytest.mark.skipif(python_version == '3.8.2', reason="For this python version a strange bug in mpmath occurs")
+def test_sympy_assignment(default_assignment_simplifications):
     assignment = SympyAssignment(dst[0, 0](0), sp.log(x + 3) / sp.log(2) + sp.log(x ** 2 + 1))
-    assignment.optimize(optims_c99)
 
-    ast = ps.create_kernel([assignment])
+    config = ps.CreateKernelConfig(default_assignment_simplifications=default_assignment_simplifications)
+    ast = ps.create_kernel([assignment], config=config)
     code = ps.get_code_str(ast)
-
-    assert 'log1p' in code
-    assert 'log2' in code
+        
+    if default_assignment_simplifications:
+        assert 'log1p' in code
+        # constant term is directly evaluated
+        assert 'log2' not in code
+    else:
+        # no optimisations will be applied so the optimised version of log will not be in the code
+        assert 'log1p' not in code
+        assert 'log2' not in code
 
     assignment.replace(assignment.lhs, dst[0, 0](1))
     assignment.replace(assignment.rhs, sp.log(2))
diff --git a/pystencils_tests/test_simplifications.py b/pystencils_tests/test_simplifications.py
index 62fa262fb..1c9ed3c0c 100644
--- a/pystencils_tests/test_simplifications.py
+++ b/pystencils_tests/test_simplifications.py
@@ -1,5 +1,7 @@
+from sys import version_info as vs
 import pytest
 import sympy as sp
+import pystencils as ps
 
 from pystencils.simp import subexpression_substitution_in_main_assignments
 from pystencils.simp import add_subexpressions_for_divisions
@@ -136,3 +138,52 @@ def test_add_subexpressions_for_field_reads():
     assert len(ac.subexpressions) == 0
     ac = add_subexpressions_for_field_reads(ac)
     assert len(ac.subexpressions) == 2
+
+
+@pytest.mark.parametrize('target', (ps.Target.CPU, ps.Target.GPU))
+@pytest.mark.parametrize('simplification', (True, False))
+@pytest.mark.skipif((vs.major, vs.minor, vs.micro) == (3, 8, 2), reason="does not work on python 3.8.2 for some reason")
+def test_sympy_optimizations(target, simplification):
+    if target == ps.Target.GPU:
+        pytest.importorskip("pycuda")
+    src, dst = ps.fields('src, dst:  float32[2d]')
+
+    # Triggers Sympy's expm1 optimization
+    # Sympy's expm1 optimization is tedious to use and the behaviour is highly depended on the sympy version. In
+    # some cases the exp expression has to be encapsulated in brackets or multiplied with 1 or 1.0
+    # for sympy to work properly ...
+    assignments = ps.AssignmentCollection({
+        src[0, 0]: 1.0 * (sp.exp(dst[0, 0]) - 1)
+    })
+
+    config = ps.CreateKernelConfig(target=target, default_assignment_simplifications=simplification)
+    ast = ps.create_kernel(assignments, config=config)
+
+    code = ps.get_code_str(ast)
+    if simplification:
+        assert 'expm1(' in code
+    else:
+        assert 'expm1(' not in code
+
+
+@pytest.mark.parametrize('target', (ps.Target.CPU, ps.Target.GPU))
+@pytest.mark.parametrize('simplification', (True, False))
+@pytest.mark.skipif((vs.major, vs.minor, vs.micro) == (3, 8, 2), reason="does not work on python 3.8.2 for some reason")
+def test_evaluate_constant_terms(target, simplification):
+    if target == ps.Target.GPU:
+        pytest.importorskip("pycuda")
+    src, dst = ps.fields('src, dst:  float32[2d]')
+
+    # Triggers Sympy's cos optimization
+    assignments = ps.AssignmentCollection({
+        src[0, 0]: -sp.cos(1) + dst[0, 0]
+    })
+
+    config = ps.CreateKernelConfig(target=target, default_assignment_simplifications=simplification)
+    ast = ps.create_kernel(assignments, config=config)
+    code = ps.get_code_str(ast)
+    if simplification:
+        assert 'cos(' not in code
+    else:
+        assert 'cos(' in code
+    print(code)
diff --git a/pystencils_tests/test_sum_prod.py b/pystencils_tests/test_sum_prod.py
index 4b4cd7131..2f6bf7359 100644
--- a/pystencils_tests/test_sum_prod.py
+++ b/pystencils_tests/test_sum_prod.py
@@ -7,34 +7,36 @@
 """
 
 """
+import pytest
 import numpy as np
-import sympy
-from sympy.abc import k
+import sympy as sp
+import sympy.abc
 
-import pystencils
+import pystencils as ps
 from pystencils.data_types import create_type
 
 
-def test_sum():
+@pytest.mark.parametrize('default_assignment_simplifications', [False, True])
+def test_sum(default_assignment_simplifications):
 
-    sum = sympy.Sum(k, (k, 1, 100))
+    sum = sp.Sum(sp.abc.k, (sp.abc.k, 1, 100))
     expanded_sum = sum.doit()
 
     print(sum)
     print(expanded_sum)
 
-    x = pystencils.fields('x: float32[1d]')
+    x = ps.fields('x: float32[1d]')
 
-    assignments = pystencils.AssignmentCollection({
-        x.center(): sum
-    })
+    assignments = ps.AssignmentCollection({x.center(): sum})
 
-    ast = pystencils.create_kernel(assignments)
-    code = str(pystencils.get_code_obj(ast))
+    config = ps.CreateKernelConfig(default_assignment_simplifications=default_assignment_simplifications)
+    ast = ps.create_kernel(assignments, config=config)
+    code = ps.get_code_str(ast)
     kernel = ast.compile()
 
     print(code)
-    assert 'double sum' in code
+    if default_assignment_simplifications is False:
+        assert 'double sum' in code
 
     array = np.zeros((10,), np.float32)
 
@@ -43,27 +45,28 @@ def test_sum():
     assert np.allclose(array, int(expanded_sum) * np.ones_like(array))
 
 
-def test_sum_use_float():
+@pytest.mark.parametrize('default_assignment_simplifications', [False, True])
+def test_sum_use_float(default_assignment_simplifications):
 
-    sum = sympy.Sum(k, (k, 1, 100))
+    sum = sympy.Sum(sp.abc.k, (sp.abc.k, 1, 100))
     expanded_sum = sum.doit()
 
     print(sum)
     print(expanded_sum)
 
-    x = pystencils.fields('x: float32[1d]')
+    x = ps.fields('x: float32[1d]')
 
-    assignments = pystencils.AssignmentCollection({
-        x.center(): sum
-    })
+    assignments = ps.AssignmentCollection({x.center(): sum})
 
-    ast = pystencils.create_kernel(assignments, data_type=create_type('float32'))
-    code = str(pystencils.get_code_obj(ast))
+    config = ps.CreateKernelConfig(default_assignment_simplifications=default_assignment_simplifications,
+                                   data_type=create_type('float32'))
+    ast = ps.create_kernel(assignments, config=config)
+    code = ps.get_code_str(ast)
     kernel = ast.compile()
 
     print(code)
-    print(pystencils.get_code_obj(ast))
-    assert 'float sum' in code
+    if default_assignment_simplifications is False:
+        assert 'float sum' in code
 
     array = np.zeros((10,), np.float32)
 
@@ -72,9 +75,10 @@ def test_sum_use_float():
     assert np.allclose(array, int(expanded_sum) * np.ones_like(array))
 
 
-def test_product():
+@pytest.mark.parametrize('default_assignment_simplifications', [False, True])
+def test_product(default_assignment_simplifications):
 
-    k = pystencils.TypedSymbol('k', create_type('int64'))
+    k = ps.TypedSymbol('k', create_type('int64'))
 
     sum = sympy.Product(k, (k, 1, 10))
     expanded_sum = sum.doit()
@@ -82,18 +86,19 @@ def test_product():
     print(sum)
     print(expanded_sum)
 
-    x = pystencils.fields('x: int64[1d]')
+    x = ps.fields('x: int64[1d]')
 
-    assignments = pystencils.AssignmentCollection({
-        x.center(): sum
-    })
+    assignments = ps.AssignmentCollection({x.center(): sum})
 
-    ast = pystencils.create_kernel(assignments)
-    code = pystencils.get_code_str(ast)
+    config = ps.CreateKernelConfig(default_assignment_simplifications=default_assignment_simplifications)
+
+    ast = ps.create_kernel(assignments, config=config)
+    code = ps.get_code_str(ast)
     kernel = ast.compile()
 
     print(code)
-    assert 'int64_t product' in code
+    if default_assignment_simplifications is False:
+        assert 'int64_t product' in code
 
     array = np.zeros((10,), np.int64)
 
@@ -104,8 +109,8 @@ def test_product():
 
 def test_prod_var_limit():
 
-    k = pystencils.TypedSymbol('k', create_type('int64'))
-    limit = pystencils.TypedSymbol('limit', create_type('int64'))
+    k = ps.TypedSymbol('k', create_type('int64'))
+    limit = ps.TypedSymbol('limit', create_type('int64'))
 
     sum = sympy.Sum(k, (k, 1, limit))
     expanded_sum = sum.replace(limit, 100).doit()
@@ -113,14 +118,12 @@ def test_prod_var_limit():
     print(sum)
     print(expanded_sum)
 
-    x = pystencils.fields('x: int64[1d]')
+    x = ps.fields('x: int64[1d]')
 
-    assignments = pystencils.AssignmentCollection({
-        x.center(): sum
-    })
+    assignments = ps.AssignmentCollection({x.center(): sum})
 
-    ast = pystencils.create_kernel(assignments)
-    pystencils.show_code(ast)
+    ast = ps.create_kernel(assignments)
+    ps.show_code(ast)
     kernel = ast.compile()
 
     array = np.zeros((10,), np.int64)
-- 
GitLab