Commit 362b4611 authored by Martin Bauer's avatar Martin Bauer
Browse files

Clean submodules for data handling and finite differences

parent b816bb31
......@@ -8,6 +8,7 @@ from .display_utils import show_code, to_dot
from .assignment_collection import AssignmentCollection
from .assignment import Assignment
from .sympyextensions import SymbolCreator
from .datahandling import create_data_handling
__all__ = ['Field', 'FieldType',
'TypedSymbol',
......@@ -16,4 +17,5 @@ __all__ = ['Field', 'FieldType',
'show_code', 'to_dot',
'AssignmentCollection',
'Assignment',
'SymbolCreator']
'SymbolCreator',
'create_data_handling']
......@@ -14,8 +14,8 @@ class FlagInterface:
"""Manages the reservation of bits (i.e. flags) in an array of unsigned integers.
Examples:
>>> from pystencils.datahandling import SerialDataHandling
>>> dh = SerialDataHandling((4, 5))
>>> from pystencils import create_data_handling
>>> dh = create_data_handling((4, 5))
>>> fi = FlagInterface(dh, 'flag_field', np.uint8)
>>> assert dh.has_data('flag_field')
>>> fi.reserve_next_flag()
......
from pystencils.datahandling.serial_datahandling import SerialDataHandling
from typing import Tuple, Union
from .serial_datahandling import SerialDataHandling
from .datahandling_interface import DataHandling
try:
# noinspection PyPep8Naming
......@@ -12,7 +14,23 @@ except ImportError:
ParallelDataHandling = None
def create_data_handling(parallel, domain_size, periodicity, default_layout='SoA', default_ghost_layers=1):
def create_data_handling(domain_size: Tuple[int, ...],
periodicity: Union[bool, Tuple[bool, ...]] = False,
default_layout: str = 'SoA',
parallel: bool = False,
default_ghost_layers: int = 1) -> DataHandling:
"""Creates a data handling instance.
Args:
parallel:
domain_size:
periodicity:
default_layout:
default_ghost_layers:
Returns:
"""
if parallel:
if wlb is None:
raise ValueError("Cannot create parallel data handling because walberla module is not available")
......@@ -39,3 +57,6 @@ def create_data_handling(parallel, domain_size, periodicity, default_layout='SoA
else:
return SerialDataHandling(domain_size, periodicity=periodicity,
default_layout=default_layout, default_ghost_layers=default_ghost_layers)
__all__ = ['create_data_handling']
import numpy as np
from pystencils import Field
from pystencils.datahandling.datahandling_interface import DataHandling
from pystencils.parallel.blockiteration import sliced_block_iteration, block_iteration
from pystencils.datahandling.blockiteration import sliced_block_iteration, block_iteration
from pystencils.utils import DotDict
# noinspection PyPep8Naming
import waLBerla as wlb
......
......@@ -309,7 +309,7 @@ class SerialDataHandling(DataHandling):
return np.array(sequence)
def create_vtk_writer(self, file_name, data_names, ghost_layers=False):
from pystencils.vtk import image_to_vtk
from pystencils.datahandling.vtk import image_to_vtk
def writer(step):
full_file_name = "%s_%08d" % (file_name, step,)
......@@ -336,7 +336,7 @@ class SerialDataHandling(DataHandling):
return writer
def create_vtk_writer_for_flag_array(self, file_name, data_name, masks_to_name, ghost_layers=False):
from pystencils.vtk import image_to_vtk
from pystencils.datahandling.vtk import image_to_vtk
def writer(step):
full_file_name = "%s_%08d" % (file_name, step,)
......
from .derivative import Diff, DiffOperator, \
diff_terms, collect_diffs, create_nested_diff, replace_diff, zero_diffs, evaluate_diffs, normalize_diff_order, \
expand_diff_full, expand_diff_linear, expand_diff_products, combine_diff_products, \
functional_derivative
from .finitedifferences import advection, diffusion, transient, Discretization2ndOrder
__all__ = ['Diff', 'DiffOperator', 'diff_terms', 'collect_diffs', 'create_nested_diff', 'replace_diff', 'zero_diffs',
'evaluate_diffs', 'normalize_diff_order', 'expand_diff_full', 'expand_diff_linear',
'expand_diff_products', 'combine_diff_products', 'functional_derivative']
......@@ -3,13 +3,14 @@ from collections import namedtuple, defaultdict
from pystencils.sympyextensions import normalize_product, prod
def default_diff_sort_key(d):
def _default_diff_sort_key(d):
return str(d.superscript), str(d.target)
class Diff(sp.Expr):
"""
Sympy Node representing a derivative. The difference to sympy's built in differential is:
"""Sympy Node representing a derivative.
The difference to sympy's built in differential is:
- shortened latex representation
- all simplifications have to be done manually
- optional marker displayed as superscript
......@@ -156,7 +157,7 @@ class DiffOperator(sp.Expr):
if len(diffs) == 0:
return mul * argument if apply_to_constants else mul
rest = [a for a in args if not isinstance(a, DiffOperator)]
diffs.sort(key=default_diff_sort_key)
diffs.sort(key=_default_diff_sort_key)
result = argument
for d in reversed(diffs):
result = Diff(result, target=d.target, superscript=d.superscript)
......@@ -174,10 +175,10 @@ class DiffOperator(sp.Expr):
# ----------------------------------------------------------------------------------------------------------------------
def derivative_terms(expr):
"""
Returns set of all derivatives in an expression
this is different from `expr.atoms(Diff)` when nested derivatives are in the expression,
def diff_terms(expr):
"""Returns set of all derivatives in an expression.
This function yields different results than `expr.atoms(Diff)` when nested derivatives are in the expression,
since this function only returns the outer derivatives
"""
result = set()
......@@ -193,9 +194,9 @@ def derivative_terms(expr):
return result
def collect_derivatives(expr):
def collect_diffs(expr):
"""Rewrites expression into a sum of distinct derivatives with pre-factors"""
return expr.collect(derivative_terms(expr))
return expr.collect(diff_terms(expr))
def create_nested_diff(arg, *args):
......@@ -208,39 +209,73 @@ def create_nested_diff(arg, *args):
return res
def expand_using_linearity(expr, functions=None, constants=None):
"""
Expands all derivative nodes by applying Diff.split_linear
:param expr: expression containing derivatives
:param functions: sequence of symbols that are considered functions and can not be pulled before the derivative.
if None, all symbols are viewed as functions
:param constants: sequence of symbols which are considered constants and can be pulled before the derivative
"""
if functions is None:
functions = expr.atoms(sp.Symbol)
if constants is not None:
functions.difference_update(constants)
def replace_diff(expr, replacement_dict):
"""replacement_dict: maps variable (target) to a new Differential operator"""
def visit(e):
if isinstance(e, Diff):
if e.target in replacement_dict:
return DiffOperator.apply(replacement_dict[e.target], visit(e.arg))
new_args = [visit(arg) for arg in e.args]
return e.func(*new_args) if new_args else e
return visit(expr)
def zero_diffs(expr, label):
"""Replaces all differentials with the given target by 0"""
def visit(e):
if isinstance(e, Diff):
if e.target == label:
return 0
new_args = [visit(arg) for arg in e.args]
return e.func(*new_args) if new_args else e
return visit(expr)
def evaluate_diffs(expr, var=None):
"""Replaces pystencils diff objects by sympy diff objects and evaluates them.
Replaces Diff nodes by sp.diff , the free variable is either the target (if var=None) otherwise
the specified var
"""
if isinstance(expr, Diff):
arg = expand_using_linearity(expr.arg, functions)
if hasattr(arg, 'func') and arg.func == sp.Add:
result = 0
for a in arg.args:
result += Diff(a, target=expr.target, superscript=expr.superscript).split_linear(functions)
if var is None:
var = expr.target
return sp.diff(evaluate_diffs(expr.arg, var), var)
else:
new_args = [evaluate_diffs(arg, var) for arg in expr.args]
return expr.func(*new_args) if new_args else expr
def normalize_diff_order(expression, functions=None, constants=None, sort_key=_default_diff_sort_key):
"""Assumes order of differentiation can be exchanged. Changes the order of nested Diffs to a standard order defined
by the sorting key 'sort_key' such that the derivative terms can be further simplified """
def visit(expr):
if isinstance(expr, Diff):
nodes = [expr]
while isinstance(nodes[-1].arg, Diff):
nodes.append(nodes[-1].arg)
processed_arg = visit(nodes[-1].arg)
nodes.sort(key=sort_key)
result = processed_arg
for d in reversed(nodes):
result = Diff(result, target=d.target, superscript=d.superscript)
return result
else:
diff = Diff(arg, target=expr.target, superscript=expr.superscript)
if diff == 0:
return 0
else:
return diff.split_linear(functions)
else:
new_args = [expand_using_linearity(e, functions) for e in expr.args]
result = sp.expand(expr.func(*new_args) if new_args else expr)
return result
new_args = [visit(e) for e in expr.args]
return expr.func(*new_args) if new_args else expr
expression = expand_diff_linear(expression.expand(), functions, constants).expand()
return visit(expression)
def full_diff_expand(expr, functions=None, constants=None):
def expand_diff_full(expr, functions=None, constants=None):
if functions is None:
functions = expr.atoms(sp.Symbol)
if constants is not None:
......@@ -278,35 +313,43 @@ def full_diff_expand(expr, functions=None, constants=None):
return visit(expr)
def normalize_diff_order(expression, functions=None, constants=None, sort_key=default_diff_sort_key):
"""Assumes order of differentiation can be exchanged. Changes the order of nested Diffs to a standard order defined
by the sorting key 'sort_key' such that the derivative terms can be further simplified """
def expand_diff_linear(expr, functions=None, constants=None):
"""Expands all derivative nodes by applying Diff.split_linear
def visit(expr):
if isinstance(expr, Diff):
nodes = [expr]
while isinstance(nodes[-1].arg, Diff):
nodes.append(nodes[-1].arg)
processed_arg = visit(nodes[-1].arg)
nodes.sort(key=sort_key)
Args:
expr: expression containing derivatives
functions: sequence of symbols that are considered functions and can not be pulled before the derivative.
if None, all symbols are viewed as functions
constants: sequence of symbols which are considered constants and can be pulled before the derivative
"""
if functions is None:
functions = expr.atoms(sp.Symbol)
if constants is not None:
functions.difference_update(constants)
result = processed_arg
for d in reversed(nodes):
result = Diff(result, target=d.target, superscript=d.superscript)
if isinstance(expr, Diff):
arg = expand_diff_linear(expr.arg, functions)
if hasattr(arg, 'func') and arg.func == sp.Add:
result = 0
for a in arg.args:
result += Diff(a, target=expr.target, superscript=expr.superscript).split_linear(functions)
return result
else:
new_args = [visit(e) for e in expr.args]
return expr.func(*new_args) if new_args else expr
expression = expand_using_linearity(expression.expand(), functions, constants).expand()
return visit(expression)
diff = Diff(arg, target=expr.target, superscript=expr.superscript)
if diff == 0:
return 0
else:
return diff.split_linear(functions)
else:
new_args = [expand_diff_linear(e, functions) for e in expr.args]
result = sp.expand(expr.func(*new_args) if new_args else expr)
return result
def expand_using_product_rule(expr):
def expand_diff_products(expr):
"""Fully expands all derivatives by applying product rule"""
if isinstance(expr, Diff):
arg = expand_using_product_rule(expr.args[0])
arg = expand_diff_products(expr.args[0])
if arg.func == sp.Add:
new_args = [Diff(e, target=expr.target, superscript=expr.superscript)
for e in arg.args]
......@@ -321,11 +364,11 @@ def expand_using_product_rule(expr):
result += pre_factor * Diff(prod_list[i], target=expr.target, superscript=expr.superscript)
return result
else:
new_args = [expand_using_product_rule(e) for e in expr.args]
new_args = [expand_diff_products(e) for e in expr.args]
return expr.func(*new_args) if new_args else expr
def combine_using_product_rule(expr):
def combine_diff_products(expr):
"""Inverse product rule"""
def expr_to_diff_decomposition(expression):
......@@ -408,53 +451,14 @@ def combine_using_product_rule(expr):
rest += process_diff_list(diff_list, label, superscript)
return rest
else:
new_args = [combine_using_product_rule(e) for e in expression.args]
new_args = [combine_diff_products(e) for e in expression.args]
return expression.func(*new_args) if new_args else expression
return combine(expr)
def replace_diff(expr, replacement_dict):
"""replacement_dict: maps variable (target) to a new Differential operator"""
def visit(e):
if isinstance(e, Diff):
if e.target in replacement_dict:
return DiffOperator.apply(replacement_dict[e.target], visit(e.arg))
new_args = [visit(arg) for arg in e.args]
return e.func(*new_args) if new_args else e
return visit(expr)
def zero_diffs(expr, label):
"""Replaces all differentials with the given target by 0"""
def visit(e):
if isinstance(e, Diff):
if e.target == label:
return 0
new_args = [visit(arg) for arg in e.args]
return e.func(*new_args) if new_args else e
return visit(expr)
def evaluate_diffs(expr, var=None):
"""Replaces Diff nodes by sp.diff , the free variable is either the target (if var=None) otherwise
the specified var"""
if isinstance(expr, Diff):
if var is None:
var = expr.target
return sp.diff(evaluate_diffs(expr.arg, var), var)
else:
new_args = [evaluate_diffs(arg, var) for arg in expr.args]
return expr.func(*new_args) if new_args else expr
def functional_derivative(functional, v):
r"""
Computes functional derivative of functional with respect to v using Euler-Lagrange equation
r"""Computes functional derivative of functional with respect to v using Euler-Lagrange equation
.. math ::
......
......@@ -54,7 +54,7 @@ def create_kernel(assignments, target='cpu', data_type="double", iteration_slice
add_openmp(ast, num_threads=cpu_openmp)
if cpu_vectorize_info:
import pystencils.backends.simd_instruction_sets as vec
from pystencils.vectorization import vectorize
from pystencils.cpu.vectorization import vectorize
vec_params = cpu_vectorize_info
vec.selected_instruction_set = vec.x86_vector_instruction_set(instruction_set=vec_params[0],
data_type=vec_params[1])
......
import sympy as sp
from pystencils import Assignment, AssignmentCollection
from pystencils.assignment_collection import SimplificationStrategy, apply_on_all_subexpressions, \
subexpression_substitution_in_existing_subexpressions
def test_simplification_strategy():
a, b, c, d, x, y, z = sp.symbols("a b c d x y z")
s0, s1, s2, s3 = sp.symbols("s_:4")
a0, a1, a2, a3 = sp.symbols("a_:4")
subexpressions = [
Assignment(s0, 2 * a + 2 * b),
Assignment(s1, 2 * a + 2 * b + 2 * c),
Assignment(s2, 2 * a + 2 * b + 2 * c + 2 * d),
]
main = [
Assignment(a0, s0 + s1),
Assignment(a1, s0 + s2),
Assignment(a2, s1 + s2),
]
ac = AssignmentCollection(main, subexpressions)
strategy = SimplificationStrategy()
strategy.add(subexpression_substitution_in_existing_subexpressions)
strategy.add(apply_on_all_subexpressions(sp.factor))
result = strategy(ac)
assert result.operation_count['adds'] == 7
assert result.operation_count['muls'] == 5
assert result.operation_count['divs'] == 0
# Trigger display routines, such that they are at least executed
report = strategy.show_intermediate_results(ac, symbols=[s0])
assert 's_0' in str(report)
report = strategy.show_intermediate_results(ac)
assert 's_{1}' in report._repr_html_()
report = strategy.create_simplification_report(ac)
assert 'Adds' in str(report)
assert 'Adds' in report._repr_html_()
assert 'factor' in str(strategy)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment