stencil.py 16.1 KB
 Martin Bauer committed May 05, 2019 1 ``````"""This submodule offers functions to work with stencils in expression an offset-list form.""" `````` Martin Bauer committed Oct 16, 2018 2 3 ``````from typing import Sequence import numpy as np `````` Martin Bauer committed Sep 19, 2018 4 5 6 7 8 ``````import sympy as sp from collections import defaultdict def inverse_direction(direction): `````` Martin Bauer committed May 05, 2019 9 10 11 12 13 14 `````` """Returns inverse i.e. negative of given direction tuple Example: >>> inverse_direction((1, -1, 0)) (-1, 1, 0) """ `````` Martin Bauer committed Sep 19, 2018 15 16 17 `````` return tuple([-i for i in direction]) `````` Martin Bauer committed May 05, 2019 18 ``````def is_valid(stencil, max_neighborhood=None): `````` Martin Bauer committed Sep 19, 2018 19 20 21 22 `````` """ Tests if a nested sequence is a valid stencil i.e. all the inner sequences have the same length. If max_neighborhood is specified, it is also verified that the stencil does not contain any direction components with absolute value greater than the maximal neighborhood. `````` Martin Bauer committed May 05, 2019 23 24 25 26 27 28 29 30 `````` Examples: >>> is_valid([(1, 0), (1, 0, 0)]) # stencil entries have different length False >>> is_valid([(2, 0), (1, 0)]) True >>> is_valid([(2, 0), (1, 0)], max_neighborhood=1) False `````` Martin Bauer committed Sep 19, 2018 31 32 33 34 35 36 37 38 39 40 41 42 `````` """ expected_dim = len(stencil[0]) for d in stencil: if len(d) != expected_dim: return False if max_neighborhood is not None: for d_i in d: if abs(d_i) > max_neighborhood: return False return True `````` Martin Bauer committed May 05, 2019 43 44 45 46 47 48 49 50 51 ``````def is_symmetric(stencil): """Tests for every direction d, that -d is also in the stencil Examples: >>> is_symmetric([(1, 0), (0, 1)]) False >>> is_symmetric([(1, 0), (-1, 0)]) True """ `````` Martin Bauer committed Sep 19, 2018 52 53 54 55 56 57 `````` for d in stencil: if inverse_direction(d) not in stencil: return False return True `````` Martin Bauer committed May 05, 2019 58 59 60 61 62 63 64 65 66 ``````def have_same_entries(s1, s2): """Checks if two stencils are the same Examples: >>> stencil1 = [(1, 0), (-1, 0), (0, 1), (0, -1)] >>> stencil2 = [(-1, 0), (0, -1), (1, 0), (0, 1)] >>> have_same_entries(stencil1, stencil2) True """ `````` Martin Bauer committed Sep 19, 2018 67 68 69 70 71 72 73 74 `````` if len(s1) != len(s2): return False return len(set(s1) - set(s2)) == 0 # -------------------------------------Expression - Coefficient Form Conversion ---------------------------------------- `````` Martin Bauer committed May 05, 2019 75 ``````def coefficient_dict(expr): `````` Martin Bauer committed Sep 19, 2018 76 77 78 79 80 81 82 83 84 85 86 87 88 `````` """Extracts coefficients in front of field accesses in a expression. Expression may only access a single field at a single index. Returns: center, coefficient dict, nonlinear part where center is the single field that is accessed in expression accessed at center and coefficient dict maps offsets to coefficients. The nonlinear part is everything that is not in the form of coefficient times field access. Examples: >>> import pystencils as ps >>> f = ps.fields("f(3) : double[2D]") `````` Martin Bauer committed May 05, 2019 89 `````` >>> field, coeffs, nonlinear_part = coefficient_dict(2 * f[0, 1](1) + 3 * f[-1, 0](1) + 123) `````` Martin Bauer committed Sep 19, 2018 90 91 92 93 `````` >>> assert nonlinear_part == 123 and field == f(1) >>> sorted(coeffs.items()) [((-1, 0), 3), ((0, 1), 2)] """ `````` Martin Bauer committed May 05, 2019 94 `````` from pystencils import Field `````` Martin Bauer committed Sep 19, 2018 95 96 97 98 99 100 101 102 103 104 105 106 107 108 `````` expr = expr.expand() field_accesses = expr.atoms(Field.Access) fields = set(fa.field for fa in field_accesses) accessed_indices = set(fa.index for fa in field_accesses) if len(fields) != 1: raise ValueError("Could not extract stencil coefficients. " "Expression has to be a linear function of exactly one field.") if len(accessed_indices) != 1: raise ValueError("Could not extract stencil coefficients. Field is accessed at multiple indices") field = fields.pop() idx = accessed_indices.pop() `````` Martin Bauer committed May 05, 2019 109 110 `````` coeffs = defaultdict(lambda: 0) coeffs.update({fa.offsets: expr.coeff(fa) for fa in field_accesses}) `````` Martin Bauer committed Sep 19, 2018 111 `````` `````` Martin Bauer committed May 05, 2019 112 `````` linear_part = sum(c * field[off](*idx) for off, c in coeffs.items()) `````` Martin Bauer committed Sep 19, 2018 113 `````` nonlinear_part = expr - linear_part `````` Martin Bauer committed May 05, 2019 114 `````` return field(*idx), coeffs, nonlinear_part `````` Martin Bauer committed Sep 19, 2018 115 116 `````` `````` Martin Bauer committed May 05, 2019 117 ``````def coefficients(expr): `````` Martin Bauer committed Sep 19, 2018 118 119 `````` """Returns two lists - one with accessed offsets and one with their coefficients. `````` Martin Bauer committed May 05, 2019 120 `````` Same restrictions as `coefficient_dict` apply. Expression must not have any nonlinear part `````` Martin Bauer committed Sep 19, 2018 121 122 123 `````` >>> import pystencils as ps >>> f = ps.fields("f(3) : double[2D]") `````` Martin Bauer committed May 05, 2019 124 `````` >>> coff = coefficients(2 * f[0, 1](1) + 3 * f[-1, 0](1)) `````` Martin Bauer committed Sep 19, 2018 125 `````` """ `````` Martin Bauer committed May 05, 2019 126 `````` field_center, coeffs, nonlinear_part = coefficient_dict(expr) `````` Martin Bauer committed Sep 19, 2018 127 `````` assert nonlinear_part == 0 `````` Martin Bauer committed May 05, 2019 128 129 `````` stencil = list(coeffs.keys()) entries = [coeffs[c] for c in stencil] `````` Martin Bauer committed Sep 19, 2018 130 131 132 `````` return stencil, entries `````` Martin Bauer committed May 05, 2019 133 ``````def coefficient_list(expr, matrix_form=False): `````` Martin Bauer committed Sep 19, 2018 134 135 `````` """Returns stencil coefficients in the form of nested lists `````` Martin Bauer committed May 05, 2019 136 `````` Same restrictions as `coefficient_dict` apply. Expression must not have any nonlinear part `````` Martin Bauer committed Sep 19, 2018 137 138 139 140 `````` Examples: >>> import pystencils as ps >>> f = ps.fields("f: double[2D]") `````` Martin Bauer committed May 05, 2019 141 `````` >>> coefficient_list(2 * f[0, 1] + 3 * f[-1, 0]) `````` Martin Bauer committed Sep 19, 2018 142 `````` [[0, 0, 0], [3, 0, 0], [0, 2, 0]] `````` Martin Bauer committed May 05, 2019 143 `````` >>> coefficient_list(2 * f[0, 1] + 3 * f[-1, 0], matrix_form=True) `````` Martin Bauer committed Sep 19, 2018 144 145 146 147 148 `````` Matrix([ [0, 2, 0], [3, 0, 0], [0, 0, 0]]) """ `````` Martin Bauer committed May 05, 2019 149 `````` field_center, coeffs, nonlinear_part = coefficient_dict(expr) `````` Martin Bauer committed Sep 19, 2018 150 151 152 153 154 `````` assert nonlinear_part == 0 field = field_center.field dim = field.spatial_dimensions max_offsets = defaultdict(lambda: 0) `````` Martin Bauer committed May 05, 2019 155 `````` for offset in coeffs.keys(): `````` Martin Bauer committed Sep 19, 2018 156 157 158 159 `````` for d, off in enumerate(offset): max_offsets[d] = max(max_offsets[d], abs(off)) if dim == 1: `````` Martin Bauer committed May 05, 2019 160 `````` result = [coeffs[(i,)] for i in range(-max_offsets[0], max_offsets[0] + 1)] `````` Martin Bauer committed Sep 19, 2018 161 162 163 164 165 166 `````` return sp.Matrix(result) if matrix_form else result else: y_range = list(range(-max_offsets[1], max_offsets[1] + 1)) if matrix_form: y_range.reverse() if dim == 2: `````` Martin Bauer committed May 05, 2019 167 `````` result = [[coeffs[(i, j)] `````` Martin Bauer committed Sep 19, 2018 168 169 170 171 `````` for i in range(-max_offsets[0], max_offsets[0] + 1)] for j in y_range] return sp.Matrix(result) if matrix_form else result elif dim == 3: `````` Martin Bauer committed May 05, 2019 172 `````` result = [[[coeffs[(i, j, k)] `````` Martin Bauer committed Sep 19, 2018 173 174 175 176 177 178 179 180 `````` for i in range(-max_offsets[0], max_offsets[0] + 1)] for j in y_range] for k in range(-max_offsets[2], max_offsets[2] + 1)] return [sp.Matrix(l) for l in result] if matrix_form else result else: raise ValueError("Can only handle fields with 1,2 or 3 spatial dimensions") `````` Martin Bauer committed Oct 16, 2018 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 ``````# ------------------------------------- Point-on-compass notation ------------------------------------------------------ def offset_component_to_direction_string(coordinate_id: int, value: int) -> str: """Translates numerical offset to string notation. x offsets are labeled with east 'E' and 'W', y offsets with north 'N' and 'S' and z offsets with top 'T' and bottom 'B' If the absolute value of the offset is bigger than 1, this number is prefixed. Args: coordinate_id: integer 0, 1 or 2 standing for x,y and z value: integer offset Examples: >>> offset_component_to_direction_string(0, 1) 'E' >>> offset_component_to_direction_string(1, 2) '2N' """ assert 0 <= coordinate_id < 3, "Works only for at most 3D arrays" name_components = (('W', 'E'), # west, east ('S', 'N'), # south, north ('B', 'T')) # bottom, top if value == 0: result = "" elif value < 0: result = name_components[coordinate_id][0] else: result = name_components[coordinate_id][1] if abs(value) > 1: result = "%d%s" % (abs(value), result) return result def offset_to_direction_string(offsets: Sequence[int]) -> str: """ Translates numerical offset to string notation. For details see :func:`offset_component_to_direction_string` Args: offsets: 3-tuple with x,y,z offset Examples: >>> offset_to_direction_string([1, -1, 0]) 'SE' >>> offset_to_direction_string(([-3, 0, -2])) '2B3W' """ if len(offsets) > 3: return str(offsets) names = ["", "", ""] for i in range(len(offsets)): names[i] = offset_component_to_direction_string(i, offsets[i]) name = "".join(reversed(names)) if name == "": name = "C" return name def direction_string_to_offset(direction: str, dim: int = 3): """ Reverse mapping of :func:`offset_to_direction_string` Args: direction: string representation of offset dim: dimension of offset, i.e the length of the returned list Examples: >>> direction_string_to_offset('NW', dim=3) array([-1, 1, 0]) >>> direction_string_to_offset('NW', dim=2) array([-1, 1]) >>> direction_string_to_offset(offset_to_direction_string((3,-2,1))) array([ 3, -2, 1]) """ offset_dict = { 'C': np.array([0, 0, 0]), 'W': np.array([-1, 0, 0]), 'E': np.array([1, 0, 0]), 'S': np.array([0, -1, 0]), 'N': np.array([0, 1, 0]), 'B': np.array([0, 0, -1]), 'T': np.array([0, 0, 1]), } offset = np.array([0, 0, 0]) while len(direction) > 0: factor = 1 first_non_digit = 0 while direction[first_non_digit].isdigit(): first_non_digit += 1 if first_non_digit > 0: factor = int(direction[:first_non_digit]) direction = direction[first_non_digit:] cur_offset = offset_dict[direction[0]] offset += factor * cur_offset direction = direction[1:] return offset[:dim] `````` Martin Bauer committed Sep 19, 2018 285 286 287 ``````# -------------------------------------- Visualization ----------------------------------------------------------------- `````` Martin Bauer committed May 05, 2019 288 ``````def plot(stencil, **kwargs): `````` Martin Bauer committed Sep 19, 2018 289 290 `````` dim = len(stencil[0]) if dim == 2: `````` Martin Bauer committed May 05, 2019 291 `````` plot_2d(stencil, **kwargs) `````` Martin Bauer committed Sep 19, 2018 292 293 294 295 296 297 298 `````` else: slicing = False if 'slice' in kwargs: slicing = kwargs['slice'] del kwargs['slice'] if slicing: `````` Martin Bauer committed May 05, 2019 299 `````` plot_3d_slicing(stencil, **kwargs) `````` Martin Bauer committed Sep 19, 2018 300 `````` else: `````` Martin Bauer committed May 05, 2019 301 `````` plot_3d(stencil, **kwargs) `````` Martin Bauer committed Sep 19, 2018 302 303 `````` `````` Martin Bauer committed May 05, 2019 304 ``````def plot_2d(stencil, axes=None, figure=None, data=None, textsize='12', **kwargs): `````` Martin Bauer committed Sep 19, 2018 305 306 307 `````` """ Creates a matplotlib 2D plot of the stencil `````` Martin Bauer committed Sep 20, 2018 308 309 310 311 312 `````` Args: stencil: sequence of directions axes: optional matplotlib axes data: data to annotate the directions with, if none given, the indices are used textsize: size of annotation text `````` Martin Bauer committed Sep 19, 2018 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 `````` """ from matplotlib.patches import BoxStyle import matplotlib.pyplot as plt if axes is None: if figure is None: figure = plt.gcf() axes = figure.gca() text_box_style = BoxStyle("Round", pad=0.3) head_length = 0.1 max_offsets = [max(abs(d[c]) for d in stencil) for c in (0, 1)] if data is None: data = list(range(len(stencil))) for direction, annotation in zip(stencil, data): assert len(direction) == 2, "Works only for 2D stencils" if not(direction[0] == 0 and direction[1] == 0): axes.arrow(0, 0, direction[0], direction[1], head_width=0.08, head_length=head_length, color='k') if isinstance(annotation, sp.Basic): annotation = "\$" + sp.latex(annotation) + "\$" else: annotation = str(annotation) def position_correction(d, magnitude=0.18): if d < 0: return -magnitude elif d > 0: return +magnitude else: return 0 text_position = [direction[c] + position_correction(direction[c]) for c in (0, 1)] axes.text(*text_position, annotation, verticalalignment='center', zorder=30, horizontalalignment='center', size=textsize, bbox=dict(boxstyle=text_box_style, facecolor='#00b6eb', alpha=0.85, linewidth=0)) axes.set_axis_off() axes.set_aspect('equal') max_offsets = [m if m > 0 else 0.1 for m in max_offsets] border = 0.1 axes.set_xlim([-border - max_offsets[0], border + max_offsets[0]]) axes.set_ylim([-border - max_offsets[1], border + max_offsets[1]]) `````` Martin Bauer committed May 05, 2019 360 ``````def plot_3d_slicing(stencil, slice_axis=2, figure=None, data=None, **kwargs): `````` Martin Bauer committed Sep 19, 2018 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 `````` """Visualizes a 3D, first-neighborhood stencil by plotting 3 slices along a given axis. Args: stencil: stencil as sequence of directions slice_axis: 0, 1, or 2 indicating the axis to slice through data: optional data to print as text besides the arrows """ import matplotlib.pyplot as plt for d in stencil: for element in d: assert element == -1 or element == 0 or element == 1, "This function can only first neighborhood stencils" if figure is None: figure = plt.gcf() axes = [figure.add_subplot(1, 3, i + 1) for i in range(3)] splitted_directions = [[], [], []] splitted_data = [[], [], []] axes_names = ['x', 'y', 'z'] for i, d in enumerate(stencil): split_idx = d[slice_axis] + 1 reduced_dir = tuple([element for j, element in enumerate(d) if j != slice_axis]) splitted_directions[split_idx].append(reduced_dir) splitted_data[split_idx].append(i if data is None else data[i]) for i in range(3): `````` Martin Bauer committed May 05, 2019 389 `````` plot_2d(splitted_directions[i], axes=axes[i], data=splitted_data[i], **kwargs) `````` Martin Bauer committed Sep 19, 2018 390 `````` for i in [-1, 0, 1]: `````` Martin Bauer committed May 05, 2019 391 `````` axes[i + 1].set_title("Cut at %s=%d" % (axes_names[slice_axis], i), y=1.08) `````` Martin Bauer committed Sep 19, 2018 392 393 `````` `````` Martin Bauer committed May 05, 2019 394 ``````def plot_3d(stencil, figure=None, axes=None, data=None, textsize='8'): `````` Martin Bauer committed Sep 19, 2018 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 `````` """ Draws 3D stencil into a 3D coordinate system, parameters are similar to :func:`visualize_stencil_2d` If data is None, no labels are drawn. To draw the labels as in the 2D case, use ``data=list(range(len(stencil)))`` """ from matplotlib.patches import FancyArrowPatch from mpl_toolkits.mplot3d import proj3d import matplotlib.pyplot as plt from matplotlib.patches import BoxStyle from itertools import product, combinations import numpy as np class Arrow3D(FancyArrowPatch): def __init__(self, xs, ys, zs, *args, **kwargs): FancyArrowPatch.__init__(self, (0, 0), (0, 0), *args, **kwargs) self._verts3d = xs, ys, zs def draw(self, renderer): xs3d, ys3d, zs3d = self._verts3d xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, renderer.M) self.set_positions((xs[0], ys[0]), (xs[1], ys[1])) FancyArrowPatch.draw(self, renderer) if axes is None: if figure is None: figure = plt.figure() axes = figure.gca(projection='3d') axes.set_aspect("equal") if data is None: data = [None] * len(stencil) text_offset = 1.25 text_box_style = BoxStyle("Round", pad=0.3) # Draw cell (cube) r = [-1, 1] for s, e in combinations(np.array(list(product(r, r, r))), 2): if np.sum(np.abs(s - e)) == r[1] - r[0]: axes.plot3D(*zip(s, e), color="k", alpha=0.5) for d, annotation in zip(stencil, data): assert len(d) == 3, "Works only for 3D stencils" if not (d[0] == 0 and d[1] == 0 and d[2] == 0): if d[0] == 0: color = '#348abd' elif d[1] == 0: color = '#fac364' elif sum([abs(d) for d in d]) == 2: color = '#95bd50' else: color = '#808080' a = Arrow3D([0, d[0]], [0, d[1]], [0, d[2]], mutation_scale=20, lw=2, arrowstyle="-|>", color=color) axes.add_artist(a) if annotation: if isinstance(annotation, sp.Basic): annotation = "\$" + sp.latex(annotation) + "\$" else: annotation = str(annotation) axes.text(d[0] * text_offset, d[1] * text_offset, d[2] * text_offset, annotation, verticalalignment='center', zorder=30, size=textsize, bbox=dict(boxstyle=text_box_style, facecolor='#777777', alpha=0.6, linewidth=0)) axes.set_xlim([-text_offset * 1.1, text_offset * 1.1]) axes.set_ylim([-text_offset * 1.1, text_offset * 1.1]) axes.set_zlim([-text_offset * 1.1, text_offset * 1.1]) axes.set_axis_off() `````` Martin Bauer committed May 05, 2019 466 ``````def plot_expression(expr, **kwargs): `````` Martin Bauer committed Sep 20, 2018 467 `````` """Displays coefficients of a linear update expression of a single field as matplotlib arrow drawing.""" `````` Martin Bauer committed May 05, 2019 468 `````` stencil, coeffs = coefficients(expr) `````` Martin Bauer committed Sep 19, 2018 469 470 471 `````` dim = len(stencil[0]) assert 0 < dim <= 3 if dim == 1: `````` Martin Bauer committed May 05, 2019 472 `````` return coefficient_list(expr, matrix_form=True) `````` Martin Bauer committed Sep 19, 2018 473 `````` elif dim == 2: `````` Martin Bauer committed May 05, 2019 474 `````` return plot_2d(stencil, data=coeffs, **kwargs) `````` Martin Bauer committed Sep 19, 2018 475 `````` elif dim == 3: `````` Martin Bauer committed May 05, 2019 476 `` return plot_3d_slicing(stencil, data=coeffs, **kwargs)``