From 8a517fd10853f1c7300d8ab3d78377f0b676b9e2 Mon Sep 17 00:00:00 2001 From: Martin Bauer <martin.bauer@fau.de> Date: Wed, 19 Sep 2018 12:50:08 +0200 Subject: [PATCH] Moved stencils from lbmpy->pystencils, New coeff extraction funcs - new functions to easily extract stencil coefficients and visualize them - Moved stencil functions from lbmpy to pystencils --- __init__.py | 4 +- datahandling/serial_datahandling.py | 13 ++ field.py | 10 + stencils.py | 340 ++++++++++++++++++++++++++++ utils.py | 17 ++ 5 files changed, 383 insertions(+), 1 deletion(-) create mode 100644 stencils.py diff --git a/__init__.py b/__init__.py index 5dd9cd572..699662a56 100644 --- a/__init__.py +++ b/__init__.py @@ -10,6 +10,7 @@ from .assignment import Assignment from .sympyextensions import SymbolCreator from .datahandling import create_data_handling from .kernel_decorator import kernel +from .stencils import visualize_stencil_expression from . import fd __all__ = ['Field', 'FieldType', 'fields', @@ -22,4 +23,5 @@ __all__ = ['Field', 'FieldType', 'fields', 'SymbolCreator', 'create_data_handling', 'kernel', - 'fd'] + 'fd', + 'visualize_stencil_expression'] diff --git a/datahandling/serial_datahandling.py b/datahandling/serial_datahandling.py index 95d4e6220..7aa5d2de6 100644 --- a/datahandling/serial_datahandling.py +++ b/datahandling/serial_datahandling.py @@ -381,3 +381,16 @@ class SerialDataHandling(DataHandling): def save_all(self, file): np.savez_compressed(file, **self.cpu_arrays) + + def load_all(self, file): + file_contents = np.load(file) + for arr_name, arr_contents in self.cpu_arrays.items(): + if arr_name not in file_contents: + print("Skipping read data {} because there is no data with this name in data handling".format(arr_name)) + continue + if file_contents[arr_name].shape != arr_contents.shape: + print("Skipping read data {} because shapes don't match. " + "Read array shape {}, exising array shape {}".format(arr_name, file_contents[arr_name].shape, + arr_contents.shape)) + continue + np.copyto(arr_contents, file_contents[arr_name]) diff --git a/field.py b/field.py index 47cd0955c..d0e95b489 100644 --- a/field.py +++ b/field.py @@ -602,6 +602,16 @@ class Field: else: return "{{%s}_{%s}}" % (n, offset_str) + def __str__(self): + n = self._field.latex_name if self._field.latex_name else self._field.name + offset_str = ",".join([sp.latex(o) for o in self.offsets]) + if self.is_absolute_access: + offset_str = "[abs]{}".format(offset_str) + if self.index and self.index != (0,): + return "%s[%s](%s)" % (n, offset_str, self.index if len(self.index) > 1 else self.index[0]) + else: + return "%s[%s]" % (n, offset_str) + def get_layout_from_strides(strides: Sequence[int], index_dimension_ids: Optional[List[int]] = None): index_dimension_ids = [] if index_dimension_ids is None else index_dimension_ids diff --git a/stencils.py b/stencils.py new file mode 100644 index 000000000..0efe601fd --- /dev/null +++ b/stencils.py @@ -0,0 +1,340 @@ +import sympy as sp +from collections import defaultdict +from pystencils import Field + + +def inverse_direction(direction): + """Returns inverse i.e. negative of given direction tuple""" + return tuple([-i for i in direction]) + + +def is_valid_stencil(stencil, max_neighborhood=None): + """ + Tests if a nested sequence is a valid stencil i.e. all the inner sequences have the same length. + If max_neighborhood is specified, it is also verified that the stencil does not contain any direction components + with absolute value greater than the maximal neighborhood. + """ + expected_dim = len(stencil[0]) + for d in stencil: + if len(d) != expected_dim: + return False + if max_neighborhood is not None: + for d_i in d: + if abs(d_i) > max_neighborhood: + return False + return True + + +def is_symmetric_stencil(stencil): + """Tests for every direction d, that -d is also in the stencil""" + for d in stencil: + if inverse_direction(d) not in stencil: + return False + return True + + +def stencils_have_same_entries(s1, s2): + if len(s1) != len(s2): + return False + return len(set(s1) - set(s2)) == 0 + + +# -------------------------------------Expression - Coefficient Form Conversion ---------------------------------------- + + +def stencil_coefficient_dict(expr): + """Extracts coefficients in front of field accesses in a expression. + + Expression may only access a single field at a single index. + + Returns: + center, coefficient dict, nonlinear part + where center is the single field that is accessed in expression accessed at center + and coefficient dict maps offsets to coefficients. The nonlinear part is everything that is not in the form of + coefficient times field access. + + Examples: + >>> import pystencils as ps + >>> f = ps.fields("f(3) : double[2D]") + >>> field, coeffs, nonlinear_part = stencil_coefficient_dict(2 * f[0, 1](1) + 3 * f[-1, 0](1) + 123) + >>> assert nonlinear_part == 123 and field == f(1) + >>> sorted(coeffs.items()) + [((-1, 0), 3), ((0, 1), 2)] + """ + expr = expr.expand() + field_accesses = expr.atoms(Field.Access) + fields = set(fa.field for fa in field_accesses) + accessed_indices = set(fa.index for fa in field_accesses) + + if len(fields) != 1: + raise ValueError("Could not extract stencil coefficients. " + "Expression has to be a linear function of exactly one field.") + if len(accessed_indices) != 1: + raise ValueError("Could not extract stencil coefficients. Field is accessed at multiple indices") + + field = fields.pop() + idx = accessed_indices.pop() + + coefficients = defaultdict(lambda: 0) + coefficients.update({fa.offsets: expr.coeff(fa) for fa in field_accesses}) + + linear_part = sum(c * field[off](*idx) for off, c in coefficients.items()) + nonlinear_part = expr - linear_part + return field(*idx), coefficients, nonlinear_part + + +def stencil_coefficients(expr): + """Returns two lists - one with accessed offsets and one with their coefficients. + + Same restrictions as `stencil_coefficient_dict` apply. Expression must not have any nonlinear part + + >>> import pystencils as ps + >>> f = ps.fields("f(3) : double[2D]") + >>> stencil_coefficients(2 * f[0, 1](1) + 3 * f[-1, 0](1)) + ([(-1, 0), (0, 1)], [3, 2]) + """ + field_center, coefficients, nonlinear_part = stencil_coefficient_dict(expr) + assert nonlinear_part == 0 + stencil = list(coefficients.keys()) + entries = [coefficients[c] for c in stencil] + return stencil, entries + + +def stencil_coefficient_list(expr, matrix_form=False): + """Returns stencil coefficients in the form of nested lists + + Same restrictions as `stencil_coefficient_dict` apply. Expression must not have any nonlinear part + + Examples: + >>> import pystencils as ps + >>> f = ps.fields("f: double[2D]") + >>> stencil_coefficient_list(2 * f[0, 1] + 3 * f[-1, 0]) + [[0, 0, 0], [3, 0, 0], [0, 2, 0]] + >>> stencil_coefficient_list(2 * f[0, 1] + 3 * f[-1, 0], matrix_form=True) + Matrix([ + [0, 2, 0], + [3, 0, 0], + [0, 0, 0]]) + """ + field_center, coefficients, nonlinear_part = stencil_coefficient_dict(expr) + assert nonlinear_part == 0 + field = field_center.field + + dim = field.spatial_dimensions + max_offsets = defaultdict(lambda: 0) + for offset in coefficients.keys(): + for d, off in enumerate(offset): + max_offsets[d] = max(max_offsets[d], abs(off)) + + if dim == 1: + result = [coefficients[(i,)] for i in range(-max_offsets[0], max_offsets[0] + 1)] + return sp.Matrix(result) if matrix_form else result + else: + y_range = list(range(-max_offsets[1], max_offsets[1] + 1)) + if matrix_form: + y_range.reverse() + if dim == 2: + result = [[coefficients[(i, j)] + for i in range(-max_offsets[0], max_offsets[0] + 1)] + for j in y_range] + return sp.Matrix(result) if matrix_form else result + elif dim == 3: + result = [[[coefficients[(i, j, k)] + for i in range(-max_offsets[0], max_offsets[0] + 1)] + for j in y_range] + for k in range(-max_offsets[2], max_offsets[2] + 1)] + return [sp.Matrix(l) for l in result] if matrix_form else result + else: + raise ValueError("Can only handle fields with 1,2 or 3 spatial dimensions") + + +# -------------------------------------- Visualization ----------------------------------------------------------------- + + +def visualize_stencil(stencil, **kwargs): + dim = len(stencil[0]) + if dim == 2: + visualize_stencil_2d(stencil, **kwargs) + else: + slicing = False + if 'slice' in kwargs: + slicing = kwargs['slice'] + del kwargs['slice'] + + if slicing: + visualize_stencil_3d_by_slicing(stencil, **kwargs) + else: + visualize_stencil_3d(stencil, **kwargs) + + +def visualize_stencil_2d(stencil, axes=None, figure=None, data=None, textsize='12', **kwargs): + """ + Creates a matplotlib 2D plot of the stencil + + :param stencil: sequence of directions + :param axes: optional matplotlib axes + :param data: data to annotate the directions with, if none given, the indices are used + :param textsize: size of annotation text + """ + from matplotlib.patches import BoxStyle + import matplotlib.pyplot as plt + + if axes is None: + if figure is None: + figure = plt.gcf() + axes = figure.gca() + + text_box_style = BoxStyle("Round", pad=0.3) + head_length = 0.1 + max_offsets = [max(abs(d[c]) for d in stencil) for c in (0, 1)] + + if data is None: + data = list(range(len(stencil))) + + for direction, annotation in zip(stencil, data): + assert len(direction) == 2, "Works only for 2D stencils" + + if not(direction[0] == 0 and direction[1] == 0): + axes.arrow(0, 0, direction[0], direction[1], head_width=0.08, head_length=head_length, color='k') + + if isinstance(annotation, sp.Basic): + annotation = "$" + sp.latex(annotation) + "$" + else: + annotation = str(annotation) + + def position_correction(d, magnitude=0.18): + if d < 0: + return -magnitude + elif d > 0: + return +magnitude + else: + return 0 + text_position = [direction[c] + position_correction(direction[c]) for c in (0, 1)] + axes.text(*text_position, annotation, verticalalignment='center', + zorder=30, horizontalalignment='center', size=textsize, + bbox=dict(boxstyle=text_box_style, facecolor='#00b6eb', alpha=0.85, linewidth=0)) + + axes.set_axis_off() + axes.set_aspect('equal') + max_offsets = [m if m > 0 else 0.1 for m in max_offsets] + border = 0.1 + axes.set_xlim([-border - max_offsets[0], border + max_offsets[0]]) + axes.set_ylim([-border - max_offsets[1], border + max_offsets[1]]) + + +def visualize_stencil_3d_by_slicing(stencil, slice_axis=2, figure=None, data=None, **kwargs): + """Visualizes a 3D, first-neighborhood stencil by plotting 3 slices along a given axis. + + Args: + stencil: stencil as sequence of directions + slice_axis: 0, 1, or 2 indicating the axis to slice through + data: optional data to print as text besides the arrows + """ + import matplotlib.pyplot as plt + + for d in stencil: + for element in d: + assert element == -1 or element == 0 or element == 1, "This function can only first neighborhood stencils" + + if figure is None: + figure = plt.gcf() + + axes = [figure.add_subplot(1, 3, i + 1) for i in range(3)] + splitted_directions = [[], [], []] + splitted_data = [[], [], []] + axes_names = ['x', 'y', 'z'] + + for i, d in enumerate(stencil): + split_idx = d[slice_axis] + 1 + reduced_dir = tuple([element for j, element in enumerate(d) if j != slice_axis]) + splitted_directions[split_idx].append(reduced_dir) + splitted_data[split_idx].append(i if data is None else data[i]) + + for i in range(3): + visualize_stencil_2d(splitted_directions[i], axes=axes[i], data=splitted_data[i], **kwargs) + for i in [-1, 0, 1]: + axes[i + 1].set_title("Cut at %s=%d" % (axes_names[slice_axis], i)) + + +def visualize_stencil_3d(stencil, figure=None, axes=None, data=None, textsize='8'): + """ + Draws 3D stencil into a 3D coordinate system, parameters are similar to :func:`visualize_stencil_2d` + If data is None, no labels are drawn. To draw the labels as in the 2D case, use ``data=list(range(len(stencil)))`` + """ + from matplotlib.patches import FancyArrowPatch + from mpl_toolkits.mplot3d import proj3d + import matplotlib.pyplot as plt + from matplotlib.patches import BoxStyle + from itertools import product, combinations + import numpy as np + + class Arrow3D(FancyArrowPatch): + def __init__(self, xs, ys, zs, *args, **kwargs): + FancyArrowPatch.__init__(self, (0, 0), (0, 0), *args, **kwargs) + self._verts3d = xs, ys, zs + + def draw(self, renderer): + xs3d, ys3d, zs3d = self._verts3d + xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, renderer.M) + self.set_positions((xs[0], ys[0]), (xs[1], ys[1])) + FancyArrowPatch.draw(self, renderer) + + if axes is None: + if figure is None: + figure = plt.figure() + axes = figure.gca(projection='3d') + axes.set_aspect("equal") + + if data is None: + data = [None] * len(stencil) + + text_offset = 1.25 + text_box_style = BoxStyle("Round", pad=0.3) + + # Draw cell (cube) + r = [-1, 1] + for s, e in combinations(np.array(list(product(r, r, r))), 2): + if np.sum(np.abs(s - e)) == r[1] - r[0]: + axes.plot3D(*zip(s, e), color="k", alpha=0.5) + + for d, annotation in zip(stencil, data): + assert len(d) == 3, "Works only for 3D stencils" + if not (d[0] == 0 and d[1] == 0 and d[2] == 0): + if d[0] == 0: + color = '#348abd' + elif d[1] == 0: + color = '#fac364' + elif sum([abs(d) for d in d]) == 2: + color = '#95bd50' + else: + color = '#808080' + + a = Arrow3D([0, d[0]], [0, d[1]], [0, d[2]], mutation_scale=20, lw=2, arrowstyle="-|>", color=color) + axes.add_artist(a) + + if annotation: + if isinstance(annotation, sp.Basic): + annotation = "$" + sp.latex(annotation) + "$" + else: + annotation = str(annotation) + + axes.text(d[0] * text_offset, d[1] * text_offset, d[2] * text_offset, + annotation, verticalalignment='center', zorder=30, + size=textsize, bbox=dict(boxstyle=text_box_style, facecolor='#777777', alpha=0.6, linewidth=0)) + + axes.set_xlim([-text_offset * 1.1, text_offset * 1.1]) + axes.set_ylim([-text_offset * 1.1, text_offset * 1.1]) + axes.set_zlim([-text_offset * 1.1, text_offset * 1.1]) + axes.set_axis_off() + + +def visualize_stencil_expression(expr, **kwargs): + stencil, coefficients = stencil_coefficients(expr) + dim = len(stencil[0]) + assert 0 < dim <= 3 + if dim == 1: + return stencil_coefficient_list(expr, matrix_form=True) + elif dim == 2: + return visualize_stencil_2d(stencil, data=coefficients, **kwargs) + elif dim == 3: + return visualize_stencil_3d_by_slicing(stencil, data=coefficients, **kwargs) diff --git a/utils.py b/utils.py index bc7833629..da3fec7d7 100644 --- a/utils.py +++ b/utils.py @@ -2,6 +2,7 @@ import os from tempfile import NamedTemporaryFile from contextlib import contextmanager from typing import Mapping +from collections import Counter class DotDict(dict): @@ -66,3 +67,19 @@ def atomic_file_write(file_path): f.file.close() yield f.name os.rename(f.name, file_path) + + +def fully_contains(l1, l2): + """Tests if elements of sequence 1 are in sequence 2 in same or higher number. + + >>> fully_contains([1, 1, 2], [1, 2]) # 1 is only present once in second list + False + >>> fully_contains([1, 1, 2], [1, 1, 4, 2]) + True + """ + l1_counter = Counter(l1) + l2_counter = Counter(l2) + for element, count in l1_counter.items(): + if l2_counter[element] < count: + return False + return True -- GitLab