Commit fe199929 authored by Martin Bauer's avatar Martin Bauer
Browse files

Refactored finite difference spatial discretization

- added isotropic version
parent 8a517fd1
import sympy as sp
from collections import namedtuple, defaultdict
from pystencils import Field
from pystencils.sympyextensions import normalize_product, prod
......@@ -22,6 +24,8 @@ class Diff(sp.Expr):
def __new__(cls, argument, target=-1, superscript=-1):
if argument == 0:
return sp.Rational(0, 1)
if isinstance(argument, Field):
argument = argument.center
return sp.Expr.__new__(cls, argument.expand(), sp.sympify(target), sp.sympify(superscript))
@property
......@@ -176,6 +180,35 @@ class DiffOperator(sp.Expr):
# ----------------------------------------------------------------------------------------------------------------------
def diff(expr, *args):
"""Shortcut function to create nested derivatives
>>> f = sp.Symbol("f")
>>> diff(f, 0, 0, 1) == Diff(Diff( Diff(f, 1), 0), 0)
True
"""
if len(args) == 0:
return expr
result = expr
for index in reversed(args):
result = Diff(result, index)
return result
def diff_args(expr):
"""Extracts the indices and argument of possibly nested derivative - inverse of diff function
>>> args = (sp.Symbol("x"), 0, 1, 2, 5, 1)
>>> e = diff(*args)
>>> assert diff_args(e) == args
"""
if not isinstance(expr, Diff):
return expr,
else:
inner_res = diff_args(expr.args[0])
return (inner_res[0], expr.args[1], *inner_res[1:])
def diff_terms(expr):
"""Returns set of all derivatives in an expression.
......@@ -200,16 +233,6 @@ def collect_diffs(expr):
return expr.collect(diff_terms(expr))
def create_nested_diff(arg, *args):
"""Shortcut to create nested derivatives"""
assert arg is not None
args = sorted(args, reverse=True, key=lambda e: e.name if isinstance(e, sp.Symbol) else e)
res = arg
for i in args:
res = Diff(res, i)
return res
def replace_diff(expr, replacement_dict):
"""replacement_dict: maps variable (target) to a new Differential operator"""
......@@ -464,6 +487,27 @@ def combine_diff_products(expr):
return combine(expr)
def replace_generic_laplacian(expr, dim=None):
"""Laplacian can be written as Diff(Diff(term)) without explicitly giving the dimensions.
This function replaces these constructs by diff(term, 0, 0) + diff(term, 1, 1) + ...
For this to work, the arguments of the derivative have to be field or field accesses such that the spatial
dimension can be determined.
"""
if isinstance(expr, Diff):
arg, *indices = diff_args(expr)
if isinstance(arg, Field.Access):
dim = arg.field.spatial_dimensions
assert dim is not None
if len(indices) == 2 and all(i == -1 for i in indices):
return sum(diff(arg, i, i) for i in range(dim))
else:
return expr
else:
new_args = [replace_generic_laplacian(a, dim) for a in expr.args]
return expr.func(*new_args) if new_args else expr
def functional_derivative(functional, v):
r"""Computes functional derivative of functional with respect to v using Euler-Lagrange equation
......
......@@ -73,11 +73,11 @@ class Discretization2ndOrder:
self.dt = dt
@staticmethod
def __diff_order(e):
def _diff_order(e):
if not isinstance(e, Diff):
return 0
else:
return 1 + Discretization2ndOrder.__diff_order(e.args[0])
return 1 + Discretization2ndOrder._diff_order(e.args[0])
def _discretize_diffusion(self, expr):
result = 0
......@@ -110,7 +110,7 @@ class Discretization2ndOrder:
return e.func(*new_args) if new_args else e
def _discretize_diff(self, e):
order = self.__diff_order(e)
order = self._diff_order(e)
if order == 1:
fa = e.args[0]
index = e.target
......
import sympy as sp
from functools import partial
from pystencils import AssignmentCollection, Field
from pystencils.fd import Diff
from .derivative import diff_args
def fd_stencils_standard(indices, dx, fa):
order = len(indices)
if order == 1:
idx = indices[0]
return (fa.neighbor(idx, 1) - fa.neighbor(idx, -1)) / (2 * dx)
elif order == 2:
if indices[0] == indices[1]:
return (-2 * fa + fa.neighbor(indices[0], -1) + fa.neighbor(indices[0], +1)) / (dx ** 2)
else:
offsets = [(1, 1), [-1, 1], [1, -1], [-1, -1]]
return sum(o1 * o2 * fa.neighbor(indices[0], o1).neighbor(indices[1], o2)
for o1, o2 in offsets) / (4 * dx ** 2)
raise NotImplementedError("Supports only derivatives up to order 2")
def fd_stencils_isotropic(indices, dx, fa):
dim = fa.field.spatial_dimensions
if dim == 1:
return fd_stencils_standard(indices, dx, fa)
elif dim == 2:
order = len(indices)
if order == 1:
idx = indices[0]
assert 0 <= idx < 2
other_idx = 1 if indices[0] == 0 else 0
weights = {-1: sp.Rational(1, 12) / dx,
0: sp.Rational(1, 3) / dx,
1: sp.Rational(1, 12) / dx}
upper_terms = sum(fa.neighbor(idx, +1).neighbor(other_idx, off) * w for off, w in weights.items())
lower_terms = sum(fa.neighbor(idx, -1).neighbor(other_idx, off) * w for off, w in weights.items())
return upper_terms - lower_terms
elif order == 2:
if indices[0] == indices[1]:
idx = indices[0]
other_idx = 1 if idx == 0 else 0
diagonals = sp.Rational(1, 12) * sum(fa.neighbor(0, i).neighbor(1, j) for i in (-1, 1) for j in (-1, 1))
div_direction = sp.Rational(5, 6) * sum(fa.neighbor(idx, i) for i in (-1, 1))
other_direction = - sp.Rational(1, 6) * sum(fa.neighbor(other_idx, i) for i in (-1, 1))
center = - sp.Rational(5, 3) * fa
return (diagonals + div_direction + other_direction + center) / (dx ** 2)
else:
return fd_stencils_standard(indices, dx, fa)
raise NotImplementedError("Supports only derivatives up to order 2 for 1D and 2D setups")
def discretize_spatial(expr, dx, stencil=fd_stencils_standard):
if isinstance(stencil, str):
if stencil == 'standard':
stencil = fd_stencils_standard
elif stencil == 'isotropic':
stencil = fd_stencils_isotropic
else:
raise ValueError("Unknown stencil. Supported 'standard' and 'isotropic'")
if isinstance(expr, list):
return [discretize_spatial(e, dx, stencil) for e in expr]
elif isinstance(expr, sp.Matrix):
return expr.applyfunc(partial(discretize_spatial, dx=dx, stencil=stencil))
elif isinstance(expr, AssignmentCollection):
return expr.copy(main_assignments=[e for e in expr.main_assignments],
subexpressions=[e for e in expr.subexpressions])
elif isinstance(expr, Diff):
arg, *indices = diff_args(expr)
if not isinstance(arg, Field.Access):
raise ValueError("Only derivatives with field or field accesses as arguments can be discretized")
return stencil(indices, dx, arg)
else:
new_args = [discretize_spatial(a, dx, stencil) for a in expr.args]
return expr.func(*new_args) if new_args else expr
......@@ -386,6 +386,8 @@ class Field:
return hash((self._layout, self.shape, self.strides, self._dtype, self.field_type, self._field_name))
def __eq__(self, other):
if not isinstance(other, Field):
return False
self_tuple = (self.shape, self.strides, self.name, self.dtype, self.field_type)
other_tuple = (other.shape, other.strides, other.name, other.dtype, other.field_type)
return self_tuple == other_tuple
......
......@@ -171,10 +171,11 @@ def visualize_stencil_2d(stencil, axes=None, figure=None, data=None, textsize='1
"""
Creates a matplotlib 2D plot of the stencil
:param stencil: sequence of directions
:param axes: optional matplotlib axes
:param data: data to annotate the directions with, if none given, the indices are used
:param textsize: size of annotation text
Args:
stencil: sequence of directions
axes: optional matplotlib axes
data: data to annotate the directions with, if none given, the indices are used
textsize: size of annotation text
"""
from matplotlib.patches import BoxStyle
import matplotlib.pyplot as plt
......@@ -329,6 +330,7 @@ def visualize_stencil_3d(stencil, figure=None, axes=None, data=None, textsize='8
def visualize_stencil_expression(expr, **kwargs):
"""Displays coefficients of a linear update expression of a single field as matplotlib arrow drawing."""
stencil, coefficients = stencil_coefficients(expr)
dim = len(stencil[0])
assert 0 < dim <= 3
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment