From 362b4611cfa294a84c154428a0126fab62afea8e Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Mon, 23 Apr 2018 12:17:54 +0200
Subject: [PATCH] Clean submodules for data handling and finite differences

---
 __init__.py                                   |   4 +-
 boundaries/boundaryhandling.py                |   4 +-
 vectorization.py => cpu/vectorization.py      |   0
 datahandling/__init__.py                      |  25 +-
 datahandling/parallel_datahandling.py         |   2 +-
 datahandling/serial_datahandling.py           |   4 +-
 vtk.py => datahandling/vtk.py                 |   0
 fd/__init__.py                                |  10 +
 derivative.py => fd/derivative.py             | 206 ++++----
 .../finitedifferences.py                      | 481 +++++++++---------
 kernelcreation.py                             |   2 +-
 test_simplification_strategy.py               |  43 --
 12 files changed, 390 insertions(+), 391 deletions(-)
 rename vectorization.py => cpu/vectorization.py (100%)
 rename vtk.py => datahandling/vtk.py (100%)
 create mode 100644 fd/__init__.py
 rename derivative.py => fd/derivative.py (90%)
 rename finitedifferences.py => fd/finitedifferences.py (98%)
 delete mode 100644 test_simplification_strategy.py

diff --git a/__init__.py b/__init__.py
index dc7ee7421..8f7deed40 100644
--- a/__init__.py
+++ b/__init__.py
@@ -8,6 +8,7 @@ from .display_utils import show_code, to_dot
 from .assignment_collection import AssignmentCollection
 from .assignment import Assignment
 from .sympyextensions import SymbolCreator
+from .datahandling import create_data_handling
 
 __all__ = ['Field', 'FieldType',
            'TypedSymbol',
@@ -16,4 +17,5 @@ __all__ = ['Field', 'FieldType',
            'show_code', 'to_dot',
            'AssignmentCollection',
            'Assignment',
-           'SymbolCreator']
+           'SymbolCreator',
+           'create_data_handling']
diff --git a/boundaries/boundaryhandling.py b/boundaries/boundaryhandling.py
index 314490e7a..73330c33a 100644
--- a/boundaries/boundaryhandling.py
+++ b/boundaries/boundaryhandling.py
@@ -14,8 +14,8 @@ class FlagInterface:
     """Manages the reservation of bits (i.e. flags) in an array of unsigned integers.
 
     Examples:
-        >>> from pystencils.datahandling import SerialDataHandling
-        >>> dh = SerialDataHandling((4, 5))
+        >>> from pystencils import create_data_handling
+        >>> dh = create_data_handling((4, 5))
         >>> fi = FlagInterface(dh, 'flag_field', np.uint8)
         >>> assert dh.has_data('flag_field')
         >>> fi.reserve_next_flag()
diff --git a/vectorization.py b/cpu/vectorization.py
similarity index 100%
rename from vectorization.py
rename to cpu/vectorization.py
diff --git a/datahandling/__init__.py b/datahandling/__init__.py
index 519e9b2fd..54ef48c2c 100644
--- a/datahandling/__init__.py
+++ b/datahandling/__init__.py
@@ -1,4 +1,6 @@
-from pystencils.datahandling.serial_datahandling import SerialDataHandling
+from typing import Tuple, Union
+from .serial_datahandling import SerialDataHandling
+from .datahandling_interface import DataHandling
 
 try:
     # noinspection PyPep8Naming
@@ -12,7 +14,23 @@ except ImportError:
     ParallelDataHandling = None
 
 
-def create_data_handling(parallel, domain_size, periodicity, default_layout='SoA', default_ghost_layers=1):
+def create_data_handling(domain_size: Tuple[int, ...],
+                         periodicity: Union[bool, Tuple[bool, ...]] = False,
+                         default_layout: str = 'SoA',
+                         parallel: bool = False,
+                         default_ghost_layers: int = 1) -> DataHandling:
+    """Creates a data handling instance.
+
+    Args:
+        parallel:
+        domain_size:
+        periodicity:
+        default_layout:
+        default_ghost_layers:
+
+    Returns:
+
+    """
     if parallel:
         if wlb is None:
             raise ValueError("Cannot create parallel data handling because walberla module is not available")
@@ -39,3 +57,6 @@ def create_data_handling(parallel, domain_size, periodicity, default_layout='SoA
     else:
         return SerialDataHandling(domain_size, periodicity=periodicity,
                                   default_layout=default_layout, default_ghost_layers=default_ghost_layers)
+
+
+__all__ = ['create_data_handling']
diff --git a/datahandling/parallel_datahandling.py b/datahandling/parallel_datahandling.py
index e023c6076..6ac9e9241 100644
--- a/datahandling/parallel_datahandling.py
+++ b/datahandling/parallel_datahandling.py
@@ -1,7 +1,7 @@
 import numpy as np
 from pystencils import Field
 from pystencils.datahandling.datahandling_interface import DataHandling
-from pystencils.parallel.blockiteration import sliced_block_iteration, block_iteration
+from pystencils.datahandling.blockiteration import sliced_block_iteration, block_iteration
 from pystencils.utils import DotDict
 # noinspection PyPep8Naming
 import waLBerla as wlb
diff --git a/datahandling/serial_datahandling.py b/datahandling/serial_datahandling.py
index 93e702651..b46845278 100644
--- a/datahandling/serial_datahandling.py
+++ b/datahandling/serial_datahandling.py
@@ -309,7 +309,7 @@ class SerialDataHandling(DataHandling):
         return np.array(sequence)
 
     def create_vtk_writer(self, file_name, data_names, ghost_layers=False):
-        from pystencils.vtk import image_to_vtk
+        from pystencils.datahandling.vtk import image_to_vtk
 
         def writer(step):
             full_file_name = "%s_%08d" % (file_name, step,)
@@ -336,7 +336,7 @@ class SerialDataHandling(DataHandling):
         return writer
 
     def create_vtk_writer_for_flag_array(self, file_name, data_name, masks_to_name, ghost_layers=False):
-        from pystencils.vtk import image_to_vtk
+        from pystencils.datahandling.vtk import image_to_vtk
 
         def writer(step):
             full_file_name = "%s_%08d" % (file_name, step,)
diff --git a/vtk.py b/datahandling/vtk.py
similarity index 100%
rename from vtk.py
rename to datahandling/vtk.py
diff --git a/fd/__init__.py b/fd/__init__.py
new file mode 100644
index 000000000..2fcf3d9cf
--- /dev/null
+++ b/fd/__init__.py
@@ -0,0 +1,10 @@
+from .derivative import Diff, DiffOperator, \
+    diff_terms, collect_diffs, create_nested_diff, replace_diff, zero_diffs, evaluate_diffs, normalize_diff_order, \
+    expand_diff_full, expand_diff_linear, expand_diff_products, combine_diff_products, \
+    functional_derivative
+from .finitedifferences import advection, diffusion, transient, Discretization2ndOrder
+
+
+__all__ = ['Diff', 'DiffOperator', 'diff_terms', 'collect_diffs', 'create_nested_diff', 'replace_diff', 'zero_diffs',
+           'evaluate_diffs', 'normalize_diff_order', 'expand_diff_full', 'expand_diff_linear',
+           'expand_diff_products', 'combine_diff_products', 'functional_derivative']
diff --git a/derivative.py b/fd/derivative.py
similarity index 90%
rename from derivative.py
rename to fd/derivative.py
index 1610f83e8..c2b497a3d 100644
--- a/derivative.py
+++ b/fd/derivative.py
@@ -3,13 +3,14 @@ from collections import namedtuple, defaultdict
 from pystencils.sympyextensions import normalize_product, prod
 
 
-def default_diff_sort_key(d):
+def _default_diff_sort_key(d):
     return str(d.superscript), str(d.target)
 
 
 class Diff(sp.Expr):
-    """
-    Sympy Node representing a derivative. The difference to sympy's built in differential is:
+    """Sympy Node representing a derivative.
+
+    The difference to sympy's built in differential is:
         - shortened latex representation
         - all simplifications have to be done manually
         - optional marker displayed as superscript
@@ -156,7 +157,7 @@ class DiffOperator(sp.Expr):
             if len(diffs) == 0:
                 return mul * argument if apply_to_constants else mul
             rest = [a for a in args if not isinstance(a, DiffOperator)]
-            diffs.sort(key=default_diff_sort_key)
+            diffs.sort(key=_default_diff_sort_key)
             result = argument
             for d in reversed(diffs):
                 result = Diff(result, target=d.target, superscript=d.superscript)
@@ -174,10 +175,10 @@ class DiffOperator(sp.Expr):
 # ----------------------------------------------------------------------------------------------------------------------
 
 
-def derivative_terms(expr):
-    """
-    Returns set of all derivatives in an expression
-    this is different from `expr.atoms(Diff)` when nested derivatives are in the expression,
+def diff_terms(expr):
+    """Returns set of all derivatives in an expression.
+
+    This function yields different results than `expr.atoms(Diff)` when nested derivatives are in the expression,
     since this function only returns the outer derivatives
     """
     result = set()
@@ -193,9 +194,9 @@ def derivative_terms(expr):
     return result
 
 
-def collect_derivatives(expr):
+def collect_diffs(expr):
     """Rewrites expression into a sum of distinct derivatives with pre-factors"""
-    return expr.collect(derivative_terms(expr))
+    return expr.collect(diff_terms(expr))
 
 
 def create_nested_diff(arg, *args):
@@ -208,39 +209,73 @@ def create_nested_diff(arg, *args):
     return res
 
 
-def expand_using_linearity(expr, functions=None, constants=None):
-    """
-    Expands all derivative nodes by applying Diff.split_linear
-    :param expr: expression containing derivatives
-    :param functions: sequence of symbols that are considered functions and can not be pulled before the derivative.
-                      if None, all symbols are viewed as functions
-    :param constants: sequence of symbols which are considered constants and can be pulled before the derivative
-    """
-    if functions is None:
-        functions = expr.atoms(sp.Symbol)
-        if constants is not None:
-            functions.difference_update(constants)
+def replace_diff(expr, replacement_dict):
+    """replacement_dict: maps variable (target) to a new Differential operator"""
+
+    def visit(e):
+        if isinstance(e, Diff):
+            if e.target in replacement_dict:
+                return DiffOperator.apply(replacement_dict[e.target], visit(e.arg))
+        new_args = [visit(arg) for arg in e.args]
+        return e.func(*new_args) if new_args else e
+
+    return visit(expr)
+
+
+def zero_diffs(expr, label):
+    """Replaces all differentials with the given target by 0"""
+
+    def visit(e):
+        if isinstance(e, Diff):
+            if e.target == label:
+                return 0
+        new_args = [visit(arg) for arg in e.args]
+        return e.func(*new_args) if new_args else e
+
+    return visit(expr)
+
+
+def evaluate_diffs(expr, var=None):
+    """Replaces pystencils diff objects by sympy diff objects and evaluates them.
 
+    Replaces Diff nodes by sp.diff , the free variable is either the target (if var=None) otherwise
+    the specified var
+    """
     if isinstance(expr, Diff):
-        arg = expand_using_linearity(expr.arg, functions)
-        if hasattr(arg, 'func') and arg.func == sp.Add:
-            result = 0
-            for a in arg.args:
-                result += Diff(a, target=expr.target, superscript=expr.superscript).split_linear(functions)
+        if var is None:
+            var = expr.target
+        return sp.diff(evaluate_diffs(expr.arg, var), var)
+    else:
+        new_args = [evaluate_diffs(arg, var) for arg in expr.args]
+        return expr.func(*new_args) if new_args else expr
+
+
+def normalize_diff_order(expression, functions=None, constants=None, sort_key=_default_diff_sort_key):
+    """Assumes order of differentiation can be exchanged. Changes the order of nested Diffs to a standard order defined
+    by the sorting key 'sort_key' such that the derivative terms can be further simplified """
+
+    def visit(expr):
+        if isinstance(expr, Diff):
+            nodes = [expr]
+            while isinstance(nodes[-1].arg, Diff):
+                nodes.append(nodes[-1].arg)
+
+            processed_arg = visit(nodes[-1].arg)
+            nodes.sort(key=sort_key)
+
+            result = processed_arg
+            for d in reversed(nodes):
+                result = Diff(result, target=d.target, superscript=d.superscript)
             return result
         else:
-            diff = Diff(arg, target=expr.target, superscript=expr.superscript)
-            if diff == 0:
-                return 0
-            else:
-                return diff.split_linear(functions)
-    else:
-        new_args = [expand_using_linearity(e, functions) for e in expr.args]
-        result = sp.expand(expr.func(*new_args) if new_args else expr)
-        return result
+            new_args = [visit(e) for e in expr.args]
+            return expr.func(*new_args) if new_args else expr
+
+    expression = expand_diff_linear(expression.expand(), functions, constants).expand()
+    return visit(expression)
 
 
-def full_diff_expand(expr, functions=None, constants=None):
+def expand_diff_full(expr, functions=None, constants=None):
     if functions is None:
         functions = expr.atoms(sp.Symbol)
         if constants is not None:
@@ -278,35 +313,43 @@ def full_diff_expand(expr, functions=None, constants=None):
         return visit(expr)
 
 
-def normalize_diff_order(expression, functions=None, constants=None, sort_key=default_diff_sort_key):
-    """Assumes order of differentiation can be exchanged. Changes the order of nested Diffs to a standard order defined
-    by the sorting key 'sort_key' such that the derivative terms can be further simplified """
+def expand_diff_linear(expr, functions=None, constants=None):
+    """Expands all derivative nodes by applying Diff.split_linear
 
-    def visit(expr):
-        if isinstance(expr, Diff):
-            nodes = [expr]
-            while isinstance(nodes[-1].arg, Diff):
-                nodes.append(nodes[-1].arg)
-
-            processed_arg = visit(nodes[-1].arg)
-            nodes.sort(key=sort_key)
+    Args:
+        expr: expression containing derivatives
+        functions: sequence of symbols that are considered functions and can not be pulled before the derivative.
+                   if None, all symbols are viewed as functions
+        constants: sequence of symbols which are considered constants and can be pulled before the derivative
+    """
+    if functions is None:
+        functions = expr.atoms(sp.Symbol)
+        if constants is not None:
+            functions.difference_update(constants)
 
-            result = processed_arg
-            for d in reversed(nodes):
-                result = Diff(result, target=d.target, superscript=d.superscript)
+    if isinstance(expr, Diff):
+        arg = expand_diff_linear(expr.arg, functions)
+        if hasattr(arg, 'func') and arg.func == sp.Add:
+            result = 0
+            for a in arg.args:
+                result += Diff(a, target=expr.target, superscript=expr.superscript).split_linear(functions)
             return result
         else:
-            new_args = [visit(e) for e in expr.args]
-            return expr.func(*new_args) if new_args else expr
-
-    expression = expand_using_linearity(expression.expand(), functions, constants).expand()
-    return visit(expression)
+            diff = Diff(arg, target=expr.target, superscript=expr.superscript)
+            if diff == 0:
+                return 0
+            else:
+                return diff.split_linear(functions)
+    else:
+        new_args = [expand_diff_linear(e, functions) for e in expr.args]
+        result = sp.expand(expr.func(*new_args) if new_args else expr)
+        return result
 
 
-def expand_using_product_rule(expr):
+def expand_diff_products(expr):
     """Fully expands all derivatives by applying product rule"""
     if isinstance(expr, Diff):
-        arg = expand_using_product_rule(expr.args[0])
+        arg = expand_diff_products(expr.args[0])
         if arg.func == sp.Add:
             new_args = [Diff(e, target=expr.target, superscript=expr.superscript)
                         for e in arg.args]
@@ -321,11 +364,11 @@ def expand_using_product_rule(expr):
                 result += pre_factor * Diff(prod_list[i], target=expr.target, superscript=expr.superscript)
             return result
     else:
-        new_args = [expand_using_product_rule(e) for e in expr.args]
+        new_args = [expand_diff_products(e) for e in expr.args]
         return expr.func(*new_args) if new_args else expr
 
 
-def combine_using_product_rule(expr):
+def combine_diff_products(expr):
     """Inverse product rule"""
 
     def expr_to_diff_decomposition(expression):
@@ -408,53 +451,14 @@ def combine_using_product_rule(expr):
                 rest += process_diff_list(diff_list, label, superscript)
             return rest
         else:
-            new_args = [combine_using_product_rule(e) for e in expression.args]
+            new_args = [combine_diff_products(e) for e in expression.args]
             return expression.func(*new_args) if new_args else expression
 
     return combine(expr)
 
 
-def replace_diff(expr, replacement_dict):
-    """replacement_dict: maps variable (target) to a new Differential operator"""
-
-    def visit(e):
-        if isinstance(e, Diff):
-            if e.target in replacement_dict:
-                return DiffOperator.apply(replacement_dict[e.target], visit(e.arg))
-        new_args = [visit(arg) for arg in e.args]
-        return e.func(*new_args) if new_args else e
-
-    return visit(expr)
-
-
-def zero_diffs(expr, label):
-    """Replaces all differentials with the given target by 0"""
-
-    def visit(e):
-        if isinstance(e, Diff):
-            if e.target == label:
-                return 0
-        new_args = [visit(arg) for arg in e.args]
-        return e.func(*new_args) if new_args else e
-
-    return visit(expr)
-
-
-def evaluate_diffs(expr, var=None):
-    """Replaces Diff nodes by sp.diff , the free variable is either the target (if var=None) otherwise
-    the specified var"""
-    if isinstance(expr, Diff):
-        if var is None:
-            var = expr.target
-        return sp.diff(evaluate_diffs(expr.arg, var), var)
-    else:
-        new_args = [evaluate_diffs(arg, var) for arg in expr.args]
-        return expr.func(*new_args) if new_args else expr
-
-
 def functional_derivative(functional, v):
-    r"""
-    Computes functional derivative of functional with respect to v using Euler-Lagrange equation
+    r"""Computes functional derivative of functional with respect to v using Euler-Lagrange equation
 
     .. math ::
 
diff --git a/finitedifferences.py b/fd/finitedifferences.py
similarity index 98%
rename from finitedifferences.py
rename to fd/finitedifferences.py
index 26124f999..b6f05f30e 100644
--- a/finitedifferences.py
+++ b/fd/finitedifferences.py
@@ -4,146 +4,143 @@ import sympy as sp
 from pystencils.assignment_collection import AssignmentCollection
 from pystencils.field import Field
 from pystencils.sympyextensions import fast_subs
-from pystencils.derivative import Diff
+from pystencils.fd.derivative import Diff
 
 
-def grad(var, dim=3):
-    r"""
-    Gradients are represented as a special symbol:
-    e.g. :math:`\nabla x = (x^{\Delta 0}, x^{\Delta 1}, x^{\Delta 2})`
-
-    This function takes a symbol and creates the gradient symbols according to convention above
+# --------------------------------------- Advection Diffusion ----------------------------------------------------------
 
-    :param var: symbol to take the gradient of
-    :param dim: dimension (length) of the gradient vector
-    """
-    if hasattr(var, "__getitem__"):
-        return [[sp.Symbol("%s^Delta^%d" % (v.name, i)) for v in var] for i in range(dim)]
+def advection(advected_scalar, velocity_field, idx=None):
+    """Advection term: divergence( velocity_field * advected_scalar )"""
+    if isinstance(advected_scalar, Field):
+        first_arg = advected_scalar.center
+    elif isinstance(advected_scalar, Field.Access):
+        first_arg = advected_scalar
     else:
-        return [sp.Symbol("%s^Delta^%d" % (var.name, i)) for i in range(dim)]
-
-
-def discretize_center(term, symbols_to_field_dict, dx, dim=3):
-    """
-    Expects term that contains given symbols and gradient components of these symbols and replaces them
-    by field accesses. Gradients are replaced by centralized approximations:
-    ``(upper neighbor - lower neighbor ) / ( 2*dx)``
-    :param term: term where symbols and gradient(symbol) should be replaced
-    :param symbols_to_field_dict: mapping of symbols to Field
-    :param dx: width and height of one cell
-    :param dim: dimension
+        raise ValueError("Advected scalar has to be a pystencils Field or Field.Access")
 
-    Example:
-      >>> x = sp.Symbol("x")
-      >>> grad_x = grad(x, dim=3)
-      >>> term = x * grad_x[0]
-      >>> term
-      x*x^Delta^0
-      >>> f = Field.create_generic('f', spatial_dimensions=3)
-      >>> discretize_center(term, { x: f }, dx=1, dim=3)
-      f_C*(f_E/2 - f_W/2)
-    """
-    substitutions = {}
-    for symbols, field in symbols_to_field_dict.items():
-        if not hasattr(symbols, "__getitem__"):
-            symbols = [symbols]
-        g = grad(symbols, dim)
-        substitutions.update({symbol: field(i) for i, symbol in enumerate(symbols)})
-        for d in range(dim):
-            up, down = __up_down_offsets(d, dim)
-            substitutions.update({g[d][i]: (field[up](i) - field[down](i)) / dx / 2 for i in range(len(symbols))})
-    return term.subs(substitutions)
+    args = [first_arg, velocity_field if not isinstance(velocity_field, Field) else velocity_field.center]
+    if idx is not None:
+        args.append(idx)
+    return Advection(*args)
 
 
-def discretize_staggered(term, symbols_to_field_dict, coordinate, coordinate_offset, dx, dim=3):
-    """
-    Expects term that contains given symbols and gradient components of these symbols and replaces them
-    by field accesses. Gradients in coordinate direction  are replaced by staggered version at cell boundary.
-    Symbols themselves and gradients in other directions are replaced by interpolated version at cell face.
+def diffusion(scalar, diffusion_coeff, idx=None):
+    if isinstance(scalar, Field):
+        first_arg = scalar.center
+    elif isinstance(scalar, Field.Access):
+        first_arg = scalar
+    else:
+        raise ValueError("Diffused scalar has to be a pystencils Field or Field.Access")
 
-    Args:
-        term: input term where symbols and gradients are replaced
-        symbols_to_field_dict: mapping of symbols to Field
-        coordinate: id for coordinate (0 for x, 1 for y, ... ) defining cell boundary.
-                    Only gradients in this direction are replaced e.g. if symbol^Delta^coordinate
-        coordinate_offset: either +1 or -1 for upper or lower face in coordinate direction
-        dx: width and height of one cell
-        dim: dimension
+    args = [first_arg, diffusion_coeff if not isinstance(diffusion_coeff, Field) else diffusion_coeff.center]
+    if idx is not None:
+        args.append(idx)
+    return Diffusion(*args)
 
-    Examples:
-      Discretizing at right/east face of cell i.e. coordinate=0, offset=1)
-      >>> x, dx = sp.symbols("x dx")
-      >>> grad_x = grad(x, dim=3)
-      >>> term = x * grad_x[0]
-      >>> term
-      x*x^Delta^0
-      >>> f = Field.create_generic('f', spatial_dimensions=3)
-      >>> discretize_staggered(term, symbols_to_field_dict={ x: f}, dx=dx, coordinate=0, coordinate_offset=1, dim=3)
-      (-f_C + f_E)*(f_C/2 + f_E/2)/dx
-    """
-    assert coordinate_offset == 1 or coordinate_offset == -1
-    assert 0 <= coordinate < dim
 
-    substitutions = {}
-    for symbols, field in symbols_to_field_dict.items():
-        if not hasattr(symbols, "__getitem__"):
-            symbols = [symbols]
+def transient(scalar, idx=None):
+    if isinstance(scalar, Field):
+        args = [scalar.center]
+    elif isinstance(scalar, Field.Access):
+        args = [scalar]
+    else:
+        raise ValueError("Scalar has to be a pystencils Field or Field.Access")
+    if idx is not None:
+        args.append(idx)
+    return Transient(*args)
 
-        offset = [0] * dim
-        offset[coordinate] = coordinate_offset
-        offset = np.array(offset, dtype=np.int)
 
-        gradient = grad(symbols)[coordinate]
-        substitutions.update({s: (field[offset](i) + field(i)) / 2 for i, s in enumerate(symbols)})
-        substitutions.update({g: (field[offset](i) - field(i)) / dx * coordinate_offset
-                              for i, g in enumerate(gradient)})
-        for d in range(dim):
-            if d == coordinate:
-                continue
-            up, down = __up_down_offsets(d, dim)
-            for i, s in enumerate(symbols):
-                center_grad = (field[up](i) - field[down](i)) / (2 * dx)
-                neighbor_grad = (field[up + offset](i) - field[down + offset](i)) / (2 * dx)
-                substitutions[grad(s)[d]] = (center_grad + neighbor_grad) / 2
+class Discretization2ndOrder:
+    def __init__(self, dx=sp.Symbol("dx"), dt=sp.Symbol("dt")):
+        self.dx = dx
+        self.dt = dt
 
-    return fast_subs(term, substitutions)
+    @staticmethod
+    def __diff_order(e):
+        if not isinstance(e, Diff):
+            return 0
+        else:
+            return 1 + Discretization2ndOrder.__diff_order(e.args[0])
 
+    def _discretize_diffusion(self, expr):
+        result = 0
+        for c in range(expr.dim):
+            first_diffs = [offset *
+                           (expr.diffusion_scalar_at_offset(c, offset) * expr.diffusion_coefficient_at_offset(c, offset)
+                            - expr.diffusion_scalar_at_offset(0, 0) * expr.diffusion_coefficient_at_offset(0, 0))
+                           for offset in [-1, 1]]
+            result += first_diffs[1] - first_diffs[0]
+        return result / (self.dx ** 2)
 
-def discretize_divergence(vector_term, symbols_to_field_dict, dx):
-    """
-    Computes discrete divergence of symbolic vector
+    def _discretize_advection(self, expr):
+        result = 0
+        for c in range(expr.dim):
+            interpolated = [(expr.advected_scalar_at_offset(c, offset) * expr.velocity_field_at_offset(c, offset, c) +
+                             expr.advected_scalar_at_offset(c, 0) * expr.velocity_field_at_offset(c, 0, c)) / 2
+                            for offset in [-1, 1]]
+            result += interpolated[1] - interpolated[0]
+        return result / self.dx
 
-    Args:
-        vector_term: sequence of terms, interpreted as vector
-        symbols_to_field_dict: mapping of symbols to Field
-        dx: length of a cell
+    def _discretize_spatial(self, e):
+        if isinstance(e, Diffusion):
+            return self._discretize_diffusion(e)
+        elif isinstance(e, Advection):
+            return self._discretize_advection(e)
+        elif isinstance(e, Diff):
+            return self._discretize_diff(e)
+        else:
+            new_args = [self._discretize_spatial(a) for a in e.args]
+            return e.func(*new_args) if new_args else e
 
-    Examples:
-        Laplace stencil
-        >>> x, dx = sp.symbols("x dx")
-        >>> grad_x = grad(x, dim=3)
-        >>> f = Field.create_generic('f', spatial_dimensions=3)
-        >>> sp.simplify(discretize_divergence(grad_x, {x : f}, dx))
-        (f_B - 6*f_C + f_E + f_N + f_S + f_T + f_W)/dx**2
-    """
-    dim = len(vector_term)
-    result = 0
-    for d in range(dim):
-        for offset in [-1, 1]:
-            result += offset * discretize_staggered(vector_term[d], symbols_to_field_dict, d, offset, dx, dim)
-    return result / dx
+    def _discretize_diff(self, e):
+        order = self.__diff_order(e)
+        if order == 1:
+            fa = e.args[0]
+            index = e.target
+            return (fa.neighbor(index, 1) - fa.neighbor(index, -1)) / (2 * self.dx)
+        elif order == 2:
+            indices = sorted([e.target, e.args[0].target])
+            fa = e.args[0].args[0]
+            if indices[0] == indices[1] and all(i >= 0 for i in indices):
+                result = (-2 * fa + fa.neighbor(indices[0], -1) + fa.neighbor(indices[0], +1))
+            elif indices[0] == indices[1]:
+                result = 0
+                for d in range(fa.field.spatial_dimensions):
+                    result += (-2 * fa + fa.neighbor(d, -1) + fa.neighbor(d, +1))
+            else:
+                assert all(i >= 0 for i in indices)
+                offsets = [(1, 1), [-1, 1], [1, -1], [-1, -1]]
+                result = sum(o1 * o2 * fa.neighbor(indices[0], o1).neighbor(indices[1], o2) for o1, o2 in offsets) / 4
+            return result / (self.dx ** 2)
+        else:
+            raise NotImplementedError("Term contains derivatives of order > 2")
 
+    def __call__(self, expr):
+        if isinstance(expr, list):
+            return [self(e) for e in expr]
+        elif isinstance(expr, sp.Matrix):
+            return expr.applyfunc(self.__call__)
+        elif isinstance(expr, AssignmentCollection):
+            return expr.copy(main_assignments=[e for e in expr.main_assignments],
+                             subexpressions=[e for e in expr.subexpressions])
 
-def __up_down_offsets(d, dim):
-    coord = [0] * dim
-    coord[d] = 1
-    up = np.array(coord, dtype=np.int)
-    coord[d] = -1
-    down = np.array(coord, dtype=np.int)
-    return up, down
+        transient_terms = expr.atoms(Transient)
+        if len(transient_terms) == 0:
+            return self._discretize_spatial(expr)
+        elif len(transient_terms) == 1:
+            transient_term = transient_terms.pop()
+            solve_result = sp.solve(expr, transient_term)
+            if len(solve_result) != 1:
+                raise ValueError("Could not solve for transient term" + str(solve_result))
+            rhs = solve_result.pop()
+            # explicit euler
+            return transient_term.scalar + self.dt * self._discretize_spatial(rhs)
+        else:
+            print(transient_terms)
+            raise NotImplementedError("Cannot discretize expression with more than one transient term")
 
 
-# --------------------------------------- Advection Diffusion ----------------------------------------------------------
+# -------------------------------------- Helper Classes ----------------------------------------------------------------
 
 class Advection(sp.Function):
 
@@ -192,21 +189,6 @@ class Advection(sp.Function):
         return self.scalar.neighbor(offset_dim, offset_value)(idx)
 
 
-def advection(advected_scalar, velocity_field, idx=None):
-    """Advection term: divergence( velocity_field * advected_scalar )"""
-    if isinstance(advected_scalar, Field):
-        first_arg = advected_scalar.center
-    elif isinstance(advected_scalar, Field.Access):
-        first_arg = advected_scalar
-    else:
-        raise ValueError("Advected scalar has to be a pystencils Field or Field.Access")
-
-    args = [first_arg, velocity_field if not isinstance(velocity_field, Field) else velocity_field.center]
-    if idx is not None:
-        args.append(idx)
-    return Advection(*args)
-
-
 class Diffusion(sp.Function):
 
     @property
@@ -249,20 +231,6 @@ class Diffusion(sp.Function):
             return d
 
 
-def diffusion(scalar, diffusion_coeff, idx=None):
-    if isinstance(scalar, Field):
-        first_arg = scalar.center
-    elif isinstance(scalar, Field.Access):
-        first_arg = scalar
-    else:
-        raise ValueError("Diffused scalar has to be a pystencils Field or Field.Access")
-
-    args = [first_arg, diffusion_coeff if not isinstance(diffusion_coeff, Field) else diffusion_coeff.center]
-    if idx is not None:
-        args.append(idx)
-    return Diffusion(*args)
-
-
 class Transient(sp.Function):
     @property
     def scalar(self):
@@ -280,103 +248,140 @@ class Transient(sp.Function):
         return r"\partial_t %s" % (printer.doprint(sp.Symbol(self.scalar.name + name_suffix)),)
 
 
-def transient(scalar, idx=None):
-    if isinstance(scalar, Field):
-        args = [scalar.center]
-    elif isinstance(scalar, Field.Access):
-        args = [scalar]
+# -------------------------------------------- Deprecated Functions ----------------------------------------------------
+
+
+def grad(var, dim=3):
+    r"""
+    Gradients are represented as a special symbol:
+    e.g. :math:`\nabla x = (x^{\Delta 0}, x^{\Delta 1}, x^{\Delta 2})`
+
+    This function takes a symbol and creates the gradient symbols according to convention above
+
+    :param var: symbol to take the gradient of
+    :param dim: dimension (length) of the gradient vector
+    """
+    if hasattr(var, "__getitem__"):
+        return [[sp.Symbol("%s^Delta^%d" % (v.name, i)) for v in var] for i in range(dim)]
     else:
-        raise ValueError("Scalar has to be a pystencils Field or Field.Access")
-    if idx is not None:
-        args.append(idx)
-    return Transient(*args)
+        return [sp.Symbol("%s^Delta^%d" % (var.name, i)) for i in range(dim)]
 
 
-class Discretization2ndOrder:
-    def __init__(self, dx=sp.Symbol("dx"), dt=sp.Symbol("dt")):
-        self.dx = dx
-        self.dt = dt
+def discretize_center(term, symbols_to_field_dict, dx, dim=3):
+    """
+    Expects term that contains given symbols and gradient components of these symbols and replaces them
+    by field accesses. Gradients are replaced by centralized approximations:
+    ``(upper neighbor - lower neighbor ) / ( 2*dx)``
+    :param term: term where symbols and gradient(symbol) should be replaced
+    :param symbols_to_field_dict: mapping of symbols to Field
+    :param dx: width and height of one cell
+    :param dim: dimension
 
-    @staticmethod
-    def __diff_order(e):
-        if not isinstance(e, Diff):
-            return 0
-        else:
-            return 1 + Discretization2ndOrder.__diff_order(e.args[0])
+    Example:
+      >>> x = sp.Symbol("x")
+      >>> grad_x = grad(x, dim=3)
+      >>> term = x * grad_x[0]
+      >>> term
+      x*x^Delta^0
+      >>> f = Field.create_generic('f', spatial_dimensions=3)
+      >>> discretize_center(term, { x: f }, dx=1, dim=3)
+      f_C*(f_E/2 - f_W/2)
+    """
+    substitutions = {}
+    for symbols, field in symbols_to_field_dict.items():
+        if not hasattr(symbols, "__getitem__"):
+            symbols = [symbols]
+        g = grad(symbols, dim)
+        substitutions.update({symbol: field(i) for i, symbol in enumerate(symbols)})
+        for d in range(dim):
+            up, down = __up_down_offsets(d, dim)
+            substitutions.update({g[d][i]: (field[up](i) - field[down](i)) / dx / 2 for i in range(len(symbols))})
+    return term.subs(substitutions)
 
-    def _discretize_diffusion(self, expr):
-        result = 0
-        for c in range(expr.dim):
-            first_diffs = [offset *
-                           (expr.diffusion_scalar_at_offset(c, offset) * expr.diffusion_coefficient_at_offset(c, offset)
-                            - expr.diffusion_scalar_at_offset(0, 0) * expr.diffusion_coefficient_at_offset(0, 0))
-                           for offset in [-1, 1]]
-            result += first_diffs[1] - first_diffs[0]
-        return result / (self.dx ** 2)
 
-    def _discretize_advection(self, expr):
-        result = 0
-        for c in range(expr.dim):
-            interpolated = [(expr.advected_scalar_at_offset(c, offset) * expr.velocity_field_at_offset(c, offset, c) +
-                             expr.advected_scalar_at_offset(c, 0) * expr.velocity_field_at_offset(c, 0, c)) / 2
-                            for offset in [-1, 1]]
-            result += interpolated[1] - interpolated[0]
-        return result / self.dx
+def discretize_staggered(term, symbols_to_field_dict, coordinate, coordinate_offset, dx, dim=3):
+    """
+    Expects term that contains given symbols and gradient components of these symbols and replaces them
+    by field accesses. Gradients in coordinate direction  are replaced by staggered version at cell boundary.
+    Symbols themselves and gradients in other directions are replaced by interpolated version at cell face.
 
-    def _discretize_spatial(self, e):
-        if isinstance(e, Diffusion):
-            return self._discretize_diffusion(e)
-        elif isinstance(e, Advection):
-            return self._discretize_advection(e)
-        elif isinstance(e, Diff):
-            return self._discretize_diff(e)
-        else:
-            new_args = [self._discretize_spatial(a) for a in e.args]
-            return e.func(*new_args) if new_args else e
+    Args:
+        term: input term where symbols and gradients are replaced
+        symbols_to_field_dict: mapping of symbols to Field
+        coordinate: id for coordinate (0 for x, 1 for y, ... ) defining cell boundary.
+                    Only gradients in this direction are replaced e.g. if symbol^Delta^coordinate
+        coordinate_offset: either +1 or -1 for upper or lower face in coordinate direction
+        dx: width and height of one cell
+        dim: dimension
 
-    def _discretize_diff(self, e):
-        order = self.__diff_order(e)
-        if order == 1:
-            fa = e.args[0]
-            index = e.target
-            return (fa.neighbor(index, 1) - fa.neighbor(index, -1)) / (2 * self.dx)
-        elif order == 2:
-            indices = sorted([e.target, e.args[0].target])
-            fa = e.args[0].args[0]
-            if indices[0] == indices[1] and all(i >= 0 for i in indices):
-                result = (-2 * fa + fa.neighbor(indices[0], -1) + fa.neighbor(indices[0], +1))
-            elif indices[0] == indices[1]:
-                result = 0
-                for d in range(fa.field.spatial_dimensions):
-                    result += (-2 * fa + fa.neighbor(d, -1) + fa.neighbor(d, +1))
-            else:
-                assert all(i >= 0 for i in indices)
-                offsets = [(1, 1), [-1, 1], [1, -1], [-1, -1]]
-                result = sum(o1 * o2 * fa.neighbor(indices[0], o1).neighbor(indices[1], o2) for o1, o2 in offsets) / 4
-            return result / (self.dx ** 2)
-        else:
-            raise NotImplementedError("Term contains derivatives of order > 2")
+    Examples:
+      Discretizing at right/east face of cell i.e. coordinate=0, offset=1)
+      >>> x, dx = sp.symbols("x dx")
+      >>> grad_x = grad(x, dim=3)
+      >>> term = x * grad_x[0]
+      >>> term
+      x*x^Delta^0
+      >>> f = Field.create_generic('f', spatial_dimensions=3)
+      >>> discretize_staggered(term, symbols_to_field_dict={ x: f}, dx=dx, coordinate=0, coordinate_offset=1, dim=3)
+      (-f_C + f_E)*(f_C/2 + f_E/2)/dx
+    """
+    assert coordinate_offset == 1 or coordinate_offset == -1
+    assert 0 <= coordinate < dim
 
-    def __call__(self, expr):
-        if isinstance(expr, list):
-            return [self(e) for e in expr]
-        elif isinstance(expr, sp.Matrix):
-            return expr.applyfunc(self.__call__)
-        elif isinstance(expr, AssignmentCollection):
-            return expr.copy(main_assignments=[e for e in expr.main_assignments],
-                             subexpressions=[e for e in expr.subexpressions])
+    substitutions = {}
+    for symbols, field in symbols_to_field_dict.items():
+        if not hasattr(symbols, "__getitem__"):
+            symbols = [symbols]
 
-        transient_terms = expr.atoms(Transient)
-        if len(transient_terms) == 0:
-            return self._discretize_spatial(expr)
-        elif len(transient_terms) == 1:
-            transient_term = transient_terms.pop()
-            solve_result = sp.solve(expr, transient_term)
-            if len(solve_result) != 1:
-                raise ValueError("Could not solve for transient term" + str(solve_result))
-            rhs = solve_result.pop()
-            # explicit euler
-            return transient_term.scalar + self.dt * self._discretize_spatial(rhs)
-        else:
-            print(transient_terms)
-            raise NotImplementedError("Cannot discretize expression with more than one transient term")
+        offset = [0] * dim
+        offset[coordinate] = coordinate_offset
+        offset = np.array(offset, dtype=np.int)
+
+        gradient = grad(symbols)[coordinate]
+        substitutions.update({s: (field[offset](i) + field(i)) / 2 for i, s in enumerate(symbols)})
+        substitutions.update({g: (field[offset](i) - field(i)) / dx * coordinate_offset
+                              for i, g in enumerate(gradient)})
+        for d in range(dim):
+            if d == coordinate:
+                continue
+            up, down = __up_down_offsets(d, dim)
+            for i, s in enumerate(symbols):
+                center_grad = (field[up](i) - field[down](i)) / (2 * dx)
+                neighbor_grad = (field[up + offset](i) - field[down + offset](i)) / (2 * dx)
+                substitutions[grad(s)[d]] = (center_grad + neighbor_grad) / 2
+
+    return fast_subs(term, substitutions)
+
+
+def discretize_divergence(vector_term, symbols_to_field_dict, dx):
+    """
+    Computes discrete divergence of symbolic vector
+
+    Args:
+        vector_term: sequence of terms, interpreted as vector
+        symbols_to_field_dict: mapping of symbols to Field
+        dx: length of a cell
+
+    Examples:
+        Laplace stencil
+        >>> x, dx = sp.symbols("x dx")
+        >>> grad_x = grad(x, dim=3)
+        >>> f = Field.create_generic('f', spatial_dimensions=3)
+        >>> sp.simplify(discretize_divergence(grad_x, {x : f}, dx))
+        (f_B - 6*f_C + f_E + f_N + f_S + f_T + f_W)/dx**2
+    """
+    dim = len(vector_term)
+    result = 0
+    for d in range(dim):
+        for offset in [-1, 1]:
+            result += offset * discretize_staggered(vector_term[d], symbols_to_field_dict, d, offset, dx, dim)
+    return result / dx
+
+
+def __up_down_offsets(d, dim):
+    coord = [0] * dim
+    coord[d] = 1
+    up = np.array(coord, dtype=np.int)
+    coord[d] = -1
+    down = np.array(coord, dtype=np.int)
+    return up, down
diff --git a/kernelcreation.py b/kernelcreation.py
index ffd8c0d38..be918b0bd 100644
--- a/kernelcreation.py
+++ b/kernelcreation.py
@@ -54,7 +54,7 @@ def create_kernel(assignments, target='cpu', data_type="double", iteration_slice
             add_openmp(ast, num_threads=cpu_openmp)
         if cpu_vectorize_info:
             import pystencils.backends.simd_instruction_sets as vec
-            from pystencils.vectorization import vectorize
+            from pystencils.cpu.vectorization import vectorize
             vec_params = cpu_vectorize_info
             vec.selected_instruction_set = vec.x86_vector_instruction_set(instruction_set=vec_params[0],
                                                                           data_type=vec_params[1])
diff --git a/test_simplification_strategy.py b/test_simplification_strategy.py
deleted file mode 100644
index 8087c65b8..000000000
--- a/test_simplification_strategy.py
+++ /dev/null
@@ -1,43 +0,0 @@
-import sympy as sp
-from pystencils import Assignment, AssignmentCollection
-from pystencils.assignment_collection import SimplificationStrategy, apply_on_all_subexpressions, \
-    subexpression_substitution_in_existing_subexpressions
-
-
-def test_simplification_strategy():
-    a, b, c, d, x, y, z = sp.symbols("a b c d x y z")
-    s0, s1, s2, s3 = sp.symbols("s_:4")
-    a0, a1, a2, a3 = sp.symbols("a_:4")
-
-    subexpressions = [
-        Assignment(s0, 2 * a + 2 * b),
-        Assignment(s1, 2 * a + 2 * b + 2 * c),
-        Assignment(s2, 2 * a + 2 * b + 2 * c + 2 * d),
-    ]
-    main = [
-        Assignment(a0, s0 + s1),
-        Assignment(a1, s0 + s2),
-        Assignment(a2, s1 + s2),
-    ]
-    ac = AssignmentCollection(main, subexpressions)
-
-    strategy = SimplificationStrategy()
-    strategy.add(subexpression_substitution_in_existing_subexpressions)
-    strategy.add(apply_on_all_subexpressions(sp.factor))
-
-    result = strategy(ac)
-    assert result.operation_count['adds'] == 7
-    assert result.operation_count['muls'] == 5
-    assert result.operation_count['divs'] == 0
-
-    # Trigger display routines, such that they are at least executed
-    report = strategy.show_intermediate_results(ac, symbols=[s0])
-    assert 's_0' in str(report)
-    report = strategy.show_intermediate_results(ac)
-    assert 's_{1}' in report._repr_html_()
-
-    report = strategy.create_simplification_report(ac)
-    assert 'Adds' in str(report)
-    assert 'Adds' in report._repr_html_()
-
-    assert 'factor' in str(strategy)
-- 
GitLab