diff --git a/fd/derivative.py b/fd/derivative.py index b6f6839f68e58a3a779bf560011e5f5c48d7194b..cc7b050c59427d7afa048ab39b0b6236932858eb 100644 --- a/fd/derivative.py +++ b/fd/derivative.py @@ -1,5 +1,7 @@ import sympy as sp from collections import namedtuple, defaultdict + +from pystencils import Field from pystencils.sympyextensions import normalize_product, prod @@ -22,6 +24,8 @@ class Diff(sp.Expr): def __new__(cls, argument, target=-1, superscript=-1): if argument == 0: return sp.Rational(0, 1) + if isinstance(argument, Field): + argument = argument.center return sp.Expr.__new__(cls, argument.expand(), sp.sympify(target), sp.sympify(superscript)) @property @@ -176,6 +180,35 @@ class DiffOperator(sp.Expr): # ---------------------------------------------------------------------------------------------------------------------- +def diff(expr, *args): + """Shortcut function to create nested derivatives + + >>> f = sp.Symbol("f") + >>> diff(f, 0, 0, 1) == Diff(Diff( Diff(f, 1), 0), 0) + True + """ + if len(args) == 0: + return expr + result = expr + for index in reversed(args): + result = Diff(result, index) + return result + + +def diff_args(expr): + """Extracts the indices and argument of possibly nested derivative - inverse of diff function + + >>> args = (sp.Symbol("x"), 0, 1, 2, 5, 1) + >>> e = diff(*args) + >>> assert diff_args(e) == args + """ + if not isinstance(expr, Diff): + return expr, + else: + inner_res = diff_args(expr.args[0]) + return (inner_res[0], expr.args[1], *inner_res[1:]) + + def diff_terms(expr): """Returns set of all derivatives in an expression. @@ -200,16 +233,6 @@ def collect_diffs(expr): return expr.collect(diff_terms(expr)) -def create_nested_diff(arg, *args): - """Shortcut to create nested derivatives""" - assert arg is not None - args = sorted(args, reverse=True, key=lambda e: e.name if isinstance(e, sp.Symbol) else e) - res = arg - for i in args: - res = Diff(res, i) - return res - - def replace_diff(expr, replacement_dict): """replacement_dict: maps variable (target) to a new Differential operator""" @@ -464,6 +487,27 @@ def combine_diff_products(expr): return combine(expr) +def replace_generic_laplacian(expr, dim=None): + """Laplacian can be written as Diff(Diff(term)) without explicitly giving the dimensions. + + This function replaces these constructs by diff(term, 0, 0) + diff(term, 1, 1) + ... + For this to work, the arguments of the derivative have to be field or field accesses such that the spatial + dimension can be determined. + """ + if isinstance(expr, Diff): + arg, *indices = diff_args(expr) + if isinstance(arg, Field.Access): + dim = arg.field.spatial_dimensions + assert dim is not None + if len(indices) == 2 and all(i == -1 for i in indices): + return sum(diff(arg, i, i) for i in range(dim)) + else: + return expr + else: + new_args = [replace_generic_laplacian(a, dim) for a in expr.args] + return expr.func(*new_args) if new_args else expr + + def functional_derivative(functional, v): r"""Computes functional derivative of functional with respect to v using Euler-Lagrange equation diff --git a/fd/finitedifferences.py b/fd/finitedifferences.py index 4b584779b66d45de1b74ec581f38fc74b51cb7a5..0e36590be9280f506ae28d8f8ce45435a28b2a40 100644 --- a/fd/finitedifferences.py +++ b/fd/finitedifferences.py @@ -73,11 +73,11 @@ class Discretization2ndOrder: self.dt = dt @staticmethod - def __diff_order(e): + def _diff_order(e): if not isinstance(e, Diff): return 0 else: - return 1 + Discretization2ndOrder.__diff_order(e.args[0]) + return 1 + Discretization2ndOrder._diff_order(e.args[0]) def _discretize_diffusion(self, expr): result = 0 @@ -110,7 +110,7 @@ class Discretization2ndOrder: return e.func(*new_args) if new_args else e def _discretize_diff(self, e): - order = self.__diff_order(e) + order = self._diff_order(e) if order == 1: fa = e.args[0] index = e.target diff --git a/fd/spatial.py b/fd/spatial.py new file mode 100644 index 0000000000000000000000000000000000000000..ff0a28bb5ad979ab7ca07c5318957830d9d750b2 --- /dev/null +++ b/fd/spatial.py @@ -0,0 +1,77 @@ +import sympy as sp +from functools import partial +from pystencils import AssignmentCollection, Field +from pystencils.fd import Diff +from .derivative import diff_args + + +def fd_stencils_standard(indices, dx, fa): + order = len(indices) + if order == 1: + idx = indices[0] + return (fa.neighbor(idx, 1) - fa.neighbor(idx, -1)) / (2 * dx) + elif order == 2: + if indices[0] == indices[1]: + return (-2 * fa + fa.neighbor(indices[0], -1) + fa.neighbor(indices[0], +1)) / (dx ** 2) + else: + offsets = [(1, 1), [-1, 1], [1, -1], [-1, -1]] + return sum(o1 * o2 * fa.neighbor(indices[0], o1).neighbor(indices[1], o2) + for o1, o2 in offsets) / (4 * dx ** 2) + raise NotImplementedError("Supports only derivatives up to order 2") + + +def fd_stencils_isotropic(indices, dx, fa): + dim = fa.field.spatial_dimensions + if dim == 1: + return fd_stencils_standard(indices, dx, fa) + elif dim == 2: + order = len(indices) + + if order == 1: + idx = indices[0] + assert 0 <= idx < 2 + other_idx = 1 if indices[0] == 0 else 0 + weights = {-1: sp.Rational(1, 12) / dx, + 0: sp.Rational(1, 3) / dx, + 1: sp.Rational(1, 12) / dx} + upper_terms = sum(fa.neighbor(idx, +1).neighbor(other_idx, off) * w for off, w in weights.items()) + lower_terms = sum(fa.neighbor(idx, -1).neighbor(other_idx, off) * w for off, w in weights.items()) + return upper_terms - lower_terms + elif order == 2: + if indices[0] == indices[1]: + idx = indices[0] + other_idx = 1 if idx == 0 else 0 + diagonals = sp.Rational(1, 12) * sum(fa.neighbor(0, i).neighbor(1, j) for i in (-1, 1) for j in (-1, 1)) + div_direction = sp.Rational(5, 6) * sum(fa.neighbor(idx, i) for i in (-1, 1)) + other_direction = - sp.Rational(1, 6) * sum(fa.neighbor(other_idx, i) for i in (-1, 1)) + center = - sp.Rational(5, 3) * fa + return (diagonals + div_direction + other_direction + center) / (dx ** 2) + else: + return fd_stencils_standard(indices, dx, fa) + raise NotImplementedError("Supports only derivatives up to order 2 for 1D and 2D setups") + + +def discretize_spatial(expr, dx, stencil=fd_stencils_standard): + if isinstance(stencil, str): + if stencil == 'standard': + stencil = fd_stencils_standard + elif stencil == 'isotropic': + stencil = fd_stencils_isotropic + else: + raise ValueError("Unknown stencil. Supported 'standard' and 'isotropic'") + + if isinstance(expr, list): + return [discretize_spatial(e, dx, stencil) for e in expr] + elif isinstance(expr, sp.Matrix): + return expr.applyfunc(partial(discretize_spatial, dx=dx, stencil=stencil)) + elif isinstance(expr, AssignmentCollection): + return expr.copy(main_assignments=[e for e in expr.main_assignments], + subexpressions=[e for e in expr.subexpressions]) + elif isinstance(expr, Diff): + arg, *indices = diff_args(expr) + if not isinstance(arg, Field.Access): + raise ValueError("Only derivatives with field or field accesses as arguments can be discretized") + return stencil(indices, dx, arg) + else: + new_args = [discretize_spatial(a, dx, stencil) for a in expr.args] + return expr.func(*new_args) if new_args else expr diff --git a/field.py b/field.py index d0e95b489f6357b26994577e304d741b7e77e8e0..e1e4401823454fda3d9b5d72e5f17a3dc25fc711 100644 --- a/field.py +++ b/field.py @@ -386,6 +386,8 @@ class Field: return hash((self._layout, self.shape, self.strides, self._dtype, self.field_type, self._field_name)) def __eq__(self, other): + if not isinstance(other, Field): + return False self_tuple = (self.shape, self.strides, self.name, self.dtype, self.field_type) other_tuple = (other.shape, other.strides, other.name, other.dtype, other.field_type) return self_tuple == other_tuple diff --git a/stencils.py b/stencils.py index 0efe601fdc744868b182e2f30c92b03e26d1d0af..d304cb5ae8a1fff4661b6caa49d27242e8f98612 100644 --- a/stencils.py +++ b/stencils.py @@ -171,10 +171,11 @@ def visualize_stencil_2d(stencil, axes=None, figure=None, data=None, textsize='1 """ Creates a matplotlib 2D plot of the stencil - :param stencil: sequence of directions - :param axes: optional matplotlib axes - :param data: data to annotate the directions with, if none given, the indices are used - :param textsize: size of annotation text + Args: + stencil: sequence of directions + axes: optional matplotlib axes + data: data to annotate the directions with, if none given, the indices are used + textsize: size of annotation text """ from matplotlib.patches import BoxStyle import matplotlib.pyplot as plt @@ -329,6 +330,7 @@ def visualize_stencil_3d(stencil, figure=None, axes=None, data=None, textsize='8 def visualize_stencil_expression(expr, **kwargs): + """Displays coefficients of a linear update expression of a single field as matplotlib arrow drawing.""" stencil, coefficients = stencil_coefficients(expr) dim = len(stencil[0]) assert 0 < dim <= 3