Commit fe199929 authored by Martin Bauer's avatar Martin Bauer
Browse files

Refactored finite difference spatial discretization

- added isotropic version
parent 8a517fd1
import sympy as sp import sympy as sp
from collections import namedtuple, defaultdict from collections import namedtuple, defaultdict
from pystencils import Field
from pystencils.sympyextensions import normalize_product, prod from pystencils.sympyextensions import normalize_product, prod
...@@ -22,6 +24,8 @@ class Diff(sp.Expr): ...@@ -22,6 +24,8 @@ class Diff(sp.Expr):
def __new__(cls, argument, target=-1, superscript=-1): def __new__(cls, argument, target=-1, superscript=-1):
if argument == 0: if argument == 0:
return sp.Rational(0, 1) return sp.Rational(0, 1)
if isinstance(argument, Field):
argument = argument.center
return sp.Expr.__new__(cls, argument.expand(), sp.sympify(target), sp.sympify(superscript)) return sp.Expr.__new__(cls, argument.expand(), sp.sympify(target), sp.sympify(superscript))
@property @property
...@@ -176,6 +180,35 @@ class DiffOperator(sp.Expr): ...@@ -176,6 +180,35 @@ class DiffOperator(sp.Expr):
# ---------------------------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------------------------
def diff(expr, *args):
"""Shortcut function to create nested derivatives
>>> f = sp.Symbol("f")
>>> diff(f, 0, 0, 1) == Diff(Diff( Diff(f, 1), 0), 0)
True
"""
if len(args) == 0:
return expr
result = expr
for index in reversed(args):
result = Diff(result, index)
return result
def diff_args(expr):
"""Extracts the indices and argument of possibly nested derivative - inverse of diff function
>>> args = (sp.Symbol("x"), 0, 1, 2, 5, 1)
>>> e = diff(*args)
>>> assert diff_args(e) == args
"""
if not isinstance(expr, Diff):
return expr,
else:
inner_res = diff_args(expr.args[0])
return (inner_res[0], expr.args[1], *inner_res[1:])
def diff_terms(expr): def diff_terms(expr):
"""Returns set of all derivatives in an expression. """Returns set of all derivatives in an expression.
...@@ -200,16 +233,6 @@ def collect_diffs(expr): ...@@ -200,16 +233,6 @@ def collect_diffs(expr):
return expr.collect(diff_terms(expr)) return expr.collect(diff_terms(expr))
def create_nested_diff(arg, *args):
"""Shortcut to create nested derivatives"""
assert arg is not None
args = sorted(args, reverse=True, key=lambda e: e.name if isinstance(e, sp.Symbol) else e)
res = arg
for i in args:
res = Diff(res, i)
return res
def replace_diff(expr, replacement_dict): def replace_diff(expr, replacement_dict):
"""replacement_dict: maps variable (target) to a new Differential operator""" """replacement_dict: maps variable (target) to a new Differential operator"""
...@@ -464,6 +487,27 @@ def combine_diff_products(expr): ...@@ -464,6 +487,27 @@ def combine_diff_products(expr):
return combine(expr) return combine(expr)
def replace_generic_laplacian(expr, dim=None):
"""Laplacian can be written as Diff(Diff(term)) without explicitly giving the dimensions.
This function replaces these constructs by diff(term, 0, 0) + diff(term, 1, 1) + ...
For this to work, the arguments of the derivative have to be field or field accesses such that the spatial
dimension can be determined.
"""
if isinstance(expr, Diff):
arg, *indices = diff_args(expr)
if isinstance(arg, Field.Access):
dim = arg.field.spatial_dimensions
assert dim is not None
if len(indices) == 2 and all(i == -1 for i in indices):
return sum(diff(arg, i, i) for i in range(dim))
else:
return expr
else:
new_args = [replace_generic_laplacian(a, dim) for a in expr.args]
return expr.func(*new_args) if new_args else expr
def functional_derivative(functional, v): def functional_derivative(functional, v):
r"""Computes functional derivative of functional with respect to v using Euler-Lagrange equation r"""Computes functional derivative of functional with respect to v using Euler-Lagrange equation
......
...@@ -73,11 +73,11 @@ class Discretization2ndOrder: ...@@ -73,11 +73,11 @@ class Discretization2ndOrder:
self.dt = dt self.dt = dt
@staticmethod @staticmethod
def __diff_order(e): def _diff_order(e):
if not isinstance(e, Diff): if not isinstance(e, Diff):
return 0 return 0
else: else:
return 1 + Discretization2ndOrder.__diff_order(e.args[0]) return 1 + Discretization2ndOrder._diff_order(e.args[0])
def _discretize_diffusion(self, expr): def _discretize_diffusion(self, expr):
result = 0 result = 0
...@@ -110,7 +110,7 @@ class Discretization2ndOrder: ...@@ -110,7 +110,7 @@ class Discretization2ndOrder:
return e.func(*new_args) if new_args else e return e.func(*new_args) if new_args else e
def _discretize_diff(self, e): def _discretize_diff(self, e):
order = self.__diff_order(e) order = self._diff_order(e)
if order == 1: if order == 1:
fa = e.args[0] fa = e.args[0]
index = e.target index = e.target
......
import sympy as sp
from functools import partial
from pystencils import AssignmentCollection, Field
from pystencils.fd import Diff
from .derivative import diff_args
def fd_stencils_standard(indices, dx, fa):
order = len(indices)
if order == 1:
idx = indices[0]
return (fa.neighbor(idx, 1) - fa.neighbor(idx, -1)) / (2 * dx)
elif order == 2:
if indices[0] == indices[1]:
return (-2 * fa + fa.neighbor(indices[0], -1) + fa.neighbor(indices[0], +1)) / (dx ** 2)
else:
offsets = [(1, 1), [-1, 1], [1, -1], [-1, -1]]
return sum(o1 * o2 * fa.neighbor(indices[0], o1).neighbor(indices[1], o2)
for o1, o2 in offsets) / (4 * dx ** 2)
raise NotImplementedError("Supports only derivatives up to order 2")
def fd_stencils_isotropic(indices, dx, fa):
dim = fa.field.spatial_dimensions
if dim == 1:
return fd_stencils_standard(indices, dx, fa)
elif dim == 2:
order = len(indices)
if order == 1:
idx = indices[0]
assert 0 <= idx < 2
other_idx = 1 if indices[0] == 0 else 0
weights = {-1: sp.Rational(1, 12) / dx,
0: sp.Rational(1, 3) / dx,
1: sp.Rational(1, 12) / dx}
upper_terms = sum(fa.neighbor(idx, +1).neighbor(other_idx, off) * w for off, w in weights.items())
lower_terms = sum(fa.neighbor(idx, -1).neighbor(other_idx, off) * w for off, w in weights.items())
return upper_terms - lower_terms
elif order == 2:
if indices[0] == indices[1]:
idx = indices[0]
other_idx = 1 if idx == 0 else 0
diagonals = sp.Rational(1, 12) * sum(fa.neighbor(0, i).neighbor(1, j) for i in (-1, 1) for j in (-1, 1))
div_direction = sp.Rational(5, 6) * sum(fa.neighbor(idx, i) for i in (-1, 1))
other_direction = - sp.Rational(1, 6) * sum(fa.neighbor(other_idx, i) for i in (-1, 1))
center = - sp.Rational(5, 3) * fa
return (diagonals + div_direction + other_direction + center) / (dx ** 2)
else:
return fd_stencils_standard(indices, dx, fa)
raise NotImplementedError("Supports only derivatives up to order 2 for 1D and 2D setups")
def discretize_spatial(expr, dx, stencil=fd_stencils_standard):
if isinstance(stencil, str):
if stencil == 'standard':
stencil = fd_stencils_standard
elif stencil == 'isotropic':
stencil = fd_stencils_isotropic
else:
raise ValueError("Unknown stencil. Supported 'standard' and 'isotropic'")
if isinstance(expr, list):
return [discretize_spatial(e, dx, stencil) for e in expr]
elif isinstance(expr, sp.Matrix):
return expr.applyfunc(partial(discretize_spatial, dx=dx, stencil=stencil))
elif isinstance(expr, AssignmentCollection):
return expr.copy(main_assignments=[e for e in expr.main_assignments],
subexpressions=[e for e in expr.subexpressions])
elif isinstance(expr, Diff):
arg, *indices = diff_args(expr)
if not isinstance(arg, Field.Access):
raise ValueError("Only derivatives with field or field accesses as arguments can be discretized")
return stencil(indices, dx, arg)
else:
new_args = [discretize_spatial(a, dx, stencil) for a in expr.args]
return expr.func(*new_args) if new_args else expr
...@@ -386,6 +386,8 @@ class Field: ...@@ -386,6 +386,8 @@ class Field:
return hash((self._layout, self.shape, self.strides, self._dtype, self.field_type, self._field_name)) return hash((self._layout, self.shape, self.strides, self._dtype, self.field_type, self._field_name))
def __eq__(self, other): def __eq__(self, other):
if not isinstance(other, Field):
return False
self_tuple = (self.shape, self.strides, self.name, self.dtype, self.field_type) self_tuple = (self.shape, self.strides, self.name, self.dtype, self.field_type)
other_tuple = (other.shape, other.strides, other.name, other.dtype, other.field_type) other_tuple = (other.shape, other.strides, other.name, other.dtype, other.field_type)
return self_tuple == other_tuple return self_tuple == other_tuple
......
...@@ -171,10 +171,11 @@ def visualize_stencil_2d(stencil, axes=None, figure=None, data=None, textsize='1 ...@@ -171,10 +171,11 @@ def visualize_stencil_2d(stencil, axes=None, figure=None, data=None, textsize='1
""" """
Creates a matplotlib 2D plot of the stencil Creates a matplotlib 2D plot of the stencil
:param stencil: sequence of directions Args:
:param axes: optional matplotlib axes stencil: sequence of directions
:param data: data to annotate the directions with, if none given, the indices are used axes: optional matplotlib axes
:param textsize: size of annotation text data: data to annotate the directions with, if none given, the indices are used
textsize: size of annotation text
""" """
from matplotlib.patches import BoxStyle from matplotlib.patches import BoxStyle
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
...@@ -329,6 +330,7 @@ def visualize_stencil_3d(stencil, figure=None, axes=None, data=None, textsize='8 ...@@ -329,6 +330,7 @@ def visualize_stencil_3d(stencil, figure=None, axes=None, data=None, textsize='8
def visualize_stencil_expression(expr, **kwargs): def visualize_stencil_expression(expr, **kwargs):
"""Displays coefficients of a linear update expression of a single field as matplotlib arrow drawing."""
stencil, coefficients = stencil_coefficients(expr) stencil, coefficients = stencil_coefficients(expr)
dim = len(stencil[0]) dim = len(stencil[0])
assert 0 < dim <= 3 assert 0 < dim <= 3
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment