Skip to content
Snippets Groups Projects
Commit 2e42e5ba authored by Martin Bauer's avatar Martin Bauer
Browse files

Generalized assignment_from_stencil

- relative application of stencil by passing in field accesses
- documentation + doctest
parent aa2c89d8
Branches
Tags
No related merge requests found
...@@ -15,11 +15,11 @@ def print_assignment_latex(printer, expr): ...@@ -15,11 +15,11 @@ def print_assignment_latex(printer, expr):
"""sympy cannot print Assignments as Latex. Thus, this function is added to the sympy Latex printer""" """sympy cannot print Assignments as Latex. Thus, this function is added to the sympy Latex printer"""
printed_lhs = printer.doprint(expr.lhs) printed_lhs = printer.doprint(expr.lhs)
printed_rhs = printer.doprint(expr.rhs) printed_rhs = printer.doprint(expr.rhs)
return "{printed_lhs} \leftarrow {printed_rhs}".format(printed_lhs=printed_lhs, printed_rhs=printed_rhs) return r"{printed_lhs} \leftarrow {printed_rhs}".format(printed_lhs=printed_lhs, printed_rhs=printed_rhs)
def assignment_str(assignment): def assignment_str(assignment):
return "{lhs} ← {rhs}".format(lhs=assignment.lhs, rhs=assignment.rhs) return r"{lhs} ← {rhs}".format(lhs=assignment.lhs, rhs=assignment.rhs)
if Assignment: if Assignment:
...@@ -51,15 +51,64 @@ else: ...@@ -51,15 +51,64 @@ else:
_print_Assignment = print_assignment_latex _print_Assignment = print_assignment_latex
def assignment_from_stencil(stencil_array, input_field, output_field, normalization_factor=None): def assignment_from_stencil(stencil_array, input_field, output_field,
normalization_factor=None, order='visual') -> Assignment:
"""Creates an assignment
Args:
stencil_array: nested list of numpy array defining the stencil weights
input_field: field or field access, defining where the stencil should be applied to
output_field: field or field access where the result is written to
normalization_factor: optional normalization factor for the stencil
order: defines how the stencil_array is interpreted. Possible values are 'visual' and 'numpy'.
For details see examples
Returns:
Assignment that can be used to create a kernel
Examples:
>>> import pystencils as ps
>>> f, g = ps.fields("f, g: [2D]")
>>> stencil = [[0, 2, 0],
... [3, 4, 5],
... [0, 6, 0]]
By default 'visual ordering is used - i.e. the stencil is applied as the nested lists are written down
>>> assignment_from_stencil(stencil, f, g, order='visual')
Assignment(g_C, 3*f_W + 6*f_S + 4*f_C + 2*f_N + 5*f_E)
'numpy' ordering uses the first coordinate of the stencil array for x offset, second for y offset etc.
>>> assignment_from_stencil(stencil, f, g, order='numpy')
Assignment(g_C, 2*f_W + 3*f_S + 4*f_C + 5*f_N + 6*f_E)
You can also pass field accesses to apply the stencil at an already shifted position:
>>> assignment_from_stencil(stencil, f[1, 0], g[2, 0])
Assignment(g_2E, 3*f_C + 6*f_SE + 4*f_E + 2*f_NE + 5*f_2E)
"""
from pystencils import Field
stencil_array = np.array(stencil_array) stencil_array = np.array(stencil_array)
if order == 'visual':
stencil_array = np.swapaxes(stencil_array, 0, 1)
stencil_array = np.flip(stencil_array, axis=1)
elif order == 'numpy':
pass
else:
raise ValueError("'order' has to be either 'visual' or 'numpy'")
if isinstance(input_field, Field):
input_field = input_field.center
if isinstance(output_field, Field):
output_field = output_field.center
rhs = 0 rhs = 0
offset = tuple(s // 2 for s in stencil_array.shape) offset = tuple(s // 2 for s in stencil_array.shape)
for index, factor in np.ndenumerate(stencil_array): for index, factor in np.ndenumerate(stencil_array):
rhs += factor * input_field[tuple(i - o for i, o in zip(index, offset))] shift = tuple(i - o for i, o in zip(index, offset))
rhs += factor * input_field.get_shifted(*shift)
if normalization_factor: if normalization_factor:
rhs *= normalization_factor rhs *= normalization_factor
return Assignment(output_field.center(), rhs) return Assignment(output_field, rhs)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment