From 8a517fd10853f1c7300d8ab3d78377f0b676b9e2 Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Wed, 19 Sep 2018 12:50:08 +0200
Subject: [PATCH] Moved stencils from lbmpy->pystencils, New coeff extraction
 funcs

- new functions to easily extract stencil coefficients and visualize them
- Moved stencil functions from lbmpy to pystencils
---
 __init__.py                         |   4 +-
 datahandling/serial_datahandling.py |  13 ++
 field.py                            |  10 +
 stencils.py                         | 340 ++++++++++++++++++++++++++++
 utils.py                            |  17 ++
 5 files changed, 383 insertions(+), 1 deletion(-)
 create mode 100644 stencils.py

diff --git a/__init__.py b/__init__.py
index 5dd9cd572..699662a56 100644
--- a/__init__.py
+++ b/__init__.py
@@ -10,6 +10,7 @@ from .assignment import Assignment
 from .sympyextensions import SymbolCreator
 from .datahandling import create_data_handling
 from .kernel_decorator import kernel
+from .stencils import visualize_stencil_expression
 from . import fd
 
 __all__ = ['Field', 'FieldType', 'fields',
@@ -22,4 +23,5 @@ __all__ = ['Field', 'FieldType', 'fields',
            'SymbolCreator',
            'create_data_handling',
            'kernel',
-           'fd']
+           'fd',
+           'visualize_stencil_expression']
diff --git a/datahandling/serial_datahandling.py b/datahandling/serial_datahandling.py
index 95d4e6220..7aa5d2de6 100644
--- a/datahandling/serial_datahandling.py
+++ b/datahandling/serial_datahandling.py
@@ -381,3 +381,16 @@ class SerialDataHandling(DataHandling):
 
     def save_all(self, file):
         np.savez_compressed(file, **self.cpu_arrays)
+
+    def load_all(self, file):
+        file_contents = np.load(file)
+        for arr_name, arr_contents in self.cpu_arrays.items():
+            if arr_name not in file_contents:
+                print("Skipping read data {} because there is no data with this name in data handling".format(arr_name))
+                continue
+            if file_contents[arr_name].shape != arr_contents.shape:
+                print("Skipping read data {} because shapes don't match. "
+                      "Read array shape {}, exising array shape {}".format(arr_name, file_contents[arr_name].shape,
+                                                                           arr_contents.shape))
+                continue
+            np.copyto(arr_contents, file_contents[arr_name])
diff --git a/field.py b/field.py
index 47cd0955c..d0e95b489 100644
--- a/field.py
+++ b/field.py
@@ -602,6 +602,16 @@ class Field:
             else:
                 return "{{%s}_{%s}}" % (n, offset_str)
 
+        def __str__(self):
+            n = self._field.latex_name if self._field.latex_name else self._field.name
+            offset_str = ",".join([sp.latex(o) for o in self.offsets])
+            if self.is_absolute_access:
+                offset_str = "[abs]{}".format(offset_str)
+            if self.index and self.index != (0,):
+                return "%s[%s](%s)" % (n, offset_str, self.index if len(self.index) > 1 else self.index[0])
+            else:
+                return "%s[%s]" % (n, offset_str)
+
 
 def get_layout_from_strides(strides: Sequence[int], index_dimension_ids: Optional[List[int]] = None):
     index_dimension_ids = [] if index_dimension_ids is None else index_dimension_ids
diff --git a/stencils.py b/stencils.py
new file mode 100644
index 000000000..0efe601fd
--- /dev/null
+++ b/stencils.py
@@ -0,0 +1,340 @@
+import sympy as sp
+from collections import defaultdict
+from pystencils import Field
+
+
+def inverse_direction(direction):
+    """Returns inverse i.e. negative of given direction tuple"""
+    return tuple([-i for i in direction])
+
+
+def is_valid_stencil(stencil, max_neighborhood=None):
+    """
+    Tests if a nested sequence is a valid stencil i.e. all the inner sequences have the same length.
+    If max_neighborhood is specified, it is also verified that the stencil does not contain any direction components
+    with absolute value greater than the maximal neighborhood.
+    """
+    expected_dim = len(stencil[0])
+    for d in stencil:
+        if len(d) != expected_dim:
+            return False
+        if max_neighborhood is not None:
+            for d_i in d:
+                if abs(d_i) > max_neighborhood:
+                    return False
+    return True
+
+
+def is_symmetric_stencil(stencil):
+    """Tests for every direction d, that -d is also in the stencil"""
+    for d in stencil:
+        if inverse_direction(d) not in stencil:
+            return False
+    return True
+
+
+def stencils_have_same_entries(s1, s2):
+    if len(s1) != len(s2):
+        return False
+    return len(set(s1) - set(s2)) == 0
+
+
+# -------------------------------------Expression - Coefficient Form Conversion ----------------------------------------
+
+
+def stencil_coefficient_dict(expr):
+    """Extracts coefficients in front of field accesses in a expression.
+
+    Expression may only access a single field at a single index.
+
+    Returns:
+        center, coefficient dict, nonlinear part
+        where center is the single field that is accessed in expression accessed at center
+        and coefficient dict maps offsets to coefficients. The nonlinear part is everything that is not in the form of
+        coefficient times field access.
+
+    Examples:
+        >>> import pystencils as ps
+        >>> f = ps.fields("f(3) : double[2D]")
+        >>> field, coeffs, nonlinear_part = stencil_coefficient_dict(2 * f[0, 1](1) + 3 * f[-1, 0](1) + 123)
+        >>> assert nonlinear_part == 123 and field == f(1)
+        >>> sorted(coeffs.items())
+        [((-1, 0), 3), ((0, 1), 2)]
+    """
+    expr = expr.expand()
+    field_accesses = expr.atoms(Field.Access)
+    fields = set(fa.field for fa in field_accesses)
+    accessed_indices = set(fa.index for fa in field_accesses)
+
+    if len(fields) != 1:
+        raise ValueError("Could not extract stencil coefficients. "
+                         "Expression has to be a linear function of exactly one field.")
+    if len(accessed_indices) != 1:
+        raise ValueError("Could not extract stencil coefficients. Field is accessed at multiple indices")
+
+    field = fields.pop()
+    idx = accessed_indices.pop()
+
+    coefficients = defaultdict(lambda: 0)
+    coefficients.update({fa.offsets: expr.coeff(fa) for fa in field_accesses})
+
+    linear_part = sum(c * field[off](*idx) for off, c in coefficients.items())
+    nonlinear_part = expr - linear_part
+    return field(*idx), coefficients, nonlinear_part
+
+
+def stencil_coefficients(expr):
+    """Returns two lists - one with accessed offsets and one with their coefficients.
+
+    Same restrictions as `stencil_coefficient_dict` apply. Expression must not have any nonlinear part
+
+    >>> import pystencils as ps
+    >>> f = ps.fields("f(3) : double[2D]")
+    >>> stencil_coefficients(2 * f[0, 1](1) + 3 * f[-1, 0](1))
+    ([(-1, 0), (0, 1)], [3, 2])
+    """
+    field_center, coefficients, nonlinear_part = stencil_coefficient_dict(expr)
+    assert nonlinear_part == 0
+    stencil = list(coefficients.keys())
+    entries = [coefficients[c] for c in stencil]
+    return stencil, entries
+
+
+def stencil_coefficient_list(expr, matrix_form=False):
+    """Returns stencil coefficients in the form of nested lists
+
+    Same restrictions as `stencil_coefficient_dict` apply. Expression must not have any nonlinear part
+
+    Examples:
+        >>> import pystencils as ps
+        >>> f = ps.fields("f: double[2D]")
+        >>> stencil_coefficient_list(2 * f[0, 1] + 3 * f[-1, 0])
+        [[0, 0, 0], [3, 0, 0], [0, 2, 0]]
+        >>> stencil_coefficient_list(2 * f[0, 1] + 3 * f[-1, 0], matrix_form=True)
+        Matrix([
+        [0, 2, 0],
+        [3, 0, 0],
+        [0, 0, 0]])
+    """
+    field_center, coefficients, nonlinear_part = stencil_coefficient_dict(expr)
+    assert nonlinear_part == 0
+    field = field_center.field
+
+    dim = field.spatial_dimensions
+    max_offsets = defaultdict(lambda: 0)
+    for offset in coefficients.keys():
+        for d, off in enumerate(offset):
+            max_offsets[d] = max(max_offsets[d], abs(off))
+
+    if dim == 1:
+        result = [coefficients[(i,)] for i in range(-max_offsets[0], max_offsets[0] + 1)]
+        return sp.Matrix(result) if matrix_form else result
+    else:
+        y_range = list(range(-max_offsets[1], max_offsets[1] + 1))
+        if matrix_form:
+            y_range.reverse()
+        if dim == 2:
+            result = [[coefficients[(i, j)]
+                       for i in range(-max_offsets[0], max_offsets[0] + 1)]
+                      for j in y_range]
+            return sp.Matrix(result) if matrix_form else result
+        elif dim == 3:
+            result = [[[coefficients[(i, j, k)]
+                        for i in range(-max_offsets[0], max_offsets[0] + 1)]
+                       for j in y_range]
+                      for k in range(-max_offsets[2], max_offsets[2] + 1)]
+            return [sp.Matrix(l) for l in result] if matrix_form else result
+        else:
+            raise ValueError("Can only handle fields with 1,2 or 3 spatial dimensions")
+
+
+# -------------------------------------- Visualization -----------------------------------------------------------------
+
+
+def visualize_stencil(stencil, **kwargs):
+    dim = len(stencil[0])
+    if dim == 2:
+        visualize_stencil_2d(stencil, **kwargs)
+    else:
+        slicing = False
+        if 'slice' in kwargs:
+            slicing = kwargs['slice']
+            del kwargs['slice']
+
+        if slicing:
+            visualize_stencil_3d_by_slicing(stencil, **kwargs)
+        else:
+            visualize_stencil_3d(stencil, **kwargs)
+
+
+def visualize_stencil_2d(stencil, axes=None, figure=None, data=None, textsize='12', **kwargs):
+    """
+    Creates a matplotlib 2D plot of the stencil
+
+    :param stencil: sequence of directions
+    :param axes: optional matplotlib axes
+    :param data: data to annotate the directions with, if none given, the indices are used
+    :param textsize: size of annotation text
+    """
+    from matplotlib.patches import BoxStyle
+    import matplotlib.pyplot as plt
+
+    if axes is None:
+        if figure is None:
+            figure = plt.gcf()
+        axes = figure.gca()
+
+    text_box_style = BoxStyle("Round", pad=0.3)
+    head_length = 0.1
+    max_offsets = [max(abs(d[c]) for d in stencil) for c in (0, 1)]
+
+    if data is None:
+        data = list(range(len(stencil)))
+
+    for direction, annotation in zip(stencil, data):
+        assert len(direction) == 2, "Works only for 2D stencils"
+
+        if not(direction[0] == 0 and direction[1] == 0):
+            axes.arrow(0, 0, direction[0], direction[1], head_width=0.08, head_length=head_length, color='k')
+
+        if isinstance(annotation, sp.Basic):
+            annotation = "$" + sp.latex(annotation) + "$"
+        else:
+            annotation = str(annotation)
+
+        def position_correction(d, magnitude=0.18):
+            if d < 0:
+                return -magnitude
+            elif d > 0:
+                return +magnitude
+            else:
+                return 0
+        text_position = [direction[c] + position_correction(direction[c]) for c in (0, 1)]
+        axes.text(*text_position, annotation, verticalalignment='center',
+                  zorder=30, horizontalalignment='center', size=textsize,
+                  bbox=dict(boxstyle=text_box_style, facecolor='#00b6eb', alpha=0.85, linewidth=0))
+
+    axes.set_axis_off()
+    axes.set_aspect('equal')
+    max_offsets = [m if m > 0 else 0.1 for m in max_offsets]
+    border = 0.1
+    axes.set_xlim([-border - max_offsets[0], border + max_offsets[0]])
+    axes.set_ylim([-border - max_offsets[1], border + max_offsets[1]])
+
+
+def visualize_stencil_3d_by_slicing(stencil, slice_axis=2, figure=None, data=None, **kwargs):
+    """Visualizes a 3D, first-neighborhood stencil by plotting 3 slices along a given axis.
+
+    Args:
+        stencil: stencil as sequence of directions
+        slice_axis: 0, 1, or 2 indicating the axis to slice through
+        data: optional data to print as text besides the arrows
+    """
+    import matplotlib.pyplot as plt
+
+    for d in stencil:
+        for element in d:
+            assert element == -1 or element == 0 or element == 1, "This function can only first neighborhood stencils"
+
+    if figure is None:
+        figure = plt.gcf()
+
+    axes = [figure.add_subplot(1, 3, i + 1) for i in range(3)]
+    splitted_directions = [[], [], []]
+    splitted_data = [[], [], []]
+    axes_names = ['x', 'y', 'z']
+
+    for i, d in enumerate(stencil):
+        split_idx = d[slice_axis] + 1
+        reduced_dir = tuple([element for j, element in enumerate(d) if j != slice_axis])
+        splitted_directions[split_idx].append(reduced_dir)
+        splitted_data[split_idx].append(i if data is None else data[i])
+
+    for i in range(3):
+        visualize_stencil_2d(splitted_directions[i], axes=axes[i], data=splitted_data[i], **kwargs)
+    for i in [-1, 0, 1]:
+        axes[i + 1].set_title("Cut at %s=%d" % (axes_names[slice_axis], i))
+
+
+def visualize_stencil_3d(stencil, figure=None, axes=None, data=None, textsize='8'):
+    """
+    Draws 3D stencil into a 3D coordinate system, parameters are similar to :func:`visualize_stencil_2d`
+    If data is None, no labels are drawn. To draw the labels as in the 2D case, use ``data=list(range(len(stencil)))``
+    """
+    from matplotlib.patches import FancyArrowPatch
+    from mpl_toolkits.mplot3d import proj3d
+    import matplotlib.pyplot as plt
+    from matplotlib.patches import BoxStyle
+    from itertools import product, combinations
+    import numpy as np
+
+    class Arrow3D(FancyArrowPatch):
+        def __init__(self, xs, ys, zs, *args, **kwargs):
+            FancyArrowPatch.__init__(self, (0, 0), (0, 0), *args, **kwargs)
+            self._verts3d = xs, ys, zs
+
+        def draw(self, renderer):
+            xs3d, ys3d, zs3d = self._verts3d
+            xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, renderer.M)
+            self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
+            FancyArrowPatch.draw(self, renderer)
+
+    if axes is None:
+        if figure is None:
+            figure = plt.figure()
+        axes = figure.gca(projection='3d')
+        axes.set_aspect("equal")
+
+    if data is None:
+        data = [None] * len(stencil)
+
+    text_offset = 1.25
+    text_box_style = BoxStyle("Round", pad=0.3)
+
+    # Draw cell (cube)
+    r = [-1, 1]
+    for s, e in combinations(np.array(list(product(r, r, r))), 2):
+        if np.sum(np.abs(s - e)) == r[1] - r[0]:
+            axes.plot3D(*zip(s, e), color="k", alpha=0.5)
+
+    for d, annotation in zip(stencil, data):
+        assert len(d) == 3, "Works only for 3D stencils"
+        if not (d[0] == 0 and d[1] == 0 and d[2] == 0):
+            if d[0] == 0:
+                color = '#348abd'
+            elif d[1] == 0:
+                color = '#fac364'
+            elif sum([abs(d) for d in d]) == 2:
+                color = '#95bd50'
+            else:
+                color = '#808080'
+
+            a = Arrow3D([0, d[0]], [0, d[1]], [0, d[2]], mutation_scale=20, lw=2, arrowstyle="-|>", color=color)
+            axes.add_artist(a)
+
+        if annotation:
+            if isinstance(annotation, sp.Basic):
+                annotation = "$" + sp.latex(annotation) + "$"
+            else:
+                annotation = str(annotation)
+
+            axes.text(d[0] * text_offset, d[1] * text_offset, d[2] * text_offset,
+                      annotation, verticalalignment='center', zorder=30,
+                      size=textsize, bbox=dict(boxstyle=text_box_style, facecolor='#777777', alpha=0.6, linewidth=0))
+
+    axes.set_xlim([-text_offset * 1.1, text_offset * 1.1])
+    axes.set_ylim([-text_offset * 1.1, text_offset * 1.1])
+    axes.set_zlim([-text_offset * 1.1, text_offset * 1.1])
+    axes.set_axis_off()
+
+
+def visualize_stencil_expression(expr, **kwargs):
+    stencil, coefficients = stencil_coefficients(expr)
+    dim = len(stencil[0])
+    assert 0 < dim <= 3
+    if dim == 1:
+        return stencil_coefficient_list(expr, matrix_form=True)
+    elif dim == 2:
+        return visualize_stencil_2d(stencil, data=coefficients, **kwargs)
+    elif dim == 3:
+        return visualize_stencil_3d_by_slicing(stencil, data=coefficients, **kwargs)
diff --git a/utils.py b/utils.py
index bc7833629..da3fec7d7 100644
--- a/utils.py
+++ b/utils.py
@@ -2,6 +2,7 @@ import os
 from tempfile import NamedTemporaryFile
 from contextlib import contextmanager
 from typing import Mapping
+from collections import Counter
 
 
 class DotDict(dict):
@@ -66,3 +67,19 @@ def atomic_file_write(file_path):
         f.file.close()
         yield f.name
     os.rename(f.name, file_path)
+
+
+def fully_contains(l1, l2):
+    """Tests if elements of sequence 1 are in sequence 2 in same or higher number.
+
+    >>> fully_contains([1, 1, 2], [1, 2])  # 1 is only present once in second list
+    False
+    >>> fully_contains([1, 1, 2], [1, 1, 4, 2])
+    True
+    """
+    l1_counter = Counter(l1)
+    l2_counter = Counter(l2)
+    for element, count in l1_counter.items():
+        if l2_counter[element] < count:
+            return False
+    return True
-- 
GitLab