Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
No results found
Show changes
Showing
with 3633 additions and 0 deletions
import ast
import inspect
import textwrap
from typing import Callable, Union, List, Dict, Tuple
import sympy as sp
from pystencils.assignment import Assignment
from pystencils.sympyextensions import SymbolCreator
from pystencils.config import CreateKernelConfig
__all__ = ['kernel', 'kernel_config']
def _kernel(func: Callable[..., None], **kwargs) -> Tuple[List[Assignment], str]:
"""
Convenient function for kernel decorator to prevent code duplication
Args:
func: decorated function
**kwargs: kwargs for the function
Returns:
assignments, function_name
"""
source = inspect.getsource(func)
source = textwrap.dedent(source)
a = ast.parse(source)
KernelFunctionRewrite().visit(a)
ast.fix_missing_locations(a)
gl = func.__globals__.copy()
assignments = []
def assignment_adder(lhs, rhs):
assignments.append(Assignment(lhs, rhs))
gl['_add_assignment'] = assignment_adder
gl['_Piecewise'] = sp.Piecewise
gl.update(inspect.getclosurevars(func).nonlocals)
exec(compile(a, filename="<ast>", mode="exec"), gl)
func = gl[func.__name__]
args = inspect.getfullargspec(func).args
if 's' in args and 's' not in kwargs:
kwargs['s'] = SymbolCreator()
func(**kwargs)
return assignments, func.__name__
def kernel(func: Callable[..., None], **kwargs) -> List[Assignment]:
"""Decorator to simplify generation of pystencils Assignments.
Changes the meaning of the '@=' operator. Each line containing this operator gives a symbolic assignment
in the result list. Furthermore the meaning of the ternary inline 'if-else' changes meaning to denote a
sympy Piecewise.
The decorated function may not receive any arguments, with exception of an argument called 's' that specifies
a SymbolCreator()
Args:
func: decorated function
**kwargs: kwargs for the function
Examples:
>>> import pystencils as ps
>>> @kernel
... def my_kernel(s):
... f, g = ps.fields('f, g: [2D]')
... s.neighbors @= f[0,1] + f[1,0]
... g[0,0] @= s.neighbors + f[0,0] if f[0,0] > 0 else 0
>>> f, g = ps.fields('f, g: [2D]')
>>> assert my_kernel[0].rhs == f[0,1] + f[1,0]
"""
assignments, _ = _kernel(func, **kwargs)
return assignments
def kernel_config(config: CreateKernelConfig, **kwargs) -> Callable[..., Dict]:
"""Decorator to simplify generation of pystencils Assignments, which takes a configuration
and updates the function name accordingly.
Changes the meaning of the '@=' operator. Each line containing this operator gives a symbolic assignment
in the result list. Furthermore, the meaning of the ternary inline 'if-else' changes meaning to denote a
sympy Piecewise.
The decorated function may not receive any arguments, with exception to an argument called 's' that specifies
a SymbolCreator()
Args:
config: Specify whether to return the list with assignments, or a dictionary containing additional settings
like func_name
Returns:
decorator with config
Examples:
>>> import pystencils as ps
>>> kernel_configuration = ps.CreateKernelConfig()
>>> @kernel_config(kernel_configuration)
... def my_kernel(s):
... src, dst = ps.fields('src, dst: [2D]')
... s.neighbors @= src[0, 1] + src[1, 0]
... dst[0, 0] @= s.neighbors + src[0, 0] if src[0, 0] > 0 else 0
>>> f, g = ps.fields('src, dst: [2D]')
>>> assert my_kernel['assignments'][0].rhs == f[0, 1] + f[1, 0]
"""
def decorator(func: Callable[..., None]) -> Union[List[Assignment], Dict]:
"""
Args:
func: decorated function
Returns:
Dict for unpacking into create_kernel
"""
assignments, func_name = _kernel(func, **kwargs)
config.function_name = func_name
return {'assignments': assignments, 'config': config}
return decorator
# noinspection PyMethodMayBeStatic
class KernelFunctionRewrite(ast.NodeTransformer):
def visit_IfExp(self, node):
piecewise_func = ast.Name(id='_Piecewise', ctx=ast.Load())
piecewise_func = ast.copy_location(piecewise_func, node)
piecewise_args = [ast.Tuple(elts=[node.body, node.test], ctx=ast.Load()),
ast.Tuple(elts=[node.orelse, ast.NameConstant(value=True)], ctx=ast.Load())]
result = ast.Call(func=piecewise_func, args=piecewise_args, keywords=[])
return ast.copy_location(result, node)
def visit_AugAssign(self, node):
self.generic_visit(node)
node.target.ctx = ast.Load()
new_node = ast.Expr(ast.Call(func=ast.Name(id='_add_assignment', ctx=ast.Load()),
args=[node.target, node.value],
keywords=[]))
return ast.copy_location(new_node, node)
def visit_FunctionDef(self, node):
self.generic_visit(node)
node.decorator_list = []
return node
import pystencils
class KernelWrapper:
"""
Light-weight wrapper around a compiled kernel.
Can be called while still providing access to underlying AST.
"""
def __init__(self, kernel, parameters, ast_node: pystencils.astnodes.KernelFunction):
self.kernel = kernel
self.parameters = parameters
self.ast = ast_node
self.num_regs = None
def __call__(self, **kwargs):
return self.kernel(**kwargs)
@property
def code(self):
return pystencils.get_code_str(self.ast)
import itertools
import warnings
from typing import Union, List
import sympy as sp
from pystencils.config import CreateKernelConfig
from pystencils.assignment import Assignment, AddAugmentedAssignment
from pystencils.astnodes import Node, Block, Conditional, LoopOverCoordinate, SympyAssignment
from pystencils.cpu.vectorization import vectorize
from pystencils.enums import Target, Backend
from pystencils.field import Field, FieldType
from pystencils.node_collection import NodeCollection
from pystencils.simp.assignment_collection import AssignmentCollection
from pystencils.kernel_contrains_check import KernelConstraintsCheck
from pystencils.simplificationfactory import create_simplification_strategy
from pystencils.stencil import direction_string_to_offset, inverse_direction_string
from pystencils.transformations import (
loop_blocking, move_constants_before_loop, remove_conditionals_in_staggered_kernel)
def create_kernel(assignments: Union[Assignment, List[Assignment],
AddAugmentedAssignment, List[AddAugmentedAssignment],
AssignmentCollection, List[Node], NodeCollection],
*,
config: CreateKernelConfig = None, **kwargs):
"""
Creates abstract syntax tree (AST) of kernel, using a list of update equations.
This function forms the general API and delegates the kernel creation to others depending on the CreateKernelConfig.
Args:
assignments: can be a single assignment, sequence of assignments or an `AssignmentCollection`
config: CreateKernelConfig which includes the needed configuration
kwargs: Arguments for updating the config
Returns:
abstract syntax tree (AST) object, that can either be printed as source code with `show_code` or
can be compiled with through its 'compile()' member
Example:
>>> import pystencils as ps
>>> import numpy as np
>>> s, d = ps.fields('s, d: [2D]')
>>> assignment = ps.Assignment(d[0,0], s[0, 1] + s[0, -1] + s[1, 0] + s[-1, 0])
>>> kernel_ast = ps.create_kernel(assignment, config=ps.CreateKernelConfig(cpu_openmp=True))
>>> kernel = kernel_ast.compile()
>>> d_arr = np.zeros([5, 5])
>>> kernel(d=d_arr, s=np.ones([5, 5]))
>>> d_arr
array([[0., 0., 0., 0., 0.],
[0., 4., 4., 4., 0.],
[0., 4., 4., 4., 0.],
[0., 4., 4., 4., 0.],
[0., 0., 0., 0., 0.]])
"""
# ---- Updating configuration from kwargs
if not config:
config = CreateKernelConfig(**kwargs)
else:
for k, v in kwargs.items():
if not hasattr(config, k):
raise KeyError(f'{v} is not a valid kwarg. Please look in CreateKernelConfig for valid settings')
setattr(config, k, v)
# ---- Normalizing parameters
if isinstance(assignments, (Assignment, AddAugmentedAssignment)):
assignments = [assignments]
assert assignments, "Assignments must not be empty!"
if isinstance(assignments, list):
assignments = NodeCollection(assignments)
elif isinstance(assignments, AssignmentCollection):
# TODO Markus check and doku
# --- applying first default simplifications
try:
if config.default_assignment_simplifications:
simplification = create_simplification_strategy()
assignments = simplification(assignments)
except Exception as e:
warnings.warn(f"It was not possible to apply the default pystencils optimisations to the "
f"AssignmentCollection due to the following problem :{e}")
simplification_hints = assignments.simplification_hints
assignments = NodeCollection.from_assignment_collection(assignments)
assignments.simplification_hints = simplification_hints
if config.index_fields:
return create_indexed_kernel(assignments, config=config)
else:
return create_domain_kernel(assignments, config=config)
def create_domain_kernel(assignments: NodeCollection, *, config: CreateKernelConfig):
"""
Creates abstract syntax tree (AST) of kernel, using a NodeCollection.
Note that `create_domain_kernel` is a lower level function which shoul be accessed by not providing `index_fields`
to create_kernel
Args:
assignments: `pystencils.node_collection.NodeCollection` containing all assignements and nodes to be processed
config: CreateKernelConfig which includes the needed configuration
Returns:
abstract syntax tree (AST) object, that can either be printed as source code with `show_code` or
can be compiled with through its 'compile()' member
Example:
>>> import pystencils as ps
>>> import numpy as np
>>> from pystencils.kernelcreation import create_domain_kernel
>>> from pystencils.node_collection import NodeCollection
>>> s, d = ps.fields('s, d: [2D]')
>>> assignment = ps.Assignment(d[0,0], s[0, 1] + s[0, -1] + s[1, 0] + s[-1, 0])
>>> kernel_config = ps.CreateKernelConfig(cpu_openmp=True)
>>> kernel_ast = create_domain_kernel(NodeCollection([assignment]), config=kernel_config)
>>> kernel = kernel_ast.compile()
>>> d_arr = np.zeros([5, 5])
>>> kernel(d=d_arr, s=np.ones([5, 5]))
>>> d_arr
array([[0., 0., 0., 0., 0.],
[0., 4., 4., 4., 0.],
[0., 4., 4., 4., 0.],
[0., 4., 4., 4., 0.],
[0., 0., 0., 0., 0.]])
"""
# --- eval
assignments.evaluate_terms()
# FUTURE WORK from here we shouldn't NEED sympy
# --- check constrains
check = KernelConstraintsCheck(check_independence_condition=not config.skip_independence_check,
check_double_write_condition=not config.allow_double_writes)
check.visit(assignments)
assignments.bound_fields = check.fields_written
assignments.rhs_fields = check.fields_read
# ---- Creating ast
ast = None
if config.target == Target.CPU:
if config.backend == Backend.C:
from pystencils.cpu import add_openmp, create_kernel
ast = create_kernel(assignments, config=config)
for optimization in config.cpu_prepend_optimizations:
optimization(ast)
omp_collapse = None
if config.cpu_blocking:
omp_collapse = loop_blocking(ast, config.cpu_blocking)
if config.cpu_openmp:
add_openmp(ast, num_threads=config.cpu_openmp, collapse=omp_collapse,
assume_single_outer_loop=config.omp_single_loop)
if config.cpu_vectorize_info:
if config.cpu_vectorize_info is True:
vectorize(ast)
elif isinstance(config.cpu_vectorize_info, dict):
vectorize(ast, **config.cpu_vectorize_info)
if config.cpu_openmp and config.cpu_blocking and 'nontemporal' in config.cpu_vectorize_info and \
config.cpu_vectorize_info['nontemporal'] and 'cachelineZero' in ast.instruction_set:
# This condition is stricter than it needs to be: if blocks along the fastest axis start on a
# cache line boundary, it's okay. But we cannot determine that here.
# We don't need to disallow OpenMP collapsing because it is never applied to the inner loop.
raise ValueError("Blocking cannot be combined with cacheline-zeroing")
else:
raise ValueError("Invalid value for cpu_vectorize_info")
elif config.target == Target.GPU:
if config.backend == Backend.CUDA:
from pystencils.gpu import create_cuda_kernel
ast = create_cuda_kernel(assignments, config=config)
if not ast:
raise NotImplementedError(
f'{config.target} together with {config.backend} is not supported by `create_domain_kernel`')
if config.use_auto_for_assignments:
for a in ast.atoms(SympyAssignment):
a.use_auto = True
return ast
def create_indexed_kernel(assignments: NodeCollection, *, config: CreateKernelConfig):
"""
Similar to :func:`create_kernel`, but here not all cells of a field are updated but only cells with
coordinates which are stored in an index field. This traversal method can e.g. be used for boundary handling.
The coordinates are stored in a separated index_field, which is a one dimensional array with struct data type.
This struct has to contain fields named 'x', 'y' and for 3D fields ('z'). These names are configurable with the
'coordinate_names' parameter. The struct can have also other fields that can be read and written in the kernel, for
example boundary parameters.
Note that `create_indexed_kernel` is a lower level function which shoul be accessed by providing `index_fields`
to create_kernel
Args:
assignments: `pystencils.node_collection.NodeCollection` containing all assignements and nodes to be processed
config: CreateKernelConfig which includes the needed configuration
Returns:
abstract syntax tree (AST) object, that can either be printed as source code with `show_code` or
can be compiled with through its 'compile()' member
Example:
>>> import pystencils as ps
>>> from pystencils.node_collection import NodeCollection
>>> import numpy as np
>>> from pystencils.kernelcreation import create_indexed_kernel
>>>
>>> # Index field stores the indices of the cell to visit together with optional values
>>> index_arr_dtype = np.dtype([('x', np.int32), ('y', np.int32), ('val', np.double)])
>>> index_arr = np.array([(1, 1, 0.1), (2, 2, 0.2), (3, 3, 0.3)], dtype=index_arr_dtype)
>>> idx_field = ps.fields(idx=index_arr)
>>>
>>> # Additional values stored in index field can be accessed in the kernel as well
>>> s, d = ps.fields('s, d: [2D]')
>>> assignment = ps.Assignment(d[0, 0], 2 * s[0, 1] + 2 * s[1, 0] + idx_field('val'))
>>> kernel_config = ps.CreateKernelConfig(index_fields=[idx_field], coordinate_names=('x', 'y'))
>>> kernel_ast = create_indexed_kernel(NodeCollection([assignment]), config=kernel_config)
>>> kernel = kernel_ast.compile()
>>> d_arr = np.zeros([5, 5])
>>> kernel(s=np.ones([5, 5]), d=d_arr, idx=index_arr)
>>> d_arr
array([[0. , 0. , 0. , 0. , 0. ],
[0. , 4.1, 0. , 0. , 0. ],
[0. , 0. , 4.2, 0. , 0. ],
[0. , 0. , 0. , 4.3, 0. ],
[0. , 0. , 0. , 0. , 0. ]])
"""
# --- eval
assignments.evaluate_terms()
# FUTURE WORK from here we shouldn't NEED sympy
# --- check constrains
check = KernelConstraintsCheck(check_independence_condition=not config.skip_independence_check,
check_double_write_condition=not config.allow_double_writes)
check.visit(assignments)
assignments.bound_fields = check.fields_written
assignments.rhs_fields = check.fields_read
ast = None
if config.target == Target.CPU and config.backend == Backend.C:
from pystencils.cpu import add_openmp, create_indexed_kernel
ast = create_indexed_kernel(assignments, config=config)
if config.cpu_openmp:
add_openmp(ast, num_threads=config.cpu_openmp)
elif config.target == Target.GPU:
if config.backend == Backend.CUDA:
from pystencils.gpu import created_indexed_cuda_kernel
ast = created_indexed_cuda_kernel(assignments, config=config)
if not ast:
raise NotImplementedError(f'Indexed kernels are not yet supported for {config.target} with {config.backend}')
return ast
def create_staggered_kernel(assignments, target: Target = Target.CPU, gpu_exclusive_conditions=False, **kwargs):
"""Kernel that updates a staggered field.
.. image:: /img/staggered_grid.svg
For a staggered field, the first index coordinate defines the location of the staggered value.
Further index coordinates can be used to store vectors/tensors at each point.
Args:
assignments: a sequence of assignments or an AssignmentCollection.
Assignments to staggered field are processed specially, while subexpressions and assignments to
regular fields are passed through to `create_kernel`. Multiple different staggered fields can be
used, but they all need to use the same stencil (i.e. the same number of staggered points) and
shape.
target: 'CPU' or 'GPU'
gpu_exclusive_conditions: disable the use of multiple conditionals inside the loop. The outer layers are then
handled in an else branch.
kwargs: passed directly to create_kernel, iteration_slice and ghost_layers parameters are not allowed
Returns:
AST, see `create_kernel`
"""
# TODO: Add doku like in the other kernels
if 'ghost_layers' in kwargs:
assert kwargs['ghost_layers'] is None
del kwargs['ghost_layers']
if 'iteration_slice' in kwargs:
assert kwargs['iteration_slice'] is None
del kwargs['iteration_slice']
if 'omp_single_loop' in kwargs:
assert kwargs['omp_single_loop'] is False
del kwargs['omp_single_loop']
if isinstance(assignments, AssignmentCollection):
subexpressions = assignments.subexpressions + [a for a in assignments.main_assignments
if not hasattr(a, 'lhs')
or type(a.lhs) is not Field.Access
or not FieldType.is_staggered(a.lhs.field)]
assignments = [a for a in assignments.main_assignments if hasattr(a, 'lhs')
and type(a.lhs) is Field.Access
and FieldType.is_staggered(a.lhs.field)]
else:
subexpressions = [a for a in assignments if not hasattr(a, 'lhs')
or type(a.lhs) is not Field.Access
or not FieldType.is_staggered(a.lhs.field)]
assignments = [a for a in assignments if hasattr(a, 'lhs')
and type(a.lhs) is Field.Access
and FieldType.is_staggered(a.lhs.field)]
if len(set([tuple(a.lhs.field.staggered_stencil) for a in assignments])) != 1:
raise ValueError("All assignments need to be made to staggered fields with the same stencil")
if len(set([a.lhs.field.shape for a in assignments])) != 1:
raise ValueError("All assignments need to be made to staggered fields with the same shape")
staggered_field = assignments[0].lhs.field
stencil = staggered_field.staggered_stencil
dim = staggered_field.spatial_dimensions
shape = staggered_field.shape
counters = [LoopOverCoordinate.get_loop_counter_symbol(i) for i in range(dim)]
final_assignments = []
# find out whether any of the ghost layers is not needed
common_exclusions = set(["E", "W", "N", "S", "T", "B"][:2 * dim])
for direction in stencil:
exclusions = set(["E", "W", "N", "S", "T", "B"][:2 * dim])
for elementary_direction in direction:
exclusions.remove(inverse_direction_string(elementary_direction))
common_exclusions.intersection_update(exclusions)
ghost_layers = [[0, 0] for d in range(dim)]
for direction in common_exclusions:
direction = direction_string_to_offset(direction)
for d, s in enumerate(direction):
if s == 1:
ghost_layers[d][1] = 1
elif s == -1:
ghost_layers[d][0] = 1
def condition(direction):
"""exclude those staggered points that correspond to fluxes between ghost cells"""
exclusions = set(["E", "W", "N", "S", "T", "B"][:2 * dim])
for elementary_direction in direction:
exclusions.remove(inverse_direction_string(elementary_direction))
conditions = []
for e in exclusions:
if e in common_exclusions:
continue
offset = direction_string_to_offset(e)
for i, o in enumerate(offset):
if o == 1:
conditions.append(counters[i] < shape[i] - 1)
elif o == -1:
conditions.append(counters[i] > 0)
return sp.And(*conditions)
if gpu_exclusive_conditions:
outer_assignment = None
conditions = {direction: condition(direction) for direction in stencil}
for num_conditions in range(len(stencil)):
for combination in itertools.combinations(conditions.values(), num_conditions):
for assignment in assignments:
direction = stencil[assignment.lhs.index[0]]
if conditions[direction] in combination:
assignment = SympyAssignment(assignment.lhs, assignment.rhs)
outer_assignment = Conditional(sp.And(*combination), Block([assignment]), outer_assignment)
inner_assignment = []
for assignment in assignments:
inner_assignment.append(SympyAssignment(assignment.lhs, assignment.rhs))
last_conditional = Conditional(sp.And(*[condition(d) for d in stencil]),
Block(inner_assignment), outer_assignment)
final_assignments = [s for s in subexpressions if not hasattr(s, 'lhs')] + \
[SympyAssignment(s.lhs, s.rhs) for s in subexpressions if hasattr(s, 'lhs')] + \
[last_conditional]
config = CreateKernelConfig(target=target, ghost_layers=ghost_layers, omp_single_loop=False, **kwargs)
ast = create_kernel(final_assignments, config=config)
return ast
for assignment in assignments:
direction = stencil[assignment.lhs.index[0]]
sp_assignments = [s for s in subexpressions if not hasattr(s, 'lhs')] + \
[SympyAssignment(s.lhs, s.rhs) for s in subexpressions if hasattr(s, 'lhs')] + \
[SympyAssignment(assignment.lhs, assignment.rhs)]
last_conditional = Conditional(condition(direction), Block(sp_assignments))
final_assignments.append(last_conditional)
remove_start_conditional = any([gl[0] == 0 for gl in ghost_layers])
prepend_optimizations = [lambda ast: remove_conditionals_in_staggered_kernel(ast, remove_start_conditional),
move_constants_before_loop]
if 'cpu_prepend_optimizations' in kwargs:
prepend_optimizations += kwargs['cpu_prepend_optimizations']
del kwargs['cpu_prepend_optimizations']
config = CreateKernelConfig(ghost_layers=ghost_layers, target=target, omp_single_loop=False,
cpu_prepend_optimizations=prepend_optimizations, **kwargs)
ast = create_kernel(final_assignments, config=config)
return ast
from typing import Any, Dict, List, Union, Optional, Set
import sympy
import sympy as sp
from sympy.codegen.rewriting import ReplaceOptim, optimize
from pystencils.assignment import Assignment, AddAugmentedAssignment
import pystencils.astnodes as ast
from pystencils.backends.cbackend import CustomCodeNode
from pystencils.functions import DivFunc
from pystencils.simp import AssignmentCollection
from pystencils.typing import FieldPointerSymbol
class NodeCollection:
def __init__(self, assignments: List[Union[ast.Node, Assignment]],
simplification_hints: Optional[Dict[str, Any]] = None,
bound_fields: Set[sp.Symbol] = None, rhs_fields: Set[sp.Symbol] = None):
def visit(obj):
if isinstance(obj, (list, tuple)):
return [visit(e) for e in obj]
if isinstance(obj, Assignment):
if isinstance(obj.lhs, FieldPointerSymbol):
return ast.SympyAssignment(obj.lhs, obj.rhs, is_const=obj.lhs.dtype.const)
return ast.SympyAssignment(obj.lhs, obj.rhs)
elif isinstance(obj, AddAugmentedAssignment):
return ast.SympyAssignment(obj.lhs, obj.lhs + obj.rhs)
elif isinstance(obj, ast.SympyAssignment):
return obj
elif isinstance(obj, ast.Conditional):
true_block = visit(obj.true_block)
false_block = None if obj.false_block is None else visit(obj.false_block)
return ast.Conditional(obj.condition_expr, true_block=true_block, false_block=false_block)
elif isinstance(obj, ast.Block):
return ast.Block([visit(e) for e in obj.args])
elif isinstance(obj, ast.Node) and not isinstance(obj, ast.LoopOverCoordinate):
return obj
else:
raise ValueError("Invalid object in the List of Assignments " + str(type(obj)))
self.all_assignments = visit(assignments)
self.simplification_hints = simplification_hints if simplification_hints else {}
self.bound_fields = bound_fields if bound_fields else {}
self.rhs_fields = rhs_fields if rhs_fields else {}
@staticmethod
def from_assignment_collection(assignment_collection: AssignmentCollection):
return NodeCollection(assignments=assignment_collection.all_assignments,
simplification_hints=assignment_collection.simplification_hints,
bound_fields=assignment_collection.bound_fields,
rhs_fields=assignment_collection.rhs_fields)
def evaluate_terms(self):
evaluate_constant_terms = ReplaceOptim(
lambda e: hasattr(e, 'is_constant') and e.is_constant and not e.is_integer,
lambda p: p.evalf()
)
evaluate_pow = ReplaceOptim(
lambda e: e.is_Pow and e.exp.is_Integer and abs(e.exp) <= 8,
lambda p: sp.UnevaluatedExpr(sp.Mul(*([p.base] * +p.exp), evaluate=False)) if p.exp > 0 else
(DivFunc(sp.Integer(1), p.base) if p.exp == -1 else
DivFunc(sp.Integer(1), sp.UnevaluatedExpr(sp.Mul(*([p.base] * -p.exp), evaluate=False))))
)
sympy_optimisations = [evaluate_constant_terms, evaluate_pow]
def visitor(node):
if isinstance(node, CustomCodeNode):
return node
elif isinstance(node, ast.Block):
return node.func([visitor(child) for child in node.args])
elif isinstance(node, ast.SympyAssignment):
new_lhs = visitor(node.lhs)
new_rhs = visitor(node.rhs)
return node.func(new_lhs, new_rhs, node.is_const, node.use_auto)
elif isinstance(node, ast.Node):
return node.func(*[visitor(child) for child in node.args])
elif isinstance(node, sympy.Basic):
return optimize(node, sympy_optimisations)
else:
raise NotImplementedError(f'{node} {type(node)} has no valid visitor')
self.all_assignments = [visitor(assignment) for assignment in self.all_assignments]
from typing import List
import sympy as sp
from pystencils.assignment import Assignment
from pystencils.astnodes import Node
from pystencils.sympyextensions import is_constant
from pystencils.transformations import generic_visit
class PlaceholderFunction:
pass
def to_placeholder_function(expr, name):
"""Replaces an expression by a sympy function.
- replacing an expression with just a symbol would lead to problem when calculating derivatives
- placeholder functions get rid of this problem
Examples:
>>> x, t = sp.symbols("x, t")
>>> temperature = x**2 + t**4 # some 'complicated' dependency
>>> temperature_placeholder = to_placeholder_function(temperature, 'T')
>>> diffusivity = temperature_placeholder + 42 * t
>>> sp.diff(diffusivity, t) # returns a symbol instead of the computed derivative
_dT_dt + 42
>>> result, subexpr = remove_placeholder_functions(diffusivity)
>>> result
T + 42*t
>>> subexpr
[Assignment(T, t**4 + x**2), Assignment(_dT_dt, 4*t**3), Assignment(_dT_dx, 2*x)]
"""
symbols = list(expr.atoms(sp.Symbol))
symbols.sort(key=lambda e: e.name)
derivative_symbols = [sp.Symbol(f"_d{name}_d{s.name}") for s in symbols]
derivatives = [sp.diff(expr, s) for s in symbols]
assignments = [Assignment(sp.Symbol(name), expr)]
assignments += [Assignment(symbol, derivative)
for symbol, derivative in zip(derivative_symbols, derivatives)
if not is_constant(derivative)]
def fdiff(_, index):
result = derivatives[index - 1]
return result if is_constant(result) else derivative_symbols[index - 1]
func = type(name, (sp.Function, PlaceholderFunction),
{'fdiff': fdiff,
'value': sp.Symbol(name),
'subexpressions': assignments,
'nargs': len(symbols)})
return func(*symbols)
def remove_placeholder_functions(expr):
subexpressions = []
def visit(e):
if isinstance(e, Node):
return e
elif isinstance(e, PlaceholderFunction):
for se in e.subexpressions:
if se.lhs not in {a.lhs for a in subexpressions}:
subexpressions.append(se)
return e.value
else:
new_args = [visit(a) for a in e.args]
return e.func(*new_args) if new_args else e
return generic_visit(expr, visit), subexpressions
def prepend_placeholder_functions(assignments: List[Assignment]):
result, subexpressions = remove_placeholder_functions(assignments)
return subexpressions + result
"""
This module extends the pyplot module with functions to show scalar and vector fields in the usual
simulation coordinate system (y-axis goes up), instead of the "image coordinate system" (y axis goes down) that
matplotlib normally uses.
"""
import warnings
from itertools import cycle
from matplotlib.pyplot import *
def vector_field(array, step=2, **kwargs):
"""Plots given vector field as quiver (arrow) plot.
Args:
array: numpy array with 3 dimensions, first two are spatial x,y coordinate, the last
coordinate should have shape 2 and stores the 2 velocity components
step: plots only every steps's cell, increase the step for high resolution arrays
kwargs: keyword arguments passed to :func:`matplotlib.pyplot.quiver`
Returns:
quiver plot object
"""
assert len(array.shape) == 3, "Wrong shape of array - did you forget to slice your 3D domain first?"
assert array.shape[2] == 2, "Last array dimension is expected to store 2D vectors"
vel_n = array.swapaxes(0, 1)
res = quiver(vel_n[::step, ::step, 0], vel_n[::step, ::step, 1], **kwargs)
axis('equal')
return res
def vector_field_magnitude(array, **kwargs):
"""Plots the magnitude of a vector field as colormap.
Args:
array: numpy array with 3 dimensions, first two are spatial x,y coordinate, the last
coordinate should have shape 2 and stores the 2 velocity components
kwargs: keyword arguments passed to :func:`matplotlib.pyplot.imshow`
Returns:
imshow object
"""
assert len(array.shape) == 3, "Wrong shape of array - did you forget to slice your 3D domain first?"
assert array.shape[2] in (2, 3), "Wrong size of the last coordinate. Has to be a 2D or 3D vector field."
from numpy.linalg import norm
norm = norm(array, axis=2, ord=2)
if hasattr(array, 'mask'):
norm = np.ma.masked_array(norm, mask=array.mask[:, :, 0])
return scalar_field(norm, **kwargs)
def scalar_field(array, **kwargs):
"""Plots field values as colormap.
Works just as imshow, but uses coordinate system where second coordinate (y) points upwards.
Args:
array: two dimensional numpy array
kwargs: keyword arguments passed to :func:`matplotlib.pyplot.imshow`
Returns:
imshow object
"""
import numpy
array = numpy.swapaxes(array, 0, 1)
res = imshow(array, origin='lower', **kwargs)
axis('equal')
return res
def scalar_field_surface(array, **kwargs):
"""Plots scalar field as 3D surface
Args:
array: the two dimensional numpy array to plot
kwargs: keyword arguments passed to :func:`mpl_toolkits.mplot3d.Axes3D.plot_surface`
"""
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
fig = gcf()
ax = fig.add_subplot(111, projection='3d')
x, y = np.meshgrid(np.arange(array.shape[0]), np.arange(array.shape[1]), indexing='ij')
kwargs.setdefault('rstride', 2)
kwargs.setdefault('cstride', 2)
kwargs.setdefault('color', 'b')
kwargs.setdefault('cmap', cm.coolwarm)
return ax.plot_surface(x, y, array, **kwargs)
def scalar_field_alpha_value(array, color, clip=False, **kwargs):
"""Plots an image with same color everywhere, using the array values as transparency.
Array is supposed to have values between 0 and 1 (if this is not the case it is normalized).
An image is plotted that has the same color everywhere, the passed array determines the transparency.
Regions where the array is 1 are fully opaque, areas with 0 are fully transparent.
Args:
array: 2D array with alpha values
color: fill color
clip: if True, all values in the array larger than 1 are set to 1, all values smaller than 0 are set to zero
if False, the array is linearly scaled to the [0, 1] interval
**kwargs: arguments passed to imshow
Returns:
imshow object
"""
import numpy
import matplotlib
assert len(array.shape) == 2, "Wrong shape of array - did you forget to slice your 3D domain first?"
array = numpy.swapaxes(array, 0, 1)
if clip:
normalized_field = array.copy()
normalized_field[normalized_field < 0] = 0
normalized_field[normalized_field > 1] = 1
else:
minimum, maximum = numpy.min(array), numpy.max(array)
normalized_field = (array - minimum) / (maximum - minimum)
color = matplotlib.colors.to_rgba(color)
field_to_plot = numpy.empty(array.shape + (4,))
# set the complete array to the color
for i in range(3):
field_to_plot[:, :, i] = color[i]
# only the alpha channel varies using the array values
field_to_plot[:, :, 3] = normalized_field
res = imshow(field_to_plot, origin='lower', **kwargs)
axis('equal')
return res
def scalar_field_contour(array, **kwargs):
"""Small wrapper around contour to transform the coordinate system.
For details see :func:`matplotlib.pyplot.imshow`
"""
array = np.swapaxes(array, 0, 1)
res = contour(array, **kwargs)
axis('equal')
return res
def multiple_scalar_fields(array, **kwargs):
"""Plots a 3D array by slicing the last dimension and creates on plot for each entry of the last dimension.
Args:
array: 3D array to plot.
**kwargs: passed along to imshow
"""
assert len(array.shape) == 3
sub_plots = array.shape[-1]
for i in range(sub_plots):
subplot(1, sub_plots, i + 1)
title(str(i))
scalar_field(array[..., i], **kwargs)
colorbar()
def phase_plot(phase_field: np.ndarray, linewidth=1.0, clip=True) -> None:
"""Plots a phase field array using the phase variables as alpha channel.
Args:
phase_field: array with len(shape) == 3, first two dimensions are spatial, the last one indexes the phase
components.
linewidth: line width of the 0.5 contour lines that are drawn over the alpha blended phase images
clip: see scalar_field_alpha_value function
"""
color_cycle = cycle(['#fe0002', '#00fe00', '#0000ff', '#ffa800', '#f600ff'])
assert len(phase_field.shape) == 3
with warnings.catch_warnings():
warnings.simplefilter("ignore")
for i in range(phase_field.shape[-1]):
scalar_field_alpha_value(phase_field[..., i], next(color_cycle), clip=clip, interpolation='bilinear')
if linewidth:
for i in range(phase_field.shape[-1]):
scalar_field_contour(phase_field[..., i], levels=[0.5], colors='k', linewidths=[linewidth])
def sympy_function(expr, x_values=None, **kwargs):
"""Plots the graph of a sympy term that depends on one symbol only.
Args:
expr: sympy term that depends on one symbol only, which is plotted on the x axis
x_values: describes sampling of x axis. Possible values are:
* tuple of (start, stop) or (start, stop, nr_of_steps)
* None, then start=0, stop=1, nr_of_steps=100
* 1D numpy array with x values
**kwargs: passed on to :func:`matplotlib.pyplot.plot`
Returns:
plot object
"""
import sympy as sp
if x_values is None:
x_arr = np.linspace(0, 1, 100)
elif type(x_values) is tuple:
x_arr = np.linspace(*x_values)
elif isinstance(x_values, np.ndarray):
assert len(x_values.shape) == 1
x_arr = x_values
else:
raise ValueError("Invalid value for parameter x_values")
symbols = expr.atoms(sp.Symbol)
assert len(symbols) == 1, "Sympy expression may only depend on one variable only. Depends on " + str(symbols)
y_arr = sp.lambdify(symbols.pop(), expr)(x_arr)
return plot(x_arr, y_arr, **kwargs)
# ------------------------------------------- Animations ---------------------------------------------------------------
def __scale_array(arr):
from numpy.linalg import norm
norm_arr = norm(arr, axis=2, ord=2)
if isinstance(arr, np.ma.MaskedArray):
norm_arr = np.ma.masked_array(norm_arr, arr.mask[..., 0])
return arr / norm_arr.max()
def vector_field_animation(run_function, step=2, rescale=True, plot_setup_function=lambda *_: None,
plot_update_function=lambda *_: None, interval=200, frames=180, **kwargs):
"""Creates a matplotlib animation of a vector field using a quiver plot.
Args:
run_function: callable without arguments, returning a 2D vector field i.e. numpy array with len(shape)==3
step: see documentation of vector_field function
rescale: if True, the length of the arrows is rescaled in every time step
plot_setup_function: optional callable with the quiver object as argument,
that can be used to set up the plot (title, legend,..)
plot_update_function: optional callable with the quiver object as argument
that is called of the quiver object was updated
interval: delay between frames in milliseconds (see matplotlib.FuncAnimation)
frames: how many frames should be generated, see matplotlib.FuncAnimation
**kwargs: passed to quiver plot
Returns:
matplotlib animation object
"""
import matplotlib.animation as animation
fig = gcf()
im = None
field = run_function()
if rescale:
field = __scale_array(field)
kwargs.setdefault('scale', 0.6)
kwargs.setdefault('angles', 'xy')
kwargs.setdefault('scale_units', 'xy')
quiver_plot = vector_field(field, step=step, **kwargs)
plot_setup_function(quiver_plot)
def update_figure(*_):
f = run_function()
f = np.swapaxes(f, 0, 1)
if rescale:
f = __scale_array(f)
u, v = f[::step, ::step, 0], f[::step, ::step, 1]
quiver_plot.set_UVC(u, v)
plot_update_function(quiver_plot)
return im,
return animation.FuncAnimation(fig, update_figure, interval=interval, frames=frames)
def vector_field_magnitude_animation(run_function, plot_setup_function=lambda *_: None, rescale=False,
plot_update_function=lambda *_: None, interval=30, frames=180, **kwargs):
"""Animation of a vector field, showing the magnitude as colormap.
For arguments, see vector_field_animation
"""
import matplotlib.animation as animation
from numpy.linalg import norm
fig = gcf()
im = None
field = run_function()
if rescale:
field = __scale_array(field)
im = vector_field_magnitude(field, **kwargs)
plot_setup_function(im)
def update_figure(*_):
f = run_function()
if rescale:
f = __scale_array(f)
normed = norm(f, axis=2, ord=2)
if hasattr(f, 'mask'):
normed = np.ma.masked_array(normed, mask=f.mask[:, :, 0])
normed = np.swapaxes(normed, 0, 1)
im.set_array(normed)
plot_update_function(im)
return im,
return animation.FuncAnimation(fig, update_figure, interval=interval, frames=frames)
def scalar_field_animation(run_function, plot_setup_function=lambda *_: None, rescale=True,
plot_update_function=lambda *_: None, interval=30, frames=180, **kwargs):
"""Animation of scalar field as colored image, see `scalar_field`."""
import matplotlib.animation as animation
fig = gcf()
im = None
field = run_function()
if rescale:
f_min, f_max = np.min(field), np.max(field)
field = (field - f_min) / (f_max - f_min)
im = scalar_field(field, vmin=0.0, vmax=1.0, **kwargs)
else:
im = scalar_field(field, **kwargs)
plot_setup_function(im)
def update_figure(*_):
f = run_function()
if rescale:
f_min, f_max = np.min(f), np.max(f)
f = (f - f_min) / (f_max - f_min)
if hasattr(f, 'mask'):
f = np.ma.masked_array(f, mask=f.mask[:, :])
f = np.swapaxes(f, 0, 1)
im.set_array(f)
plot_update_function(im)
return im,
return animation.FuncAnimation(fig, update_figure, interval=interval, frames=frames)
def surface_plot_animation(run_function, frames=90, interval=30, zlim=None, **kwargs):
"""Animation of scalar field as 3D plot."""
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.animation as animation
from matplotlib import cm
fig = gcf()
ax = fig.add_subplot(111, projection='3d')
data = run_function()
x, y = np.meshgrid(np.arange(data.shape[0]), np.arange(data.shape[1]), indexing='ij')
kwargs.setdefault('rstride', 2)
kwargs.setdefault('cstride', 2)
kwargs.setdefault('color', 'b')
kwargs.setdefault('cmap', cm.coolwarm)
ax.plot_surface(x, y, data, **kwargs)
if zlim is not None:
ax.set_zlim(*zlim)
def update_figure(*_):
d = run_function()
ax.clear()
plot = ax.plot_surface(x, y, d, **kwargs)
if zlim is not None:
ax.set_zlim(*zlim)
return plot,
return animation.FuncAnimation(fig, update_figure, interval=interval, frames=frames, blit=False)
import copy
import numpy as np
import sympy as sp
from pystencils.typing import TypedSymbol, CastFunc
from pystencils.astnodes import LoopOverCoordinate
from pystencils.backends.cbackend import CustomCodeNode
from pystencils.sympyextensions import fast_subs
class RNGBase(CustomCodeNode):
id = 0
def __init__(self, dim, time_step=TypedSymbol("time_step", np.uint32), offsets=None, keys=None):
if keys is None:
keys = (0,) * self._num_keys
if offsets is None:
offsets = (0,) * dim
if len(keys) != self._num_keys:
raise ValueError(f"Provided {len(keys)} keys but need {self._num_keys}")
if len(offsets) != dim:
raise ValueError(f"Provided {len(offsets)} offsets but need {dim}")
coordinates = [LoopOverCoordinate.get_loop_counter_symbol(i) + offsets[i] for i in range(dim)]
if dim < 3:
coordinates.append(0)
self._args = sp.sympify([time_step, *coordinates, *keys])
self.result_symbols = tuple(TypedSymbol(f'random_{self.id}_{i}', self._data_type)
for i in range(self._num_vars))
symbols_read = set.union(*[s.atoms(sp.Symbol) for s in self.args])
super().__init__("", symbols_read=symbols_read, symbols_defined=self.result_symbols)
self.headers = [f'"{self._name.split("_")[0]}_rand.h"']
RNGBase.id += 1
@property
def args(self):
return self._args
def fast_subs(self, subs_dict, skip):
rng = copy.deepcopy(self)
rng._args = [fast_subs(a, subs_dict, skip) for a in rng._args]
return rng
def get_code(self, dialect, vector_instruction_set, print_arg):
code = "\n"
for r in self.result_symbols:
if vector_instruction_set and not self.args[1].atoms(CastFunc):
# this vector RNG has become scalar through substitution
code += f"{r.dtype} {r.name};\n"
else:
code += f"{vector_instruction_set[r.dtype.c_name] if vector_instruction_set else r.dtype} " + \
f"{r.name};\n"
args = [print_arg(a) for a in self.args] + ['' + r.name for r in self.result_symbols]
code += (self._name + "(" + ", ".join(args) + ");\n")
return code
def __repr__(self):
return ", ".join([str(s) for s in self.result_symbols]) + " \\leftarrow " + \
self._name.capitalize() + "_RNG(" + ", ".join([str(a) for a in self.args]) + ")"
def _hashable_content(self):
return (self._name, *self.result_symbols, *self.args)
def __eq__(self, other):
return type(self) is type(other) and self._hashable_content() == other._hashable_content()
def __hash__(self):
return hash(self._hashable_content())
class PhiloxTwoDoubles(RNGBase):
_name = "philox_double2"
_data_type = np.float64
_num_vars = 2
_num_keys = 2
class PhiloxFourFloats(RNGBase):
_name = "philox_float4"
_data_type = np.float32
_num_vars = 4
_num_keys = 2
class AESNITwoDoubles(RNGBase):
_name = "aesni_double2"
_data_type = np.float64
_num_vars = 2
_num_keys = 4
class AESNIFourFloats(RNGBase):
_name = "aesni_float4"
_data_type = np.float32
_num_vars = 4
_num_keys = 4
def random_symbol(assignment_list, dim, seed=TypedSymbol("seed", np.uint32), rng_node=PhiloxTwoDoubles,
time_step=TypedSymbol("time_step", np.uint32), offsets=None):
"""Return a symbol generator for random numbers
Args:
assignment_list: the subexpressions member of an AssignmentCollection, into which helper variables assignments
will be inserted
dim: 2 or 3 for two or three spatial dimensions
seed: an integer or TypedSymbol(..., np.uint32) to seed the random number generator. If you create multiple
symbol generators, please pass them different seeds so you don't get the same stream of random numbers!
rng_node: which random number generator to use (PhiloxTwoDoubles, PhiloxFourFloats, AESNITwoDoubles,
AESNIFourFloats).
time_step: TypedSymbol(..., np.uint32) that indicates the number of the current time step
offsets: tuple of offsets (constant integers or TypedSymbol(..., np.uint32)) that give the global coordinates
of the local origin
"""
counter = 0
while True:
keys = (counter, seed) + (0,) * (rng_node._num_keys - 2)
node = rng_node(dim, keys=keys, time_step=time_step, offsets=offsets)
inserted = False
for symbol in node.result_symbols:
if not inserted:
assignment_list.insert(0, node)
inserted = True
yield symbol
counter += 1
from pystencils.runhelper.db import Database
from pystencils.runhelper.parameterstudy import ParameterStudy
__all__ = ['Database', 'ParameterStudy']
import socket
import time
from types import MappingProxyType
from typing import Dict, Iterator, Sequence
import blitzdb
import six
from blitzdb.backends.file.backend import serializer_classes
from blitzdb.backends.file.utils import JsonEncoder
from pystencils.cpu.cpujit import get_compiler_config
from pystencils import CreateKernelConfig, Target, Backend, Field
import json
import sympy as sp
from pystencils.typing import BasicType
class PystencilsJsonEncoder(JsonEncoder):
def default(self, obj):
if isinstance(obj, CreateKernelConfig):
return obj.__dict__
if isinstance(obj, (sp.Float, sp.Rational)):
return float(obj)
if isinstance(obj, sp.Integer):
return int(obj)
if isinstance(obj, (BasicType, MappingProxyType)):
return str(obj)
if isinstance(obj, (Target, Backend, sp.Symbol)):
return obj.name
if isinstance(obj, Field):
return f"pystencils.Field(name = {obj.name}, field_type = {obj.field_type.name}, " \
f"dtype = {str(obj.dtype)}, layout = {obj.layout}, shape = {obj.shape}, " \
f"strides = {obj.strides})"
return JsonEncoder.default(self, obj)
class PystencilsJsonSerializer(object):
@classmethod
def serialize(cls, data):
if six.PY3:
if isinstance(data, bytes):
return json.dumps(data.decode('utf-8'), cls=PystencilsJsonEncoder, ensure_ascii=False).encode('utf-8')
else:
return json.dumps(data, cls=PystencilsJsonEncoder, ensure_ascii=False).encode('utf-8')
else:
return json.dumps(data, cls=PystencilsJsonEncoder, ensure_ascii=False).encode('utf-8')
@classmethod
def deserialize(cls, data):
if six.PY3:
return json.loads(data.decode('utf-8'))
else:
return json.loads(data.decode('utf-8'))
class Database:
"""NoSQL database for storing simulation results.
Two backends are supported:
* `blitzdb`: simple file-based solution similar to sqlite for SQL databases, stores json files
no server setup required, but slow for larger collections
* `mongodb`: mongodb backend via `pymongo`
A simulation result is stored as an object consisting of
* parameters: dict with simulation parameters
* results: dict with results
* environment: information about the machine, compiler configuration and time
Args:
file: database identifier, for blitzdb pass a directory name here. Database folder is created if it doesn't
exist yet. For larger collections use mongodb. In this case pass a pymongo connection string
e.g. "mongo://server:9131"
Example:
>>> from tempfile import TemporaryDirectory
>>> with TemporaryDirectory() as tmp_dir:
... db = Database(tmp_dir) # create database in temporary folder
... params = {'method': 'finite_diff', 'dx': 1.5} # some hypothetical simulation parameters
... db.save(params, result={'error': 1e-6}) # save simulation parameters together with hypothetical results
... assert db.was_already_simulated(params) # search for parameters in database
... assert next(db.filter_params(params))['params'] == params # get data set, keys are 'params', 'results'
... # and 'env'
... # get a pandas object with all results matching a query
... df = db.to_pandas({'dx': 1.5}, remove_prefix=True)
... # order columns alphabetically (just for doctest output)
... df.reindex(sorted(df.columns), axis=1) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
dx error method
pk
... 1.5 0.000001 finite_diff
"""
class SimulationResult(blitzdb.Document):
pass
def __init__(self, file: str, serializer_info: tuple = None) -> None:
if file.startswith("mongo://"):
from pymongo import MongoClient
db_name = file[len("mongo://"):]
c = MongoClient()
self.backend = blitzdb.MongoBackend(c[db_name])
else:
self.backend = blitzdb.FileBackend(file)
self.backend.autocommit = True
if serializer_info:
serializer_classes.update({serializer_info[0]: serializer_info[1]})
self.backend.load_config({'serializer_class': serializer_info[0]}, True)
def save(self, params: Dict, result: Dict, env: Dict = None, **kwargs) -> None:
"""Stores a simulation result in the database.
Args:
params: dict of simulation parameters
result: dict of simulation results
env: optional environment - if None a default environment with compiler configuration, machine info and time
is used
**kwargs: the final object is updated with the keyword arguments
"""
document_dict = {
'params': params,
'result': result,
'env': env if env else self.get_environment(),
}
document_dict.update(kwargs)
document = Database.SimulationResult(document_dict, backend=self.backend)
document.save()
self.backend.commit()
def filter_params(self, parameter_query: Dict, *args, **kwargs) -> Iterator['SimulationResult']:
"""Query using simulation parameters.
See blitzdb documentation for filter
Args:
parameter_query: blitzdb filter dict using only simulation parameters
*args: arguments passed to blitzdb filter
**kwargs: arguments passed to blitzdb filter
Returns:
generator of SimulationResult, which is a dict-like object with keys 'params', 'result' and 'env'
"""
query = {'params.' + k: v for k, v in parameter_query.items()}
return self.filter(query, *args, **kwargs)
def filter(self, *args, **kwargs):
"""blitzdb filter on SimulationResult, not only simulation parameters.
Can be used to filter for results or environment options.
The filter dictionary has to have prefixes "params." , "env." or "result."
"""
return self.backend.filter(Database.SimulationResult, *args, **kwargs)
def was_already_simulated(self, parameters):
"""Checks if there is at least one simulation result matching the passed parameters."""
return len(self.filter({'params': parameters})) > 0
# Columns with these prefixes are not included in pandas result
pandas_columns_to_ignore = ['changedParams.', 'env.']
def to_pandas(self, parameter_query, remove_prefix=True, drop_constant_columns=False):
"""Queries for simulations with given parameters and returns them in a pandas data frame.
Args:
parameter_query: see filter method
remove_prefix: if True the name of the pandas columns are not prefixed with "params." or "results."
drop_constant_columns: if True, all columns are dropped that have the same value is all rows
Returns:
pandas data frame
"""
from pandas import json_normalize
query_result = self.filter_params(parameter_query)
attributes = [e.attributes for e in query_result]
if not attributes:
return
df = json_normalize(attributes)
df.set_index('pk', inplace=True)
if self.pandas_columns_to_ignore:
remove_columns_by_prefix(df, self.pandas_columns_to_ignore, inplace=True)
if remove_prefix:
remove_prefix_in_column_name(df, inplace=True)
if drop_constant_columns:
df, _ = remove_constant_columns(df)
return df
@staticmethod
def get_environment():
result = {
'timestamp': time.mktime(time.gmtime()),
'hostname': socket.gethostname(),
'cpuCompilerConfig': get_compiler_config(),
}
try:
from git import Repo
except ImportError:
return result
try:
from git import InvalidGitRepositoryError
repo = Repo(search_parent_directories=True)
result['git_hash'] = str(repo.head.commit)
except InvalidGitRepositoryError:
pass
return result
# ----------------------------------------- Helper Functions -----------------------------------------------------------
def remove_constant_columns(df):
"""Removes all columns of a pandas data frame that have the same value in all rows."""
import pandas as pd
remaining_df = df.loc[:, df.apply(pd.Series.nunique) > 1]
constants = df.loc[:, df.apply(pd.Series.nunique) <= 1].iloc[0]
return remaining_df, constants
def remove_columns_by_prefix(df, prefixes: Sequence[str], inplace: bool = False):
"""Remove all columns from a pandas data frame whose name starts with one of the given prefixes."""
if not inplace:
df = df.copy()
for column_name in df.columns:
for prefix in prefixes:
if column_name.startswith(prefix):
del df[column_name]
return df
def remove_prefix_in_column_name(df, inplace: bool = False):
"""Removes dotted prefixes from pandas column names.
A column named 'result.finite_diff.dx' is renamed to 'finite_diff.dx', everything before the first dot is removed.
If the column name does not contain a dot, the column name is not changed.
"""
if not inplace:
df = df.copy()
new_column_names = []
for column_name in df.columns:
if '.' in column_name:
new_column_names.append(column_name[column_name.index('.') + 1:])
else:
new_column_names.append(column_name)
df.columns = new_column_names
return df
import datetime
import itertools
import json
import os
import socket
from collections import namedtuple
from copy import deepcopy
from time import sleep
from typing import Any, Callable, Dict, Optional, Sequence, Tuple
from pystencils.runhelper import Database
from pystencils.runhelper.db import PystencilsJsonSerializer
from pystencils.utils import DotDict
ParameterDict = Dict[str, Any]
WeightFunction = Callable[[Dict], int]
FilterFunction = Callable[[ParameterDict], Optional[ParameterDict]]
class ParameterStudy:
"""Manages and runs multiple configurations locally or distributed and stores results in NoSQL database.
To run a parameter study, define a run function that takes all parameters as keyword arguments and returns the
results as a (possibly nested) dictionary. Then, define the parameter sets that this function should be run with.
Examples:
>>> import tempfile
>>>
>>> def dummy_run_function(p1, p2, p3, p4):
... print("Run called with", p1, p2, p3, p4)
... return { 'result1': p1 * p2, 'result2': p3 + p4 }
>>>
>>> with tempfile.TemporaryDirectory() as tmp_dir:
... ps = ParameterStudy(dummy_run_function, database_connector=tmp_dir)
... ps.add_run({'p1': 5, 'p2': 42, 'p3': 'abc', 'p4': 'def'})
... ps.add_combinations( [('p1', [1, 2]),
... ('p3', ['x', 'y'])], constant_parameters={'p2': 5, 'p4': 'z' })
... ps.run()
... ps.run_scenarios_not_in_database()
... ps.run_from_command_line(argv=['local']) # alternative to run - exposes a command line interface if
... # no argv is passed. Does not run anything here, because
... # configuration already in database are skipped
Run called with 2 5 y z
Run called with 2 5 x z
Run called with 1 5 y z
Run called with 1 5 x z
Run called with 5 42 abc def
Above example runs all parameter combinations locally and stores the returned result in the NoSQL database.
It is also possible to distribute the runs to multiple processes, by starting a server on one machine and multiple
executing runners on other machines. The server distributes configurations to the runners, collects their results
to stores the results in the database.
"""
Run = namedtuple("Run", ['parameter_dict', 'weight'])
def __init__(self, run_function: Callable[..., Dict], runs: Sequence = (),
database_connector: str = './db',
serializer_info: tuple = ('pystencils_serializer', PystencilsJsonSerializer)) -> None:
self.runs = list(runs)
self.run_function = run_function
self.db = Database(database_connector, serializer_info)
def add_run(self, parameter_dict: ParameterDict, weight: int = 1) -> None:
"""Schedule a dictionary of parameters to run in this parameter study.
Args:
parameter_dict: used as keyword arguments to the run function.
weight: weight of the run configuration which should be proportional to runtime of this case,
used for progress display and distribution to processes.
"""
self.runs.append(self.Run(parameter_dict, weight))
def add_combinations(self, degrees_of_freedom: Sequence[Tuple[str, Sequence[Any]]],
constant_parameters: Optional[ParameterDict] = None,
filter_function: Optional[FilterFunction] = None,
runtime_weight_function: Optional[WeightFunction] = None) -> None:
"""Add all possible combinations of given parameters as runs.
This is a convenience function to simulate all possible parameter combinations of a scenario.
Configurations can be filtered and weighted by passing filter- and weighting functions.
Args:
degrees_of_freedom: defines for each parameter the possible values it can take on
constant_parameters: parameter dict, for parameters that should not be changed
filter_function: optional function that receives a parameter dict and returns the potentially modified dict
or None if this combination should not be added.
runtime_weight_function: function mapping a parameter dict to the runtime weight (see weight at add_runs)
Examples:
degrees_of_freedom = [('p1', [1,2]),
('p2', ['a', 'b'])]
is equivalent to calling add_run four times, with all possible parameter combinations.
"""
parameter_names = [e[0] for e in degrees_of_freedom]
parameter_values = [e[1] for e in degrees_of_freedom]
default_params_dict = {} if constant_parameters is None else constant_parameters
for value_tuple in itertools.product(*parameter_values):
params_dict = deepcopy(default_params_dict)
params_dict.update({name: value for name, value in zip(parameter_names, value_tuple)})
params = DotDict(params_dict)
if filter_function:
params = filter_function(params)
if params is None:
continue
weight = 1 if not runtime_weight_function else runtime_weight_function(params)
self.add_run(params, weight)
def run(self, process: int = 0, num_processes: int = 1, parameter_update: Optional[ParameterDict] = None) -> None:
"""Runs all added configurations.
Args:
process: configurations are split into num_processes chunks according to weights and only the
process'th chunk is run. To run all, use process=0 and num_processes=1
num_processes: see above
parameter_update: Extend/override all configurations with this dictionary.
"""
parameter_update = {} if parameter_update is None else parameter_update
own_runs = self._distribute_runs(self.runs, process, num_processes)
for run in own_runs:
parameter_dict = run.parameter_dict.copy()
parameter_dict.update(parameter_update)
result = self.run_function(**parameter_dict)
self.db.save(run.parameter_dict, result, None, changed_params=parameter_update)
def run_scenarios_not_in_database(self, parameter_update: Optional[ParameterDict] = None) -> None:
"""Same as run method, but runs only configuration for which no result is in the database yet."""
parameter_update = {} if parameter_update is None else parameter_update
filtered_runs = self._filter_already_simulated(self.runs)
for run in filtered_runs:
parameter_dict = run.parameter_dict.copy()
parameter_dict.update(parameter_update)
result = self.run_function(**parameter_dict)
self.db.save(run.parameter_dict, result, changed_params=parameter_update)
def run_server(self, ip: str = "0.0.0.0", port: int = 8342):
"""Runs server to supply runner clients with scenarios to simulate and collect results from them.
Skips scenarios that are already in the database."""
from http.server import BaseHTTPRequestHandler, HTTPServer
filtered_runs = self._filter_already_simulated(self.runs)
if not filtered_runs:
print("No Scenarios to simulate")
return
class ParameterStudyServer(BaseHTTPRequestHandler):
parameterStudy = self
all_runs = filtered_runs
runs = filtered_runs.copy()
currently_running = {}
finished_runs = []
def next_scenario(self, received_json_data):
client_name = received_json_data['client_name']
if len(self.runs) > 0:
run_status = "%d/%d" % (len(self.finished_runs), len(self.all_runs))
work_status = "%d/%d" % (sum(r.weight for r in self.finished_runs),
sum(r.weight for r in self.all_runs))
format_args = {
'remaining': len(self.runs),
'time': datetime.datetime.now().strftime("%H:%M:%S"),
'client_name': client_name,
'run_status': run_status,
'work_status': work_status,
}
scenario = self.runs.pop(0)
print(" {time} {client_name} fetched scenario. Scenarios: {run_status}, Work: {work_status}"
.format(**format_args))
self.currently_running[client_name] = scenario
return {'status': 'ok', 'params': scenario.parameter_dict}
else:
return {'status': 'finished'}
def result(self, received_json_data):
client_name = received_json_data['client_name']
run = self.currently_running[client_name]
self.finished_runs.append(run)
del self.currently_running[client_name]
d = received_json_data
def hash_dict(dictionary):
import hashlib
return hashlib.sha1(json.dumps(dictionary, sort_keys=True).encode()).hexdigest()
assert hash_dict(d['params']) == hash_dict(run.parameter_dict), \
str(d['params']) + "is not equal to " + str(run.parameter_dict)
self.parameterStudy.db.save(run.parameter_dict,
result=d['result'], env=d['env'], changed_params=d['changed_params'])
return {}
# noinspection PyPep8Naming
def do_POST(self) -> None:
mapping = {'/next_scenario': self.next_scenario,
'/result': self.result}
if self.path in mapping.keys():
data = self._read_contents()
self.send_response(200)
self.send_header("Content-type", "application/json")
self.end_headers()
json_data = json.loads(data)
response = mapping[self.path](json_data)
self.wfile.write(json.dumps(response).encode())
else:
self.send_response(400)
# noinspection PyPep8Naming
def do_GET(self):
return self.do_POST()
def _read_contents(self):
return self.rfile.read(int(self.headers['Content-Length'])).decode()
def log_message(self, fmt, *args):
return
print(f"Listening to connections on {ip}:{port}. Scenarios to simulate: {len(filtered_runs)}")
server = HTTPServer((ip, port), ParameterStudyServer)
while len(ParameterStudyServer.currently_running) > 0 or len(ParameterStudyServer.runs) > 0:
server.handle_request()
server.handle_request()
def run_client(self, client_name: str = "{hostname}_{pid}", server: str = 'localhost', port: int = 8342,
parameter_update: Optional[ParameterDict] = None, max_time=None) -> None:
"""Start runner client that retrieves configuration from server, runs it and reports results back to server.
Args:
client_name: name of the client. Has to be unique for each client.
Placeholders {hostname} and {pid} can be used to generate unique name.
server: url to server
port: port as specified in run_server
parameter_update: Used to override/extend parameters received from the server.
Typical use cases is to set optimization or GPU parameters for some clients to make
some clients simulate on CPU, others on GPU
max_time: maximum runtime in seconds: the client runs scenario after scenario, but starts only a new
scenario if not more than max_time seconds have passed since this function was called.
So the time given here should be the total maximum runtime minus a typical runtime for one setup
"""
from urllib.request import urlopen
from urllib.error import URLError
import time
parameter_update = {} if parameter_update is None else parameter_update
url = f"http://{server}:{port}"
client_name = client_name.format(hostname=socket.gethostname(), pid=os.getpid())
start_time = time.time()
while True:
try:
if max_time is not None and (time.time() - start_time) > max_time:
print("Stopping client - maximum time reached")
break
http_response = urlopen(url + "/next_scenario",
data=json.dumps({'client_name': client_name}).encode())
scenario = json.loads(http_response.read().decode())
if scenario['status'] != 'ok':
break
original_params = scenario['params'].copy()
scenario['params'].update(parameter_update)
result = self.run_function(**scenario['params'])
answer = {'params': original_params,
'changed_params': parameter_update,
'result': result,
'env': Database.get_environment(),
'client_name': client_name}
urlopen(url + '/result', data=json.dumps(answer).encode())
except URLError:
print(f"Cannot connect to server {url} retrying in 5 seconds...")
sleep(5)
def run_from_command_line(self, argv: Optional[Sequence[str]] = None) -> None:
"""Exposes interface to command line with possibility to run directly or distributed via server/client."""
from argparse import ArgumentParser
def server(a):
if a.database:
self.db = Database(a.database)
self.run_server(a.host, a.port)
def client(a):
self.run_client(a.client_name, a.host, a.port, json.loads(a.parameter_override), a.max_time)
def local(a):
if a.database:
self.db = Database(a.database)
self.run_scenarios_not_in_database(json.loads(a.parameter_override))
parser = ArgumentParser()
subparsers = parser.add_subparsers()
local_parser = subparsers.add_parser('local', aliases=['l'],
help="Run scenarios locally which are not yet in database", )
local_parser.add_argument("-d", "--database", type=str, default="")
local_parser.add_argument("-P", "--parameter_override", type=str, default="{}",
help="JSON: the parameter dictionary is updated with these parameters. Use this to "
"set host specific options like GPU call parameters. Enclose in \" ")
local_parser.set_defaults(func=local)
server_parser = subparsers.add_parser('server', aliases=['s'],
help="Runs server to distribute different scenarios to workers", )
server_parser.add_argument("-p", "--port", type=int, default=8342, help="Port to listen on")
server_parser.add_argument("-H", "--host", type=str, default="0.0.0.0", help="IP/Hostname to listen on")
server_parser.add_argument("-d", "--database", type=str, default="")
server_parser.set_defaults(func=server)
client_parser = subparsers.add_parser('client', aliases=['c'],
help="Runs a worker client connection to scenario distribution server")
client_parser.add_argument("-p", "--port", type=int, default=8342, help="Port to connect to")
client_parser.add_argument("-H", "--host", type=str, default="localhost", help="Host or IP to connect to")
client_parser.add_argument("-n", "--client_name", type=str, default="{hostname}_{pid}",
help="Unique client name, you can use {hostname} and {pid} as placeholder")
client_parser.add_argument("-P", "--parameter_override", type=str, default="{}",
help="JSON: the parameter dictionary is updated with these parameters. Use this to "
"set host specific options like GPU call parameters. Enclose in \" ")
client_parser.add_argument("-t", "--max_time", type=int, default=None,
help="If more than this time in seconds has passed, "
"the client stops running scenarios.")
client_parser.set_defaults(func=client)
args = parser.parse_args(argv)
if not len(vars(args)):
parser.print_help()
else:
args.func(args)
def _filter_already_simulated(self, all_runs):
"""Removes all runs from the given list, that are already in the database"""
already_simulated = {json.dumps(e.params) for e in self.db.filter({})}
return [r for r in all_runs if json.dumps(r.parameter_dict) not in already_simulated]
@staticmethod
def _distribute_runs(all_runs, process, num_processes):
"""Partitions runs by their weights into num_processes chunks and returns the process's chunk."""
sorted_runs = sorted(all_runs, key=lambda e: e.weight, reverse=True)
result = sorted_runs[process::num_processes]
result.reverse() # start with faster scenarios
return result
import numpy as np
import sympy as sp
import pystencils as ps
from pystencils.jupyter import make_imshow_animation, display_animation, set_display_mode
import pystencils.plot as plt
__all__ = ['sp', 'np', 'ps', 'plt', 'make_imshow_animation', 'display_animation', 'set_display_mode']
from .assignment_collection import AssignmentCollection
from .simplifications import (
add_subexpressions_for_constants,
add_subexpressions_for_divisions, add_subexpressions_for_field_reads,
add_subexpressions_for_sums, apply_on_all_subexpressions, apply_to_all_assignments,
subexpression_substitution_in_existing_subexpressions,
subexpression_substitution_in_main_assignments, sympy_cse, sympy_cse_on_assignment_list)
from .subexpression_insertion import (
insert_aliases, insert_zeros, insert_constants,
insert_constant_additions, insert_constant_multiples,
insert_squares, insert_symbol_times_minus_one)
from .simplificationstrategy import SimplificationStrategy
__all__ = ['AssignmentCollection', 'SimplificationStrategy',
'sympy_cse', 'sympy_cse_on_assignment_list', 'apply_to_all_assignments',
'apply_on_all_subexpressions', 'subexpression_substitution_in_existing_subexpressions',
'subexpression_substitution_in_main_assignments', 'add_subexpressions_for_constants',
'add_subexpressions_for_divisions', 'add_subexpressions_for_sums', 'add_subexpressions_for_field_reads',
'insert_aliases', 'insert_zeros', 'insert_constants',
'insert_constant_additions', 'insert_constant_multiples',
'insert_squares', 'insert_symbol_times_minus_one']
import itertools
from copy import copy
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Union
import sympy as sp
import pystencils
from pystencils.assignment import Assignment
from pystencils.simp.simplifications import (sort_assignments_topologically, transform_lhs_and_rhs, transform_rhs)
from pystencils.sympyextensions import count_operations, fast_subs
class AssignmentCollection:
"""
A collection of equations with subexpression definitions, also represented as assignments,
that are used in the main equations. AssignmentCollection can be passed to simplification methods.
These simplification methods can change the subexpressions, but the number and
left hand side of the main equations themselves is not altered.
Additionally a dictionary of simplification hints is stored, which are set by the functions that create
assignment collections to transport information to the simplification system.
Args:
main_assignments: List of assignments. Main assignments are characterised, that the right hand side of each
assignment is a field access. Thus the generated equations write on arrays.
subexpressions: List of assignments defining subexpressions used in main equations
simplification_hints: Dict that is used to annotate the assignment collection with hints that are
used by the simplification system. See documentation of the simplification rules for
potentially required hints and their meaning.
subexpression_symbol_generator: Generator for new symbols that are used when new subexpressions are added
used to get new symbols that are unique for this AssignmentCollection
"""
# ------------------------------- Creation & Inplace Manipulation --------------------------------------------------
def __init__(self, main_assignments: Union[List[Assignment], Dict[sp.Expr, sp.Expr]],
subexpressions: Union[List[Assignment], Dict[sp.Expr, sp.Expr]] = None,
simplification_hints: Optional[Dict[str, Any]] = None,
subexpression_symbol_generator: Iterator[sp.Symbol] = None) -> None:
if subexpressions is None:
subexpressions = {}
if isinstance(main_assignments, Dict):
main_assignments = [Assignment(k, v)
for k, v in main_assignments.items()]
if isinstance(subexpressions, Dict):
subexpressions = [Assignment(k, v)
for k, v in subexpressions.items()]
main_assignments = list(itertools.chain.from_iterable(
[(a if isinstance(a, Iterable) else [a]) for a in main_assignments]))
subexpressions = list(itertools.chain.from_iterable(
[(a if isinstance(a, Iterable) else [a]) for a in subexpressions]))
self.main_assignments = main_assignments
self.subexpressions = subexpressions
if simplification_hints is None:
simplification_hints = {}
self.simplification_hints = simplification_hints
ctrs = [int(n.name[3:])for n in self.rhs_symbols if "xi_" in n.name]
max_ctr = max(ctrs) + 1 if len(ctrs) > 0 else 0
if subexpression_symbol_generator is None:
self.subexpression_symbol_generator = SymbolGen(ctr=max_ctr)
else:
self.subexpression_symbol_generator = subexpression_symbol_generator
def add_simplification_hint(self, key: str, value: Any) -> None:
"""Adds an entry to the simplification_hints dictionary and checks that is does not exist yet."""
assert key not in self.simplification_hints, "This hint already exists"
self.simplification_hints[key] = value
def add_subexpression(self, rhs: sp.Expr, lhs: Optional[sp.Symbol] = None, topological_sort=True) -> sp.Symbol:
"""Adds a subexpression to current collection.
Args:
rhs: right hand side of new subexpression
lhs: optional left hand side of new subexpression. If None a new unique symbol is generated.
topological_sort: sort the subexpressions topologically after insertion, to make sure that
definition of a symbol comes before its usage. If False, subexpression is appended.
Returns:
left hand side symbol (which could have been generated)
"""
if lhs is None:
lhs = next(self.subexpression_symbol_generator)
eq = Assignment(lhs, rhs)
self.subexpressions.append(eq)
if topological_sort:
self.topological_sort(sort_subexpressions=True,
sort_main_assignments=False)
return lhs
def topological_sort(self, sort_subexpressions: bool = True, sort_main_assignments: bool = True) -> None:
"""Sorts subexpressions and/or main_equations topologically to make sure symbol usage comes after definition."""
if sort_subexpressions:
self.subexpressions = sort_assignments_topologically(self.subexpressions)
if sort_main_assignments:
self.main_assignments = sort_assignments_topologically(self.main_assignments)
# ---------------------------------------------- Properties -------------------------------------------------------
@property
def all_assignments(self) -> List[Assignment]:
"""Subexpression and main equations as a single list."""
return self.subexpressions + self.main_assignments
@property
def rhs_symbols(self) -> Set[sp.Symbol]:
"""All symbols used in the assignment collection, which occur on the rhs of any assignment."""
rhs_symbols = set()
for eq in self.all_assignments:
if isinstance(eq, Assignment):
rhs_symbols.update(eq.rhs.atoms(sp.Symbol))
elif isinstance(eq, pystencils.astnodes.Node):
rhs_symbols.update(eq.undefined_symbols)
return rhs_symbols
@property
def free_symbols(self) -> Set[sp.Symbol]:
"""All symbols used in the assignment collection, which do not occur as left hand sides in any assignment."""
return self.rhs_symbols - self.bound_symbols
@property
def bound_symbols(self) -> Set[sp.Symbol]:
"""All symbols which occur on the left hand side of a main assignment or a subexpression."""
bound_symbols_set = set(
[assignment.lhs for assignment in self.all_assignments if isinstance(assignment, Assignment)]
)
assert len(bound_symbols_set) == len(list(a for a in self.all_assignments if isinstance(a, Assignment))), \
"Not in SSA form - same symbol assigned multiple times"
bound_symbols_set = bound_symbols_set.union(*[
assignment.symbols_defined for assignment in self.all_assignments
if isinstance(assignment, pystencils.astnodes.Node)
])
return bound_symbols_set
@property
def rhs_fields(self):
"""All fields accessed in the assignment collection, which do not occur as left hand sides in any assignment."""
return {s.field for s in self.rhs_symbols if hasattr(s, 'field')}
@property
def free_fields(self):
"""All fields accessed in the assignment collection, which do not occur as left hand sides in any assignment."""
return {s.field for s in self.free_symbols if hasattr(s, 'field')}
@property
def bound_fields(self):
"""All field accessed on the left hand side of a main assignment or a subexpression."""
return {s.field for s in self.bound_symbols if hasattr(s, 'field')}
@property
def defined_symbols(self) -> Set[sp.Symbol]:
"""All symbols which occur as left-hand-sides of one of the main equations"""
lhs_set = set([assignment.lhs for assignment in self.main_assignments if isinstance(assignment, Assignment)])
return (lhs_set.union(*[assignment.symbols_defined for assignment in self.main_assignments
if isinstance(assignment, pystencils.astnodes.Node)]))
@property
def operation_count(self):
"""See :func:`count_operations` """
return count_operations(self.all_assignments, only_type=None)
def atoms(self, *args):
return set().union(*[a.atoms(*args) for a in self.all_assignments])
def dependent_symbols(self, symbols: Iterable[sp.Symbol]) -> Set[sp.Symbol]:
"""Returns all symbols that depend on one of the passed symbols.
A symbol 'a' depends on a symbol 'b', if there is an assignment 'a <- some_expression(b)' i.e. when
'b' is required to compute 'a'.
"""
queue = list(symbols)
def add_symbols_from_expr(expr):
dependent_symbols = expr.atoms(sp.Symbol)
for ds in dependent_symbols:
queue.append(ds)
handled_symbols = set()
assignment_dict = {e.lhs: e.rhs for e in self.all_assignments}
while len(queue) > 0:
e = queue.pop(0)
if e in handled_symbols:
continue
if e in assignment_dict:
add_symbols_from_expr(assignment_dict[e])
handled_symbols.add(e)
return handled_symbols
def lambdify(self, symbols: Sequence[sp.Symbol], fixed_symbols: Optional[Dict[sp.Symbol, Any]] = None, module=None):
"""Returns a python function to evaluate this equation collection.
Args:
symbols: symbol(s) which are the parameter for the created function
fixed_symbols: dictionary with substitutions, that are applied before sympy's lambdify
module: same as sympy.lambdify parameter. Defines which module to use e.g. 'numpy'
Examples:
>>> a, b, c, d = sp.symbols("a b c d")
>>> ac = AssignmentCollection([Assignment(c, a + b), Assignment(d, a**2 + b)],
... subexpressions=[Assignment(b, a + b / 2)])
>>> python_function = ac.lambdify([a], fixed_symbols={b: 2})
>>> python_function(4)
{c: 6, d: 18}
"""
assignments = self.new_with_substitutions(fixed_symbols, substitute_on_lhs=False) if fixed_symbols else self
assignments = assignments.new_without_subexpressions().main_assignments
lambdas = {assignment.lhs: sp.lambdify(symbols, assignment.rhs, module) for assignment in assignments}
def f(*args, **kwargs):
return {s: func(*args, **kwargs) for s, func in lambdas.items()}
return f
# ---------------------------- Creating new modified collections ---------------------------------------------------
def copy(self,
main_assignments: Optional[List[Assignment]] = None,
subexpressions: Optional[List[Assignment]] = None) -> 'AssignmentCollection':
"""Returns a copy with optionally replaced main_assignments and/or subexpressions."""
res = copy(self)
res.simplification_hints = self.simplification_hints.copy()
res.subexpression_symbol_generator = copy(self.subexpression_symbol_generator)
if main_assignments is not None:
res.main_assignments = main_assignments
else:
res.main_assignments = self.main_assignments.copy()
if subexpressions is not None:
res.subexpressions = subexpressions
else:
res.subexpressions = self.subexpressions.copy()
return res
def new_with_substitutions(self, substitutions: Dict, add_substitutions_as_subexpressions: bool = False,
substitute_on_lhs: bool = True,
sort_topologically: bool = True) -> 'AssignmentCollection':
"""Returns new object, where terms are substituted according to the passed substitution dict.
Args:
substitutions: dict that is passed to sympy subs, substitutions are done main assignments and subexpressions
add_substitutions_as_subexpressions: if True, the substitutions are added as assignments to subexpressions
substitute_on_lhs: if False, the substitutions are done only on the right hand side of assignments
sort_topologically: if subexpressions are added as substitutions and this parameters is true,
the subexpressions are sorted topologically after insertion
Returns:
New AssignmentCollection where substitutions have been applied, self is not altered.
"""
transform = transform_lhs_and_rhs if substitute_on_lhs else transform_rhs
transformed_subexpressions = transform(self.subexpressions, fast_subs, substitutions)
transformed_assignments = transform(self.main_assignments, fast_subs, substitutions)
if add_substitutions_as_subexpressions:
transformed_subexpressions = [Assignment(b, a) for a, b in
substitutions.items()] + transformed_subexpressions
if sort_topologically:
transformed_subexpressions = sort_assignments_topologically(transformed_subexpressions)
return self.copy(transformed_assignments, transformed_subexpressions)
def new_merged(self, other: 'AssignmentCollection') -> 'AssignmentCollection':
"""Returns a new collection which contains self and other. Subexpressions are renamed if they clash."""
own_definitions = set([e.lhs for e in self.main_assignments])
other_definitions = set([e.lhs for e in other.main_assignments])
assert len(own_definitions.intersection(other_definitions)) == 0, \
"Cannot merge collections, since both define the same symbols"
own_subexpression_symbols = {e.lhs: e.rhs for e in self.subexpressions}
substitution_dict = {}
processed_other_subexpression_equations = []
for other_subexpression_eq in other.subexpressions:
if other_subexpression_eq.lhs in own_subexpression_symbols:
new_rhs = fast_subs(other_subexpression_eq.rhs, substitution_dict)
if new_rhs == own_subexpression_symbols[other_subexpression_eq.lhs]:
continue # exact the same subexpression equation exists already
else:
# different definition - a new name has to be introduced
new_lhs = next(self.subexpression_symbol_generator)
new_eq = Assignment(new_lhs, new_rhs)
processed_other_subexpression_equations.append(new_eq)
substitution_dict[other_subexpression_eq.lhs] = new_lhs
else:
processed_other_subexpression_equations.append(fast_subs(other_subexpression_eq, substitution_dict))
processed_other_main_assignments = [fast_subs(eq, substitution_dict) for eq in other.main_assignments]
return self.copy(self.main_assignments + processed_other_main_assignments,
self.subexpressions + processed_other_subexpression_equations)
def new_filtered(self, symbols_to_extract: Iterable[sp.Symbol]) -> 'AssignmentCollection':
"""Extracts equations that have symbols_to_extract as left hand side, together with necessary subexpressions.
Returns:
new AssignmentCollection, self is not altered
"""
symbols_to_extract = set(symbols_to_extract)
dependent_symbols = self.dependent_symbols(symbols_to_extract)
new_assignments = []
for eq in self.all_assignments:
if eq.lhs in symbols_to_extract:
new_assignments.append(eq)
new_sub_expr = [eq for eq in self.all_assignments
if eq.lhs in dependent_symbols and eq.lhs not in symbols_to_extract]
return self.copy(new_assignments, new_sub_expr)
def new_without_unused_subexpressions(self) -> 'AssignmentCollection':
"""Returns new collection that only contains subexpressions required to compute the main assignments."""
all_lhs = [eq.lhs for eq in self.main_assignments]
return self.new_filtered(all_lhs)
def new_with_inserted_subexpression(self, symbol: sp.Symbol) -> 'AssignmentCollection':
"""Eliminates the subexpression with the given symbol on its left hand side, by substituting it everywhere."""
new_subexpressions = []
subs_dict = None
for se in self.subexpressions:
if se.lhs == symbol:
subs_dict = {se.lhs: se.rhs}
else:
new_subexpressions.append(se)
if subs_dict is None:
return self
new_subexpressions = [Assignment(eq.lhs, fast_subs(eq.rhs, subs_dict)) for eq in new_subexpressions]
new_eqs = [Assignment(eq.lhs, fast_subs(eq.rhs, subs_dict)) for eq in self.main_assignments]
return self.copy(new_eqs, new_subexpressions)
def new_without_subexpressions(self, subexpressions_to_keep=None) -> 'AssignmentCollection':
"""Returns a new collection where all subexpressions have been inserted."""
if subexpressions_to_keep is None:
subexpressions_to_keep = set()
if len(self.subexpressions) == 0:
return self.copy()
subexpressions_to_keep = set(subexpressions_to_keep)
kept_subexpressions = []
if self.subexpressions[0].lhs in subexpressions_to_keep:
substitution_dict = {}
kept_subexpressions.append(self.subexpressions[0])
else:
substitution_dict = {self.subexpressions[0].lhs: self.subexpressions[0].rhs}
subexpression = [e for e in self.subexpressions]
for i in range(1, len(subexpression)):
subexpression[i] = fast_subs(subexpression[i], substitution_dict)
if subexpression[i].lhs in subexpressions_to_keep:
kept_subexpressions.append(subexpression[i])
else:
substitution_dict[subexpression[i].lhs] = subexpression[i].rhs
new_assignment = [fast_subs(eq, substitution_dict) for eq in self.main_assignments]
return self.copy(new_assignment, kept_subexpressions)
# ----------------------------------------- Display and Printing -------------------------------------------------
def _repr_html_(self):
"""Interface to Jupyter notebook, to display as a nicely formatted HTML table"""
def make_html_equation_table(equations):
no_border = 'style="border:none"'
html_table = '<table style="border:none; width: 100%; ">'
line = '<tr {nb}> <td {nb}>$${eq}$$</td> </tr> '
for eq in equations:
format_dict = {'eq': sp.latex(eq),
'nb': no_border, }
html_table += line.format(**format_dict)
html_table += "</table>"
return html_table
result = ""
if len(self.subexpressions) > 0:
result += "<div>Subexpressions:</div>"
result += make_html_equation_table(self.subexpressions)
result += "<div>Main Assignments:</div>"
result += make_html_equation_table(self.main_assignments)
return result
def __repr__(self):
return f"AssignmentCollection: {str(tuple(self.defined_symbols))[1:-1]} <- f{tuple(self.free_symbols)}"
def __str__(self):
result = "Subexpressions:\n"
for eq in self.subexpressions:
result += f"\t{eq}\n"
result += "Main Assignments:\n"
for eq in self.main_assignments:
result += f"\t{eq}\n"
return result
def __iter__(self):
return self.all_assignments.__iter__()
@property
def main_assignments_dict(self):
return {a.lhs: a.rhs for a in self.main_assignments}
@property
def subexpressions_dict(self):
return {a.lhs: a.rhs for a in self.subexpressions}
def set_main_assignments_from_dict(self, main_assignments_dict):
self.main_assignments = [Assignment(k, v)
for k, v in main_assignments_dict.items()]
def set_sub_expressions_from_dict(self, sub_expressions_dict):
self.subexpressions = [Assignment(k, v)
for k, v in sub_expressions_dict.items()]
def find(self, *args, **kwargs):
return set.union(
*[a.find(*args, **kwargs) for a in self.all_assignments]
)
def match(self, *args, **kwargs):
rtn = {}
for a in self.all_assignments:
partial_result = a.match(*args, **kwargs)
if partial_result:
rtn.update(partial_result)
return rtn
def subs(self, *args, **kwargs):
return AssignmentCollection(
main_assignments=[a.subs(*args, **kwargs) for a in self.main_assignments],
subexpressions=[a.subs(*args, **kwargs) for a in self.subexpressions]
)
def replace(self, *args, **kwargs):
return AssignmentCollection(
main_assignments=[a.replace(*args, **kwargs) for a in self.main_assignments],
subexpressions=[a.replace(*args, **kwargs) for a in self.subexpressions]
)
def __eq__(self, other):
return set(self.all_assignments) == set(other.all_assignments)
def __bool__(self):
return bool(self.all_assignments)
class SymbolGen:
"""Default symbol generator producing number symbols ζ_0, ζ_1, ..."""
def __init__(self, symbol="xi", dtype=None, ctr=0):
self._ctr = ctr
self._symbol = symbol
self._dtype = dtype
def __iter__(self):
return self
def __next__(self):
name = f"{self._symbol}_{self._ctr}"
self._ctr += 1
if self._dtype is not None:
return pystencils.TypedSymbol(name, self._dtype)
return sp.Symbol(name)
from itertools import chain
from typing import Callable, List, Sequence, Union
from collections import defaultdict
import sympy as sp
from pystencils.assignment import Assignment
from pystencils.astnodes import Node
from pystencils.field import Field
from pystencils.sympyextensions import subs_additive, is_constant, recursive_collect
from pystencils.typing import TypedSymbol
def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]:
"""Sorts assignments in topological order, such that symbols used on rhs occur first on a lhs"""
edges = []
for c1, e1 in enumerate(assignments):
if hasattr(e1, 'lhs') and hasattr(e1, 'rhs'):
symbols = [e1.lhs]
elif isinstance(e1, Node):
symbols = e1.symbols_defined
else:
raise NotImplementedError(f"Cannot sort topologically. Object of type {type(e1)} cannot be handled.")
for lhs in symbols:
for c2, e2 in enumerate(assignments):
if isinstance(e2, Assignment) and lhs in e2.rhs.free_symbols:
edges.append((c1, c2))
elif isinstance(e2, Node) and lhs in e2.undefined_symbols:
edges.append((c1, c2))
return [assignments[i] for i in sp.topological_sort((range(len(assignments)), edges))]
def sympy_cse(ac, **kwargs):
"""Searches for common subexpressions inside the assignment collection.
Searches is done in both the existing subexpressions as well as the assignments themselves.
It uses the sympy subexpression detection to do this. Return a new assignment collection
with the additional subexpressions found
"""
symbol_gen = ac.subexpression_symbol_generator
all_assignments = [e for e in chain(ac.subexpressions, ac.main_assignments) if isinstance(e, Assignment)]
other_objects = [e for e in chain(ac.subexpressions, ac.main_assignments) if not isinstance(e, Assignment)]
replacements, new_eq = sp.cse(all_assignments, symbols=symbol_gen, **kwargs)
replacement_eqs = [Assignment(*r) for r in replacements]
modified_subexpressions = new_eq[:len(ac.subexpressions)]
modified_update_equations = new_eq[len(ac.subexpressions):]
new_subexpressions = sort_assignments_topologically(other_objects + replacement_eqs + modified_subexpressions)
return ac.copy(modified_update_equations, new_subexpressions)
def sympy_cse_on_assignment_list(assignments: List[Assignment]) -> List[Assignment]:
"""Extracts common subexpressions from a list of assignments."""
from pystencils.simp.assignment_collection import AssignmentCollection
ec = AssignmentCollection([], assignments)
return sympy_cse(ec).all_assignments
def subexpression_substitution_in_existing_subexpressions(ac):
"""Goes through the subexpressions list and replaces the term in the following subexpressions."""
result = []
for outer_ctr, s in enumerate(ac.subexpressions):
new_rhs = s.rhs
for inner_ctr in range(outer_ctr):
sub_expr = ac.subexpressions[inner_ctr]
new_rhs = subs_additive(new_rhs, sub_expr.lhs, sub_expr.rhs, required_match_replacement=1.0)
new_rhs = new_rhs.subs(sub_expr.rhs, sub_expr.lhs)
result.append(Assignment(s.lhs, new_rhs))
return ac.copy(ac.main_assignments, result)
def subexpression_substitution_in_main_assignments(ac):
"""Replaces already existing subexpressions in the equations of the assignment_collection."""
result = []
for s in ac.main_assignments:
new_rhs = s.rhs
for sub_expr in ac.subexpressions:
new_rhs = subs_additive(new_rhs, sub_expr.lhs, sub_expr.rhs, required_match_replacement=1.0)
result.append(Assignment(s.lhs, new_rhs))
return ac.copy(result)
def add_subexpressions_for_constants(ac):
"""Extracts constant factors to subexpressions in the given assignment collection.
SymPy will exclude common factors from a sum only if they are symbols. This simplification
can be applied to exclude common numeric constants from multiple terms of a sum. As a consequence,
the number of multiplications is reduced and in some cases, more common subexpressions can be found.
"""
constants_to_subexp_dict = defaultdict(lambda: next(ac.subexpression_symbol_generator))
def visit(expr):
args = list(expr.args)
if len(args) == 0:
return expr
if isinstance(expr, sp.Add) or isinstance(expr, sp.Mul):
for i, arg in enumerate(args):
if is_constant(arg) and abs(arg) != 1:
if arg < 0:
args[i] = - constants_to_subexp_dict[- arg]
else:
args[i] = constants_to_subexp_dict[arg]
return expr.func(*(visit(a) for a in args))
main_assignments = [Assignment(a.lhs, visit(a.rhs)) for a in ac.main_assignments]
subexpressions = [Assignment(a.lhs, visit(a.rhs)) for a in ac.subexpressions]
symbols_to_collect = set(constants_to_subexp_dict.values())
main_assignments = [Assignment(a.lhs, recursive_collect(a.rhs, symbols_to_collect, True)) for a in main_assignments]
subexpressions = [Assignment(a.lhs, recursive_collect(a.rhs, symbols_to_collect, True)) for a in subexpressions]
subexpressions = [Assignment(symb, c) for c, symb in constants_to_subexp_dict.items()] + subexpressions
return ac.copy(main_assignments=main_assignments, subexpressions=subexpressions)
def add_subexpressions_for_divisions(ac):
r"""Introduces subexpressions for all divisions which have no constant in the denominator.
For example :math:`\frac{1}{x}` is replaced while :math:`\frac{1}{3}` is not replaced.
"""
divisors = set()
def search_divisors(term):
if term.func == sp.Pow:
if term.exp.is_integer and term.exp.is_number and term.exp < 0:
divisors.add(term)
else:
for a in term.args:
search_divisors(a)
for eq in ac.all_assignments:
search_divisors(eq.rhs)
divisors = sorted(list(divisors), key=lambda x: str(x))
new_symbol_gen = ac.subexpression_symbol_generator
substitutions = {divisor: new_symbol for new_symbol, divisor in zip(new_symbol_gen, divisors)}
return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True, substitute_on_lhs=False)
def add_subexpressions_for_sums(ac):
r"""Introduces subexpressions for all sums - i.e. splits addends into subexpressions."""
addends = []
def contains_sum(term):
if term.func == sp.Add:
return True
if term.is_Atom:
return False
return any([contains_sum(a) for a in term.args])
def search_addends(term):
if term.func == sp.Add:
if all([not contains_sum(a) for a in term.args]):
addends.extend(term.args)
for a in term.args:
search_addends(a)
for eq in ac.all_assignments:
search_addends(eq.rhs)
addends = [a for a in addends if not isinstance(a, sp.Symbol) or isinstance(a, Field.Access)]
new_symbol_gen = ac.subexpression_symbol_generator
substitutions = {addend: new_symbol for new_symbol, addend in zip(new_symbol_gen, addends)}
return ac.new_with_substitutions(substitutions, True, substitute_on_lhs=False)
def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments=True, data_type=None):
r"""Substitutes field accesses on rhs of assignments with subexpressions
Can change semantics of the update rule (which is the goal of this transformation)
This is useful if a field should be update in place - all values are loaded before into subexpression variables,
then the new values are computed and written to the same field in-place.
Additionally, if a datatype is given to the function the rhs symbol of the new isolated field read will have
this data type. This is useful for mixed precision kernels
"""
field_reads = set()
to_iterate = []
if subexpressions:
to_iterate = chain(to_iterate, ac.subexpressions)
if main_assignments:
to_iterate = chain(to_iterate, ac.main_assignments)
for assignment in to_iterate:
if hasattr(assignment, 'lhs') and hasattr(assignment, 'rhs'):
field_reads.update(assignment.rhs.atoms(Field.Access))
if not field_reads:
return ac
substitutions = dict()
for fa in field_reads:
lhs = next(ac.subexpression_symbol_generator)
if data_type is not None:
substitutions.update({fa: TypedSymbol(lhs.name, data_type)})
else:
substitutions.update({fa: lhs})
return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True,
substitute_on_lhs=False, sort_topologically=False)
def transform_rhs(assignment_list, transformation, *args, **kwargs):
"""Applies a transformation function on the rhs of each element of the passed assignment list
If the list also contains other object, like AST nodes, these are ignored.
Additional parameters are passed to the transformation function"""
return [Assignment(a.lhs, transformation(a.rhs, *args, **kwargs)) if hasattr(a, 'lhs') and hasattr(a, 'rhs') else a
for a in assignment_list]
def transform_lhs_and_rhs(assignment_list, transformation, *args, **kwargs):
return [Assignment(transformation(a.lhs, *args, **kwargs),
transformation(a.rhs, *args, **kwargs))
if hasattr(a, 'lhs') and hasattr(a, 'rhs') else a
for a in assignment_list]
def apply_to_all_assignments(operation: Callable[[sp.Expr], sp.Expr]):
"""Applies a given operation to all equations in collection."""
def f(ac):
return ac.copy(transform_rhs(ac.main_assignments, operation))
f.__name__ = operation.__name__
return f
def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]):
"""Applies the given operation on all subexpressions of the AC."""
def f(ac):
return ac.copy(ac.main_assignments, transform_rhs(ac.subexpressions, operation))
f.__name__ = operation.__name__
return f
# TODO Markus
# make this really work for Assignmentcollections
# this function should ONLY evaluate
# do the optims_c99 elsewhere optionally
# def apply_sympy_optimisations(ac: AssignmentCollection):
# """ Evaluates constant expressions (e.g. :math:`\\sqrt{3}` will be replaced by its floating point representation)
# and applies the default sympy optimisations. See sympy.codegen.rewriting
# """
#
# # Evaluates all constant terms
#
# assignments = ac.all_assignments
#
# evaluate_constant_terms = ReplaceOptim(lambda e: hasattr(e, 'is_constant') and e.is_constant and not e.is_integer,
# lambda p: p.evalf())
#
# sympy_optimisations = [evaluate_constant_terms] + list(optims_c99)
#
# assignments = [Assignment(a.lhs, optimize(a.rhs, sympy_optimisations))
# if hasattr(a, 'lhs')
# else a for a in assignments]
# assignments_nodes = [a.atoms(SympyAssignment) for a in assignments]
# for a in chain.from_iterable(assignments_nodes):
# a.optimize(sympy_optimisations)
#
# return AssignmentCollection(assignments)
from collections import namedtuple
from typing import Any, Callable, Optional, Sequence
import sympy as sp
from pystencils.simp.assignment_collection import AssignmentCollection
class SimplificationStrategy:
"""A simplification strategy is an ordered collection of simplification rules.
Each simplification is a function taking an assignment collection, and returning a new simplified
assignment collection. The strategy can nicely print intermediate simplification stages and results
to Jupyter notebooks.
"""
def __init__(self):
self._rules = []
def add(self, rule: Callable[[AssignmentCollection], AssignmentCollection]) -> None:
"""Adds the given simplification rule to the end of the collection.
Args:
rule: function that rewrites/simplifies an assignment collection
"""
self._rules.append(rule)
@property
def rules(self):
return self._rules
def apply(self, assignment_collection: AssignmentCollection) -> AssignmentCollection:
"""Runs all rules on the given assignment collection."""
for t in self._rules:
assignment_collection = t(assignment_collection)
return assignment_collection
def __call__(self, assignment_collection: AssignmentCollection) -> AssignmentCollection:
"""Same as apply"""
return self.apply(assignment_collection)
def create_simplification_report(self, assignment_collection: AssignmentCollection) -> Any:
"""Creates a report to be displayed as HTML in a Jupyter notebook.
The simplification report contains the number of operations at each simplification stage together
with the run-time the simplification took.
"""
ReportElement = namedtuple('ReportElement', ['simplificationName', 'runtime', 'adds', 'muls', 'divs', 'total'])
class Report:
def __init__(self):
self.elements = []
def add(self, element):
self.elements.append(element)
def __str__(self):
try:
import tabulate
return tabulate(self.elements, headers=['Name', 'Runtime', 'Adds', 'Muls', 'Divs', 'Total'])
except ImportError:
result = "Name, Adds, Muls, Divs, Runtime\n"
for e in self.elements:
result += ",".join([str(tuple_item) for tuple_item in e]) + "\n"
return result
def _repr_html_(self):
html_table = '<table style="border:none">'
html_table += "<tr><th>Name</th>" \
"<th>Runtime</th>" \
"<th>Adds</th>" \
"<th>Muls</th>" \
"<th>Divs</th>" \
"<th>Total</th></tr>"
line = "<tr><td>{simplificationName}</td>" \
"<td>{runtime}</td> <td>{adds}</td> <td>{muls}</td> <td>{divs}</td> <td>{total}</td> </tr>"
for e in self.elements:
# noinspection PyProtectedMember
html_table += line.format(**e._asdict())
html_table += "</table>"
return html_table
import timeit
report = Report()
op = assignment_collection.operation_count
total = op['adds'] + op['muls'] + op['divs']
report.add(ReportElement("OriginalTerm", '-', op['adds'], op['muls'], op['divs'], total))
for t in self._rules:
start_time = timeit.default_timer()
assignment_collection = t(assignment_collection)
end_time = timeit.default_timer()
op = assignment_collection.operation_count
time_str = f"{(end_time - start_time) * 1000:.2f} ms"
total = op['adds'] + op['muls'] + op['divs']
report.add(ReportElement(t.__name__, time_str, op['adds'], op['muls'], op['divs'], total))
return report
def show_intermediate_results(self, assignment_collection: AssignmentCollection,
symbols: Optional[Sequence[sp.Symbol]] = None) -> Any:
"""Shows the assignment collection after the application of each rule as HTML report for Jupyter notebook.
Args:
assignment_collection: the collection to apply the rules to
symbols: if not None, only the assignments are shown that have one of these symbols as left hand side
"""
class IntermediateResults:
def __init__(self, strategy, collection, restrict_symbols):
self.strategy = strategy
self.assignment_collection = collection
self.restrict_symbols = restrict_symbols
def __str__(self):
def print_assignment_collection(title, c):
text = title
if self.restrict_symbols:
text += "\n".join([str(e) for e in c.new_filtered(self.restrict_symbols).main_assignments])
else:
text += (" " * 3 + (" " * 3).join(str(c).splitlines(True)))
return text
result = print_assignment_collection("Initial Version", self.assignment_collection)
collection = self.assignment_collection
for rule in self.strategy.rules:
collection = rule(collection)
result += print_assignment_collection(rule.__name__, collection)
return result
def _repr_html_(self):
def print_assignment_collection(title, c):
text = f'<h5 style="padding-bottom:10px">{title}</h5> <div style="padding-left:20px;">'
if self.restrict_symbols:
text += "\n".join(["$$" + sp.latex(e) + '$$'
for e in c.new_filtered(self.restrict_symbols).main_assignments])
else:
# noinspection PyProtectedMember
text += c._repr_html_()
text += "</div>"
return text
result = print_assignment_collection("Initial Version", self.assignment_collection)
collection = self.assignment_collection
for rule in self.strategy.rules:
collection = rule(collection)
result += print_assignment_collection(rule.__name__, collection)
return result
return IntermediateResults(self, assignment_collection, symbols)
def __repr__(self):
result = "Simplification Strategy:\n"
for t in self._rules:
result += f" - {t.__name__}\n"
return result
import sympy as sp
from pystencils.sympyextensions import is_constant
# Subexpression Insertion
def insert_subexpressions(ac, selection_callback, skip=None):
"""
Removes a number of subexpressions from an assignment collection by
inserting their right-hand side wherever they occur.
Args:
- selection_callback: Function that is called to qualify subexpressions
for insertion. Should return `True` for any subexpression that is to be
inserted, and `False` otherwise.
- skip: Set of symbols (left-hand sides of subexpressions) that should be
ignored even if qualified by the callback.
"""
if skip is None:
skip = set()
i = 0
while i < len(ac.subexpressions):
exp = ac.subexpressions[i]
if exp.lhs not in skip and selection_callback(exp):
ac = ac.new_with_inserted_subexpression(exp.lhs)
else:
i += 1
return ac
def insert_aliases(ac, **kwargs):
"""Inserts subexpressions that are aliases of other symbols,
i.e. their right-hand side is only another symbol."""
return insert_subexpressions(ac, lambda x: isinstance(x.rhs, sp.Symbol), **kwargs)
def insert_zeros(ac, **kwargs):
"""Inserts subexpressions whose right-hand side is zero."""
zero = sp.Integer(0)
return insert_subexpressions(ac, lambda x: x.rhs == zero, **kwargs)
def insert_constants(ac, **kwargs):
"""Inserts subexpressions whose right-hand side is constant,
i.e. contains no symbols."""
return insert_subexpressions(ac, lambda x: is_constant(x.rhs), **kwargs)
def insert_symbol_times_minus_one(ac, **kwargs):
"""Inserts subexpressions whose right-hand side is just a
negation of another symbol."""
def callback(exp):
rhs = exp.rhs
minus_one = sp.Integer(-1)
atoms = rhs.atoms(sp.Symbol)
return len(atoms) == 1 and rhs == minus_one * atoms.pop()
return insert_subexpressions(ac, callback, **kwargs)
def insert_constant_multiples(ac, **kwargs):
"""Inserts subexpressions whose right-hand side is a constant
multiplied with another symbol."""
def callback(exp):
rhs = exp.rhs
symbols = rhs.atoms(sp.Symbol)
numbers = rhs.atoms(sp.Number)
return len(symbols) == 1 and len(numbers) == 1 and \
rhs == numbers.pop() * symbols.pop()
return insert_subexpressions(ac, callback, **kwargs)
def insert_constant_additions(ac, **kwargs):
"""Inserts subexpressions whose right-hand side is a sum of a
constant and another symbol."""
def callback(exp):
rhs = exp.rhs
symbols = rhs.atoms(sp.Symbol)
numbers = rhs.atoms(sp.Number)
return len(symbols) == 1 and len(numbers) == 1 and \
rhs == numbers.pop() + symbols.pop()
return insert_subexpressions(ac, callback, **kwargs)
def insert_squares(ac, **kwargs):
"""Inserts subexpressions whose right-hand side is another symbol squared."""
def callback(exp):
rhs = exp.rhs
symbols = rhs.atoms(sp.Symbol)
return len(symbols) == 1 and rhs == symbols.pop() ** 2
return insert_subexpressions(ac, callback, **kwargs)
def bind_symbols_to_skip(insertion_function, skip):
return lambda ac: insertion_function(ac, skip=skip)
from pystencils.simp import (SimplificationStrategy, insert_constants, insert_symbol_times_minus_one,
insert_constant_multiples, insert_constant_additions, insert_squares, insert_zeros)
def create_simplification_strategy():
"""
Creates a default simplification `ps.simp.SimplificationStrategy`. The idea behind the default simplification
strategy is to reduce the number of subexpressions by inserting single constants and to evaluate constant
terms beforehand.
"""
s = SimplificationStrategy()
s.add(insert_symbol_times_minus_one)
s.add(insert_constant_multiples)
s.add(insert_constant_additions)
s.add(insert_squares)
s.add(insert_zeros)
s.add(insert_constants)
s.add(lambda ac: ac.new_without_unused_subexpressions())
import sympy as sp
from pystencils.field import create_numpy_array_with_layout, get_layout_of_array
class SliceMaker(object):
def __getitem__(self, item):
return item
make_slice = SliceMaker()
class SlicedGetter(object):
def __init__(self, function_returning_array):
self._functionReturningArray = function_returning_array
def __getitem__(self, item):
return self._functionReturningArray(item)
class SlicedGetterDataHandling:
def __init__(self, data_handling, name):
self.dh = data_handling
self.name = name
def __getitem__(self, slice_obj):
if slice_obj is None:
slice_obj = make_slice[:, :] if self.data_handling.dim == 2 else make_slice[:, :, 0.5]
return self.dh.gather_array(self.name, slice_obj).squeeze()
def normalize_slice(slices, sizes):
"""Converts slices with floating point and/or negative entries to integer slices"""
if len(slices) != len(sizes):
raise ValueError("Slice dimension does not match sizes")
result = []
for s, size in zip(slices, sizes):
if type(s) is int:
if s < 0:
s = size + s
result.append(s)
continue
if type(s) is float:
result.append(int(s * size))
continue
assert (type(s) is slice)
if s.start is None:
new_start = 0
elif type(s.start) is float:
new_start = int(s.start * size)
elif not isinstance(s.start, sp.Basic) and s.start < 0:
new_start = size + s.start
else:
new_start = s.start
if s.stop is None:
new_stop = size
elif type(s.stop) is float:
new_stop = int(s.stop * size)
elif not isinstance(s.stop, sp.Basic) and s.stop < 0:
new_stop = size + s.stop
else:
new_stop = s.stop
result.append(slice(new_start, new_stop, s.step if s.step is not None else 1))
return tuple(result)
def shift_slice(slices, offset):
def shift_slice_component(slice_comp, shift_offset):
if slice_comp is None:
return None
elif isinstance(slice_comp, int):
return slice_comp + shift_offset
elif isinstance(slice_comp, float):
return slice_comp # relative entries are not shifted
elif isinstance(slice_comp, slice):
return slice(shift_slice_component(slice_comp.start, shift_offset),
shift_slice_component(slice_comp.stop, shift_offset),
slice_comp.step)
else:
raise ValueError()
if hasattr(offset, '__len__'):
return tuple(shift_slice_component(k, off) for k, off in zip(slices, offset))
else:
if isinstance(slices, slice) or isinstance(slices, int) or isinstance(slices, float):
return shift_slice_component(slices, offset)
else:
return tuple(shift_slice_component(k, offset) for k in slices)
def slice_from_direction(direction_name, dim, normal_offset=0, tangential_offset=0):
"""
Create a slice from a direction named by compass scheme:
i.e. 'N' for north returns same as make_slice[:, -1]
the naming is:
- x: W, E (west, east)
- y: S, N (south, north)
- z: B, T (bottom, top)
Also combinations are allowed like north-east 'NE'
:param direction_name: name of direction as explained above
:param dim: dimension of the returned slice (should be 2 or 3)
:param normal_offset: the offset in 'normal' direction: e.g. slice_from_direction('N',2, normal_offset=2)
would return make_slice[:, -3]
:param tangential_offset: offset in the other directions: e.g. slice_from_direction('N',2, tangential_offset=2)
would return make_slice[2:-2, -1]
"""
if tangential_offset == 0:
result = [slice(None, None, None)] * dim
else:
result = [slice(tangential_offset, -tangential_offset, None)] * dim
normal_slice_high, normal_slice_low = -1 - normal_offset, normal_offset
for dim_idx, (low_name, high_name) in enumerate([('W', 'E'), ('S', 'N'), ('B', 'T')]):
if low_name in direction_name:
assert high_name not in direction_name, "Invalid direction name"
result[dim_idx] = normal_slice_low
if high_name in direction_name:
assert low_name not in direction_name, "Invalid direction name"
result[dim_idx] = normal_slice_high
return tuple(result)
def remove_ghost_layers(arr, index_dimensions=0, ghost_layers=1):
if ghost_layers <= 0:
return arr
dimensions = len(arr.shape)
spatial_dimensions = dimensions - index_dimensions
indexing = [slice(ghost_layers, -ghost_layers, None), ] * spatial_dimensions
indexing += [slice(None, None, None)] * index_dimensions
return arr[tuple(indexing)]
def add_ghost_layers(arr, index_dimensions=0, ghost_layers=1, layout=None):
dimensions = len(arr.shape)
spatial_dimensions = dimensions - index_dimensions
new_shape = [e + 2 * ghost_layers for e in arr.shape[:spatial_dimensions]] + list(arr.shape[spatial_dimensions:])
if layout is None:
layout = get_layout_of_array(arr)
result = create_numpy_array_with_layout(new_shape, layout)
result.fill(0.0)
indexing = [slice(ghost_layers, -ghost_layers, None), ] * spatial_dimensions
indexing += [slice(None, None, None)] * index_dimensions
result[tuple(indexing)] = arr
return result
def get_slice_before_ghost_layer(direction, ghost_layers=1, thickness=None, full_slice=False):
"""
Returns slicing expression for region before ghost layer
:param direction: tuple specifying direction of slice
:param ghost_layers: number of ghost layers
:param thickness: thickness of the slice, defaults to number of ghost layers
:param full_slice: if true also the ghost cells in directions orthogonal to direction are contained in the
returned slice. Example (d=W ): if full_slice then also the ghost layer in N-S and T-B
are included, otherwise only inner cells are returned
"""
if not thickness:
thickness = ghost_layers
full_slice_inc = ghost_layers if not full_slice else 0
slices = []
for dir_component in direction:
if dir_component == -1:
s = slice(ghost_layers, thickness + ghost_layers)
elif dir_component == 0:
end = -full_slice_inc
s = slice(full_slice_inc, end if end != 0 else None)
elif dir_component == 1:
start = -thickness - ghost_layers
end = -ghost_layers
s = slice(start if start != 0 else None, end if end != 0 else None)
else:
raise ValueError("Invalid direction: only -1, 0, 1 components are allowed")
slices.append(s)
return tuple(slices)
def get_ghost_region_slice(direction, ghost_layers=1, thickness=None, full_slice=False):
"""
Returns slice of ghost region. For parameters see :func:`get_slice_before_ghost_layer`
"""
if not thickness:
thickness = ghost_layers
assert thickness > 0
assert thickness <= ghost_layers
full_slice_inc = ghost_layers if not full_slice else 0
slices = []
for dir_component in direction:
if dir_component == -1:
s = slice(ghost_layers - thickness, ghost_layers)
elif dir_component == 0:
end = -full_slice_inc
s = slice(full_slice_inc, end if end != 0 else None)
elif dir_component == 1:
start = -ghost_layers
end = - ghost_layers + thickness
s = slice(start if start != 0 else None, end if end != 0 else None)
else:
raise ValueError("Invalid direction: only -1, 0, 1 components are allowed")
slices.append(s)
return tuple(slices)
def get_periodic_boundary_src_dst_slices(stencil, ghost_layers=1, thickness=None):
src_dst_slice_tuples = []
for d in stencil:
if sum([abs(e) for e in d]) == 0:
continue
inv_dir = (-e for e in d)
src = get_slice_before_ghost_layer(inv_dir, ghost_layers, thickness=thickness, full_slice=False)
dst = get_ghost_region_slice(d, ghost_layers, thickness=thickness, full_slice=False)
src_dst_slice_tuples.append((src, dst))
return src_dst_slice_tuples
def get_periodic_boundary_functor(stencil, ghost_layers=1, thickness=None):
"""
Returns a function that applies periodic boundary conditions
:param stencil: sequence of directions e.g. ( [0,1], [0,-1] ) for y periodicity
:param ghost_layers: how many ghost layers the array has
:param thickness: how many of the ghost layers to copy, None means 'all'
:return: function that takes a single array and applies the periodic copy operation
"""
src_dst_slice_tuples = get_periodic_boundary_src_dst_slices(stencil, ghost_layers, thickness)
def functor(pdfs, **_):
for src_slice, dst_slice in src_dst_slice_tuples:
pdfs[dst_slice] = pdfs[src_slice]
return functor
def slice_intersection(slice1, slice2):
slice1 = [s if not isinstance(s, int) else slice(s, s + 1, None) for s in slice1]
slice2 = [s if not isinstance(s, int) else slice(s, s + 1, None) for s in slice2]
new_min = [max(s1.start, s2.start) for s1, s2 in zip(slice1, slice2)]
new_max = [min(s1.stop, s2.stop) for s1, s2 in zip(slice1, slice2)]
if any(max_p - min_p < 0 for min_p, max_p in zip(new_min, new_max)):
return None
return [slice(min_p, max_p, None) for min_p, max_p in zip(new_min, new_max)]
import sympy
import pystencils
import pystencils.astnodes
x_, y_, z_ = tuple(pystencils.astnodes.LoopOverCoordinate.get_loop_counter_symbol(i) for i in range(3))
x_staggered, y_staggered, z_staggered = x_ + 0.5, y_ + 0.5, z_ + 0.5
def x_vector(ndim):
return sympy.Matrix(tuple(pystencils.astnodes.LoopOverCoordinate.get_loop_counter_symbol(i) for i in range(ndim)))
def x_staggered_vector(ndim):
return sympy.Matrix(tuple(
pystencils.astnodes.LoopOverCoordinate.get_loop_counter_symbol(i) + 0.5 for i in range(ndim)
))
"""This submodule offers functions to work with stencils in expression an offset-list form."""
from collections import defaultdict
from typing import Sequence
import numpy as np
import sympy as sp
from pystencils.utils import binary_numbers
def inverse_direction(direction):
"""Returns inverse i.e. negative of given direction tuple
Example:
>>> inverse_direction((1, -1, 0))
(-1, 1, 0)
"""
return tuple([-i for i in direction])
def inverse_direction_string(direction):
"""Returns inverse of given direction string"""
return offset_to_direction_string(inverse_direction(direction_string_to_offset(direction)))
def is_valid(stencil, max_neighborhood=None):
"""
Tests if a nested sequence is a valid stencil i.e. all the inner sequences have the same length.
If max_neighborhood is specified, it is also verified that the stencil does not contain any direction components
with absolute value greater than the maximal neighborhood.
Examples:
>>> is_valid([(1, 0), (1, 0, 0)]) # stencil entries have different length
False
>>> is_valid([(2, 0), (1, 0)])
True
>>> is_valid([(2, 0), (1, 0)], max_neighborhood=1)
False
>>> is_valid([(2, 0), (1, 0)], max_neighborhood=2)
True
"""
expected_dim = len(stencil[0])
for d in stencil:
if len(d) != expected_dim:
return False
if max_neighborhood is not None:
for d_i in d:
if abs(d_i) > max_neighborhood:
return False
return True
def is_symmetric(stencil):
"""Tests for every direction d, that -d is also in the stencil
Examples:
>>> is_symmetric([(1, 0), (0, 1)])
False
>>> is_symmetric([(1, 0), (-1, 0)])
True
"""
for d in stencil:
if inverse_direction(d) not in stencil:
return False
return True
def have_same_entries(s1, s2):
"""Checks if two stencils are the same
Examples:
>>> stencil1 = [(1, 0), (-1, 0), (0, 1), (0, -1)]
>>> stencil2 = [(-1, 0), (0, -1), (1, 0), (0, 1)]
>>> stencil3 = [(-1, 0), (0, -1), (1, 0)]
>>> have_same_entries(stencil1, stencil2)
True
>>> have_same_entries(stencil1, stencil3)
False
"""
if len(s1) != len(s2):
return False
return len(set(s1) - set(s2)) == 0
# -------------------------------------Expression - Coefficient Form Conversion ----------------------------------------
def coefficient_dict(expr):
"""Extracts coefficients in front of field accesses in a expression.
Expression may only access a single field at a single index.
Returns:
center, coefficient dict, nonlinear part
where center is the single field that is accessed in expression accessed at center
and coefficient dict maps offsets to coefficients. The nonlinear part is everything that is not in the form of
coefficient times field access.
Examples:
>>> import pystencils as ps
>>> f = ps.fields("f(3) : double[2D]")
>>> field, coeffs, nonlinear_part = coefficient_dict(2 * f[0, 1](1) + 3 * f[-1, 0](1) + 123)
>>> assert nonlinear_part == 123 and field == f(1)
>>> sorted(coeffs.items())
[((-1, 0), 3), ((0, 1), 2)]
"""
from pystencils.field import Field
expr = expr.expand()
field_accesses = expr.atoms(Field.Access)
fields = set(fa.field for fa in field_accesses)
accessed_indices = set(fa.index for fa in field_accesses)
if len(fields) != 1:
raise ValueError("Could not extract stencil coefficients. "
"Expression has to be a linear function of exactly one field.")
if len(accessed_indices) != 1:
raise ValueError("Could not extract stencil coefficients. Field is accessed at multiple indices")
field = fields.pop()
idx = accessed_indices.pop()
coeffs = defaultdict(lambda: 0)
coeffs.update({fa.offsets: expr.coeff(fa) for fa in field_accesses})
linear_part = sum(c * field[off](*idx) for off, c in coeffs.items())
nonlinear_part = expr - linear_part
return field(*idx), coeffs, nonlinear_part
def coefficients(expr):
"""Returns two lists - one with accessed offsets and one with their coefficients.
Same restrictions as `coefficient_dict` apply. Expression must not have any nonlinear part
>>> import pystencils as ps
>>> f = ps.fields("f(3) : double[2D]")
>>> coff = coefficients(2 * f[0, 1](1) + 3 * f[-1, 0](1))
"""
field_center, coeffs, nonlinear_part = coefficient_dict(expr)
assert nonlinear_part == 0
stencil = list(coeffs.keys())
entries = [coeffs[c] for c in stencil]
return stencil, entries
def coefficient_list(expr, matrix_form=False):
"""Returns stencil coefficients in the form of nested lists
Same restrictions as `coefficient_dict` apply. Expression must not have any nonlinear part
Examples:
>>> import pystencils as ps
>>> f = ps.fields("f: double[2D]")
>>> coefficient_list(2 * f[0, 1] + 3 * f[-1, 0])
[[0, 0, 0], [3, 0, 0], [0, 2, 0]]
>>> coefficient_list(2 * f[0, 1] + 3 * f[-1, 0], matrix_form=True)
Matrix([
[0, 2, 0],
[3, 0, 0],
[0, 0, 0]])
"""
field_center, coeffs, nonlinear_part = coefficient_dict(expr)
assert nonlinear_part == 0
field = field_center.field
dim = field.spatial_dimensions
max_offsets = defaultdict(lambda: 0)
for offset in coeffs.keys():
for d, off in enumerate(offset):
max_offsets[d] = max(max_offsets[d], abs(off))
if dim == 1:
result = [coeffs[(i,)] for i in range(-max_offsets[0], max_offsets[0] + 1)]
return sp.Matrix(result) if matrix_form else result
else:
y_range = list(range(-max_offsets[1], max_offsets[1] + 1))
if matrix_form:
y_range.reverse()
if dim == 2:
result = [[coeffs[(i, j)]
for i in range(-max_offsets[0], max_offsets[0] + 1)]
for j in y_range]
return sp.Matrix(result) if matrix_form else result
elif dim == 3:
result = [[[coeffs[(i, j, k)]
for i in range(-max_offsets[0], max_offsets[0] + 1)]
for j in y_range]
for k in range(-max_offsets[2], max_offsets[2] + 1)]
return [sp.Matrix(l) for l in result] if matrix_form else result
else:
raise ValueError("Can only handle fields with 1,2 or 3 spatial dimensions")
# ------------------------------------- Point-on-compass notation ------------------------------------------------------
def offset_component_to_direction_string(coordinate_id: int, value: int) -> str:
"""Translates numerical offset to string notation.
x offsets are labeled with east 'E' and 'W',
y offsets with north 'N' and 'S' and
z offsets with top 'T' and bottom 'B'
If the absolute value of the offset is bigger than 1, this number is prefixed.
Args:
coordinate_id: integer 0, 1 or 2 standing for x,y and z
value: integer offset
Examples:
>>> offset_component_to_direction_string(0, 1)
'E'
>>> offset_component_to_direction_string(1, 2)
'2N'
"""
assert 0 <= coordinate_id < 3, "Works only for at most 3D arrays"
name_components = (('W', 'E'), # west, east
('S', 'N'), # south, north
('B', 'T')) # bottom, top
if value == 0:
result = ""
elif value < 0:
result = name_components[coordinate_id][0]
else:
result = name_components[coordinate_id][1]
if abs(value) > 1:
result = "%d%s" % (abs(value), result)
return result
def offset_to_direction_string(offsets: Sequence[int]) -> str:
"""
Translates numerical offset to string notation.
For details see :func:`offset_component_to_direction_string`
Args:
offsets: 3-tuple with x,y,z offset
Examples:
>>> offset_to_direction_string([1, -1, 0])
'SE'
>>> offset_to_direction_string(([-3, 0, -2]))
'2B3W'
"""
if len(offsets) > 3:
return str(offsets)
names = ["", "", ""]
for i in range(len(offsets)):
names[i] = offset_component_to_direction_string(i, offsets[i])
name = "".join(reversed(names))
if name == "":
name = "C"
return name
def direction_string_to_offset(direction: str, dim: int = 3):
"""
Reverse mapping of :func:`offset_to_direction_string`
Args:
direction: string representation of offset
dim: dimension of offset, i.e the length of the returned list
Examples:
>>> direction_string_to_offset('NW', dim=3)
array([-1, 1, 0])
>>> direction_string_to_offset('NW', dim=2)
array([-1, 1])
>>> direction_string_to_offset(offset_to_direction_string((3,-2,1)))
array([ 3, -2, 1])
"""
offset_dict = {
'C': np.array([0, 0, 0]),
'W': np.array([-1, 0, 0]),
'E': np.array([1, 0, 0]),
'S': np.array([0, -1, 0]),
'N': np.array([0, 1, 0]),
'B': np.array([0, 0, -1]),
'T': np.array([0, 0, 1]),
}
offset = np.array([0, 0, 0])
while len(direction) > 0:
factor = 1
first_non_digit = 0
while direction[first_non_digit].isdigit():
first_non_digit += 1
if first_non_digit > 0:
factor = int(direction[:first_non_digit])
direction = direction[first_non_digit:]
cur_offset = offset_dict[direction[0]]
offset += factor * cur_offset
direction = direction[1:]
return offset[:dim]
def adjacent_directions(direction):
"""
Returns all adjacent directions for a direction as tuple of tuples. This is useful for exmple to find all directions
relevant for neighbour communication.
Args:
direction: tuple representing a direction. For example (0, 1, 0) for the northern side
Examples:
>>> adjacent_directions((0, 0, 0))
((0, 0, 0),)
>>> adjacent_directions((0, 1, 0))
((0, 1, 0),)
>>> adjacent_directions((0, 1, 1))
((0, 0, 1), (0, 1, 0), (0, 1, 1))
>>> adjacent_directions((-1, -1))
((-1, -1), (-1, 0), (0, -1))
"""
result = set()
if all(e == 0 for e in direction):
result.add(direction)
return tuple(result)
binary_numbers_list = binary_numbers(len(direction))
for adjacent_direction in binary_numbers_list:
for i, entry in enumerate(direction):
if entry == 0:
adjacent_direction[i] = 0
if entry == -1 and adjacent_direction[i] == 1:
adjacent_direction[i] = -1
if not all(e == 0 for e in adjacent_direction):
result.add(tuple(adjacent_direction))
return tuple(sorted(result))
# -------------------------------------- Visualization -----------------------------------------------------------------
def plot(stencil, **kwargs):
dim = len(stencil[0])
if dim == 2:
plot_2d(stencil, **kwargs)
else:
slicing = False
if 'slice' in kwargs:
slicing = kwargs['slice']
del kwargs['slice']
if slicing:
plot_3d_slicing(stencil, **kwargs)
else:
plot_3d(stencil, **kwargs)
def plot_2d(stencil, axes=None, figure=None, data=None, textsize='12', **kwargs):
"""
Creates a matplotlib 2D plot of the stencil
Args:
stencil: sequence of directions
axes: optional matplotlib axes
figure: optional matplotlib figure
data: data to annotate the directions with, if none given, the indices are used
textsize: size of annotation text
"""
from matplotlib.patches import BoxStyle
import matplotlib.pyplot as plt
if axes is None:
if figure is None:
figure = plt.gcf()
axes = figure.gca()
text_box_style = BoxStyle("Round", pad=0.3)
head_length = 0.1
max_offsets = [max(abs(int(d[c])) for d in stencil) for c in (0, 1)]
if data is None:
data = list(range(len(stencil)))
for direction, annotation in zip(stencil, data):
assert len(direction) == 2, "Works only for 2D stencils"
direction = tuple(int(i) for i in direction)
if not (direction[0] == 0 and direction[1] == 0):
axes.arrow(0, 0, direction[0], direction[1], head_width=0.08, head_length=head_length, color='k')
if isinstance(annotation, sp.Basic):
annotation = "$" + sp.latex(annotation) + "$"
else:
annotation = str(annotation)
def position_correction(d, magnitude=0.18):
if d < 0:
return -magnitude
elif d > 0:
return +magnitude
else:
return 0
text_position = [direction[c] + position_correction(direction[c]) for c in (0, 1)]
axes.text(x=text_position[0], y=text_position[1], s=annotation, verticalalignment='center',
zorder=30, horizontalalignment='center', size=textsize,
bbox=dict(boxstyle=text_box_style, facecolor='#00b6eb', alpha=0.85, linewidth=0))
axes.set_axis_off()
axes.set_aspect('equal')
max_offsets = [m if m > 0 else 0.1 for m in max_offsets]
border = 0.1
axes.set_xlim([-border - max_offsets[0], border + max_offsets[0]])
axes.set_ylim([-border - max_offsets[1], border + max_offsets[1]])
def plot_3d_slicing(stencil, slice_axis=2, figure=None, data=None, **kwargs):
"""Visualizes a 3D, first-neighborhood stencil by plotting 3 slices along a given axis.
Args:
stencil: stencil as sequence of directions
slice_axis: 0, 1, or 2 indicating the axis to slice through
figure: optional matplotlib figure
data: optional data to print as text besides the arrows
"""
import matplotlib.pyplot as plt
for d in stencil:
for element in d:
assert element == -1 or element == 0 or element == 1, "This function can only first neighborhood stencils"
if figure is None:
figure = plt.gcf()
axes = [figure.add_subplot(1, 3, i + 1) for i in range(3)]
splitted_directions = [[], [], []]
splitted_data = [[], [], []]
axes_names = ['x', 'y', 'z']
for i, d in enumerate(stencil):
split_idx = d[slice_axis] + 1
reduced_dir = tuple([element for j, element in enumerate(d) if j != slice_axis])
splitted_directions[split_idx].append(reduced_dir)
splitted_data[split_idx].append(i if data is None else data[i])
for i in range(3):
plot_2d(splitted_directions[i], axes=axes[i], data=splitted_data[i], **kwargs)
for i in [-1, 0, 1]:
axes[i + 1].set_title("Cut at %s=%d" % (axes_names[slice_axis], i), y=1.08)
def plot_3d(stencil, figure=None, axes=None, data=None, textsize='8'):
"""
Draws 3D stencil into a 3D coordinate system, parameters are similar to :func:`visualize_stencil_2d`
If data is None, no labels are drawn. To draw the labels as in the 2D case, use ``data=list(range(len(stencil)))``
"""
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d import proj3d
import matplotlib.pyplot as plt
from matplotlib.patches import BoxStyle
from itertools import product, combinations
import numpy as np
class Arrow3D(FancyArrowPatch):
def __init__(self, xs, ys, zs, *args, **kwargs):
FancyArrowPatch.__init__(self, (0, 0), (0, 0), *args, **kwargs)
self._verts3d = xs, ys, zs
def do_3d_projection(self, *_):
xs3d, ys3d, zs3d = self._verts3d
xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M)
self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
return np.min(zs)
if axes is None:
if figure is None:
figure = plt.figure()
axes = figure.add_subplot(projection='3d')
try:
axes.set_aspect("equal")
except NotImplementedError:
pass
if data is None:
data = [None] * len(stencil)
text_offset = 1.25
text_box_style = BoxStyle("Round", pad=0.3)
# Draw cell (cube)
r = [-1, 1]
for s, e in combinations(np.array(list(product(r, r, r))), 2):
if np.sum(np.abs(s - e)) == r[1] - r[0]:
axes.plot(*zip(s, e), color="k", alpha=0.5)
for d, annotation in zip(stencil, data):
assert len(d) == 3, "Works only for 3D stencils"
d = tuple(int(i) for i in d)
if not (d[0] == 0 and d[1] == 0 and d[2] == 0):
if d[0] == 0:
color = '#348abd'
elif d[1] == 0:
color = '#fac364'
elif sum([abs(d) for d in d]) == 2:
color = '#95bd50'
else:
color = '#808080'
a = Arrow3D([0, d[0]], [0, d[1]], [0, d[2]], mutation_scale=20, lw=2, arrowstyle="-|>", color=color)
axes.add_artist(a)
if annotation:
if isinstance(annotation, sp.Basic):
annotation = "$" + sp.latex(annotation) + "$"
else:
annotation = str(annotation)
axes.text(x=d[0] * text_offset, y=d[1] * text_offset, z=d[2] * text_offset,
s=annotation, verticalalignment='center', zorder=30,
size=textsize, bbox=dict(boxstyle=text_box_style, facecolor='#777777', alpha=0.6, linewidth=0))
axes.set_xlim([-text_offset * 1.1, text_offset * 1.1])
axes.set_ylim([-text_offset * 1.1, text_offset * 1.1])
axes.set_zlim([-text_offset * 1.1, text_offset * 1.1])
axes.set_axis_off()
def plot_expression(expr, **kwargs):
"""Displays coefficients of a linear update expression of a single field as matplotlib arrow drawing."""
stencil, coeffs = coefficients(expr)
dim = len(stencil[0])
assert 0 < dim <= 3
if dim == 1:
return coefficient_list(expr, matrix_form=True)
elif dim == 2:
return plot_2d(stencil, data=coeffs, **kwargs)
elif dim == 3:
return plot_3d_slicing(stencil, data=coeffs, **kwargs)