Commit 42c9e289 authored by Martin Bauer's avatar Martin Bauer
Browse files

pystencils.fds: staggered spatial finite differences

parent 54196a50
......@@ -3,9 +3,10 @@ from .derivative import Diff, DiffOperator, \
expand_diff_full, expand_diff_linear, expand_diff_products, combine_diff_products, \
functional_derivative, diff
from .finitedifferences import advection, diffusion, transient, Discretization2ndOrder
from .spatial import discretize_spatial
from .spatial import discretize_spatial, discretize_spatial_staggered
__all__ = ['Diff', 'diff', 'DiffOperator', 'diff_terms', 'collect_diffs',
'zero_diffs', 'evaluate_diffs', 'normalize_diff_order', 'expand_diff_full', 'expand_diff_linear',
'expand_diff_products', 'combine_diff_products', 'functional_derivative',
'advection', 'diffusion', 'transient', 'Discretization2ndOrder', 'discretize_spatial']
'advection', 'diffusion', 'transient', 'Discretization2ndOrder', 'discretize_spatial',
'discretize_spatial_staggered']
......@@ -4,6 +4,8 @@ from typing import Union, Optional
from pystencils import Field, AssignmentCollection
from pystencils.fd import Diff
from pystencils.fd.derivative import diff_args
from pystencils.fd.spatial import fd_stencils_standard
from pystencils.sympyextensions import fast_subs
......@@ -68,9 +70,10 @@ def transient(scalar, idx=None):
class Discretization2ndOrder:
def __init__(self, dx=sp.Symbol("dx"), dt=sp.Symbol("dt")):
def __init__(self, dx=sp.Symbol("dx"), dt=sp.Symbol("dt"), discretization_stencil_func=fd_stencils_standard):
self.dx = dx
self.dt = dt
self.spatial_stencil = discretization_stencil_func
@staticmethod
def _diff_order(e):
......@@ -104,7 +107,10 @@ class Discretization2ndOrder:
elif isinstance(e, Advection):
return self._discretize_advection(e)
elif isinstance(e, Diff):
return self._discretize_diff(e)
arg, *indices = diff_args(e)
if not isinstance(arg, Field.Access):
raise ValueError("Only derivatives with field or field accesses as arguments can be discretized")
return self.spatial_stencil(indices, self.dx, arg)
else:
new_args = [self._discretize_spatial(a) for a in e.args]
return e.func(*new_args) if new_args else e
......
from typing import Tuple
import sympy as sp
from functools import partial
from pystencils.astnodes import LoopOverCoordinate
from pystencils.cache import memorycache
from pystencils import AssignmentCollection, Field
from pystencils.fd import Diff
from pystencils.transformations import generic_visit
from .derivative import diff_args
from .derivation import FiniteDifferenceStencilDerivation
......@@ -107,22 +110,59 @@ def discretize_spatial(expr, dx, stencil=fd_stencils_standard):
else:
raise ValueError("Unknown stencil. Supported 'standard' and 'isotropic'")
if isinstance(expr, list):
return [discretize_spatial(e, dx, stencil) for e in expr]
elif isinstance(expr, sp.Matrix):
return expr.applyfunc(partial(discretize_spatial, dx=dx, stencil=stencil))
elif isinstance(expr, AssignmentCollection):
return expr.copy(main_assignments=[e for e in expr.main_assignments],
subexpressions=[e for e in expr.subexpressions])
elif isinstance(expr, Diff):
arg, *indices = diff_args(expr)
def visitor(e):
if isinstance(e, Diff):
arg, *indices = diff_args(e)
if not isinstance(arg, Field.Access):
raise ValueError("Only derivatives with field or field accesses as arguments can be discretized")
return stencil(indices, dx, arg)
else:
new_args = [discretize_spatial(a, dx, stencil) for a in expr.args]
return expr.func(*new_args) if new_args else expr
new_args = [discretize_spatial(a, dx, stencil) for a in e.args]
return e.func(*new_args) if new_args else e
return generic_visit(expr, visitor)
def discretize_spatial_staggered(expr, dx, stencil=fd_stencils_standard):
def staggered_visitor(e, coordinate, sign):
if isinstance(e, Diff):
arg, *indices = diff_args(e)
if len(indices) != 1:
raise ValueError("Function supports only up to second derivatives")
if not isinstance(arg, Field.Access):
raise ValueError("Argument of inner derivative has to be field access")
target = indices[0]
if target == coordinate:
assert sign in (-1, 1)
return (arg.neighbor(coordinate, sign) - arg) / dx * sign
else:
return (stencil(indices, dx, arg.neighbor(coordinate, sign))
+ stencil(indices, dx, arg)) / 2
elif isinstance(e, Field.Access):
return (e.neighbor(coordinate, sign) + e) / 2
elif isinstance(e, sp.Symbol):
loop_idx = LoopOverCoordinate.is_loop_counter_symbol(e)
return e + sign / 2 if loop_idx == coordinate else e
else:
new_args = [staggered_visitor(a, coordinate, sign) for a in e.args]
return e.func(*new_args) if new_args else e
def visitor(e):
if isinstance(e, Diff):
arg, *indices = diff_args(e)
if isinstance(arg, Field.Access):
return stencil(indices, dx, arg)
else:
if not len(indices) == 1:
raise ValueError("This term is not support by the staggered discretization strategy")
target = indices[0]
return (staggered_visitor(arg, target, 1) - staggered_visitor(arg, target, -1)) / dx
else:
new_args = [visitor(a) for a in e.args]
return e.func(*new_args) if new_args else e
return generic_visit(expr, visitor)
# -------------------------------------- special stencils --------------------------------------------------------------
......
......@@ -532,7 +532,7 @@ class Field:
"""Value of index coordinates as tuple."""
return self._index
def neighbor(self, coord_id: int, offset: Sequence[int]) -> 'Field.Access':
def neighbor(self, coord_id: int, offset: int) -> 'Field.Access':
"""Returns a new Access with changed spatial coordinates.
Args:
......
......@@ -7,6 +7,8 @@ import hashlib
import sympy as sp
from sympy.logic.boolalg import Boolean
from sympy.tensor import IndexedBase
from pystencils.simp.assignment_collection import AssignmentCollection
from pystencils.assignment import Assignment
from pystencils.field import Field, FieldType
from pystencils.data_types import TypedSymbol, PointerType, StructType, get_base_type, reinterpret_cast_func, \
......@@ -80,6 +82,21 @@ def filtered_tree_iteration(node, node_type, stop_type=None):
yield from filtered_tree_iteration(arg, node_type)
def generic_visit(term, visitor):
if isinstance(term, AssignmentCollection):
new_main_assignments = generic_visit(term.main_assignments, visitor)
new_subexpressions = generic_visit(term.subexpressions, visitor)
return term.copy(new_main_assignments, new_subexpressions)
elif isinstance(term, list):
return [generic_visit(e, visitor) for e in term]
elif isinstance(term, Assignment):
return Assignment(term.lhs, generic_visit(term.rhs, visitor))
elif isinstance(term, sp.Matrix):
return term.applyfunc(lambda e: generic_visit(e, visitor))
else:
return visitor(term)
def unify_shape_symbols(body, common_shape, fields):
"""Replaces symbols for array sizes to ensure they are represented by the same unique symbol.
......
import sympy as sp
import pystencils as ps
from pystencils.astnodes import LoopOverCoordinate
from pystencils.stencils import stencil_coefficients
from pystencils.fd.spatial import fd_stencils_standard, fd_stencils_isotropic, discretize_spatial
from pystencils.fd import diff
......@@ -38,3 +39,32 @@ def test_spatial_1d_unit_sum():
discretized = discretize_spatial(term, dx=h, stencil=scheme)
_, coefficients = stencil_coefficients(discretized)
assert sum(coefficients) == 0
def test_staggered_laplacian():
f = ps.fields("f : double[2D]")
a, dx = sp.symbols("a, dx")
factored_version = sum(ps.fd.Diff(a * ps.fd.Diff(f[0, 0], i), i)
for i in range(2))
expanded = ps.fd.expand_diff_full(factored_version, constants=[a])
reference = ps.fd.discretize_spatial(expanded, dx).factor()
to_test = ps.fd.discretize_spatial_staggered(factored_version, dx).factor()
assert reference == to_test
def test_staggered_combined():
from pystencils.fd import diff
f = ps.fields("f : double[2D]")
x, y = [LoopOverCoordinate.get_loop_counter_symbol(i) for i in range(2)]
dx = sp.symbols("dx")
expr = diff(x * diff(f, 0) + y * diff(f, 1), 0)
right = (x + sp.Rational(1, 2)) * (f[1, 0] - f[0, 0]) + y * (f[1, 1] - f[1, -1] + f[0, 1] - f[0, -1]) / 4
left = (x - sp.Rational(1, 2)) * (f[0, 0] - f[-1, 0]) + y * (f[-1, 1] - f[-1, -1] + f[0, 1] - f[0, -1]) / 4
reference = (right - left) / (dx ** 2)
to_test = ps.fd.discretize_spatial_staggered(expr, dx)
assert sp.expand(reference - to_test) == 0
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment