Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
No results found
Show changes
Showing
with 4686 additions and 319 deletions
from itertools import chain
from typing import Callable, List, Sequence, Union
from collections import defaultdict
import sympy as sp
from pystencils.assignment import Assignment
from pystencils.astnodes import Node
from pystencils.field import Field
from pystencils.sympyextensions import subs_additive, is_constant, recursive_collect
from pystencils.typing import TypedSymbol
def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]:
    """Returns the given entries in topological order, such that symbols used on a rhs occur first on a lhs.

    Args:
        assignments: objects having ``lhs`` and ``rhs`` attributes and/or `Node` objects

    Returns:
        list containing the same entries, ordered such that an entry defining a symbol comes before
        the entries reading it

    Raises:
        NotImplementedError: if an entry neither has ``lhs``/``rhs`` attributes nor is a `Node`
    """
    def defined_symbols(entry):
        if hasattr(entry, 'lhs') and hasattr(entry, 'rhs'):
            return [entry.lhs]
        if isinstance(entry, Node):
            return entry.symbols_defined
        raise NotImplementedError(f"Cannot sort topologically. Object of type {type(entry)} cannot be handled.")

    def reads_symbol(entry, symbol):
        return ((isinstance(entry, Assignment) and symbol in entry.rhs.free_symbols)
                or (isinstance(entry, Node) and symbol in entry.undefined_symbols))

    # an edge (producer, consumer) means: entry 'producer' has to come before entry 'consumer'
    dependency_edges = []
    for producer_idx, producer in enumerate(assignments):
        for symbol in defined_symbols(producer):
            for consumer_idx, consumer in enumerate(assignments):
                if reads_symbol(consumer, symbol):
                    dependency_edges.append((producer_idx, consumer_idx))

    vertices = range(len(assignments))
    sorted_indices = sp.topological_sort((vertices, dependency_edges))
    return [assignments[idx] for idx in sorted_indices]
def sympy_cse(ac, **kwargs):
    """Searches for common subexpressions inside the assignment collection.

    The search is done in both the existing subexpressions as well as the assignments themselves.
    It uses the sympy subexpression detection to do this.

    Args:
        ac: assignment collection to process
        **kwargs: forwarded to ``sp.cse``

    Returns:
        a new assignment collection with the additional subexpressions found
    """
    symbol_gen = ac.subexpression_symbol_generator

    # only Assignment objects are handed to sympy, everything else is passed through unchanged
    sub_assignments = [e for e in ac.subexpressions if isinstance(e, Assignment)]
    main_assignments = [e for e in ac.main_assignments if isinstance(e, Assignment)]
    other_objects = [e for e in chain(ac.subexpressions, ac.main_assignments) if not isinstance(e, Assignment)]

    replacements, new_eq = sp.cse(sub_assignments + main_assignments, symbols=symbol_gen, **kwargs)
    replacement_eqs = [Assignment(*r) for r in replacements]

    # new_eq only contains the filtered Assignments, so the boundary between subexpressions and main
    # assignments is the number of *Assignment* subexpressions, not len(ac.subexpressions)
    num_sub_assignments = len(sub_assignments)
    modified_subexpressions = new_eq[:num_sub_assignments]
    modified_update_equations = new_eq[num_sub_assignments:]

    new_subexpressions = sort_assignments_topologically(other_objects + replacement_eqs + modified_subexpressions)
    return ac.copy(modified_update_equations, new_subexpressions)
def sympy_cse_on_assignment_list(assignments: List[Assignment]) -> List[Assignment]:
    """Extracts common subexpressions from a list of assignments.

    The list is wrapped into a temporary `AssignmentCollection`, `sympy_cse` is run on it and
    ``all_assignments`` of the result is returned.
    """
    # NOTE(review): imported locally, presumably to avoid a circular import - confirm
    from pystencils.simp.assignment_collection import AssignmentCollection

    collection = AssignmentCollection([], assignments)
    simplified = sympy_cse(collection)
    return simplified.all_assignments
def subexpression_substitution_in_existing_subexpressions(ac):
    """Goes through the subexpression list and replaces each subexpression term in the following subexpressions.

    Returns:
        copy of the collection with unchanged main assignments and the rewritten subexpressions
    """
    rewritten = []
    for position, current in enumerate(ac.subexpressions):
        rhs = current.rhs
        # only subexpressions that come earlier in the list are candidates for substitution
        for earlier in ac.subexpressions[:position]:
            rhs = subs_additive(rhs, earlier.lhs, earlier.rhs, required_match_replacement=1.0)
            rhs = rhs.subs(earlier.rhs, earlier.lhs)
        rewritten.append(Assignment(current.lhs, rhs))
    return ac.copy(ac.main_assignments, rewritten)
def subexpression_substitution_in_main_assignments(ac):
    """Replaces already existing subexpressions in the main assignments of the assignment collection.

    Returns:
        copy of the collection with the rewritten main assignments
    """
    def substitute_all(rhs):
        # apply the subexpressions one after another, in the order they are stored
        for sub_expr in ac.subexpressions:
            rhs = subs_additive(rhs, sub_expr.lhs, sub_expr.rhs, required_match_replacement=1.0)
        return rhs

    rewritten = [Assignment(eq.lhs, substitute_all(eq.rhs)) for eq in ac.main_assignments]
    return ac.copy(rewritten)
def add_subexpressions_for_constants(ac):
    """Extracts constant factors to subexpressions in the given assignment collection.

    SymPy will exclude common factors from a sum only if they are symbols. This simplification
    can be applied to exclude common numeric constants from multiple terms of a sum. As a consequence,
    the number of multiplications is reduced and in some cases, more common subexpressions can be found.

    Returns:
        copy of the collection where the constants are replaced by symbols; the definitions of these
        symbols are prepended to the subexpressions
    """
    # maps constant -> symbol, a fresh symbol is drawn from the generator on first access
    symbol_for_constant = defaultdict(lambda: next(ac.subexpression_symbol_generator))

    def replace_constant(arg):
        if not (is_constant(arg) and abs(arg) != 1):
            return arg
        # negative constants share the symbol of their absolute value
        if arg < 0:
            return - symbol_for_constant[- arg]
        return symbol_for_constant[arg]

    def visit(expr):
        if len(expr.args) == 0:
            return expr
        if isinstance(expr, (sp.Add, sp.Mul)):
            children = [replace_constant(arg) for arg in expr.args]
        else:
            children = list(expr.args)
        return expr.func(*(visit(child) for child in children))

    def rewrite(assignments):
        return [Assignment(a.lhs, visit(a.rhs)) for a in assignments]

    main_assignments = rewrite(ac.main_assignments)
    subexpressions = rewrite(ac.subexpressions)

    symbols_to_collect = set(symbol_for_constant.values())

    def collect(assignments):
        return [Assignment(a.lhs, recursive_collect(a.rhs, symbols_to_collect, True)) for a in assignments]

    main_assignments = collect(main_assignments)
    subexpressions = collect(subexpressions)

    constant_definitions = [Assignment(symb, c) for c, symb in symbol_for_constant.items()]
    return ac.copy(main_assignments=main_assignments, subexpressions=constant_definitions + subexpressions)
def add_subexpressions_for_divisions(ac):
    r"""Introduces subexpressions for all divisions which have no constant in the denominator.

    For example :math:`\frac{1}{x}` is replaced while :math:`\frac{1}{3}` is not replaced.
    Technically, every power with a negative integer number as exponent found on a right-hand side is replaced.
    """
    divisors = set()

    # iterative traversal of all right-hand sides; the arguments of a power are not visited
    pending = [eq.rhs for eq in ac.all_assignments]
    while pending:
        term = pending.pop()
        if term.func == sp.Pow:
            if term.exp.is_integer and term.exp.is_number and term.exp < 0:
                divisors.add(term)
        else:
            pending.extend(term.args)

    # sort by string representation to get a reproducible symbol assignment
    ordered_divisors = sorted(divisors, key=str)
    new_symbol_gen = ac.subexpression_symbol_generator
    substitutions = {divisor: new_symbol for new_symbol, divisor in zip(new_symbol_gen, ordered_divisors)}
    return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True, substitute_on_lhs=False)
def add_subexpressions_for_sums(ac):
    r"""Introduces subexpressions for all sums - i.e. splits addends into subexpressions.

    Only the addends of sums that do not contain another sum are extracted. Plain symbols are not replaced,
    field accesses are.
    """
    def contains_sum(term):
        if term.func == sp.Add:
            return True
        if term.is_Atom:
            return False
        return any(contains_sum(arg) for arg in term.args)

    def collect_addends(term, found):
        is_innermost_sum = term.func == sp.Add and not any(contains_sum(arg) for arg in term.args)
        if is_innermost_sum:
            found.extend(term.args)
        for arg in term.args:
            collect_addends(arg, found)

    addends = []
    for eq in ac.all_assignments:
        collect_addends(eq.rhs, addends)

    addends = [a for a in addends if not isinstance(a, sp.Symbol) or isinstance(a, Field.Access)]
    new_symbol_gen = ac.subexpression_symbol_generator
    substitutions = {addend: new_symbol for new_symbol, addend in zip(new_symbol_gen, addends)}
    return ac.new_with_substitutions(substitutions, True, substitute_on_lhs=False)
def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments=True, data_type=None):
    r"""Substitutes field accesses on rhs of assignments with subexpressions

    Can change semantics of the update rule (which is the goal of this transformation)
    This is useful if a field should be updated in place - all values are loaded before into subexpression
    variables, then the new values are computed and written to the same field in-place.

    Args:
        ac: assignment collection to transform
        subexpressions: if True, field reads in the subexpressions are collected
        main_assignments: if True, field reads in the main assignments are collected
        data_type: if given, each replacement symbol is a `TypedSymbol` of this data type.
                   This is useful for mixed precision kernels

    Returns:
        the unmodified collection if no field read was found, otherwise a new collection where the
        field reads are isolated in subexpressions
    """
    sources = []
    if subexpressions:
        sources.append(ac.subexpressions)
    if main_assignments:
        sources.append(ac.main_assignments)

    field_reads = set()
    for assignment in chain.from_iterable(sources):
        # entries without lhs/rhs are skipped
        if hasattr(assignment, 'lhs') and hasattr(assignment, 'rhs'):
            field_reads.update(assignment.rhs.atoms(Field.Access))

    if not field_reads:
        return ac

    substitutions = dict()
    for fa in field_reads:
        new_symbol = next(ac.subexpression_symbol_generator)
        if data_type is None:
            substitutions[fa] = new_symbol
        else:
            substitutions[fa] = TypedSymbol(new_symbol.name, data_type)

    return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True,
                                     substitute_on_lhs=False, sort_topologically=False)
def transform_rhs(assignment_list, transformation, *args, **kwargs):
    """Applies a transformation function on the rhs of each element of the passed assignment list

    If the list also contains other objects, like AST nodes, these are passed through unchanged.
    Additional parameters are passed to the transformation function.

    Returns:
        new list with one entry per input element
    """
    transformed = []
    for entry in assignment_list:
        if hasattr(entry, 'lhs') and hasattr(entry, 'rhs'):
            transformed.append(Assignment(entry.lhs, transformation(entry.rhs, *args, **kwargs)))
        else:
            transformed.append(entry)
    return transformed
def transform_lhs_and_rhs(assignment_list, transformation, *args, **kwargs):
    """Applies a transformation function on both lhs and rhs of each element of the passed assignment list

    Elements that do not have ``lhs`` and ``rhs`` attributes are passed through unchanged.
    Additional parameters are passed to the transformation function.

    Returns:
        new list with one entry per input element
    """
    transformed = []
    for entry in assignment_list:
        if hasattr(entry, 'lhs') and hasattr(entry, 'rhs'):
            new_lhs = transformation(entry.lhs, *args, **kwargs)
            new_rhs = transformation(entry.rhs, *args, **kwargs)
            transformed.append(Assignment(new_lhs, new_rhs))
        else:
            transformed.append(entry)
    return transformed
def apply_to_all_assignments(operation: Callable[[sp.Expr], sp.Expr]):
    """Applies a given operation to all equations in collection.

    Args:
        operation: function that is applied to the rhs of every main assignment

    Returns:
        function taking an assignment collection and returning a copy with the transformed main assignments;
        it carries the ``__name__`` of ``operation``
    """
    def f(ac):
        transformed_main_assignments = transform_rhs(ac.main_assignments, operation)
        return ac.copy(transformed_main_assignments)

    # take over the name of the wrapped operation
    f.__name__ = operation.__name__
    return f
def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]):
    """Applies the given operation on all subexpressions of the AC.

    Args:
        operation: function that is applied to the rhs of every subexpression

    Returns:
        function taking an assignment collection and returning a copy with unchanged main assignments
        and the transformed subexpressions; it carries the ``__name__`` of ``operation``
    """
    def f(ac):
        transformed_subexpressions = transform_rhs(ac.subexpressions, operation)
        return ac.copy(ac.main_assignments, transformed_subexpressions)

    # take over the name of the wrapped operation
    f.__name__ = operation.__name__
    return f
# TODO Markus:
# Make this really work for AssignmentCollections.
# This function should ONLY evaluate;
# do the optims_c99 elsewhere, optionally.
# def apply_sympy_optimisations(ac: AssignmentCollection):
# """ Evaluates constant expressions (e.g. :math:`\\sqrt{3}` will be replaced by its floating point representation)
# and applies the default sympy optimisations. See sympy.codegen.rewriting
# """
#
# # Evaluates all constant terms
#
# assignments = ac.all_assignments
#
# evaluate_constant_terms = ReplaceOptim(lambda e: hasattr(e, 'is_constant') and e.is_constant and not e.is_integer,
# lambda p: p.evalf())
#
# sympy_optimisations = [evaluate_constant_terms] + list(optims_c99)
#
# assignments = [Assignment(a.lhs, optimize(a.rhs, sympy_optimisations))
# if hasattr(a, 'lhs')
# else a for a in assignments]
# assignments_nodes = [a.atoms(SympyAssignment) for a in assignments]
# for a in chain.from_iterable(assignments_nodes):
# a.optimize(sympy_optimisations)
#
# return AssignmentCollection(assignments)
from collections import namedtuple
from typing import Any, Callable, Optional, Sequence
import sympy as sp
from pystencils.simp.assignment_collection import AssignmentCollection
class SimplificationStrategy:
    """A simplification strategy is an ordered collection of simplification rules.

    Each simplification is a function taking an assignment collection, and returning a new simplified
    assignment collection. The strategy can nicely print intermediate simplification stages and results
    to Jupyter notebooks.
    """

    def __init__(self):
        self._rules = []

    def add(self, rule: Callable[[AssignmentCollection], AssignmentCollection]) -> None:
        """Adds the given simplification rule to the end of the collection.

        Args:
            rule: function that rewrites/simplifies an assignment collection
        """
        self._rules.append(rule)

    @property
    def rules(self):
        """List of rules in the order they are applied."""
        return self._rules

    def apply(self, assignment_collection: AssignmentCollection) -> AssignmentCollection:
        """Runs all rules on the given assignment collection."""
        for t in self._rules:
            assignment_collection = t(assignment_collection)
        return assignment_collection

    def __call__(self, assignment_collection: AssignmentCollection) -> AssignmentCollection:
        """Same as apply"""
        return self.apply(assignment_collection)

    def create_simplification_report(self, assignment_collection: AssignmentCollection) -> Any:
        """Creates a report to be displayed as HTML in a Jupyter notebook.

        The simplification report contains the number of operations at each simplification stage together
        with the run-time the simplification took.

        Args:
            assignment_collection: the collection the rules are applied to, one after another

        Returns:
            report object providing ``__str__`` and ``_repr_html_``
        """
        import timeit

        ReportElement = namedtuple('ReportElement', ['simplificationName', 'runtime', 'adds', 'muls', 'divs', 'total'])

        class Report:
            def __init__(self):
                self.elements = []

            def add(self, element):
                self.elements.append(element)

            def __str__(self):
                headers = ['Name', 'Runtime', 'Adds', 'Muls', 'Divs', 'Total']
                try:
                    # 'tabulate' is the package, the table function has to be imported from it
                    from tabulate import tabulate
                except ImportError:
                    # plain text fallback; header columns follow the order of the ReportElement fields
                    result = ", ".join(headers) + "\n"
                    for e in self.elements:
                        result += ",".join([str(tuple_item) for tuple_item in e]) + "\n"
                    return result
                return tabulate(self.elements, headers=headers)

            def _repr_html_(self):
                html_table = '<table style="border:none">'
                html_table += "<tr><th>Name</th>" \
                              "<th>Runtime</th>" \
                              "<th>Adds</th>" \
                              "<th>Muls</th>" \
                              "<th>Divs</th>" \
                              "<th>Total</th></tr>"
                line = "<tr><td>{simplificationName}</td>" \
                       "<td>{runtime}</td> <td>{adds}</td> <td>{muls}</td> <td>{divs}</td> <td>{total}</td> </tr>"

                for e in self.elements:
                    # noinspection PyProtectedMember
                    html_table += line.format(**e._asdict())
                html_table += "</table>"
                return html_table

        report = Report()
        # first row: operation count before any rule was applied
        op = assignment_collection.operation_count
        total = op['adds'] + op['muls'] + op['divs']
        report.add(ReportElement("OriginalTerm", '-', op['adds'], op['muls'], op['divs'], total))
        for t in self._rules:
            start_time = timeit.default_timer()
            assignment_collection = t(assignment_collection)
            end_time = timeit.default_timer()
            op = assignment_collection.operation_count
            time_str = f"{(end_time - start_time) * 1000:.2f} ms"
            total = op['adds'] + op['muls'] + op['divs']
            report.add(ReportElement(t.__name__, time_str, op['adds'], op['muls'], op['divs'], total))
        return report

    def show_intermediate_results(self, assignment_collection: AssignmentCollection,
                                  symbols: Optional[Sequence[sp.Symbol]] = None) -> Any:
        """Shows the assignment collection after the application of each rule as HTML report for Jupyter notebook.

        Args:
            assignment_collection: the collection to apply the rules to
            symbols: if not None, only the assignments are shown that have one of these symbols as left hand side
        """

        class IntermediateResults:
            def __init__(self, strategy, collection, restrict_symbols):
                self.strategy = strategy
                self.assignment_collection = collection
                self.restrict_symbols = restrict_symbols

            def __str__(self):
                def print_assignment_collection(title, c):
                    text = title
                    if self.restrict_symbols:
                        text += "\n".join([str(e) for e in c.new_filtered(self.restrict_symbols).main_assignments])
                    else:
                        text += (" " * 3 + (" " * 3).join(str(c).splitlines(True)))
                    return text

                result = print_assignment_collection("Initial Version", self.assignment_collection)
                collection = self.assignment_collection
                # the rules are re-applied every time the object is rendered
                for rule in self.strategy.rules:
                    collection = rule(collection)
                    result += print_assignment_collection(rule.__name__, collection)
                return result

            def _repr_html_(self):
                def print_assignment_collection(title, c):
                    text = f'<h5 style="padding-bottom:10px">{title}</h5> <div style="padding-left:20px;">'
                    if self.restrict_symbols:
                        text += "\n".join(["$$" + sp.latex(e) + '$$'
                                           for e in c.new_filtered(self.restrict_symbols).main_assignments])
                    else:
                        # noinspection PyProtectedMember
                        text += c._repr_html_()
                    text += "</div>"
                    return text

                result = print_assignment_collection("Initial Version", self.assignment_collection)
                collection = self.assignment_collection
                # the rules are re-applied every time the object is rendered
                for rule in self.strategy.rules:
                    collection = rule(collection)
                    result += print_assignment_collection(rule.__name__, collection)
                return result

        return IntermediateResults(self, assignment_collection, symbols)

    def __repr__(self):
        result = "Simplification Strategy:\n"
        for t in self._rules:
            result += f" - {t.__name__}\n"
        return result
import sympy as sp
from pystencils.sympyextensions import is_constant
# Subexpression Insertion
def insert_subexpressions(ac, selection_callback, skip=None):
    """
    Removes a number of subexpressions from an assignment collection by
    inserting their right-hand side wherever they occur.

    Args:
        ac: assignment collection to process
        selection_callback: Function that is called to qualify subexpressions
            for insertion. Should return `True` for any subexpression that is to be
            inserted, and `False` otherwise.
        skip: Set of symbols (left-hand sides of subexpressions) that should be
            ignored even if qualified by the callback.

    Returns:
        the assignment collection after all qualified subexpressions have been inserted
    """
    symbols_to_keep = skip if skip is not None else set()

    position = 0
    while position < len(ac.subexpressions):
        candidate = ac.subexpressions[position]
        if candidate.lhs in symbols_to_keep or not selection_callback(candidate):
            position += 1
            continue
        # after an insertion the position is not advanced, since the subexpression list is
        # taken from the new collection in the next iteration
        ac = ac.new_with_inserted_subexpression(candidate.lhs)
    return ac
def insert_aliases(ac, **kwargs):
    """Inserts subexpressions that are aliases of other symbols,
    i.e. their right-hand side is only another symbol.

    Keyword arguments are forwarded to `insert_subexpressions`.
    """
    def is_alias(subexpression):
        return isinstance(subexpression.rhs, sp.Symbol)

    return insert_subexpressions(ac, is_alias, **kwargs)
def insert_zeros(ac, **kwargs):
    """Inserts subexpressions whose right-hand side is zero.

    Keyword arguments are forwarded to `insert_subexpressions`.
    """
    zero = sp.Integer(0)

    def is_zero(subexpression):
        return subexpression.rhs == zero

    return insert_subexpressions(ac, is_zero, **kwargs)
def insert_constants(ac, **kwargs):
    """Inserts subexpressions whose right-hand side is constant,
    i.e. contains no symbols.

    Keyword arguments are forwarded to `insert_subexpressions`.
    """
    def has_constant_rhs(subexpression):
        return is_constant(subexpression.rhs)

    return insert_subexpressions(ac, has_constant_rhs, **kwargs)
def insert_symbol_times_minus_one(ac, **kwargs):
    """Inserts subexpressions whose right-hand side is just a
    negation of another symbol.

    Keyword arguments are forwarded to `insert_subexpressions`.
    """
    def is_negated_symbol(subexpression):
        rhs = subexpression.rhs
        symbols = rhs.atoms(sp.Symbol)
        if len(symbols) != 1:
            return False
        (symbol,) = symbols
        return rhs == sp.Integer(-1) * symbol

    return insert_subexpressions(ac, is_negated_symbol, **kwargs)
def insert_constant_multiples(ac, **kwargs):
    """Inserts subexpressions whose right-hand side is a constant
    multiplied with another symbol.

    Keyword arguments are forwarded to `insert_subexpressions`.
    """
    def is_constant_multiple(subexpression):
        rhs = subexpression.rhs
        symbols = rhs.atoms(sp.Symbol)
        numbers = rhs.atoms(sp.Number)
        # exactly one symbol and exactly one number have to occur in the rhs
        if len(symbols) != 1 or len(numbers) != 1:
            return False
        (symbol,) = symbols
        (number,) = numbers
        return rhs == number * symbol

    return insert_subexpressions(ac, is_constant_multiple, **kwargs)
def insert_constant_additions(ac, **kwargs):
    """Inserts subexpressions whose right-hand side is a sum of a
    constant and another symbol.

    Keyword arguments are forwarded to `insert_subexpressions`.
    """
    def is_constant_addition(subexpression):
        rhs = subexpression.rhs
        symbols = rhs.atoms(sp.Symbol)
        numbers = rhs.atoms(sp.Number)
        # exactly one symbol and exactly one number have to occur in the rhs
        if len(symbols) != 1 or len(numbers) != 1:
            return False
        (symbol,) = symbols
        (number,) = numbers
        return rhs == number + symbol

    return insert_subexpressions(ac, is_constant_addition, **kwargs)
def insert_squares(ac, **kwargs):
    """Inserts subexpressions whose right-hand side is another symbol squared.

    Keyword arguments are forwarded to `insert_subexpressions`.
    """
    def is_square_of_symbol(subexpression):
        rhs = subexpression.rhs
        symbols = rhs.atoms(sp.Symbol)
        if len(symbols) != 1:
            return False
        (symbol,) = symbols
        return rhs == symbol ** 2

    return insert_subexpressions(ac, is_square_of_symbol, **kwargs)
def bind_symbols_to_skip(insertion_function, skip):
    """Binds the ``skip`` argument of an insertion function.

    Args:
        insertion_function: function taking an assignment collection and a ``skip`` keyword argument
        skip: value that is passed as ``skip`` on every call

    Returns:
        function of a single argument (the assignment collection). It carries the metadata (e.g. ``__name__``)
        of ``insertion_function``, so that code printing the name of a rule shows the original name
        instead of ``<lambda>``.
    """
    from functools import wraps

    @wraps(insertion_function)
    def bound_insertion(ac):
        return insertion_function(ac, skip=skip)

    return bound_insertion
from pystencils.simp import (SimplificationStrategy, insert_constants, insert_symbol_times_minus_one,
insert_constant_multiples, insert_constant_additions, insert_squares, insert_zeros)
def create_simplification_strategy():
    """
    Creates a default simplification `ps.simp.SimplificationStrategy`. The idea behind the default simplification
    strategy is to reduce the number of subexpressions by inserting single constants and to evaluate constant
    terms beforehand.

    Returns:
        the strategy with the insertion rules added, followed by a final rule that removes unused subexpressions
    """
    s = SimplificationStrategy()
    s.add(insert_symbol_times_minus_one)
    s.add(insert_constant_multiples)
    s.add(insert_constant_additions)
    s.add(insert_squares)
    s.add(insert_zeros)
    s.add(insert_constants)
    s.add(lambda ac: ac.new_without_unused_subexpressions())
    # without this return the factory would hand back None instead of the strategy it just built
    return s
import sympy as sp
from pystencils.field import create_numpy_array_with_layout, get_layout_of_array
class SliceMaker:
    """Helper that returns whatever is passed to the subscript operator.

    This makes it possible to create slice objects with the usual subscript syntax.
    """

    def __getitem__(self, item):
        return item


# e.g. make_slice[1:3, 0] gives (slice(1, 3, None), 0)
make_slice = SliceMaker()
class SlicedGetter:
    """Wraps a function, such that it can be called with subscript syntax.

    ``getter[item]`` returns ``function_returning_array(item)``.
    """

    def __init__(self, function_returning_array):
        self._functionReturningArray = function_returning_array

    def __getitem__(self, item):
        wrapped_function = self._functionReturningArray
        return wrapped_function(item)
class SlicedGetterDataHandling:
    """Subscript access to the gathered array ``name`` of a data handling object.

    ``getter[slice_obj]`` returns ``data_handling.gather_array(name, slice_obj).squeeze()``.
    Passing ``None`` selects a default slice: the full domain if the data handling is two dimensional,
    otherwise the full x-y extent at the relative z position 0.5.
    """

    def __init__(self, data_handling, name):
        self.dh = data_handling
        self.name = name

    def __getitem__(self, slice_obj):
        if slice_obj is None:
            # the data handling is stored as 'self.dh'
            full = slice(None, None, None)
            # equivalent to make_slice[:, :] and make_slice[:, :, 0.5]
            slice_obj = (full, full) if self.dh.dim == 2 else (full, full, 0.5)
        return self.dh.gather_array(self.name, slice_obj).squeeze()
def normalize_slice(slices, sizes):
    """Converts slices with floating point and/or negative entries to integer slices

    Float entries are interpreted relative to the size of the respective dimension,
    negative entries are counted from the end.

    Args:
        slices: sequence of ints, floats or slice objects - one entry per dimension
        sizes: sequence of the same length, containing the size of each dimension

    Returns:
        tuple with one normalized entry per dimension; slices always get an explicit step

    Raises:
        ValueError: if slices and sizes do not have the same length
    """
    if len(slices) != len(sizes):
        raise ValueError("Slice dimension does not match sizes")

    def resolve_bound(bound, extent, default):
        if bound is None:
            return default
        if type(bound) is float:
            return int(bound * extent)
        # symbolic bounds are passed through unchanged
        if not isinstance(bound, sp.Basic) and bound < 0:
            return extent + bound
        return bound

    normalized = []
    for entry, extent in zip(slices, sizes):
        if type(entry) is int:
            normalized.append(extent + entry if entry < 0 else entry)
        elif type(entry) is float:
            normalized.append(int(entry * extent))
        else:
            assert (type(entry) is slice)
            step = 1 if entry.step is None else entry.step
            normalized.append(slice(resolve_bound(entry.start, extent, 0),
                                    resolve_bound(entry.stop, extent, extent),
                                    step))
    return tuple(normalized)
def shift_slice(slices, offset):
    """Shifts the integer entries of a slice description by an offset.

    Args:
        slices: a single int, float or slice object, or a sequence of these
        offset: either a single offset applied to all components, or a sequence with one offset
                per component

    Returns:
        the shifted component, or a tuple of shifted components. ``None`` and float (relative)
        entries are left unchanged.

    Raises:
        ValueError: if a component has an unsupported type
    """
    def shifted(component, delta):
        if component is None:
            return None
        if isinstance(component, int):
            return component + delta
        if isinstance(component, float):
            # relative entries are not shifted
            return component
        if isinstance(component, slice):
            return slice(shifted(component.start, delta), shifted(component.stop, delta), component.step)
        raise ValueError()

    if hasattr(offset, '__len__'):
        return tuple(shifted(component, delta) for component, delta in zip(slices, offset))
    if isinstance(slices, (slice, int, float)):
        return shifted(slices, offset)
    return tuple(shifted(component, offset) for component in slices)
def slice_from_direction(direction_name, dim, normal_offset=0, tangential_offset=0):
    """
    Create a slice from a direction named by compass scheme:
        i.e. 'N' for north returns same as make_slice[:, -1]
        the naming is:
            - x: W, E (west, east)
            - y: S, N (south, north)
            - z: B, T (bottom, top)
    Also combinations are allowed like north-east 'NE'

    :param direction_name: name of direction as explained above
    :param dim: dimension of the returned slice (should be 2 or 3)
    :param normal_offset: the offset in 'normal' direction: e.g. slice_from_direction('N',2, normal_offset=2)
                          would return make_slice[:, -3]
    :param tangential_offset: offset in the other directions: e.g. slice_from_direction('N',2, tangential_offset=2)
                              would return make_slice[2:-2, -1]
    :return: tuple with one entry per dimension
    """
    if tangential_offset == 0:
        tangential = slice(None, None, None)
    else:
        tangential = slice(tangential_offset, -tangential_offset, None)
    result = [tangential] * dim

    low_index = normal_offset
    high_index = -1 - normal_offset

    # (low, high) letter of each coordinate, in x, y, z order
    axis_letters = (('W', 'E'), ('S', 'N'), ('B', 'T'))
    for axis, (low_name, high_name) in enumerate(axis_letters):
        wants_low = low_name in direction_name
        wants_high = high_name in direction_name
        if wants_low:
            assert not wants_high, "Invalid direction name"
            result[axis] = low_index
        if wants_high:
            assert not wants_low, "Invalid direction name"
            result[axis] = high_index
    return tuple(result)
def remove_ghost_layers(arr, index_dimensions=0, ghost_layers=1):
    """Returns the array without its ghost layers.

    Args:
        arr: array whose first coordinates are spatial, followed by ``index_dimensions`` index coordinates
        index_dimensions: number of trailing coordinates that are not cut
        ghost_layers: number of layers removed at both ends of every spatial coordinate

    Returns:
        ``arr`` itself if ``ghost_layers`` is not positive, otherwise ``arr`` indexed with the inner slices
    """
    if ghost_layers <= 0:
        return arr

    spatial_dimensions = len(arr.shape) - index_dimensions
    inner = slice(ghost_layers, -ghost_layers, None)
    everything = slice(None, None, None)
    indexing = [inner] * spatial_dimensions + [everything] * index_dimensions
    return arr[tuple(indexing)]
def add_ghost_layers(arr, index_dimensions=0, ghost_layers=1, layout=None):
    """Returns a new, zero-filled array with additional ghost layers; the inner region holds the values of ``arr``.

    Args:
        arr: array whose first coordinates are spatial, followed by ``index_dimensions`` index coordinates
        index_dimensions: number of trailing coordinates that do not get ghost layers
        ghost_layers: number of layers added at both ends of every spatial coordinate
        layout: layout passed to ``create_numpy_array_with_layout``; if None, the layout of ``arr`` is used

    Returns:
        the newly created array
    """
    dimensions = len(arr.shape)
    spatial_dimensions = dimensions - index_dimensions
    new_shape = [e + 2 * ghost_layers for e in arr.shape[:spatial_dimensions]] + list(arr.shape[spatial_dimensions:])
    if layout is None:
        layout = get_layout_of_array(arr)
    result = create_numpy_array_with_layout(new_shape, layout)
    result.fill(0.0)
    # slice(0, -0) would select nothing, so for zero ghost layers the upper bound has to stay open
    inner = slice(ghost_layers, -ghost_layers or None, None)
    indexing = [inner] * spatial_dimensions
    indexing += [slice(None, None, None)] * index_dimensions
    result[tuple(indexing)] = arr
    return result
def get_slice_before_ghost_layer(direction, ghost_layers=1, thickness=None, full_slice=False):
    """
    Returns slicing expression for region before ghost layer

    :param direction: tuple specifying direction of slice
    :param ghost_layers: number of ghost layers
    :param thickness: thickness of the slice, defaults to number of ghost layers
    :param full_slice:  if true also the ghost cells in directions orthogonal to direction are contained in the
                        returned slice. Example (d=W ): if full_slice then also the ghost layer in N-S and T-B
                        are included, otherwise only inner cells are returned
    :return: tuple with one slice per direction component
    """
    if not thickness:
        thickness = ghost_layers
    orthogonal_margin = 0 if full_slice else ghost_layers

    def open_if_zero(bound):
        # a bound of (minus) zero has to be replaced by None to address the end of the array
        return None if bound == 0 else bound

    def component_slice(dir_component):
        if dir_component == -1:
            return slice(ghost_layers, thickness + ghost_layers)
        if dir_component == 0:
            return slice(orthogonal_margin, open_if_zero(-orthogonal_margin))
        if dir_component == 1:
            return slice(open_if_zero(-thickness - ghost_layers), open_if_zero(-ghost_layers))
        raise ValueError("Invalid direction: only -1, 0, 1 components are allowed")

    return tuple(component_slice(component) for component in direction)
def get_ghost_region_slice(direction, ghost_layers=1, thickness=None, full_slice=False):
    """
    Returns slice of ghost region. For parameters see :func:`get_slice_before_ghost_layer`

    The thickness has to be positive and not larger than the number of ghost layers.
    """
    if not thickness:
        thickness = ghost_layers
    assert thickness > 0
    assert thickness <= ghost_layers
    orthogonal_margin = 0 if full_slice else ghost_layers

    def open_if_zero(bound):
        # a bound of (minus) zero has to be replaced by None to address the end of the array
        return None if bound == 0 else bound

    def component_slice(dir_component):
        if dir_component == -1:
            return slice(ghost_layers - thickness, ghost_layers)
        if dir_component == 0:
            return slice(orthogonal_margin, open_if_zero(-orthogonal_margin))
        if dir_component == 1:
            return slice(open_if_zero(-ghost_layers), open_if_zero(- ghost_layers + thickness))
        raise ValueError("Invalid direction: only -1, 0, 1 components are allowed")

    return tuple(component_slice(component) for component in direction)
def get_periodic_boundary_src_dst_slices(stencil, ghost_layers=1, thickness=None):
    """Returns a list of (source slice, destination slice) tuples, one for each non-zero stencil direction.

    For a direction d the destination is the ghost region in direction d, the source is the region
    before the ghost layer in the inverse direction. A direction with only zero entries is skipped.

    :param stencil: sequence of directions
    :param ghost_layers: number of ghost layers
    :param thickness: passed on to the slice functions
    """
    src_dst_slice_tuples = []
    for direction in stencil:
        is_center = sum(abs(component) for component in direction) == 0
        if is_center:
            continue
        inverse = tuple(-component for component in direction)
        source = get_slice_before_ghost_layer(inverse, ghost_layers, thickness=thickness, full_slice=False)
        destination = get_ghost_region_slice(direction, ghost_layers, thickness=thickness, full_slice=False)
        src_dst_slice_tuples.append((source, destination))
    return src_dst_slice_tuples
def get_periodic_boundary_functor(stencil, ghost_layers=1, thickness=None):
    """
    Returns a function that applies periodic boundary conditions

    :param stencil: sequence of directions e.g. ( [0,1], [0,-1] ) for y periodicity
    :param ghost_layers: how many ghost layers the array has
    :param thickness: how many of the ghost layers to copy, None means 'all'
    :return: function that takes a single array and applies the periodic copy operation
    """
    # the slices are computed once here and re-used in every call of the functor
    copy_instructions = get_periodic_boundary_src_dst_slices(stencil, ghost_layers, thickness)

    def functor(pdfs, **_):
        # additional keyword arguments are accepted and ignored
        for source, destination in copy_instructions:
            pdfs[destination] = pdfs[source]

    return functor
def slice_intersection(slice1, slice2):
    """Computes the intersection of two slice sequences, coordinate by coordinate.

    Integer entries are treated as slices of length one. The start and stop of all slices are compared
    directly, so they have to be set.

    Returns:
        list of slices describing the intersection, or None if in any coordinate the upper bound is
        smaller than the lower bound
    """
    def as_slices(entries):
        return [slice(e, e + 1, None) if isinstance(e, int) else e for e in entries]

    bounds = [(max(first.start, second.start), min(first.stop, second.stop))
              for first, second in zip(as_slices(slice1), as_slices(slice2))]

    for lower, upper in bounds:
        if upper - lower < 0:
            return None
    return [slice(lower, upper, None) for lower, upper in bounds]
import sympy
import pystencils
import pystencils.astnodes
# loop counter symbols of the first three coordinates
x_, y_, z_ = [pystencils.astnodes.LoopOverCoordinate.get_loop_counter_symbol(axis) for axis in range(3)]
# the same coordinates shifted by half a cell
x_staggered = x_ + 0.5
y_staggered = y_ + 0.5
z_staggered = z_ + 0.5
def x_vector(ndim):
    """Returns a sympy Matrix containing the loop counter symbols of the first ``ndim`` coordinates."""
    loop_counters = [pystencils.astnodes.LoopOverCoordinate.get_loop_counter_symbol(axis) for axis in range(ndim)]
    return sympy.Matrix(tuple(loop_counters))
def x_staggered_vector(ndim):
    """Returns a sympy Matrix containing the loop counter symbols of the first ``ndim`` coordinates,
    each shifted by 0.5."""
    staggered_counters = [pystencils.astnodes.LoopOverCoordinate.get_loop_counter_symbol(axis) + 0.5
                          for axis in range(ndim)]
    return sympy.Matrix(tuple(staggered_counters))
"""This submodule offers functions to work with stencils in expression an offset-list form."""
from collections import defaultdict
from typing import Sequence
import numpy as np
import sympy as sp
from pystencils.utils import binary_numbers
def inverse_direction(direction):
    """Returns inverse i.e. negative of given direction tuple

    Every component of the direction is negated, the result is always a tuple.

    Example:
        >>> inverse_direction((1, -1, 0))
        (-1, 1, 0)
    """
    return tuple(-component for component in direction)
def inverse_direction_string(direction):
    """Returns inverse of given direction string

    The string is converted to an offset, the offset is inverted and converted back to a string.
    """
    offset = direction_string_to_offset(direction)
    inverted_offset = inverse_direction(offset)
    return offset_to_direction_string(inverted_offset)
def is_valid(stencil, max_neighborhood=None):
    """
    Tests if a nested sequence is a valid stencil i.e. all the inner sequences have the same length.
    If max_neighborhood is specified, it is also verified that the stencil does not contain any direction components
    with absolute value greater than the maximal neighborhood.

    The length of the first entry is taken as reference, so the stencil must not be empty.

    Examples:
        >>> is_valid([(1, 0), (1, 0, 0)])  # stencil entries have different length
        False
        >>> is_valid([(2, 0), (1, 0)])
        True
        >>> is_valid([(2, 0), (1, 0)], max_neighborhood=1)
        False
        >>> is_valid([(2, 0), (1, 0)], max_neighborhood=2)
        True
    """
    expected_dim = len(stencil[0])
    check_neighborhood = max_neighborhood is not None
    for direction in stencil:
        if len(direction) != expected_dim:
            return False
        if check_neighborhood and any(abs(component) > max_neighborhood for component in direction):
            return False
    return True
def is_symmetric(stencil):
    """Tests for every direction d, that -d is also in the stencil

    Examples:
        >>> is_symmetric([(1, 0), (0, 1)])
        False
        >>> is_symmetric([(1, 0), (-1, 0)])
        True
    """
    return all(inverse_direction(direction) in stencil for direction in stencil)
def have_same_entries(s1, s2):
    """Checks if two stencils are the same, i.e. contain the same directions regardless of their order

    Both stencils have to have the same length and the same set of entries. The sets are compared in
    both directions, so a stencil with a repeated entry is not reported as equal to a stencil that
    contains an additional direction.

    Examples:
        >>> stencil1 = [(1, 0), (-1, 0), (0, 1), (0, -1)]
        >>> stencil2 = [(-1, 0), (0, -1), (1, 0), (0, 1)]
        >>> stencil3 = [(-1, 0), (0, -1), (1, 0)]
        >>> have_same_entries(stencil1, stencil2)
        True
        >>> have_same_entries(stencil1, stencil3)
        False
        >>> have_same_entries([(1, 0), (1, 0)], [(1, 0), (0, 1)])
        False
    """
    if len(s1) != len(s2):
        return False
    # a one-sided difference set(s1) - set(s2) misses entries that occur only in s2
    return set(s1) == set(s2)
# -------------------------------------Expression - Coefficient Form Conversion ----------------------------------------
def coefficient_dict(expr):
    """Extracts coefficients in front of field accesses in a expression.

    Expression may only access a single field at a single index.

    Returns:
        center, coefficient dict, nonlinear part
        where center is the single field that is accessed in expression accessed at center
        and coefficient dict maps offsets to coefficients. The nonlinear part is everything that is not in the form of
        coefficient times field access.

    Raises:
        ValueError: if the expression does not access exactly one field, or accesses it at multiple indices

    Examples:
        >>> import pystencils as ps
        >>> f = ps.fields("f(3) : double[2D]")
        >>> field, coeffs, nonlinear_part = coefficient_dict(2 * f[0, 1](1) + 3 * f[-1, 0](1) + 123)
        >>> assert nonlinear_part == 123 and field == f(1)
        >>> sorted(coeffs.items())
        [((-1, 0), 3), ((0, 1), 2)]
    """
    from pystencils.field import Field

    expanded = expr.expand()
    accesses = expanded.atoms(Field.Access)
    accessed_fields = {fa.field for fa in accesses}
    accessed_indices = {fa.index for fa in accesses}

    if len(accessed_fields) != 1:
        raise ValueError("Could not extract stencil coefficients. "
                         "Expression has to be a linear function of exactly one field.")
    if len(accessed_indices) != 1:
        raise ValueError("Could not extract stencil coefficients. Field is accessed at multiple indices")

    (field,) = accessed_fields
    (idx,) = accessed_indices

    # offsets that are not accessed have coefficient zero
    coeffs = defaultdict(lambda: 0)
    for fa in accesses:
        coeffs[fa.offsets] = expanded.coeff(fa)

    linear_part = sum(c * field[off](*idx) for off, c in coeffs.items())
    nonlinear_part = expanded - linear_part
    return field(*idx), coeffs, nonlinear_part
def coefficients(expr):
    """Returns two lists - one with accessed offsets and one with their coefficients.

    Same restrictions as `coefficient_dict` apply. Expression must not have any nonlinear part

    >>> import pystencils as ps
    >>> f = ps.fields("f(3) : double[2D]")
    >>> coff = coefficients(2 * f[0, 1](1) + 3 * f[-1, 0](1))
    """
    _, coeffs, nonlinear_part = coefficient_dict(expr)
    assert nonlinear_part == 0

    stencil = list(coeffs.keys())
    # entries are in the same order as the offsets in 'stencil'
    entries = [coeffs[offset] for offset in stencil]
    return stencil, entries
def coefficient_list(expr, matrix_form=False):
    """Return the stencil coefficients of an expression as nested lists.

    Same restrictions as `coefficient_dict` apply. Expression must not have any nonlinear part.

    Args:
        expr: expression that is linear in the accesses to a single field
        matrix_form: if True, return a sympy Matrix (1D, 2D) or a list of matrices (3D);
            in that case the y axis is reversed so that positive y offsets appear in the upper rows

    Examples:
        >>> import pystencils as ps
        >>> f = ps.fields("f: double[2D]")
        >>> coefficient_list(2 * f[0, 1] + 3 * f[-1, 0])
        [[0, 0, 0], [3, 0, 0], [0, 2, 0]]
        >>> coefficient_list(2 * f[0, 1] + 3 * f[-1, 0], matrix_form=True)
        Matrix([
        [0, 2, 0],
        [3, 0, 0],
        [0, 0, 0]])
    """
    field_center, coeffs, nonlinear_part = coefficient_dict(expr)
    assert nonlinear_part == 0

    dim = field_center.field.spatial_dimensions

    # largest absolute offset per coordinate determines the extent of the result
    extent = defaultdict(int)
    for offset in coeffs.keys():
        for axis, component in enumerate(offset):
            extent[axis] = max(extent[axis], abs(component))

    def axis_range(axis):
        return range(-extent[axis], extent[axis] + 1)

    if dim == 1:
        row = [coeffs[(i,)] for i in axis_range(0)]
        return sp.Matrix(row) if matrix_form else row

    y_values = list(axis_range(1))
    if matrix_form:
        y_values.reverse()

    if dim == 2:
        rows = [[coeffs[(i, j)] for i in axis_range(0)] for j in y_values]
        return sp.Matrix(rows) if matrix_form else rows
    if dim == 3:
        layers = [[[coeffs[(i, j, k)] for i in axis_range(0)] for j in y_values]
                  for k in axis_range(2)]
        return [sp.Matrix(layer) for layer in layers] if matrix_form else layers
    raise ValueError("Can only handle fields with 1,2 or 3 spatial dimensions")
# ------------------------------------- Point-on-compass notation ------------------------------------------------------
def offset_component_to_direction_string(coordinate_id: int, value: int) -> str:
    """Translate one component of a numerical offset into compass notation.

    x offsets are labeled with east 'E' and west 'W',
    y offsets with north 'N' and south 'S' and
    z offsets with top 'T' and bottom 'B'.
    If the absolute value of the offset is bigger than 1, this number is prefixed.
    A zero offset yields the empty string.

    Args:
        coordinate_id: integer 0, 1 or 2 standing for x, y and z
        value: integer offset

    Examples:
        >>> offset_component_to_direction_string(0, 1)
        'E'
        >>> offset_component_to_direction_string(1, 2)
        '2N'
    """
    assert 0 <= coordinate_id < 3, "Works only for at most 3D arrays"

    if value == 0:
        return ""

    # (negative direction, positive direction) for x, y and z
    axis_letters = (('W', 'E'),
                    ('S', 'N'),
                    ('B', 'T'))
    negative_letter, positive_letter = axis_letters[coordinate_id]
    letter = negative_letter if value < 0 else positive_letter

    magnitude = abs(value)
    if magnitude > 1:
        return "%d%s" % (magnitude, letter)
    return letter
def offset_to_direction_string(offsets: Sequence[int]) -> str:
    """Translate a numerical offset into compass notation.

    For details see :func:`offset_component_to_direction_string`.
    The components are concatenated in reverse coordinate order (z, y, x);
    an all-zero offset is written as 'C'. Offsets with more than three
    components are returned as their plain ``str`` representation.

    Args:
        offsets: 3-tuple with x,y,z offset

    Examples:
        >>> offset_to_direction_string([1, -1, 0])
        'SE'
        >>> offset_to_direction_string(([-3, 0, -2]))
        '2B3W'
    """
    if len(offsets) > 3:
        return str(offsets)

    parts = [offset_component_to_direction_string(axis, component)
             for axis, component in enumerate(offsets)]
    name = "".join(reversed(parts))
    return name if name != "" else "C"
def direction_string_to_offset(direction: str, dim: int = 3):
    """Reverse mapping of :func:`offset_to_direction_string`.

    The string is read left to right as a sequence of direction letters, each
    optionally preceded by a decimal multiplier; the contributions are summed.

    Args:
        direction: string representation of offset
        dim: dimension of offset, i.e the length of the returned array

    Returns:
        numpy integer array holding the first ``dim`` components of the offset

    Examples:
        >>> direction_string_to_offset('NW', dim=3)
        array([-1,  1,  0])
        >>> direction_string_to_offset('NW', dim=2)
        array([-1,  1])
        >>> direction_string_to_offset(offset_to_direction_string((3,-2,1)))
        array([ 3, -2,  1])
    """
    unit_offsets = {
        'C': np.array([0, 0, 0]),
        'W': np.array([-1, 0, 0]),
        'E': np.array([1, 0, 0]),
        'S': np.array([0, -1, 0]),
        'N': np.array([0, 1, 0]),
        'B': np.array([0, 0, -1]),
        'T': np.array([0, 0, 1]),
    }
    total = np.array([0, 0, 0])

    remaining = direction
    while len(remaining) > 0:
        # length of the numeric prefix in front of the next direction letter
        digit_count = 0
        while remaining[digit_count].isdigit():
            digit_count += 1

        multiplier = int(remaining[:digit_count]) if digit_count > 0 else 1
        letter = remaining[digit_count]
        total += multiplier * unit_offsets[letter]
        remaining = remaining[digit_count + 1:]

    return total[:dim]
def adjacent_directions(direction):
    """Return all directions adjacent to a direction as a sorted tuple of tuples.

    This is useful for example to find all directions relevant for neighbour communication.

    Args:
        direction: tuple representing a direction. For example (0, 1, 0) for the northern side

    Examples:
        >>> adjacent_directions((0, 0, 0))
        ((0, 0, 0),)
        >>> adjacent_directions((0, 1, 0))
        ((0, 1, 0),)
        >>> adjacent_directions((0, 1, 1))
        ((0, 0, 1), (0, 1, 0), (0, 1, 1))
        >>> adjacent_directions((-1, -1))
        ((-1, -1), (-1, 0), (0, -1))
    """
    # the center direction is only adjacent to itself
    if all(component == 0 for component in direction):
        return tuple({direction})

    neighbours = set()
    for candidate in binary_numbers(len(direction)):
        # project the 0/1 pattern onto the non-zero components of ``direction``
        for axis, component in enumerate(direction):
            if component == 0:
                candidate[axis] = 0
            if component == -1 and candidate[axis] == 1:
                candidate[axis] = -1
        if any(entry != 0 for entry in candidate):
            neighbours.add(tuple(candidate))

    return tuple(sorted(neighbours))
# -------------------------------------- Visualization -----------------------------------------------------------------
def plot(stencil, **kwargs):
    """Draw a stencil, dispatching on the length of its first direction.

    Stencils whose first direction has two entries go to :func:`plot_2d`.
    All others go to :func:`plot_3d_slicing` if a truthy ``slice`` keyword
    argument is given, and to :func:`plot_3d` otherwise. The ``slice`` keyword
    is consumed here in the non-2D case; all remaining keyword arguments are
    forwarded unchanged.
    """
    dim = len(stencil[0])
    if dim == 2:
        plot_2d(stencil, **kwargs)
        return

    slicing = kwargs.pop('slice', False)
    drawer = plot_3d_slicing if slicing else plot_3d
    drawer(stencil, **kwargs)
def plot_2d(stencil, axes=None, figure=None, data=None, textsize='12', **kwargs):
    """
    Creates a matplotlib 2D plot of the stencil

    Every non-center direction is drawn as an arrow from the origin and every
    direction (including the center) gets a text box with its annotation.

    Args:
        stencil: sequence of directions; may be empty, in which case only the
                 axes are configured
        axes: optional matplotlib axes
        figure: optional matplotlib figure, used only if no axes are given;
                defaults to the current figure
        data: data to annotate the directions with, if none given, the indices are used
        textsize: size of annotation text
        **kwargs: accepted for interface compatibility with the other plot functions, not used here
    """
    from matplotlib.patches import BoxStyle
    import matplotlib.pyplot as plt

    def position_correction(d, magnitude=0.18):
        # push the label slightly beyond the arrow head, away from the origin
        if d < 0:
            return -magnitude
        elif d > 0:
            return +magnitude
        else:
            return 0

    if axes is None:
        if figure is None:
            figure = plt.gcf()
        axes = figure.gca()

    text_box_style = BoxStyle("Round", pad=0.3)
    head_length = 0.1
    # default=0: an empty stencil (e.g. an empty slice of a 3D stencil) must not raise
    max_offsets = [max((abs(int(d[c])) for d in stencil), default=0) for c in (0, 1)]

    if data is None:
        data = list(range(len(stencil)))

    for direction, annotation in zip(stencil, data):
        assert len(direction) == 2, "Works only for 2D stencils"
        direction = tuple(int(i) for i in direction)
        if not (direction[0] == 0 and direction[1] == 0):
            axes.arrow(0, 0, direction[0], direction[1], head_width=0.08, head_length=head_length, color='k')

        if isinstance(annotation, sp.Basic):
            annotation = "$" + sp.latex(annotation) + "$"
        else:
            annotation = str(annotation)

        text_position = [direction[c] + position_correction(direction[c]) for c in (0, 1)]
        axes.text(x=text_position[0], y=text_position[1], s=annotation, verticalalignment='center',
                  zorder=30, horizontalalignment='center', size=textsize,
                  bbox=dict(boxstyle=text_box_style, facecolor='#00b6eb', alpha=0.85, linewidth=0))

    axes.set_axis_off()
    axes.set_aspect('equal')
    # avoid a degenerate (zero-width) view if the stencil does not extend along an axis
    max_offsets = [m if m > 0 else 0.1 for m in max_offsets]
    border = 0.1
    axes.set_xlim([-border - max_offsets[0], border + max_offsets[0]])
    axes.set_ylim([-border - max_offsets[1], border + max_offsets[1]])
def plot_3d_slicing(stencil, slice_axis=2, figure=None, data=None, **kwargs):
    """Visualizes a 3D, first-neighborhood stencil by plotting 3 slices along a given axis.

    The directions are grouped by their component along ``slice_axis`` (-1, 0, 1);
    each group is drawn with :func:`plot_2d` into its own subplot.

    Args:
        stencil: stencil as sequence of directions
        slice_axis: 0, 1, or 2 indicating the axis to slice through
        figure: optional matplotlib figure, defaults to the current figure
        data: optional data to print as text besides the arrows, if none given, the indices are used
        **kwargs: forwarded to :func:`plot_2d`
    """
    import matplotlib.pyplot as plt

    for direction in stencil:
        for component in direction:
            assert component in (-1, 0, 1), "This function can only first neighborhood stencils"

    if figure is None:
        figure = plt.gcf()

    axes = [figure.add_subplot(1, 3, column) for column in (1, 2, 3)]
    directions_per_slice = [[], [], []]
    data_per_slice = [[], [], []]
    axes_names = ['x', 'y', 'z']

    for position, direction in enumerate(stencil):
        # component -1, 0, 1 along the slice axis selects subplot 0, 1, 2
        slot = direction[slice_axis] + 1
        in_plane = tuple(component for axis, component in enumerate(direction) if axis != slice_axis)
        directions_per_slice[slot].append(in_plane)
        data_per_slice[slot].append(position if data is None else data[position])

    for slot in range(3):
        plot_2d(directions_per_slice[slot], axes=axes[slot], data=data_per_slice[slot], **kwargs)
    for slot, cut in enumerate((-1, 0, 1)):
        axes[slot].set_title("Cut at %s=%d" % (axes_names[slice_axis], cut), y=1.08)
def plot_3d(stencil, figure=None, axes=None, data=None, textsize='8'):
    """
    Draws 3D stencil into a 3D coordinate system, parameters are similar to :func:`plot_2d`

    If data is None, no labels are drawn. To draw the labels as in the 2D case, use ``data=list(range(len(stencil)))``

    Args:
        stencil: sequence of 3D directions
        figure: optional matplotlib figure, used only if no axes are given; a new figure is created otherwise
        axes: optional matplotlib 3D axes
        data: optional sequence of annotations, one per direction; entries that are
              None or the empty string are not drawn
        textsize: size of annotation text
    """
    from matplotlib.patches import FancyArrowPatch
    from mpl_toolkits.mplot3d import proj3d
    import matplotlib.pyplot as plt
    from matplotlib.patches import BoxStyle
    from itertools import product, combinations
    import numpy as np

    class Arrow3D(FancyArrowPatch):
        """2D arrow patch that is re-positioned from 3D end points on every projection."""

        def __init__(self, xs, ys, zs, *args, **kwargs):
            FancyArrowPatch.__init__(self, (0, 0), (0, 0), *args, **kwargs)
            self._verts3d = xs, ys, zs

        def do_3d_projection(self, *_):
            xs3d, ys3d, zs3d = self._verts3d
            xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M)
            self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
            return np.min(zs)

    if axes is None:
        if figure is None:
            figure = plt.figure()
        axes = figure.add_subplot(projection='3d')
        try:
            axes.set_aspect("equal")
        except NotImplementedError:
            # equal aspect is not available for 3D axes in every matplotlib version
            pass

    if data is None:
        data = [None] * len(stencil)

    text_offset = 1.25
    text_box_style = BoxStyle("Round", pad=0.3)

    # Draw cell (cube): connect all corner pairs that differ in exactly one coordinate
    r = [-1, 1]
    for s, e in combinations(np.array(list(product(r, r, r))), 2):
        if np.sum(np.abs(s - e)) == r[1] - r[0]:
            axes.plot(*zip(s, e), color="k", alpha=0.5)

    for d, annotation in zip(stencil, data):
        assert len(d) == 3, "Works only for 3D stencils"
        d = tuple(int(i) for i in d)
        if not (d[0] == 0 and d[1] == 0 and d[2] == 0):
            if d[0] == 0:
                color = '#348abd'
            elif d[1] == 0:
                color = '#fac364'
            elif sum(abs(component) for component in d) == 2:
                color = '#95bd50'
            else:
                color = '#808080'

            a = Arrow3D([0, d[0]], [0, d[1]], [0, d[2]], mutation_scale=20, lw=2, arrowstyle="-|>", color=color)
            axes.add_artist(a)

        # Test explicitly for "no annotation": a plain truthiness test would also
        # suppress legitimate labels such as the index 0 or a zero coefficient.
        if annotation is not None and annotation != "":
            if isinstance(annotation, sp.Basic):
                annotation = "$" + sp.latex(annotation) + "$"
            else:
                annotation = str(annotation)

            axes.text(x=d[0] * text_offset, y=d[1] * text_offset, z=d[2] * text_offset,
                      s=annotation, verticalalignment='center', zorder=30,
                      size=textsize, bbox=dict(boxstyle=text_box_style, facecolor='#777777', alpha=0.6, linewidth=0))

    axes.set_xlim([-text_offset * 1.1, text_offset * 1.1])
    axes.set_ylim([-text_offset * 1.1, text_offset * 1.1])
    axes.set_zlim([-text_offset * 1.1, text_offset * 1.1])
    axes.set_axis_off()
def plot_expression(expr, **kwargs):
    """Displays coefficients of a linear update expression of a single field as matplotlib arrow drawing.

    For one-dimensional stencils no plot is created; the coefficients are returned
    in matrix form instead. 2D stencils are drawn with :func:`plot_2d`, 3D stencils
    with :func:`plot_3d_slicing`; keyword arguments are forwarded to these functions.
    """
    stencil, coeffs = coefficients(expr)
    dim = len(stencil[0])
    assert 0 < dim <= 3

    if dim == 1:
        return coefficient_list(expr, matrix_form=True)
    if dim == 2:
        return plot_2d(stencil, data=coeffs, **kwargs)
    if dim == 3:
        return plot_3d_slicing(stencil, data=coeffs, **kwargs)
import itertools
import operator
import warnings
from collections import Counter, defaultdict
from functools import partial, reduce
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, TypeVar, Union
import sympy as sp
from sympy import PolynomialError
from sympy.functions import Abs
from sympy.core.numbers import Zero
from pystencils.assignment import Assignment
from pystencils.functions import DivFunc
from pystencils.typing import CastFunc, get_type_of_expression, PointerType, VectorType
from pystencils.typing.typed_sympy import FieldPointerSymbol
T = TypeVar('T')
def prod(seq: Iterable[T]) -> T:
    """Multiply all elements of a sequence together; the empty product is 1."""
    result = 1
    for factor in seq:
        result = result * factor
    return result
def remove_small_floats(expr, threshold):
    """Replace every sp.Float whose absolute value is smaller than threshold by zero.

    The expression tree is rebuilt bottom-up from the processed arguments.

    >>> expr = sp.sympify("x + 1e-15 * y")
    >>> remove_small_floats(expr, 1e-14)
    x
    """
    if isinstance(expr, sp.Float) and sp.Abs(expr) < threshold:
        return 0

    children = [remove_small_floats(child, threshold) for child in expr.args]
    if not children:
        # leaf node: nothing to rebuild
        return expr
    return expr.func(*children)
def is_integer_sequence(sequence: Iterable) -> bool:
    """Checks if all elements of the passed sequence can be cast to integers

    Args:
        sequence: iterable of arbitrary objects

    Returns:
        True if ``int(element)`` succeeds for every element (vacuously True for
        an empty sequence), False if a conversion fails or the argument is not
        iterable.
    """
    try:
        for element in sequence:
            int(element)
    except (TypeError, ValueError):
        # TypeError: type without integer conversion (or a non-iterable argument)
        # ValueError: convertible type with an unparsable value, e.g. the string "a"
        return False
    return True
def scalar_product(a: Iterable[T], b: Iterable[T]) -> T:
    """Scalar product between two sequences.

    Elements are paired with ``zip``, so surplus elements of the longer
    sequence are ignored.
    """
    total = 0
    for left, right in zip(a, b):
        total = total + left * right
    return total
def kronecker_delta(*args):
    """Kronecker delta for variable number of arguments, 1 if all args are equal, otherwise 0

    Called without arguments, the result is 1.
    """
    return 0 if any(value != args[0] for value in args) else 1
def tanh_step_function_approximation(x, step_location, kind='right', steepness=0.0001):
    """Approximation of step function by a tanh function

    Args:
        x: position at which the approximation is evaluated
        step_location: location of the step; for ``kind='middle'`` a pair of
            the two step locations
        kind: 'right' (0 left of the step, 1 right of it), 'left' (the mirrored step) or
            'middle' (1 between the two step locations); for any other value None is returned
        steepness: width parameter of the tanh; smaller values give a sharper step

    >>> tanh_step_function_approximation(1.2, step_location=1.0, kind='right')
    1.00000000000000
    >>> tanh_step_function_approximation(0.9, step_location=1.0, kind='right')
    0
    >>> tanh_step_function_approximation(1.1, step_location=1.0, kind='left')
    0
    >>> tanh_step_function_approximation(0.9, step_location=1.0, kind='left')
    1.00000000000000
    >>> tanh_step_function_approximation(0.5, step_location=(0, 1), kind='middle')
    1
    """
    if kind == 'left':
        smooth_sign = sp.tanh((x - step_location) / steepness)
        return (1 - smooth_sign) / 2
    if kind == 'right':
        smooth_sign = sp.tanh((x - step_location) / steepness)
        return (1 + smooth_sign) / 2
    if kind == 'middle':
        lower, upper = step_location
        below_lower = tanh_step_function_approximation(x, lower, 'left', steepness)
        above_upper = tanh_step_function_approximation(x, upper, 'right', steepness)
        return 1 - (below_lower + above_upper)
    return None
def multidimensional_sum(i, dim):
    """Iterate over all index tuples of an ``i``-fold sum, each index running over ``range(dim)``.

    Example:
        >>> list(multidimensional_sum(2, dim=3))
        [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2)]
    """
    return itertools.product(range(dim), repeat=i)
def normalize_product(product: sp.Expr) -> List[sp.Expr]:
    """Expects a sympy expression that can be interpreted as a product and returns a list of all factors.

    Removes sp.Pow nodes that have a positive integer exponent by repeating their base in the list.

    Returns:
        * for a Mul node list of factors ('args'), with such powers unrolled
        * for a Pow node with positive integer exponent a list of factors
        * for other node types [product] is returned
    """
    def unroll_power(power):
        exponent = power.exp
        if exponent.is_integer and exponent.is_number and exponent > 0:
            return [power.base] * exponent
        return [power]

    if isinstance(product, sp.Pow):
        return unroll_power(product)
    if not isinstance(product, sp.Mul):
        return [product]

    factors = []
    for arg in product.args:
        factors.extend(unroll_power(arg) if arg.func == sp.Pow else [arg])
    return factors
def symmetric_product(*args, with_diagonal: bool = True) -> Iterable:
    """Similar to itertools.product but yields only values where the index is ascending i.e. values below/up to diagonal

    Args:
        *args: the sequences to combine (must support ``len`` and indexing)
        with_diagonal: if True, equal neighbouring indices are allowed,
            otherwise the indices have to be strictly ascending

    Examples:
        >>> list(symmetric_product([1, 2, 3], ['a', 'b', 'c']))
        [(1, 'a'), (1, 'b'), (1, 'c'), (2, 'b'), (2, 'c'), (3, 'c')]
        >>> list(symmetric_product([1, 2, 3], ['a', 'b', 'c'], with_diagonal=False))
        [(1, 'b'), (1, 'c'), (2, 'c')]
    """
    def is_ascending(index_tuple):
        neighbours = zip(index_tuple, index_tuple[1:])
        if with_diagonal:
            return all(previous <= current for previous, current in neighbours)
        return all(previous < current for previous, current in neighbours)

    index_ranges = [range(len(sequence)) for sequence in args]
    for index_tuple in itertools.product(*index_ranges):
        if is_ascending(index_tuple):
            yield tuple(sequence[position] for sequence, position in zip(args, index_tuple))
def fast_subs(expression: T, substitutions: Dict,
              skip: Optional[Callable[[sp.Expr], bool]] = None) -> T:
    """Similar to sympy subs function.

    Args:
        expression: expression where parts should be substituted; an ``sp.Matrix`` is
                    processed entry by entry on a copy
        substitutions: dict defining substitutions by mapping from old to new terms
        skip: function that marks expressions to be skipped (if True is returned) - that means that in these skipped
              expressions no substitutions are done

    Returns:
        the expression with substitutions applied; the unchanged input if ``substitutions`` is empty

    This version is much faster for big substitution dictionaries than sympy version
    """
    if type(expression) is sp.Matrix:
        # forward ``skip`` as well, otherwise it would silently be ignored for matrix entries
        return expression.copy().applyfunc(partial(fast_subs, substitutions=substitutions, skip=skip))

    def visit(expr, evaluate=True):
        if skip and skip(expr):
            return expr
        elif hasattr(expr, "fast_subs"):
            # objects providing their own implementation take over from here
            return expr.fast_subs(substitutions, skip)
        elif expr in substitutions:
            return substitutions[expr]
        elif not hasattr(expr, 'args'):
            return expr
        elif isinstance(expr, (sp.UnevaluatedExpr, DivFunc)):
            # below these nodes, sums and products are rebuilt without evaluation
            args = [visit(a, False) for a in expr.args]
            return expr.func(*args)
        else:
            param_list = [visit(a, evaluate) for a in expr.args]
            if isinstance(expr, (sp.Mul, sp.Add)):
                return expr if not param_list else expr.func(*param_list, evaluate=evaluate)
            return expr if not param_list else expr.func(*param_list)

    if len(substitutions) == 0:
        return expression
    else:
        return visit(expression)
def is_constant(expr):
    """Simple check whether a sympy expression is constant, i.e. has no free symbols.

    Works also for piecewise defined functions - sympy's is_constant() has a problem there, see:
    https://github.com/sympy/sympy/issues/16662
    """
    return not expr.free_symbols
def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr,
                  required_match_replacement: Optional[Union[int, float]] = 0.5,
                  required_match_original: Optional[Union[int, float]] = None) -> sp.Expr:
    """Transformation for replacing a given subexpression inside a sum.

    Examples:
        The next example demonstrates the advantage of subs_additive compared to sympy.subs:

        >>> x, y, z, k = sp.symbols("x y z k")
        >>> subs_additive(3*x + 3*y, replacement=k, subexpression=x + y)
        3*k

        Terms that don't match completely can be substituted at the cost of additional terms.
        This trade-off is managed using the required_match parameters.

        >>> subs_additive(3*x + 3*y + z, replacement=k, subexpression=x+y+z, required_match_original=1.0)
        3*x + 3*y + z
        >>> subs_additive(3*x + 3*y + z, replacement=k, subexpression=x+y+z, required_match_original=0.5)
        3*k - 2*z
        >>> subs_additive(3*x + 3*y + z, replacement=k, subexpression=x+y+z, required_match_original=2)
        3*k - 2*z

    Args:
        expr: input expression
        replacement: expression that is inserted for subexpression (if found)
        subexpression: expression to replace
        required_match_replacement:
             * if float: the percentage of terms of the subexpression that has to be matched in order to replace
             * if integer: the total number of terms that has to be matched in order to replace
             * None: is equal to integer 1
             * if both match parameters are given, both restrictions have to be fulfilled (i.e. logical AND)
        required_match_original:
             * if float: the percentage of terms of the original addition expression that has to be matched
             * if integer: the total number of terms that has to be matched in order to replace
             * None: is equal to integer 1

    Returns:
        new expression with replacement

    Raises:
        ValueError: if a match parameter is neither None, a float nor an int
    """
    def normalize_match_parameter(match_parameter, expression_length):
        # Converts a match parameter (None / fraction / count) into an absolute number of terms, at least 1
        if match_parameter is None:
            return 1
        elif isinstance(match_parameter, float):
            assert 0 <= match_parameter <= 1
            res = int(match_parameter * expression_length)
            return max(res, 1)
        elif isinstance(match_parameter, int):
            assert match_parameter > 0
            return match_parameter
        raise ValueError("Invalid parameter")

    normalized_replacement_match = normalize_match_parameter(required_match_replacement, len(subexpression.args))

    # NOTE(review): for a numeric subexpression the substitution goes the other way round
    # (occurrences of `replacement` are replaced by the number) - presumably intended, confirm with callers
    if isinstance(subexpression, sp.Number):
        return expr.subs({replacement: subexpression})

    def visit(current_expr):
        if current_expr.is_Add:
            expr_max_length = max(len(current_expr.args), len(subexpression.args))
            normalized_current_expr_match = normalize_match_parameter(required_match_original, expr_max_length)
            expr_coefficients = current_expr.as_coefficients_dict()
            subexpression_coefficient_dict = subexpression.as_coefficients_dict()
            # terms that occur (up to a numeric coefficient) in both the sum and the subexpression
            intersection = set(subexpression_coefficient_dict.keys()).intersection(set(expr_coefficients))
            if len(intersection) >= max(normalized_replacement_match, normalized_current_expr_match):
                # find common factor
                # factors maps each coefficient ratio to the number of terms that share it
                factors = defaultdict(int)
                skips = 0  # counted, but not used afterwards
                for common_symbol in subexpression_coefficient_dict.keys():
                    if common_symbol not in expr_coefficients:
                        skips += 1
                        continue
                    factor = expr_coefficients[common_symbol] / subexpression_coefficient_dict[common_symbol]
                    factors[sp.simplify(factor)] += 1

                # the ratio shared by the most terms is used as multiple of the subexpression
                common_factor = max(factors.items(), key=operator.itemgetter(1))[0]
                if factors[common_factor] >= max(normalized_current_expr_match, normalized_replacement_match):
                    # subtracting the full multiple leaves correction terms for everything that did not match
                    return current_expr - common_factor * subexpression + common_factor * replacement

        # if no subexpression was found
        param_list = [visit(a) for a in current_expr.args]
        if not param_list:
            return current_expr
        else:
            if current_expr.func == sp.Mul and Zero() in param_list:
                return sp.simplify(current_expr)
            else:
                # rebuild the node without evaluation to keep the structure of the processed arguments
                return current_expr.func(*param_list, evaluate=False)

    return visit(expr)
def replace_second_order_products(expr: sp.Expr, search_symbols: Iterable[sp.Symbol],
                                  positive: Optional[bool] = None,
                                  replace_mixed: Optional[List[Assignment]] = None) -> sp.Expr:
    """Replaces second order mixed terms like 4*x*y by 2*( (x+y)**2 - x**2 - y**2 ).

    This makes the term longer - simplify usually is undoing these - however this
    transformation can be done to find more common sub-expressions

    Args:
        expr: input expression
        search_symbols: symbols that are searched for
                         for example, given [x,y,z] terms like x*y, x*z, z*y are replaced
        positive: there are two ways to do this substitution, either with term
                 (x+y)**2 or (x-y)**2 . if positive=True the first version is done,
                 if positive=False the second version is done, if positive=None the
                 sign is determined by the sign of the mixed term that is replaced
        replace_mixed: if a list is passed here, the expr x+y or x-y is replaced by a special new symbol
                       and the replacement equation is added to the list

    Returns:
        the transformed expression; products that do not consist of exactly two distinct
        search symbols (besides other factors) are processed recursively through their arguments
    """
    # symbols for which a replacement equation already exists in the passed list
    mixed_symbols_replaced = set([e.lhs for e in replace_mixed]) if replace_mixed is not None else set()

    if expr.is_Mul:
        # split the product into factors that are search symbols and the remaining factors
        distinct_search_symbols = set()
        nr_of_search_terms = 0
        other_factors = sp.Integer(1)
        for t in expr.args:
            if t in search_symbols:
                nr_of_search_terms += 1
                distinct_search_symbols.add(t)
            else:
                other_factors *= t
        if len(distinct_search_symbols) == 2 and nr_of_search_terms == 2:
            # sort by name to get a deterministic role of u and v
            u, v = sorted(list(distinct_search_symbols), key=lambda symbol: symbol.name)
            if positive is None:
                # determine the sign of the prefactor by setting all of its symbols to one
                other_factors_without_symbols = other_factors
                for s in other_factors.atoms(sp.Symbol):
                    other_factors_without_symbols = other_factors_without_symbols.subs(s, 1)
                positive = other_factors_without_symbols.is_positive
                assert positive is not None
            sign = 1 if positive else -1
            if replace_mixed is not None:
                # name of the new symbol: <u>P<v> for u+v, <u>M<v> for u-v, underscores removed
                new_symbol_str = 'P' if positive else 'M'
                mixed_symbol_name = u.name + new_symbol_str + v.name
                mixed_symbol = sp.Symbol(mixed_symbol_name.replace("_", ""))
                if mixed_symbol not in mixed_symbols_replaced:
                    mixed_symbols_replaced.add(mixed_symbol)
                    replace_mixed.append(Assignment(mixed_symbol, u + sign * v))
            else:
                mixed_symbol = u + sign * v
            # u*v == sign/2 * ((u + sign*v)**2 - u**2 - v**2)
            return sp.Rational(1, 2) * sign * other_factors * (mixed_symbol ** 2 - u ** 2 - v ** 2)

    # no mixed product at this node: recurse into the arguments and rebuild without evaluation
    param_list = [replace_second_order_products(a, search_symbols, positive, replace_mixed) for a in expr.args]
    result = expr.func(*param_list, evaluate=False) if param_list else expr
    return result
def remove_higher_order_terms(expr: sp.Expr, symbols: Sequence[sp.Symbol], order: int = 3) -> sp.Expr:
    """Removes all terms that contain more than 'order' factors of given 'symbols'

    The expression is expanded first; then every summand is inspected on its own.

    Args:
        expr: input expression
        symbols: symbols whose factors are counted
        order: maximum number of factors from ``symbols`` a term may contain to be kept

    Returns:
        the expanded expression without the terms of too high order

    Example:
        >>> x, y = sp.symbols("x y")
        >>> term = x**2 * y + y**2 * x + y**3 + x + y ** 2
        >>> remove_higher_order_terms(term, order=2, symbols=[x, y])
        x + y**2
    """
    from sympy.core.power import Pow
    from sympy.core.add import Add, Mul

    expr = expr.expand()

    def velocity_factors_in_product(product):
        # number of factors from ``symbols`` in a single term, powers counted with their exponent
        factor_count = 0
        if type(product) is Mul:
            for factor in product.args:
                if type(factor) is Pow:
                    if factor.args[0] in symbols:
                        factor_count += factor.args[1]
                if factor in symbols:
                    factor_count += 1
        elif type(product) is Pow:
            if product.args[0] in symbols:
                factor_count += product.args[1]
        elif product in symbols:
            # a bare symbol is a term with exactly one factor
            factor_count = 1
        return factor_count

    if type(expr) is not Add:
        # single term: keep it or drop it as a whole
        return expr if velocity_factors_in_product(expr) <= order else Zero()

    result = 0
    for sum_term in expr.args:
        if velocity_factors_in_product(sum_term) <= order:
            result += sum_term
    return result
def complete_the_square(expr: sp.Expr, symbol_to_complete: sp.Symbol,
                        new_variable: sp.Symbol) -> Tuple[sp.Expr, Optional[Tuple[sp.Symbol, sp.Expr]]]:
    """Transforms second order polynomial into only squared part.

    If the expression is not a polynomial with exactly three coefficients in
    ``symbol_to_complete``, it is returned unchanged together with None.

    Examples:
        >>> a, b, c, s, n = sp.symbols("a b c s n")
        >>> expr = a * s**2 + b * s + c
        >>> completed_expr, substitution = complete_the_square(expr, symbol_to_complete=s, new_variable=n)
        >>> completed_expr
        a*n**2 + c - b**2/(4*a)
        >>> substitution
        (n, s + b/(2*a))

    Returns:
        (replaced_expr, tuple to pass to subs, such that old expr comes out again)
    """
    polynomial = sp.Poly(expr, symbol_to_complete)
    poly_coeffs = polynomial.all_coeffs()
    if len(poly_coeffs) != 3:
        return expr, None

    quadratic_coeff, linear_coeff = poly_coeffs[0], poly_coeffs[1]
    # shifting the variable by b / (2a) eliminates the linear term
    shifted = expr.subs(symbol_to_complete, new_variable - linear_coeff / (2 * quadratic_coeff))
    back_substitution = (new_variable, symbol_to_complete + linear_coeff / (2 * quadratic_coeff))
    return sp.simplify(shifted), back_substitution
def complete_the_squares_in_exp(expr: sp.Expr, symbols_to_complete: Sequence[sp.Symbol]):
    """Completes squares in arguments of exponential which makes them simpler to integrate.

    Inside every ``exp`` the square is completed for each of the given symbols in turn
    using a placeholder variable; at the end the placeholders are renamed back
    to the original symbols.

    Very useful for integrating Maxwell-Boltzmann equilibria and its moment generating function
    """
    placeholders = [sp.Dummy() for _ in symbols_to_complete]

    def visit(term):
        if term.func == sp.exp:
            exp_arg = term.args[0]
            for symbol, placeholder in zip(symbols_to_complete, placeholders):
                exp_arg, _ = complete_the_square(exp_arg, symbol, placeholder)
            return sp.exp(sp.expand(exp_arg))

        children = [visit(child) for child in term.args]
        return term.func(*children) if children else term

    result = visit(expr)
    for symbol, placeholder in zip(symbols_to_complete, placeholders):
        result = result.subs(placeholder, symbol)
    return result
def extract_most_common_factor(term):
    """Processes a sum of fractions: determines the most common factor and splits term in common factor and rest

    The most common factor is the absolute value of the numeric coefficient that
    occurs most often among the summands. If no coefficient occurs more than
    once and a coefficient of absolute value 1 is present, 1 is used.

    Returns:
        (common_factor, term / common_factor)
    """
    coefficient_values = term.as_coefficients_dict().values()
    occurrence_counter = Counter([Abs(value) for value in coefficient_values])

    common_factor, occurrences = max(occurrence_counter.items(), key=lambda item: item[1])
    if occurrences == 1 and (1 in occurrence_counter):
        common_factor = 1

    rest = term / common_factor
    return common_factor, rest
def recursive_collect(expr, symbols, order_by_occurences=False):
    """Applies sympy.collect recursively for a list of symbols, collecting symbol 2 in the coefficients of symbol 1,
    and so on.

    ``expr`` must be rewritable as a polynomial in the given ``symbols``.
    If it is not, ``recursive_collect`` will fail quietly, returning the original expression.

    Args:
        expr: A sympy expression.
        symbols: A sequence of symbols
        order_by_occurences: If True, during recursive descent, always collect the symbol occurring
                             most often in the expression.
    """
    if order_by_occurences:
        present_symbols = list(expr.atoms(sp.Symbol) & set(symbols))
        symbols = sorted(present_symbols, key=expr.count, reverse=True)

    if len(symbols) == 0:
        return expr

    leading_symbol, remaining_symbols = symbols[0], symbols[1:]
    collected = expr.collect(leading_symbol)

    try:
        collected_poly = sp.Poly(collected, leading_symbol)
    except PolynomialError:
        return expr

    # coefficients ordered by ascending power of the leading symbol
    ascending_coeffs = collected_poly.all_coeffs()[::-1]

    rec_sum = 0
    for power, coefficient in enumerate(ascending_coeffs):
        rec_sum = rec_sum + leading_symbol**power * recursive_collect(coefficient, remaining_symbols,
                                                                      order_by_occurences)
    return rec_sum
def summands(expr):
    """Return the set of summands of a sum; any other expression is treated as a sum with a single summand."""
    if isinstance(expr, sp.Add):
        return set(expr.args)
    return {expr}
def simplify_by_equality(expr, a, b, c):
    """
    Uses the equality a = b + c, where a and b must be symbols, to simplify expr
    by attempting to express additive combinations of two quantities by the third.

    This works on expressions that are reducible to the form
    :math:`a * (...) + b * (...) + c * (...)`,
    without any mixed terms of a, b and c.

    Args:
        expr: expression to simplify (passed through ``sp.sympify``)
        a: symbol on the left hand side of the equality
        b: first symbol on the right hand side of the equality
        c: second quantity on the right hand side, either a symbol or a constant

    Returns:
        the simplified expression in expanded form

    Raises:
        ValueError: if a or b is not a symbol, or if c is neither a symbol nor a constant
    """
    if not isinstance(a, sp.Symbol) or not isinstance(b, sp.Symbol):
        raise ValueError("a and b must be symbols.")

    c = sp.sympify(c)

    if not (isinstance(c, sp.Symbol) or is_constant(c)):
        raise ValueError("c must be either a symbol or a constant!")

    expr = sp.sympify(expr)

    # Split off the parts linear in a and b one after the other;
    # expr_expanded holds what is left after each step
    expr_expanded = sp.expand(expr)
    a_coeff = expr_expanded.coeff(a, 1)
    expr_expanded -= (a * a_coeff).expand()
    b_coeff = expr_expanded.coeff(b, 1)
    expr_expanded -= (b * b_coeff).expand()
    if isinstance(c, sp.Symbol):
        c_coeff = expr_expanded.coeff(c, 1)
        rest = expr_expanded - (c * c_coeff).expand()
    else:
        # for a constant c everything that is left is written as a multiple of c
        c_coeff = expr_expanded / c
        rest = 0

    a_summands = summands(a_coeff)
    b_summands = summands(b_coeff)
    c_summands = summands(c_coeff)

    # replace b + c by a
    # (summands occurring in the coefficients of both b and c)
    b_plus_c_coeffs = b_summands & c_summands
    for coeff in b_plus_c_coeffs:
        rest += a * coeff
    b_summands -= b_plus_c_coeffs
    c_summands -= b_plus_c_coeffs

    # replace a - b by c
    # (summands of the a coefficient that occur with opposite sign in the b coefficient)
    neg_b_summands = {-x for x in b_summands}
    a_minus_b_coeffs = a_summands & neg_b_summands
    for coeff in a_minus_b_coeffs:
        rest += c * coeff
    a_summands -= a_minus_b_coeffs
    b_summands -= {-x for x in a_minus_b_coeffs}

    # replace a - c by b
    # (summands of the a coefficient that occur with opposite sign in the c coefficient)
    neg_c_summands = {-x for x in c_summands}
    a_minus_c_coeffs = a_summands & neg_c_summands
    for coeff in a_minus_c_coeffs:
        rest += b * coeff
    a_summands -= a_minus_c_coeffs
    c_summands -= {-x for x in a_minus_c_coeffs}

    # put it back together
    return (rest + a * sum(a_summands) + b * sum(b_summands) + c * sum(c_summands)).expand()
def count_operations(term: Union[sp.Expr, List[sp.Expr], List[Assignment]],
                     only_type: Optional[str] = 'real') -> Dict[str, int]:
    """Counts the number of additions, multiplications and division.

    Args:
        term: a sympy expression (term, assignment) or sequence of sympy objects;
              for an assignment only the right hand side is counted
        only_type: 'real' or 'int' to count only operations on these types, or None for all

    Returns:
        dict with 'adds', 'muls', 'divs', 'sqrts', 'fast_sqrts', 'fast_inv_sqrts' and 'fast_div' keys

    Node types that are not handled, and exponents that cannot be classified,
    only trigger a warning; the counts are inaccurate in that case.
    """
    from pystencils.fast_approximation import fast_sqrt, fast_inv_sqrt, fast_division

    result = {'adds': 0, 'muls': 0, 'divs': 0, 'sqrts': 0,
              'fast_sqrts': 0, 'fast_inv_sqrts': 0, 'fast_div': 0}
    if isinstance(term, Sequence):
        # sequences are counted element by element and summed up per operation
        for element in term:
            r = count_operations(element, only_type)
            for operation_name in result.keys():
                result[operation_name] += r[operation_name]
        return result
    elif isinstance(term, Assignment):
        term = term.rhs

    def check_type(e):
        # Decides if an operation on expression e is counted for the requested only_type
        if only_type is None:
            return True
        if isinstance(e, FieldPointerSymbol) and only_type == "real":
            # only_type is "real" in this branch, so this always returns False
            return only_type == "int"

        try:
            base_type = get_type_of_expression(e)
        except ValueError:
            return False
        if isinstance(base_type, VectorType):
            return False
        if isinstance(base_type, PointerType):
            return only_type == 'int'
        if only_type == 'int' and (base_type.is_int() or base_type.is_uint()):
            return True
        if only_type == 'real' and (base_type.is_float()):
            return True
        else:
            return base_type == only_type

    def visit(t):
        # Adds the operations of node t to result and descends into its arguments
        visit_children = True
        if t.func is sp.Add:
            if check_type(t):
                result['adds'] += len(t.args) - 1
        elif t.func in [sp.Or, sp.And]:
            pass
        elif t.func is sp.Mul:
            if check_type(t):
                result['muls'] += len(t.args) - 1
                # factors of +1 / -1 are not counted as multiplications
                for a in t.args:
                    if a == 1 or a == -1:
                        result['muls'] -= 1
        elif isinstance(t, sp.Float) or isinstance(t, sp.Rational):
            pass
        elif isinstance(t, sp.Symbol):
            visit_children = False
        elif isinstance(t, sp.Indexed):
            visit_children = False
        elif t.is_integer:
            pass
        elif isinstance(t, CastFunc):
            # only the first argument of the cast is visited
            visit_children = False
            visit(t.args[0])
        elif t.func is fast_sqrt:
            result['fast_sqrts'] += 1
        elif t.func is fast_inv_sqrt:
            result['fast_inv_sqrts'] += 1
        elif t.func is fast_division:
            result['fast_div'] += 1
        elif t.func is sp.Pow:
            if check_type(t.args[0]):
                visit_children = True
                if t.exp.is_integer and t.exp.is_number:
                    if t.exp >= 0:
                        # x**n is counted as n - 1 multiplications
                        result['muls'] += int(t.exp) - 1
                    else:
                        # negative exponent: one division plus the multiplications for the power;
                        # one previously counted multiplication is taken back, if there is any
                        if result['muls'] > 0:
                            result['muls'] -= 1
                        result['divs'] += 1
                        result['muls'] += (-int(t.exp)) - 1
                elif sp.nsimplify(t.exp) == sp.Rational(1, 2):
                    result['sqrts'] += 1
                elif sp.nsimplify(t.exp) == -sp.Rational(1, 2):
                    # inverse square root: one square root and one division
                    result["sqrts"] += 1
                    result["divs"] += 1
                else:
                    warnings.warn(f"Cannot handle exponent {t.exp} of sp.Pow node")
            else:
                warnings.warn("Counting operations: only integer exponents are supported in Pow, "
                              "counting will be inaccurate")
        elif t.func is sp.Piecewise:
            # only the value expressions of the branches are visited, not the conditions
            for child_term, condition in t.args:
                visit(child_term)
            visit_children = False
        elif isinstance(t, (sp.Rel, sp.UnevaluatedExpr)):
            pass
        elif isinstance(t, DivFunc):
            result["divs"] += 1
        else:
            warnings.warn(f"Unknown sympy node of type {str(t.func)} counting will be inaccurate")

        if visit_children:
            for a in t.args:
                visit(a)

    visit(term)
    return result
def count_operations_in_ast(ast) -> Dict[str, int]:
    """Accumulates the operation counts of all assignments in an abstract syntax tree.

    The tree is walked depth first through ``node.args``. For every ``SympyAssignment`` that is
    reached, the right-hand side is handed to :func:`count_operations` and the returned counters
    are added up. The walk does not continue below an assignment.

    Args:
        ast: root node of the abstract syntax tree

    Returns:
        ``defaultdict(int)`` mapping each counter name to its accumulated value
    """
    from pystencils.astnodes import SympyAssignment

    def assignments_below(node):
        # An assignment is a leaf of this walk: report it and do not look at its children.
        if isinstance(node, SympyAssignment):
            yield node
            return
        for child in node.args:
            yield from assignments_below(child)

    totals = defaultdict(int)
    for assignment in assignments_below(ast):
        for counter_name, amount in count_operations(assignment.rhs).items():
            totals[counter_name] += amount
    return totals
def common_denominator(expr: sp.Expr) -> sp.Expr:
    """Returns the least common multiple of the denominators of all rational atoms of ``expr``."""
    denominators = []
    for rational in expr.atoms(sp.Rational):
        denominators.append(rational.q)
    return sp.lcm(denominators)
def get_symmetric_part(expr: sp.Expr, symbols: Iterable[sp.Symbol]) -> sp.Expr:
    r"""Returns the symmetric part of a sympy expression.

    Every symbol in ``symbols`` is replaced by its negative, and the average of the original
    and the sign-flipped expression is returned.

    Args:
        expr: sympy expression, labeled here as :math:`f`
        symbols: sequence of symbols which are considered as degrees of freedom, labeled here as :math:`x_0, x_1,...`

    Returns:
        :math:`\frac{1}{2} [ f(x_0, x_1, ..) + f(-x_0, -x_1, ..) ]`
    """
    # NOTE: the docstring is a raw string on purpose; otherwise "\f" in "\frac" is a form feed.
    substitution_dict = {e: -e for e in symbols}
    return sp.Rational(1, 2) * (expr + expr.subs(substitution_dict))
class SymbolCreator:
    """Object on which every attribute lookup yields a sympy symbol named like the attribute.

    ``__getattribute__`` is overridden unconditionally, so *any* attribute name - including
    names that would normally resolve to regular object attributes - returns ``sp.Symbol(name)``.
    """

    def __getattribute__(self, name):
        symbol = sp.Symbol(name)
        return symbol
import time
from pystencils.integer_functions import modulo_ceil
class TimeLoop:
    """Registers kernel calls and hook functions and executes them for a number of time steps.

    One pass over all calls registered with :meth:`add_call` is counted as ``steps`` time
    steps. Functions registered with :meth:`add_single_step_function` are counted as one
    time step and are used by :meth:`run` for the remainder that does not fill a full pass.

    Attributes:
        time_steps_run: total number of time steps accounted for so far
    """

    def __init__(self, steps=2):
        self._call_data = []
        self._fixed_steps = steps
        self._pre_run_functions = []
        self._post_run_functions = []
        self._single_step_functions = []
        self.time_steps_run = 0

    @property
    def fixed_steps(self):
        """Number of time steps that one pass over the registered calls is counted as."""
        return self._fixed_steps

    def add_pre_run_function(self, f):
        """Registers a zero-argument callable that is invoked by :meth:`pre_run`."""
        self._pre_run_functions.append(f)

    def add_post_run_function(self, f):
        """Registers a zero-argument callable that is invoked by :meth:`post_run`."""
        self._post_run_functions.append(f)

    def add_single_step_function(self, f):
        """Registers a zero-argument callable used by :meth:`run` for left-over single steps."""
        self._single_step_functions.append(f)

    def add_call(self, functor, argument_list):
        """Registers a call that is executed in every pass.

        Args:
            functor: callable, or an object with a ``kernel`` attribute holding the callable
            argument_list: keyword-argument dict, or list of such dicts; the functor is
                           registered once per dict
        """
        if hasattr(functor, 'kernel'):
            functor = functor.kernel
        if not isinstance(argument_list, list):
            argument_list = [argument_list]

        for argument_dict in argument_list:
            self._call_data.append((functor, argument_dict))

    def pre_run(self):
        """Calls all registered pre-run functions in registration order."""
        for f in self._pre_run_functions:
            f()

    def post_run(self):
        """Calls all registered post-run functions in registration order."""
        for f in self._post_run_functions:
            f()

    def run(self, time_steps=1):
        """Runs ``time_steps`` steps, framed by the pre- and post-run functions.

        As many full passes as fit into ``time_steps`` are executed; the remainder is covered
        by the single-step functions. A ``KeyboardInterrupt`` stops the stepping early, the
        post-run functions are still executed afterwards.
        """
        self.pre_run()
        fixed_steps = self._fixed_steps
        call_data = self._call_data
        main_iterations, rest_iterations = divmod(time_steps, fixed_steps)
        try:
            for _ in range(main_iterations):
                for func, kwargs in call_data:
                    func(**kwargs)
                self.time_steps_run += fixed_steps
            for _ in range(rest_iterations):
                for func in self._single_step_functions:
                    func()
                self.time_steps_run += 1
        except KeyboardInterrupt:
            # deliberate: allow the user to abort a long run and still execute the post-run hooks
            pass
        self.post_run()

    def benchmark_run(self, time_steps=0, init_time_steps=0):
        """Measures the wall-clock time of one time step.

        Both step counts are rounded up with ``modulo_ceil`` to a multiple of the fixed step
        count. The warm-up passes are not timed.

        Args:
            time_steps: number of time steps to measure; has to round up to a positive number
            init_time_steps: number of warm-up time steps executed before the measurement

        Returns:
            measured time in seconds divided by the number of time steps actually executed

        Raises:
            ZeroDivisionError: if no time step is measured (e.g. the default ``time_steps=0``)
        """
        init_time_steps_rounded = modulo_ceil(init_time_steps, self._fixed_steps)
        time_steps_rounded = modulo_ceil(time_steps, self._fixed_steps)
        call_data = self._call_data

        self.pre_run()
        for _ in range(init_time_steps_rounded // self._fixed_steps):
            for func, kwargs in call_data:
                func(**kwargs)
        self.time_steps_run += init_time_steps_rounded

        start = time.perf_counter()
        for _ in range(time_steps_rounded // self._fixed_steps):
            for func, kwargs in call_data:
                func(**kwargs)
        end = time.perf_counter()
        self.time_steps_run += time_steps_rounded
        self.post_run()

        # Normalize by the steps that were really run: the timed loop executes the rounded
        # count, so dividing by the requested count would overestimate the time per step.
        time_for_one_iteration = (end - start) / time_steps_rounded
        return time_for_one_iteration

    def run_time_span(self, seconds):
        """Executes full passes until at least ``seconds`` seconds have elapsed.

        Returns:
            tuple ``(time steps run, elapsed seconds)``
        """
        iterations = 0
        self.pre_run()
        start = time.perf_counter()
        while time.perf_counter() < start + seconds:
            for func, kwargs in self._call_data:
                func(**kwargs)
            iterations += self._fixed_steps
        end = time.perf_counter()
        self.post_run()
        self.time_steps_run += iterations
        return iterations, end - start

    def benchmark(self, time_for_benchmark=5, init_time_steps=2, number_of_time_steps_for_estimation='auto'):
        """Returns the time in seconds for one time step.

        Args:
            time_for_benchmark: number of seconds benchmark should take
            init_time_steps: number of time steps run initially for warm up, to get arrays into cache etc
            number_of_time_steps_for_estimation: time steps run before real benchmarks, to determine number of time
                                                 steps that approximately take 'time_for_benchmark' or 'auto'
        """
        # Run a few time step to get first estimate
        if number_of_time_steps_for_estimation == 'auto':
            self.run(1)
            iterations, total_time = self.run_time_span(0.5)
            duration_of_time_step = total_time / iterations
        else:
            duration_of_time_step = self.benchmark_run(number_of_time_steps_for_estimation, init_time_steps)

        # Run for approximately 'time_for_benchmark' seconds
        time_steps = int(time_for_benchmark / duration_of_time_step)
        time_steps = max(time_steps, 4)
        return self.benchmark_run(time_steps, init_time_steps)
import hashlib
import pickle
import warnings
from collections import OrderedDict
from copy import deepcopy
from types import MappingProxyType
from typing import Set
import sympy as sp
import pystencils as ps
import pystencils.astnodes as ast
from pystencils.assignment import Assignment
from pystencils.typing import (CastFunc, PointerType, StructType, TypedSymbol, get_base_type,
ReinterpretCastFunc, get_next_parent_of_type, parents_of_type)
from pystencils.field import Field, FieldType
from pystencils.typing import FieldPointerSymbol
from pystencils.sympyextensions import fast_subs
from pystencils.simp.assignment_collection import AssignmentCollection
from pystencils.slicing import normalize_slice
from pystencils.integer_functions import int_div
class NestedScopes:
    """Bookkeeping of symbol visibility in a stack of nested scopes.

    - a symbol that is accessed without being defined in any open scope is recorded as a
      "free parameter"
    - free parameters are kept in one global set, they do not belong to a scope
    - push/pop opens or closes the innermost scope

    >>> s = NestedScopes()
    >>> s.access_symbol("a")
    >>> s.is_defined("a")
    False
    >>> s.free_parameters
    {'a'}
    >>> s.define_symbol("b")
    >>> s.is_defined("b")
    True
    >>> s.push()
    >>> s.is_defined_locally("b")
    False
    >>> s.define_symbol("c")
    >>> s.pop()
    >>> s.is_defined("c")
    False
    """

    def __init__(self):
        self.free_parameters = set()
        # stack of scopes, outermost first; starts with a single (global) scope
        self._defined = [set()]

    def access_symbol(self, symbol):
        """Records ``symbol`` as free parameter unless it is defined in an open scope."""
        if self.is_defined(symbol):
            return
        self.free_parameters.add(symbol)

    def define_symbol(self, symbol):
        """Defines ``symbol`` in the innermost scope."""
        innermost = self._defined[-1]
        innermost.add(symbol)

    def is_defined(self, symbol):
        """Returns True if ``symbol`` is defined in any currently open scope."""
        for scope in self._defined:
            if symbol in scope:
                return True
        return False

    def is_defined_locally(self, symbol):
        """Returns True if ``symbol`` is defined in the innermost scope."""
        innermost = self._defined[-1]
        return symbol in innermost

    def push(self):
        """Opens a new, empty innermost scope."""
        self._defined.append(set())

    def pop(self):
        """Closes the innermost scope; at least one scope has to remain afterwards."""
        self._defined.pop()
        assert self.depth >= 1

    @property
    def depth(self):
        """Number of currently open scopes."""
        return len(self._defined)
def filtered_tree_iteration(node, node_type, stop_type=None):
    """Yields all descendants of ``node`` that are instances of ``node_type``, depth first.

    A matching descendant is yielded and its own children are searched as well.
    Children that are instances of ``stop_type`` (and not of ``node_type``) are pruned:
    neither they nor anything below them is visited. The root ``node`` itself is never
    yielded and is always expanded.

    Args:
        node: root of the tree; every visited object has to provide ``args``
        node_type: type (or tuple of types) of the nodes to yield
        stop_type: optional type (or tuple of types) at which the descent stops
    """
    for arg in node.args:
        if isinstance(arg, node_type):
            yield arg
        elif stop_type and isinstance(arg, stop_type):
            # test the child (not the parent) so that the subtree below it is really skipped
            continue

        # hand stop_type down, otherwise pruning would only apply to the first level
        yield from filtered_tree_iteration(arg, node_type, stop_type)
def generic_visit(term, visitor):
    """Applies ``visitor`` to the expressions contained in ``term`` and rebuilds the container.

    Containers are unpacked recursively:

    - ``AssignmentCollection``: main assignments and subexpressions are visited, the result is
      a copy of the collection holding the visited lists
    - ``list``: every entry is visited
    - ``Assignment``: only the right-hand side is visited, the left-hand side is kept
    - ``sp.Matrix``: every entry is visited

    Anything else is passed to ``visitor`` directly and its return value is returned.
    """
    if isinstance(term, AssignmentCollection):
        visited_main = generic_visit(term.main_assignments, visitor)
        visited_subexpressions = generic_visit(term.subexpressions, visitor)
        return term.copy(visited_main, visited_subexpressions)

    if isinstance(term, list):
        return [generic_visit(entry, visitor) for entry in term]

    if isinstance(term, Assignment):
        visited_rhs = generic_visit(term.rhs, visitor)
        return Assignment(term.lhs, visited_rhs)

    if isinstance(term, sp.Matrix):
        def visit_entry(entry):
            return generic_visit(entry, visitor)

        return term.applyfunc(visit_entry)

    return visitor(term)
def iterate_loops_by_depth(node, nesting_depth):
    """Generator over all LoopOverCoordinate nodes of an AST that sit at a given nesting depth.

    Args:
        node: Root node of the abstract syntax tree
        nesting_depth: Nesting depth of the loops to be listed.
                       Outermost loop has depth 0.
                       A depth of -1 indicates the innermost loops.

    Returns: Iterable listing all loop nodes of given nesting depth.

    Raises:
        ValueError: for a nesting depth that is neither nonnegative nor -1; since this is a
                    generator the error is raised when iteration starts
    """
    from pystencils.astnodes import LoopOverCoordinate

    def loops_at_depth(current, remaining_depth):
        is_loop = isinstance(current, LoopOverCoordinate)
        if remaining_depth < 0:  # already below the requested level, nothing to find here
            return
        if is_loop and remaining_depth == 0:
            yield current
            return
        # only loop nodes consume one level of nesting
        depth_for_children = remaining_depth - 1 if is_loop else remaining_depth
        for child in current.args:
            yield from loops_at_depth(child, depth_for_children)

    def innermost_loops(current):
        if isinstance(current, LoopOverCoordinate) and current.is_innermost_loop:
            yield current
            return
        for child in current.args:
            yield from innermost_loops(child)

    if nesting_depth >= 0:
        yield from loops_at_depth(node, nesting_depth)
    elif nesting_depth == -1:
        yield from innermost_loops(node)
    else:
        raise ValueError(f"Invalid nesting depth: {nesting_depth}. Choose a nonnegative number, or -1.")
def unify_shape_symbols(body, common_shape, fields):
    """Replaces the per-field array size symbols by the symbols of one common shape.

    When creating a kernel with variable array sizes, all passed arrays must have the same size.
    This is ensured when the kernel is called. Inside the kernel this means that only one symbol
    has to be used instead of one for each field. For example shape_arr1[0] and shape_arr2[0]
    must be equal, so they should also be represented by the same symbol.

    Fields with a fixed shape are left untouched. ``body.subs`` is only called if at least one
    shape component differs from the common shape; its return value is ignored.

    Args:
        body: ast node, for the kernel part where substitutions is made, is modified in-place
        common_shape: shape of the field that was chosen
        fields: all fields whose shapes should be replaced by common_shape
    """
    replacements = {}
    for field in fields:
        assert len(field.spatial_shape) == len(common_shape)
        if field.has_fixed_shape:
            continue
        replacements.update(
            (own_component, common_component)
            for common_component, own_component in zip(common_shape, field.spatial_shape)
            if own_component != common_component
        )

    if replacements:
        body.subs(replacements)
def get_common_field(field_set):
    """Picks one representative field out of a set of pystencils Fields.

    The representative can be used for shape information etc. in the kernel creation.
    It is the first field when ordering by ``str(field)``, so the choice is reproducible.

    Raises:
        ValueError: if fixed-shape and variable-shape fields are mixed, or if all fields
                    have a fixed shape but the shapes differ
    """
    fixed_shape_count = sum(1 for f in field_set if f.has_fixed_shape)
    all_fixed = fixed_shape_count == len(field_set)

    if fixed_shape_count > 0 and not all_fixed:
        fixed_field_names = ",".join([f.name for f in field_set if f.has_fixed_shape])
        var_field_names = ",".join([f.name for f in field_set if not f.has_fixed_shape])
        msg = "Mixing fixed-shaped and variable-shape fields in a single kernel is not possible\n"
        msg += f"Variable shaped: {var_field_names} \nFixed shaped: {fixed_field_names}"
        raise ValueError(msg)

    shape_set = {f.spatial_shape for f in field_set}
    if all_fixed and len(shape_set) != 1:
        raise ValueError("Differently sized field accesses in loop body: " + str(shape_set))

    # Sort the fields by their string representation so that always the same field is returned
    ordered_fields = sorted(field_set, key=str)
    return ordered_fields[0]
def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_order=None):
    """Wraps the given AST in (multiple) loops, using its :class:`pystencils.field.Field.Access` nodes.

    Args:
        body: Block object with inner loop contents
        iteration_slice: if not None, iteration is done only over this slice of the field
        ghost_layers: a sequence of pairs for each coordinate with lower and upper nr of ghost layers,
                      or a single int that is used for all coordinates on both sides.
                      If None, the number of ghost layers is determined automatically and assumed to
                      be equal for all dimensions
        loop_order: loop ordering from outer to inner loop (optimal ordering is same as layout)

    Returns:
        tuple of loop-node, ghost_layer_info
    """
    # find correct ordering by inspecting participating FieldAccesses
    relative_accesses = {access for access in body.atoms(Field.Access) if not access.is_absolute_access}
    only_absolute_accesses = len(relative_accesses) == 0  # kernel contains only absolute accesses

    # exclude accesses to buffers from the field list, because buffers are treated separately
    candidate_fields = [access.field for access in relative_accesses
                        if not (FieldType.is_buffer(access.field) or FieldType.is_custom(access.field))]
    if len(candidate_fields) == 0:  # when kernel contains only custom fields
        candidate_fields = [access.field for access in relative_accesses
                            if not (FieldType.is_buffer(access.field))]
    fields = set(candidate_fields)

    if loop_order is None:
        loop_order = get_optimal_loop_ordering(fields)

    if only_absolute_accesses:
        absolutely_accessed_fields = {access.field for access in body.atoms(Field.Access)}
        common_shape = get_common_field(absolutely_accessed_fields).spatial_shape
    else:
        common_shape = get_common_field(fields).spatial_shape
        unify_shape_symbols(body, common_shape=common_shape, fields=fields)

    if iteration_slice is not None:
        iteration_slice = normalize_slice(iteration_slice, common_shape)

    if ghost_layers is None:
        if only_absolute_accesses:
            layer_count = 0
        else:
            layer_count = max([access.required_ghost_layers for access in relative_accesses])
        ghost_layers = [(layer_count, layer_count)] * len(loop_order)
    if isinstance(ghost_layers, int):
        ghost_layers = [(ghost_layers, ghost_layers)] * len(loop_order)

    # build the loop nest from the inside out: the last entry of loop_order is the innermost loop
    current_body = body
    for loop_coordinate in reversed(loop_order):
        if iteration_slice is None:
            coordinate_layers = ghost_layers[loop_coordinate]
            begin = coordinate_layers[0]
            end = common_shape[loop_coordinate] - coordinate_layers[1]
            current_body = ast.Block([ast.LoopOverCoordinate(current_body, loop_coordinate, begin, end, 1)])
            continue

        slice_component = iteration_slice[loop_coordinate]
        if type(slice_component) is slice:
            current_body = ast.Block([ast.LoopOverCoordinate(current_body, loop_coordinate,
                                                             slice_component.start, slice_component.stop,
                                                             slice_component.step)])
        else:
            # a single index instead of a range: no loop, the loop counter is set to that value
            counter_symbol = ast.LoopOverCoordinate.get_loop_counter_symbol(loop_coordinate)
            current_body.insert_front(ast.SympyAssignment(counter_symbol, sp.sympify(slice_component)))

    return current_body, ghost_layers
def get_common_indexed_element(indexed_elements: Set[sp.IndexedBase]) -> sp.IndexedBase:
    """Returns a reproducible representative of a non-empty set of indexed elements.

    The representative is the first element when ordering by ``str(element)``.
    If the elements do not all share one ``shape``, an assertion fails for any shape that
    is an ``int``.
    """
    assert len(indexed_elements) > 0, "indexed_elements can not be empty"

    distinct_shapes = {element.shape for element in indexed_elements}
    assert len(distinct_shapes) == 1 or not any(isinstance(shape, int) for shape in distinct_shapes), \
        "If indexed elements are used, they must all have the same shape"

    ordered_elements = sorted(indexed_elements, key=str)
    return ordered_elements[0]
def add_outer_loop_over_indexed_elements(loop_node: ast.Block) -> ast.Block:
    """Wraps ``loop_node`` in an additional loop over the first index of its ``sp.Indexed`` atoms.

    If ``loop_node`` contains no ``sp.Indexed`` atoms it is returned unchanged. Otherwise a
    loop from 0 to the first shape entry of a representative indexed element is created. The
    single ``TypedSymbol`` found in the first index expression of that element is used as
    loop counter.

    Returns:
        ``loop_node`` itself, or a new Block containing the new outer loop
    """
    indexed_elements = loop_node.atoms(sp.Indexed)
    if not indexed_elements:
        return loop_node

    reference_element = get_common_indexed_element(indexed_elements)
    index_symbols = reference_element.indices[0].atoms(TypedSymbol)
    assert len(index_symbols) == 1, "index expressions must only contain one symbol representing the index"

    upper_bound = reference_element.shape[0]
    loop_counter = index_symbols.pop()
    outer_loop = ast.LoopOverCoordinate(loop_node, 0, 0, upper_bound, 1, custom_loop_ctr=loop_counter)
    return ast.Block([outer_loop])
def create_intermediate_base_pointer(field_access, coordinates, previous_ptr):
    r"""
    Creates a pointer symbol and offset that resolve a subset of the coordinates of a field access.

    Addressing elements in structured arrays is done with :math:`ptr\left[ \sum_i c_i \cdot s_i \right]`
    where :math:`c_i` is the coordinate value and :math:`s_i` the stride of a coordinate.
    The sum can be split up into multiple parts, such that parts of it can be pulled before loops.
    This function creates such an access for coordinates :math:`i \in \mbox{coordinates}`.
    The name of the returned symbol encodes which coordinates have been resolved: integer values
    are written into the name, all other values contribute through a hash.

    Args:
        field_access: instance of :class:`pystencils.field.Field.Access` which provides strides and offsets
        coordinates: mapping of coordinate ids to its value, where stride*value is calculated
        previous_ptr: the pointer which is de-referenced

    Returns
        tuple with the new pointer symbol and the calculated offset

    Examples:
        >>> field = Field.create_generic('myfield', spatial_dimensions=2, index_dimensions=1)
        >>> x, y = sp.symbols("x y")
        >>> prev_pointer = TypedSymbol("ptr", "double")
        >>> create_intermediate_base_pointer(field[1,-2](5), {0: x}, prev_pointer)
        (ptr_01, _stride_myfield_0*x + _stride_myfield_0)
        >>> create_intermediate_base_pointer(field[1,-2](5), {0: x, 1 : y }, prev_pointer)
        (ptr_01_1m2, _stride_myfield_0*x + _stride_myfield_0 + _stride_myfield_1*y - 2*_stride_myfield_1)
    """
    field = field_access.field

    offset = 0
    name_parts = []
    values_to_hash = []  # values that cannot be spelled out in the name

    for coordinate_id, coordinate_value in coordinates.items():
        stride = field.strides[coordinate_id]
        offset += stride * coordinate_value

        if coordinate_id < field.spatial_dimensions:
            # spatial coordinate: the relative offset of the access is added on top
            access_offset = field_access.offsets[coordinate_id]
            offset += stride * access_offset
            if access_offset.is_Integer:
                name_parts.append("_%d%d" % (coordinate_id, access_offset))
            else:
                values_to_hash.append(access_offset)
        elif type(coordinate_value) is int:
            name_parts.append("_%d%d" % (coordinate_id, coordinate_value))
        else:
            values_to_hash.append(coordinate_value)

    if len(values_to_hash) > 0:
        name_parts.append(hashlib.md5(pickle.dumps(values_to_hash)).hexdigest()[:16])

    # a minus sign is not allowed in the symbol name, it is spelled as 'm'
    suffix = "".join(name_parts).replace("-", 'm')
    new_ptr = TypedSymbol(previous_ptr.name + suffix, previous_ptr.dtype)
    return new_ptr, offset
def parse_base_pointer_info(base_pointer_specification, loop_order, spatial_dimensions, index_dimensions):
    """
    Creates base pointer specification for :func:`resolve_field_accesses` function.

    Specification of how many and which intermediate pointers are created for a field access.
    For example [ (0), (2,3,)]  creates on base pointer for coordinates 2 and 3 and writes the offset for coordinate
    zero directly in the field access. These specifications are defined dependent on the loop ordering.
    This function translates more readable version into the specification above.

    Allowed specifications:
        - "spatialInner<int>" spatialInner0 is the innermost loop coordinate,
          spatialInner1 the loop enclosing the innermost
        - "spatialOuter<int>" spatialOuter0 is the outermost loop,
          spatialOuter1 the loop directly inside the outermost
        - "spatialall": all spatial coordinates
        - "index<int>": index coordinate
        - "<int>": specifying directly the coordinate

    All coordinates that are not mentioned in the specification are appended as one final
    group, in ascending order.

    Args:
        base_pointer_specification: nested list with above specifications
        loop_order: list with ordering of loops from outer to inner
        spatial_dimensions: number of spatial dimensions
        index_dimensions: number of index dimensions

    Returns:
        list of tuples that can be passed to :func:`resolve_field_accesses`

    Raises:
        ValueError: if a specification cannot be parsed, refers to a coordinate that does not
                    exist, or if a coordinate is specified more than once

    Examples:
        >>> parse_base_pointer_info([['spatialOuter0'], ['index0']], loop_order=[2,1,0],
        ...                         spatial_dimensions=3, index_dimensions=1)
        [[2], [3], [0, 1]]
    """
    number_of_coordinates = spatial_dimensions + index_dimensions
    outer_to_inner = list(loop_order)
    inner_to_outer = list(reversed(loop_order))
    specified_coordinates = set()

    def add_new_element(group, elem):
        if elem >= number_of_coordinates:
            raise ValueError("Coordinate %d does not exist" % (elem,))
        if elem in specified_coordinates:
            raise ValueError("Coordinate %d specified two times" % (elem,))
        specified_coordinates.add(elem)
        group.append(elem)

    result = []
    for spec_group in base_pointer_specification:
        new_group = []
        for element in spec_group:
            if isinstance(element, int):
                add_new_element(new_group, element)
            elif element.startswith("spatial"):
                element = element[len("spatial"):]
                if element.startswith("Inner"):
                    index = int(element[len("Inner"):])
                    add_new_element(new_group, inner_to_outer[index])
                elif element.startswith("Outer"):
                    index = int(element[len("Outer"):])
                    # Count from the outermost loop. Indexing the reversed list with "-index"
                    # would map Outer0 to the innermost loop, because -0 == 0.
                    add_new_element(new_group, outer_to_inner[index])
                elif element == "all":
                    for i in range(spatial_dimensions):
                        add_new_element(new_group, i)
                else:
                    raise ValueError("Could not parse " + element)
            elif element.startswith("index"):
                index = int(element[len("index"):])
                add_new_element(new_group, spatial_dimensions + index)
            else:
                raise ValueError(f"Unknown specification {element}")

        result.append(new_group)

    rest = set(range(number_of_coordinates)) - specified_coordinates
    if rest:
        # sorted: the group must not depend on the iteration order of a set
        result.append(sorted(rest))

    return result
def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None):
    """Determines the linearized index into a buffer field as a function of the loop counters.

    Every loop contributes its zero-based iteration number, weighted by the product of the
    iteration counts of the loops listed before it. The result is multiplied by the number
    of distinct buffer accesses found in ``ast_node``.

    Args:
        ast_node: ast before any field accesses are resolved
        loop_counters: for CPU kernels: leave to default 'None' (can be determined from loop nodes)
                       for GPU kernels: list of 'loop counters' from inner to outer loop
        loop_iterations: iteration slice for each loop from inner to outer, for CPU kernels leave to default

    Returns:
        base buffer index - required by 'resolve_buffer_accesses' function
    """
    if loop_counters is None or loop_iterations is None:
        loops = list(filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, ast.SympyAssignment))
        loops.reverse()
        # the loops found have to be exactly the chain of loops enclosing the first entry
        enclosing_loops = list(parents_of_type(loops[0], ast.LoopOverCoordinate, include_current=True))
        assert len(loops) == len(enclosing_loops)
        assert all(l1 is l2 for l1, l2 in zip(loops, enclosing_loops))

        loop_counters = [loop.loop_counter_symbol for loop in loops]
        loop_iterations = [slice(loop.start, loop.stop, loop.step) for loop in loops]

    def divide_by_step(distance, step):
        # plain floor division if the step divides the distance, int_div otherwise
        if distance % step == 0:
            return distance // step
        return int_div(distance, step)

    actual_sizes = []
    actual_steps = []
    for ctr, iteration in zip(loop_counters, loop_iterations):
        extent = iteration.stop - iteration.start
        progress = ctr - iteration.start
        if iteration.step != 1:
            actual_sizes.append(divide_by_step(extent, iteration.step))
            actual_steps.append(divide_by_step(progress, iteration.step))
        else:
            actual_sizes.append(extent)
            actual_steps.append(progress)

    buffer_accesses = {fa for fa in ast_node.atoms(Field.Access) if FieldType.is_buffer(fa.field)}
    buffer_index_size = len(buffer_accesses)

    base_buffer_index = actual_steps[0]
    actual_stride = 1
    for cur_stride, actual_step in zip(actual_sizes, actual_steps[1:]):
        actual_stride *= int(cur_stride) if isinstance(cur_stride, float) else cur_stride
        base_buffer_index += actual_stride * actual_step
    return base_buffer_index * buffer_index_size
def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=None):
    """Replaces accesses to buffer fields by ``ResolvedFieldAccess`` nodes, in place.

    Only the left- and right-hand sides of ``SympyAssignment`` nodes are rewritten. Accesses
    to fields that are not buffers are left untouched.

    Args:
        ast_node: the AST root
        base_buffer_index: linearized buffer index, see :func:`get_base_buffer_index`
        read_only_field_names: names of fields whose pointer symbol is created as const

    Raises:
        RuntimeError: if a buffer access has more than one index
    """
    if read_only_field_names is None:
        read_only_field_names = set()

    def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
        if not isinstance(expr, Field.Access):
            if isinstance(expr, ast.ResolvedFieldAccess):
                return expr

            new_args = [visit_sympy_expr(e, enclosing_block, sympy_assignment) for e in expr.args]
            if not new_args:
                return expr
            # rebuild without automatic evaluation for the types listed here
            kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {}
            return expr.func(*new_args, **kwargs)

        field_access = expr
        buffer = field_access.field

        # Do not apply transformation if field is not a buffer
        if not FieldType.is_buffer(buffer):
            return expr

        field_ptr = FieldPointerSymbol(buffer.name, buffer.dtype, const=buffer.name in read_only_field_names)

        buffer_index = base_buffer_index
        if len(field_access.index) > 1:
            raise RuntimeError('Only indexing dimensions up to 1 are currently supported in buffers!')
        if len(field_access.index) > 0:
            buffer_index += field_access.index[0]

        result = ast.ResolvedFieldAccess(field_ptr, buffer_index, buffer, field_access.offsets,
                                         field_access.index)
        return visit_sympy_expr(result, enclosing_block, sympy_assignment)

    def visit_node(sub_ast):
        if not isinstance(sub_ast, ast.SympyAssignment):
            for child in sub_ast.args:
                visit_node(child)
            return

        enclosing_block = sub_ast.parent
        assert type(enclosing_block) is ast.Block
        sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block, sub_ast)
        sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block, sub_ast)

    return visit_node(ast_node)
def resolve_field_accesses(ast_node, read_only_field_names=None,
                           field_to_base_pointer_info=MappingProxyType({}),
                           field_to_fixed_coordinates=MappingProxyType({})):
    """
    Substitutes :class:`pystencils.field.Field.Access` nodes by array indexing

    The tree is modified in place: the expressions of ``SympyAssignment`` nodes and the condition
    of ``Conditional`` nodes are rewritten, and assignments for intermediate base pointers are
    inserted into the enclosing block in front of the statement that needs them.

    Args:
        ast_node: the AST root
        read_only_field_names: set of field names which are considered read-only
        field_to_base_pointer_info: map of field name to a list of coordinate groups, indicating which
                                    intermediate base pointers should be created
                                    for details see :func:`parse_base_pointer_info`
        field_to_fixed_coordinates: map of field name to a tuple of coordinate symbols. Instead of using the loop
                                    counters to index the field these symbols are used as coordinates

    Returns
        the return value of the internal node visitor, which has no explicit return value
        (NOTE(review): the AST is transformed in place; callers should keep using ``ast_node``)
    """
    if read_only_field_names is None:
        read_only_field_names = set()

    # process the mappings in a reproducible order (sorted by field name)
    field_to_base_pointer_info = OrderedDict(sorted(field_to_base_pointer_info.items(), key=lambda pair: pair[0]))
    field_to_fixed_coordinates = OrderedDict(sorted(field_to_fixed_coordinates.items(), key=lambda pair: pair[0]))

    def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
        # Returns the rewritten expression. 'sympy_assignment' is the statement inside
        # 'enclosing_block' in front of which new pointer assignments are inserted.
        if isinstance(expr, Field.Access):
            field_access = expr
            field = field_access.field

            if field_access.indirect_addressing_fields:
                # offsets/indices contain field accesses themselves: resolve those first
                new_offsets = tuple(visit_sympy_expr(off, enclosing_block, sympy_assignment)
                                    for off in field_access.offsets)
                new_indices = tuple(visit_sympy_expr(ind, enclosing_block, sympy_assignment)
                                    if isinstance(ind, sp.Basic) else ind
                                    for ind in field_access.index)
                field_access = Field.Access(field_access.field, new_offsets,
                                            new_indices, field_access.is_absolute_access)

            if field.name in field_to_base_pointer_info:
                base_pointer_info = field_to_base_pointer_info[field.name]
            else:
                # default: a single group with all coordinates, i.e. no intermediate pointers
                base_pointer_info = [list(range(field.index_dimensions + field.spatial_dimensions))]

            field_ptr = FieldPointerSymbol(
                field.name,
                field.dtype,
                const=field.name in read_only_field_names)

            def create_coordinate_dict(group_param):
                # Maps every coordinate id of the group to the value it is accessed with.
                coordinates = {}
                for e in group_param:
                    if e < field.spatial_dimensions:
                        # spatial coordinate: fixed coordinate symbol or loop counter;
                        # absolute accesses start from zero
                        if field.name in field_to_fixed_coordinates:
                            if not field_access.is_absolute_access:
                                coordinates[e] = field_to_fixed_coordinates[field.name][e]
                            else:
                                coordinates[e] = 0
                        else:
                            if not field_access.is_absolute_access:
                                coordinates[e] = ast.LoopOverCoordinate.get_loop_counter_symbol(e)
                            else:
                                coordinates[e] = 0
                        coordinates[e] *= field.dtype.item_size
                    else:
                        # index coordinate
                        if isinstance(field.dtype, StructType):
                            # struct fields are addressed by the offset of the named element
                            assert field.index_dimensions == 1
                            accessed_field_name = field_access.index[0]
                            if isinstance(accessed_field_name, sp.Symbol):
                                accessed_field_name = accessed_field_name.name
                            assert isinstance(accessed_field_name, str)
                            coordinates[e] = field.dtype.get_element_offset(accessed_field_name)
                        else:
                            coordinates[e] = field_access.index[e - field.spatial_dimensions]

                return coordinates

            last_pointer = field_ptr

            # All groups but the first get their own intermediate pointer; they are chained,
            # starting with the last group.
            for group in reversed(base_pointer_info[1:]):
                coord_dict = create_coordinate_dict(group)
                new_ptr, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer)
                if new_ptr not in enclosing_block.symbols_defined:
                    # define the pointer only once per block
                    new_assignment = ast.SympyAssignment(new_ptr, last_pointer + offset, is_const=False, use_auto=False)
                    enclosing_block.insert_before(new_assignment, sympy_assignment)
                last_pointer = new_ptr

            # the first group is resolved directly in the access, only its offset is needed
            coord_dict = create_coordinate_dict(base_pointer_info[0])
            _, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer)
            result = ast.ResolvedFieldAccess(last_pointer, offset, field_access.field,
                                             field_access.offsets, field_access.index)

            if isinstance(get_base_type(field_access.field.dtype), StructType):
                # reinterpret the access as the type of the accessed struct element
                accessed_field_name = field_access.index[0]
                if isinstance(accessed_field_name, sp.Symbol):
                    accessed_field_name = accessed_field_name.name
                new_type = field_access.field.dtype.get_element_type(accessed_field_name)
                result = ReinterpretCastFunc(result, new_type)

            return visit_sympy_expr(result, enclosing_block, sympy_assignment)
        else:
            if isinstance(expr, ast.ResolvedFieldAccess):
                return expr

            if hasattr(expr, 'args'):
                new_args = [visit_sympy_expr(e, enclosing_block, sympy_assignment) for e in expr.args]
            else:
                new_args = []
            # rebuild without automatic evaluation for the types listed here
            kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {}
            return expr.func(*new_args, **kwargs) if new_args else expr

    def visit_node(sub_ast):
        if isinstance(sub_ast, ast.SympyAssignment):
            enclosing_block = sub_ast.parent
            assert type(enclosing_block) is ast.Block
            sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block, sub_ast)
            sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block, sub_ast)
        elif isinstance(sub_ast, ast.Conditional):
            enclosing_block = sub_ast.parent
            assert type(enclosing_block) is ast.Block
            # pointer assignments needed by the condition are placed in front of the conditional
            sub_ast.condition_expr = visit_sympy_expr(sub_ast.condition_expr, enclosing_block, sub_ast)
            visit_node(sub_ast.true_block)
            if sub_ast.false_block:
                visit_node(sub_ast.false_block)
        else:
            # plain Python numbers can show up as children; they have no 'args' to descend into
            if isinstance(sub_ast, (bool, int, float)):
                return
            for a in sub_ast.args:
                visit_node(a)

    return visit_node(ast_node)
def move_constants_before_loop(ast_node):
    """Moves :class:`pystencils.ast.SympyAssignment` nodes out of loop body if they are iteration independent.

    Call this after creating the loop structure with :func:`make_loop_over_domain`

    The AST is modified in place. Only ``SympyAssignment`` nodes are moved; all other children
    of a block stay where they are.

    Args:
        ast_node: root of the AST to transform
    """
    def find_block_to_move_to(node):
        """
        Traverses parents of node as long as the symbols are independent and returns a (parent) block
        the assignment can be safely moved to
        :param node: SympyAssignment inside a Block
        :return blockToInsertTo, childOfBlockToInsertBefore
        """
        assert isinstance(node.parent, ast.Block)

        def modifies_or_declares(node: ast.Node, symbol_names: Set[str]) -> bool:
            # True if 'node' (or, for container nodes, something inside it) writes to or
            # defines one of the given symbol names.
            if isinstance(node, (ps.Assignment, ast.SympyAssignment)):
                if isinstance(node.lhs, ast.ResolvedFieldAccess):
                    return node.lhs.typed_symbol.name in symbol_names
                else:
                    return node.lhs.name in symbol_names
            elif isinstance(node, ast.Block):
                for arg in node.args:
                    # declarations directly inside a nested block are not counted
                    if isinstance(arg, ast.SympyAssignment) and arg.is_declaration:
                        continue
                    if modifies_or_declares(arg, symbol_names):
                        return True
                return False
            elif isinstance(node, ast.LoopOverCoordinate):
                return modifies_or_declares(node.body, symbol_names)
            elif isinstance(node, ast.Conditional):
                return (
                    modifies_or_declares(node.true_block, symbol_names)
                    or (node.false_block and modifies_or_declares(node.false_block, symbol_names))
                )
            elif isinstance(node, ast.KernelFunction):
                return False
            else:
                defs = {s.name for s in node.symbols_defined}
                return bool(symbol_names.intersection(defs))

        # names of all symbols the assignment reads
        dependencies = {s.name for s in node.undefined_symbols}

        # best target found so far: the block and the child of that block to insert before
        last_block = node.parent
        last_block_child = node
        element = node.parent
        prev_element = node

        while element:
            if isinstance(element, (ast.Conditional, ast.KernelFunction)):
                # Never move out of Conditionals or KernelFunctions.
                break

            elif isinstance(element, ast.Block):
                last_block = element
                last_block_child = prev_element

                if any(modifies_or_declares(sibling, dependencies) for sibling in element.args):
                    # The node depends on one of the statements in this block.
                    # Do not move further out.
                    break

            elif isinstance(element, ast.LoopOverCoordinate):
                if element.loop_counter_symbol.name in dependencies:
                    # The node depends on the loop counter.
                    # Do not move out of this loop.
                    break

            else:
                raise NotImplementedError(f'Due to defensive programming we handle only specific expressions.\n'
                                          f'The expression {element} of type {type(element)} is not known yet.')

            # No dependencies to symbols defined/modified within the current element.
            # We can move the node up one level and in front of the current element.
            prev_element = element
            element = element.parent
        return last_block, last_block_child

    def check_if_assignment_already_in_block(assignment, target_block, rhs_or_lhs=True):
        # Returns the first SympyAssignment of 'target_block' with the same rhs
        # (rhs_or_lhs=True) or the same lhs (rhs_or_lhs=False) as 'assignment', else None.
        for arg in target_block.args:
            if type(arg) is not ast.SympyAssignment:
                continue
            if (rhs_or_lhs and arg.rhs == assignment.rhs) or (not rhs_or_lhs and arg.lhs == assignment.lhs):
                return arg
        return None

    def get_blocks(node, result_list):
        # Collects all Block nodes of the tree in pre-order into 'result_list'.
        if isinstance(node, ast.Block):
            result_list.append(node)
        if isinstance(node, ast.Node):
            for a in node.args:
                get_blocks(a, result_list)

    all_blocks = []
    get_blocks(ast_node, all_blocks)
    for block in all_blocks:
        # empty the block, then decide child by child where it is put back
        children = block.take_child_nodes()
        for child in children:
            if not isinstance(child, ast.SympyAssignment):  # only move SympyAssignments
                block.append(child)
                continue

            target, child_to_insert_before = find_block_to_move_to(child)
            if target == block:  # movement not possible
                target.append(child)
            else:
                if isinstance(child, ast.SympyAssignment):
                    exists_already = check_if_assignment_already_in_block(child, target, False)
                else:
                    exists_already = False

                if not exists_already:
                    target.insert_before(child, child_to_insert_before)
                elif exists_already and exists_already.rhs == child.rhs:
                    # identical assignment is already present in the target block: drop the child,
                    # but make sure the existing assignment comes before the insertion point
                    if target.args.index(exists_already) > target.args.index(child_to_insert_before):
                        assert target.args.count(exists_already) == 1
                        assert target.args.count(child_to_insert_before) == 1
                        target.args.remove(exists_already)
                        target.insert_before(exists_already, child_to_insert_before)
                else:
                    # this variable already exists in outer block, but with different rhs
                    # -> symbol has to be renamed
                    assert isinstance(child.lhs, TypedSymbol)
                    new_symbol = TypedSymbol(sp.Dummy().name, child.lhs.dtype)
                    target.insert_before(ast.SympyAssignment(new_symbol, child.rhs, is_const=child.is_const),
                                         child_to_insert_before)
                    # the original symbol stays in the inner block and is set from the renamed one
                    block.append(ast.SympyAssignment(child.lhs, new_symbol, is_const=child.is_const))
def split_inner_loop(ast_node: ast.Node, symbol_groups):
    """
    Splits inner loop into multiple loops to minimize the amount of simultaneous load/store streams

    Args:
        ast_node: AST root, has to contain exactly one innermost and exactly one outermost loop
        symbol_groups: sequence of symbol sequences: for each symbol sequence a new inner loop is created which
                       updates these symbols and the symbols they depend on. Symbols which are in none of the
                       groups and which no symbol in a group depends on, are not updated!
    """
    loops = ast_node.atoms(ast.LoopOverCoordinate)

    innermost = [candidate for candidate in loops if candidate.is_innermost_loop]
    assert len(innermost) == 1, "Error in AST: multiple innermost loops. Was split transformation already called?"
    inner_loop = innermost[0]
    assert type(inner_loop.body) is ast.Block

    outermost = [candidate for candidate in loops if candidate.is_outermost_loop]
    assert len(outermost) == 1, "Error in AST, multiple outermost loops."
    outer_loop = outermost[0]

    def temporary_array_access(typed_symbol):
        # pointer-typed symbol of the same name, indexed with the counter of the inner loop
        pointer_symbol = TypedSymbol(typed_symbol.name, PointerType(typed_symbol.dtype))
        return sp.IndexedBase(pointer_symbol, shape=(1, ))[inner_loop.loop_counter_symbol]

    # filled group by group; symbols of earlier groups are read from their temporary array by later groups
    symbols_with_temporary_array = OrderedDict()
    assignment_map = OrderedDict((node.lhs, node) for node in inner_loop.body.args if hasattr(node, 'lhs'))

    assignment_groups = []
    for symbol_group in symbol_groups:
        # collect the group's symbols together with everything they depend on
        pending = list(symbol_group)
        resolved = set()
        while pending:
            current = pending.pop()
            if current in resolved:
                continue
            # a symbol without an assignment inside the loop body is independent already
            if current in assignment_map:
                pending.extend(
                    dependency for dependency in assignment_map[current].rhs.atoms(sp.Symbol)
                    if not isinstance(dependency, Field.Access)
                    and dependency not in symbols_with_temporary_array
                )
            resolved.add(current)

        for symbol in symbol_group:
            if isinstance(symbol, Field.Access):
                continue
            assert type(symbol) is TypedSymbol
            symbols_with_temporary_array[symbol] = temporary_array_access(symbol)

        group_assignments = []
        for node in inner_loop.body.args:
            if node.lhs not in resolved:
                continue
            # use fast_subs here because it checks if multiplications should be evaluated or not
            substituted_rhs = fast_subs(node.rhs, symbols_with_temporary_array)
            if not isinstance(node.lhs, Field.Access) and node.lhs in symbol_group:
                assert type(node.lhs) is TypedSymbol
                target = temporary_array_access(node.lhs)
            else:
                target = node.lhs
            group_assignments.append(ast.SympyAssignment(target, substituted_rhs))
        assignment_groups.append(group_assignments)

    # one copy of the inner loop per group replaces the original inner loop
    split_loops = [inner_loop.new_loop_with_different_body(ast.Block(group)) for group in assignment_groups]
    inner_loop.parent.replace(inner_loop, ast.Block(split_loops))

    # allocate the temporary arrays in front of the outermost loop and free them after it
    for symbol in symbols_with_temporary_array:
        array_pointer = TypedSymbol(symbol.name, PointerType(symbol.dtype))
        alloc_node = ast.TemporaryMemoryAllocation(array_pointer, inner_loop.stop, inner_loop.start)
        free_node = ast.TemporaryMemoryFree(alloc_node)
        outer_loop.parent.insert_front(alloc_node)
        outer_loop.parent.append(free_node)
def cut_loop(loop_node, cutting_points):
    """Cuts loop at given cutting points.

    One loop is transformed into len(cutting_points)+1 pieces that range from
    old_begin to cutting_points[0], ..., cutting_points[-1] to old_end.
    A piece of length one becomes a copy of the loop body with the loop counter substituted,
    a piece of length zero is dropped, every other piece becomes a new loop.

    Modifies the ast in place. Note Issue #5783 of SymPy. Deepcopy will evaluate mul
    https://github.com/sympy/sympy/issues/5783

    Returns:
        the ``ast.Block`` that holds the new pieces and has replaced ``loop_node`` in its parent

    Raises:
        NotImplementedError: if the step of the loop is not 1
    """
    if loop_node.step != 1:
        raise NotImplementedError("Can only split loops that have a step of 1")

    pieces = ast.Block([])
    piece_start = loop_node.start
    for piece_end in [*cutting_points, loop_node.stop]:
        piece_length = piece_end - piece_start
        if piece_length == 1:
            # single iteration: no loop needed, fix the counter to its only value
            unrolled_body = deepcopy(loop_node.body)
            unrolled_body.subs({loop_node.loop_counter_symbol: piece_start})
            pieces.append(unrolled_body)
        elif piece_length == 0:
            pass
        else:
            pieces.append(ast.LoopOverCoordinate(deepcopy(loop_node.body),
                                                 loop_node.coordinate_to_loop_over,
                                                 piece_start,
                                                 piece_end,
                                                 loop_node.step))
        piece_start = piece_end

    loop_node.parent.replace(loop_node, pieces)
    return pieces
def simplify_conditionals(node: ast.Node, loop_counter_simplification: bool = False) -> None:
    """Removes conditionals that are always true/false.

    Args:
        node: ast node, all descendants of this node are simplified
        loop_counter_simplification: if enabled, tries to detect if a conditional is always true/false
                                     depending on the surrounding loop. For example if the surrounding loop goes from
                                     x=0 to 10 and the condition is x < 0, it is removed.
                                     This analysis needs the integer set library (ISL) islpy, so it is not done by
                                     default.
    """
    from sympy.codegen.rewriting import ReplaceOptim, optimize

    strip_casts = ReplaceOptim(lambda e: isinstance(e, CastFunc), lambda p: p.expr)

    for cond_node in node.atoms(ast.Conditional):
        # TODO simplify conditional before the type system! Casts make it very hard here
        simplified = sp.simplify(optimize(cond_node.condition_expr, [strip_casts]))

        if simplified == sp.true:
            cond_node.parent.replace(cond_node, [cond_node.true_block])
            continue
        if simplified == sp.false:
            replacement = [cond_node.false_block] if cond_node.false_block else []
            cond_node.parent.replace(cond_node, replacement)
            continue
        if not loop_counter_simplification:
            continue

        try:
            # noinspection PyUnresolvedReferences
            from pystencils.integer_set_analysis import simplify_loop_counter_dependent_conditional
            simplify_loop_counter_dependent_conditional(cond_node)
        except ImportError:
            warnings.warn("Integer simplifications in conditionals skipped, because ISLpy package not installed")
def cleanup_blocks(node: ast.Node) -> None:
    """Curly Brace Removal: dissolves blocks with at most one child into their parent block.

    A block is only dissolved if its parent is a block as well; assignments are not descended into.
    """
    if isinstance(node, ast.SympyAssignment):
        return

    if not isinstance(node, ast.Block):
        for child in node.args:
            cleanup_blocks(child)
        return

    # iterate over a copy: children may replace themselves inside this block
    for child in list(node.args):
        cleanup_blocks(child)
    if len(node.args) <= 1 and isinstance(node.parent, ast.Block):
        node.parent.replace(node, node.args)
def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction, include_first=True) -> None:
    """Removes conditionals of a kernel that iterates over staggered positions.

    Every loop around (and including) the single innermost loop is cut at its last element, or at its first
    and last element if ``include_first`` is set. Afterwards conditionals are simplified, superfluous blocks
    are removed and constants are moved in front of the loops.
    """
    innermost_loops = [loop for loop in function_node.atoms(ast.LoopOverCoordinate) if loop.is_innermost_loop]
    assert len(innermost_loops) == 1, "Transformation works only on kernels with exactly one inner loop"
    inner_loop = innermost_loops.pop()

    for loop in parents_of_type(inner_loop, ast.LoopOverCoordinate, include_current=True):
        cutting_points = [loop.start + 1, loop.stop - 1] if include_first else [loop.stop - 1]
        cut_loop(loop, cutting_points)

    kernel_body = function_node.body
    simplify_conditionals(kernel_body, loop_counter_simplification=True)
    cleanup_blocks(kernel_body)
    move_constants_before_loop(kernel_body)
    cleanup_blocks(kernel_body)
# --------------------------------------- Helper Functions -------------------------------------------------------------
def get_optimal_loop_ordering(fields):
    """
    Determines the optimal loop order for a given set of fields.
    If the fields differ in their memory layout or their number of spatial dimensions an exception is thrown.

    Args:
        fields: sequence of fields

    Returns:
        list of coordinate ids, where the first list entry should be the outermost loop

    Raises:
        ValueError: if the fields do not share the number of spatial dimensions or the layout
    """
    assert len(fields) > 0
    reference = next(iter(fields))

    for candidate in fields:
        if candidate.spatial_dimensions != reference.spatial_dimensions:
            shapes_by_name = {f.name: f.spatial_shape for f in fields}
            raise ValueError(
                "All fields have to have the same number of spatial dimensions. Spatial field dimensions: "
                + str(shapes_by_name))

    distinct_layouts = {f.layout for f in fields}
    if len(distinct_layouts) > 1:
        layouts_by_name = {f.name: f.layout for f in fields}
        raise ValueError(
            "Due to different layout of the fields no optimal loop ordering exists "
            + str(layouts_by_name))

    # exactly one layout is left at this point
    (common_layout,) = distinct_layouts
    return list(common_layout)
def get_loop_hierarchy(ast_node):
    """Determines the loop structure around a given AST node, i.e. the node has to be inside the loops.

    Args:
        ast_node: the AST node whose enclosing loops are looked up

    Returns:
        list of the coordinates (``coordinate_to_loop_over``) of the enclosing ``LoopOverCoordinate`` nodes,
        starting from the outermost loop down to the innermost loop
    """
    coordinates = []
    node = ast_node
    while node is not None:
        node = get_next_parent_of_type(node, ast.LoopOverCoordinate)
        if node:
            coordinates.append(node.coordinate_to_loop_over)
    # Parents are collected from the inside out. Return a real list instead of a one-shot
    # ``reversed`` iterator so that the result can be indexed and iterated repeatedly.
    coordinates.reverse()
    return coordinates
def get_loop_counter_symbol_hierarchy(ast_node):
    """Determines the loop counter symbols around a given AST node.

    :param ast_node: the AST node
    :return: list of the loop counter symbols of all enclosing loops, where the first list entry is the
             symbol of the innermost loop
    """
    counter_symbols = []
    current = ast_node
    while current is not None:
        # walk upwards from one enclosing loop to the next one
        current = get_next_parent_of_type(current, ast.LoopOverCoordinate)
        if current:
            counter_symbols.append(current.loop_counter_symbol)
    return counter_symbols
def replace_inner_stride_with_one(ast_node: ast.KernelFunction) -> None:
    """Replaces the stride of the innermost loop of a variable sized kernel with 1 (assumes optimal loop ordering).

    Variable sized kernels can handle arbitrary field sizes and field shapes. However, the kernel is most efficient
    if the innermost loop accesses the fields with stride 1. The inner loop can also only be vectorized if the inner
    stride is 1. This transformation hard codes this inner stride to one to enable e.g. vectorization.

    Warning: the assumption is not checked at runtime!

    Raises:
        ValueError: if the innermost loops do not all iterate over one and the same coordinate
    """
    innermost_coordinates = {
        loop.coordinate_to_loop_over
        for loop in filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, stop_type=ast.SympyAssignment)
        if loop.is_innermost_loop
    }
    if len(innermost_coordinates) != 1:
        raise ValueError("Inner loops iterate over different coordinates")
    (inner_coordinate,) = innermost_coordinates

    # every field stride parameter that belongs to the inner coordinate is hard coded to 1
    substitutions = {
        parameter.symbol: 1
        for parameter in ast_node.get_parameters()
        if parameter.is_field_stride and parameter.symbol.coordinate == inner_coordinate
    }
    if substitutions:
        ast_node.subs(substitutions)
def loop_blocking(ast_node: ast.KernelFunction, block_size) -> int:
    """Blocking of loops to enhance cache locality. Modifies the ast node in-place.

    Args:
        ast_node: kernel function node before vectorization transformation has been applied
        block_size: sequence defining block size in x, y, (z) direction.
                    If chosen as zero the direction will not be used for blocking.

    Returns:
        number of dimensions blocked. If no dimension is blocked (all relevant block sizes are zero or the
        kernel contains no loops) the AST is left untouched and 0 is returned.
    """
    loops = list(filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, stop_type=ast.SympyAssignment))
    body = ast_node.body

    # coordinates in the order in which their loops are encountered, together with the common loop bounds
    coordinates = []
    loop_starts = {}
    loop_stops = {}
    for loop in loops:
        coord = loop.coordinate_to_loop_over
        if coord not in coordinates:
            coordinates.append(coord)
            loop_starts[coord] = loop.start
            loop_stops[coord] = loop.stop
        else:
            assert loop.start == loop_starts[coord] and loop.stop == loop_stops[coord], \
                f"Multiple loops over coordinate {coord} with different loop bounds"

    # Create the outer loops that iterate over the blocks
    coordinates_taken_into_account = 0
    outer_loop = None
    for coord in reversed(coordinates):
        if block_size[coord] == 0:
            continue
        coordinates_taken_into_account += 1
        body = ast.Block([outer_loop]) if outer_loop else body
        outer_loop = ast.LoopOverCoordinate(body,
                                            coord,
                                            loop_starts[coord],
                                            loop_stops[coord],
                                            step=block_size[coord],
                                            is_block_loop=True)

    if outer_loop is None:
        # Nothing to block. Without this guard the kernel body would be replaced by ``Block([None])``.
        return coordinates_taken_into_account

    ast_node.body = ast.Block([outer_loop])

    # modify the existing loops to only iterate within one block
    for inner_loop in loops:
        coord = inner_loop.coordinate_to_loop_over
        if block_size[coord] == 0:
            continue
        block_ctr = ast.LoopOverCoordinate.get_block_loop_counter_symbol(coord)
        loop_range = inner_loop.stop - inner_loop.start
        if sp.sympify(loop_range).is_number and loop_range % block_size[coord] == 0:
            # the blocks tile the iteration range exactly, no clamping needed
            stop = block_ctr + block_size[coord]
        else:
            # the last block may be incomplete: clamp to the original upper bound
            stop = sp.Min(inner_loop.stop, block_ctr + block_size[coord])
        inner_loop.start = block_ctr
        inner_loop.stop = stop
    return coordinates_taken_into_account
from pystencils.typing.cast_functions import (CastFunc, BooleanCastFunc, VectorMemoryAccess, ReinterpretCastFunc,
PointerArithmeticFunc)
from pystencils.typing.types import (is_supported_type, numpy_name_to_c, AbstractType, BasicType, VectorType,
PointerType, StructType, create_type)
from pystencils.typing.typed_sympy import (assumptions_from_dtype, TypedSymbol, FieldStrideSymbol, FieldShapeSymbol,
FieldPointerSymbol, CFunction)
from pystencils.typing.utilities import (typed_symbols, get_base_type, result_type, collate_types,
get_type_of_expression, get_next_parent_of_type, parents_of_type)
__all__ = ['CastFunc', 'BooleanCastFunc', 'VectorMemoryAccess', 'ReinterpretCastFunc', 'PointerArithmeticFunc',
'is_supported_type', 'numpy_name_to_c', 'AbstractType', 'BasicType',
'VectorType', 'PointerType', 'StructType', 'create_type', 'assumptions_from_dtype',
'TypedSymbol', 'FieldStrideSymbol', 'FieldShapeSymbol', 'FieldPointerSymbol', 'CFunction',
'typed_symbols', 'get_base_type', 'result_type', 'collate_types',
'get_type_of_expression', 'get_next_parent_of_type', 'parents_of_type']
import numpy as np
import sympy as sp
from sympy.logic.boolalg import Boolean
from pystencils.typing.types import AbstractType, BasicType
from pystencils.typing.typed_sympy import TypedSymbol
class CastFunc(sp.Function):
    """
    Represents a static cast of an expression to a data type.

    Casts are especially useful to signal which type a node should have when the node itself cannot carry
    a type, e.g. a sp.Number. The first argument is the expression, the second one the target type.
    """
    is_Atom = True

    def __new__(cls, *args, **kwargs):
        expr, dtype, *other_args = args

        # If we have two consecutive casts, throw the inner one away.
        # This optimisation is only available for simple casts. Thus the == is intended here!
        if expr.__class__ == CastFunc:
            expr = expr.args[0]
        if not isinstance(dtype, AbstractType):
            dtype = BasicType(dtype)

        # To work in conditions of sp.Piecewise a cast has to be of type Boolean as well. However, a cast
        # should only be a boolean if its argument is a boolean, otherwise this leads to problems when
        # for example comparing casts for equality:
        #
        # lhs = bitwise_and(a, cast_func(1, 'int'))
        # rhs = cast_func(0, 'int')
        # print( sp.Ne(lhs, rhs) ) # would give true if all cast_funcs are booleans
        # -> thus the separate class BooleanCastFunc is used for boolean arguments
        if isinstance(expr, Boolean) and (not isinstance(expr, TypedSymbol) or expr.dtype == BasicType('bool')):
            cls = BooleanCastFunc

        return sp.Function.__new__(cls, expr, dtype, *other_args, **kwargs)

    @property
    def canonical(self):
        """Canonical form of the wrapped expression; raises NotImplementedError if it provides none."""
        wrapped = self.args[0]
        if not hasattr(wrapped, 'canonical'):
            raise NotImplementedError()
        return wrapped.canonical

    @property
    def is_commutative(self):
        """Commutativity is taken over from the wrapped expression."""
        return self.args[0].is_commutative

    @property
    def dtype(self):
        """Target type of the cast (second argument)."""
        return self.args[1]

    @property
    def expr(self):
        """Expression that is cast (first argument)."""
        return self.args[0]

    @property
    def is_integer(self):
        """
        Uses Numpy type hierarchy to determine :func:`sympy.Expr.is_integer` predicate

        For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
        """
        if not hasattr(self.dtype, 'numpy_dtype'):
            return super().is_integer
        return np.issubdtype(self.dtype.numpy_dtype, np.integer) or super().is_integer

    @property
    def is_negative(self):
        """
        False for unsigned integer target types, otherwise the predicate of the base class.
        """
        if hasattr(self.dtype, 'numpy_dtype') and np.issubdtype(self.dtype.numpy_dtype, np.unsignedinteger):
            return False
        return super().is_negative

    @property
    def is_nonnegative(self):
        """
        True if :attr:`is_negative` is known to be False, otherwise the predicate of the base class.
        """
        return True if self.is_negative is False else super().is_nonnegative

    @property
    def is_real(self):
        """
        True for integer and floating point target types, otherwise the predicate of the base class.
        """
        if not hasattr(self.dtype, 'numpy_dtype'):
            return super().is_real
        return (np.issubdtype(self.dtype.numpy_dtype, np.integer)
                or np.issubdtype(self.dtype.numpy_dtype, np.floating)
                or super().is_real)
class BooleanCastFunc(CastFunc, Boolean):
    """Cast that is a SymPy ``Boolean`` as well.

    ``CastFunc.__new__`` selects this class for boolean arguments so that casts can be used in conditions.
    """
    # TODO: documentation
class VectorMemoryAccess(CastFunc):
    """
    Special memory access for vectorized kernel.

    Takes exactly six arguments:
    read/write expression, type, aligned, non-temporal, mask (or none), stride
    """
    nargs = (6,)
class ReinterpretCastFunc(CastFunc):
    """
    Cast variant that stands for a reinterpret cast; it is necessary for the StructType.
    """
class PointerArithmeticFunc(sp.Function, Boolean):
    # TODO: documentation, or deprecate!
    @property
    def canonical(self):
        """Canonical form of the first argument; raises NotImplementedError if it provides none."""
        first_argument = self.args[0]
        if not hasattr(first_argument, 'canonical'):
            raise NotImplementedError()
        return first_argument.canonical
from collections import namedtuple
from typing import Union, Tuple, Any, DefaultDict
import logging
import numpy as np
import sympy as sp
from sympy import Piecewise
from sympy.core.numbers import NegativeOne
from sympy.core.relational import Relational
from sympy.functions.elementary.piecewise import ExprCondPair
from sympy.functions.elementary.trigonometric import TrigonometricFunction, InverseTrigonometricFunction
from sympy.functions.elementary.hyperbolic import HyperbolicFunction
from sympy.functions.elementary.integers import RoundFunction
from sympy.logic.boolalg import BooleanFunction
from sympy.logic.boolalg import BooleanAtom
from pystencils import astnodes as ast
from pystencils.functions import DivFunc, AddressOf
from pystencils.cpu.vectorization import vec_all, vec_any
from pystencils.field import Field
from pystencils.typing.types import BasicType, PointerType
from pystencils.typing.utilities import collate_types
from pystencils.typing.cast_functions import CastFunc, BooleanCastFunc
from pystencils.typing.typed_sympy import TypedSymbol
from pystencils.fast_approximation import fast_sqrt, fast_division, fast_inv_sqrt
from pystencils.utils import ContextVar
class TypeAdder:
    # TODO: specification -> jupyter notebook
    """Adds types to the leaves of pystencils AST nodes and inserts the casts that this requires.

    ``visit`` walks lists, blocks, conditionals and assignments; ``figure_out_type`` types a single expression
    recursively and returns the (possibly rewritten) expression together with its type.

    Rules implemented by ``figure_out_type``:
        - ``Field.Access`` and ``TypedSymbol`` keep the type they already carry
        - a plain ``sp.Symbol`` is looked up by name in ``type_for_symbol``; a plain symbol on the left-hand
          side of an assignment is registered there with the type of the right-hand side first
        - integer numbers get ``default_number_int``, float and rational numbers ``default_number_float``
        - composite expressions collate the types of their arguments and cast the arguments that differ
    """
    FieldAndIndex = namedtuple('FieldAndIndex', ['field', 'index'])

    def __init__(self, type_for_symbol: DefaultDict[str, BasicType], default_number_float: BasicType,
                 default_number_int: BasicType):
        """
        Args:
            type_for_symbol: mapping from symbol name to the type of the symbol; it is extended while
                             assignments to untyped symbols are processed
            default_number_float: type given to float and rational numbers
            default_number_int: type given to integer numbers
        """
        self.type_for_symbol = type_for_symbol
        self.default_number_float = ContextVar(default_number_float)
        self.default_number_int = ContextVar(default_number_int)

    def visit(self, obj):
        """Returns a typed version of ``obj``.

        Lists and tuples are processed element-wise and returned as list. Assignments, conditionals and
        blocks are rebuilt from their typed parts, all other nodes except loops are returned unchanged.

        Raises:
            ValueError: for loops and for objects that are no AST nodes
        """
        if isinstance(obj, (list, tuple)):
            return [self.visit(e) for e in obj]
        if isinstance(obj, ast.SympyAssignment):
            return self.process_assignment(obj)
        elif isinstance(obj, ast.Conditional):
            condition, condition_type = self.figure_out_type(obj.condition_expr)
            assert condition_type == BasicType('bool')
            true_block = self.visit(obj.true_block)
            false_block = None if obj.false_block is None else self.visit(
                obj.false_block)
            return ast.Conditional(condition, true_block=true_block, false_block=false_block)
        elif isinstance(obj, ast.Block):
            return ast.Block([self.visit(e) for e in obj.args])
        elif isinstance(obj, ast.Node) and not isinstance(obj, ast.LoopOverCoordinate):
            return obj
        else:
            raise ValueError("Invalid object in kernel " + str(type(obj)))

    def process_assignment(self, assignment: ast.SympyAssignment) -> ast.SympyAssignment:
        """Types both sides of ``assignment``; the right-hand side is cast if the types of the sides differ.

        Raises:
            ValueError: if the left-hand side is neither a field access nor a symbol
        """
        # for checks it is crucial to process rhs before lhs to catch e.g. a = a + 1
        new_rhs, rhs_type = self.figure_out_type(assignment.rhs)

        lhs = assignment.lhs
        if not isinstance(lhs, (Field.Access, TypedSymbol)):
            if isinstance(lhs, sp.Symbol):
                # an untyped lhs takes over the type of its rhs
                self.type_for_symbol[lhs.name] = rhs_type
            else:
                raise ValueError(f'Lhs: `{lhs}` is not a subtype of sp.Symbol')
        new_lhs, lhs_type = self.figure_out_type(lhs)
        assert isinstance(new_lhs, (Field.Access, TypedSymbol))

        if lhs_type != rhs_type:
            logging.debug(f'Lhs"{new_lhs} of type "{lhs_type}" is assigned with a different datatype '
                          f'rhs: "{new_rhs}" of type "{rhs_type}".')
            return ast.SympyAssignment(new_lhs, CastFunc(new_rhs, lhs_type), assignment.is_const, assignment.use_auto)
        else:
            return ast.SympyAssignment(new_lhs, new_rhs, assignment.is_const, assignment.use_auto)

    # Type System Specification
    # - Defined Types: TypedSymbol, Field, Field.Access, ...?
    # - Indexed: always unsigned_integer64
    # - Undefined Types: Symbol
    #       - Is specified in Config in the dict or as 'default_type' or behaves like `auto` in the case of lhs.
    # - Constants/Numbers: Are either integer or floating. The precision and sign is specified via config
    #       - Example: 1.4 config:float32 -> float32
    # - Expressions deduce types from arguments
    # - Functions deduce types from arguments
    # - default_type and default_float and default_int can be given for a list of assignment, or
    #   individually as a list for assignment

    # Possible Problems - Do we need to support this?
    # - Mixture in expression with int and float
    # - Mixture in expression with uint64 and sint64
    # TODO Logging: Lowest log level should log all casts ----> cast factory, make cast should contain logging
    def figure_out_type(self, expr) -> Tuple[Any, Union[BasicType, PointerType]]:
        """Determines the type of ``expr`` and inserts casts where the types of its arguments differ.

        Returns:
            tuple of the (possibly rewritten) expression and its type

        Raises:
            ValueError: if an integer function or a modulo is applied to non-integer arguments
            NotImplementedError: for expressions that are unknown to the typing and for pointer arithmetic
                                 other than addition
        """
        # Trivial cases
        from pystencils.field import Field
        import pystencils.integer_functions
        from pystencils.bit_masks import flag_cond
        bool_type = BasicType('bool')

        # TODO: check the access
        if isinstance(expr, Field.Access):
            return expr, expr.dtype
        elif isinstance(expr, TypedSymbol):
            return expr, expr.dtype
        elif isinstance(expr, sp.Symbol):
            t = TypedSymbol(expr.name, self.type_for_symbol[expr.name])
            return t, t.dtype
        elif isinstance(expr, np.generic):
            assert False, f'Why do we have a np.generic in rhs???? {expr}'
        elif isinstance(expr, (sp.core.numbers.Infinity, sp.core.numbers.NegativeInfinity)):
            return expr, BasicType('float32')  # see https://en.cppreference.com/w/cpp/numeric/math/INFINITY
        elif isinstance(expr, sp.Number):
            if expr.is_Integer:
                data_type = self.default_number_int.get()
            elif expr.is_Float or expr.is_Rational:
                data_type = self.default_number_float.get()
            else:
                # report the offending number itself, not the class sp.Number
                assert False, f'{expr} is neither Float nor Integer'
            return CastFunc(expr, data_type), data_type
        elif isinstance(expr, AddressOf):
            of = expr.args[0]
            # TODO Basically this should do address_of already
            assert isinstance(of, (Field.Access, TypedSymbol, Field))
            return expr, PointerType(of.dtype)
        elif isinstance(expr, BooleanAtom):
            return expr, bool_type
        elif isinstance(expr, Relational):
            # TODO Jan: Code duplication with general case
            args_types = [self.figure_out_type(arg) for arg in expr.args]
            collated_type = collate_types([t for _, t in args_types])
            if isinstance(expr, sp.Equality) and collated_type.is_float():
                logging.warning(f"Using floating point numbers in equality comparison: {expr}")
            new_args = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types]
            new_eq = expr.func(*new_args)
            return new_eq, bool_type
        elif isinstance(expr, CastFunc):
            # an explicit cast decides the type; only its argument is typed recursively
            new_expr, _ = self.figure_out_type(expr.expr)
            return expr.func(*[new_expr, expr.dtype]), expr.dtype
        elif isinstance(expr, ast.ConditionalFieldAccess):
            access, access_type = self.figure_out_type(expr.access)
            value, value_type = self.figure_out_type(expr.outofbounds_value)
            condition, condition_type = self.figure_out_type(expr.outofbounds_condition)
            assert condition_type == bool_type
            collated_type = collate_types([access_type, value_type])
            if collated_type == access_type:
                new_access = access
            else:
                logging.warning(f"In {expr} the Field Access had to be casted to {collated_type}. This is "
                                f"probably due to a type missmatch of the Field and the value of "
                                f"ConditionalFieldAccess")
                new_access = CastFunc(access, collated_type)

            new_value = value if value_type == collated_type else CastFunc(value, collated_type)
            return expr.func(new_access, condition, new_value), collated_type
        elif isinstance(expr, (vec_any, vec_all)):
            return expr, bool_type
        elif isinstance(expr, BooleanFunction):
            args_types = [self.figure_out_type(a) for a in expr.args]
            new_args = [a if t.dtype_eq(bool_type) else BooleanCastFunc(a, bool_type) for a, t in args_types]
            return expr.func(*new_args), bool_type
        elif type(expr) in pystencils.integer_functions.__dict__.values() or isinstance(expr, sp.Mod):
            args_types = [self.figure_out_type(a) for a in expr.args]
            collated_type = collate_types([t for _, t in args_types])
            # TODO: should we downcast to integer? If yes then which integer type?
            if not collated_type.is_int():
                raise ValueError(f"Integer functions or Modulo need to be used with integer types "
                                 f"but {collated_type} was given")

            return expr, collated_type
        elif isinstance(expr, flag_cond):
            #   do not process the arguments to the bit shift - they must remain integers
            args_types = [self.figure_out_type(a) for a in (expr.args[i] for i in range(2, len(expr.args)))]
            collated_type = collate_types([t for _, t in args_types])
            new_expressions = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types]
            return expr.func(expr.args[0], expr.args[1], *new_expressions), collated_type
        # elif isinstance(expr, sp.Mul):
        #    raise NotImplementedError('sp.Mul')
        #    # TODO can we ignore this and move it to general expr handling, i.e. removing Mul? (See todo in backend)
        #    # args_types = [self.figure_out_type(arg) for arg in expr.args if arg not in (-1, 1)]
        elif isinstance(expr, sp.Indexed):
            typed_symbol = expr.base.label
            return expr, typed_symbol.dtype
        elif isinstance(expr, ExprCondPair):
            expr_expr, expr_type = self.figure_out_type(expr.expr)
            condition, condition_type = self.figure_out_type(expr.cond)
            if condition_type != bool_type:
                logging.warning(f'Condition "{condition}" is of type "{condition_type}" and not "bool"')
            return expr.func(expr_expr, condition), expr_type
        elif isinstance(expr, Piecewise):
            args_types = [self.figure_out_type(arg) for arg in expr.args]
            collated_type = collate_types([t for _, t in args_types])
            new_args = []
            for a, t in args_types:
                if t != collated_type:
                    if isinstance(a, ExprCondPair):
                        # cast only the value of the pair, the condition stays as it is
                        new_args.append(a.func(CastFunc(a.expr, collated_type), a.cond))
                    else:
                        new_args.append(CastFunc(a, collated_type))
                else:
                    new_args.append(a)
            return expr.func(*new_args) if new_args else expr, collated_type
        elif isinstance(expr, (sp.Pow, sp.exp, InverseTrigonometricFunction, TrigonometricFunction,
                               HyperbolicFunction, sp.log, RoundFunction)):
            args_types = [self.figure_out_type(arg) for arg in expr.args]
            collated_type = collate_types([t for _, t in args_types])
            new_args = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types]
            new_func = expr.func(*new_args) if new_args else expr
            if collated_type == BasicType('float64'):
                return new_func, collated_type
            else:
                return CastFunc(new_func, collated_type), collated_type
        elif isinstance(expr, (fast_sqrt, fast_division, fast_inv_sqrt)):
            # the fast approximations are always typed as float32
            args_types = [self.figure_out_type(arg) for arg in expr.args]
            collated_type = BasicType('float32')
            new_args = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types]
            new_func = expr.func(*new_args) if new_args else expr
            return CastFunc(new_func, collated_type), collated_type
        elif isinstance(expr, (sp.Add, sp.Mul, sp.Abs, sp.Min, sp.Max, DivFunc, sp.UnevaluatedExpr)):
            # Subtraction is realised a multiplication with -1 in SymPy. Thus we exclude the coefficient in this case
            # and resolve the typing entirely with the expression itself
            if isinstance(expr, sp.Mul):
                c, e = expr.as_coeff_Mul()
                if c == NegativeOne():
                    args_types = self.figure_out_type(e)
                    new_args = [NegativeOne(), args_types[0]]
                    return expr.func(*new_args, evaluate=False), args_types[1]

            args_types = [self.figure_out_type(arg) for arg in expr.args]
            collated_type = collate_types([t for _, t in args_types])
            if isinstance(collated_type, PointerType):
                if isinstance(expr, sp.Add):
                    return expr.func(*[a for a, _ in args_types]), collated_type
                else:
                    raise NotImplementedError(f'Pointer Arithmetic is implemented only for Add, not {expr}')
            new_args = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types]

            if isinstance(expr, (sp.Add, sp.Mul)):
                return expr.func(*new_args, evaluate=False) if new_args else expr, collated_type
            else:
                return expr.func(*new_args) if new_args else expr, collated_type
        else:
            raise NotImplementedError(f'expr {type(expr)}: {expr} unknown to typing')
from typing import List
from pystencils.astnodes import Node
from pystencils.config import CreateKernelConfig
from pystencils.typing.leaf_typing import TypeAdder
def add_types(node_list: List[Node], config: CreateKernelConfig):
    """Traverses AST and replaces every :class:`sympy.Symbol` by a :class:`pystencils.typedsymbol.TypedSymbol`.

    The AST needs to be a pystencils AST. Thus, in the list of nodes every entry must be inherited from
    `pystencils.astnodes.Node`

    Args:
        node_list: List of pystencils Nodes.
        config: CreateKernelConfig; provides the symbol types (``data_type``) and the default types of
                float and integer numbers

    Returns:
        ``typed_equations`` list of equations where symbols have been replaced by typed symbols
    """
    type_adder = TypeAdder(
        type_for_symbol=config.data_type,
        default_number_float=config.default_number_float,
        default_number_int=config.default_number_int,
    )
    typed_equations = type_adder.visit(node_list)
    return typed_equations
from typing import Union
import numpy as np
import sympy as sp
from sympy.core.cache import cacheit
from pystencils.typing.types import BasicType, create_type, PointerType
def assumptions_from_dtype(dtype: "Union[BasicType, np.dtype]"):
    """Derives SymPy assumptions from :class:`BasicType` or a Numpy dtype

    Args:
        dtype (BasicType, np.dtype): a Numpy data type, or an object providing one as ``numpy_dtype``

    Returns:
        A dict of SymPy assumptions; it is empty if ``dtype`` cannot be interpreted by Numpy
    """
    numpy_dtype = dtype.numpy_dtype if hasattr(dtype, 'numpy_dtype') else dtype

    assumptions = {}
    try:
        is_integer = np.issubdtype(numpy_dtype, np.integer)
        if is_integer:
            assumptions['integer'] = True
        if np.issubdtype(numpy_dtype, np.unsignedinteger):
            assumptions['negative'] = False
        if is_integer or np.issubdtype(numpy_dtype, np.floating):
            assumptions['real'] = True
    except Exception:  # TODO this is dirty
        # types that Numpy does not understand simply yield no assumptions
        pass
    return assumptions
class TypedSymbol(sp.Symbol):
    """SymPy symbol that carries a data type in addition to its name."""

    def __new__(cls, *args, **kwds):
        return TypedSymbol.__xnew_cached_(cls, *args, **kwds)

    def __new_stage2__(cls, name, dtype, **kwargs):  # TODO does not match signature of sp.Symbol???
        # TODO: also Symbol should be allowed ---> see sympy Variable
        # explicitly passed assumptions take precedence over those derived from the data type
        symbol_assumptions = assumptions_from_dtype(dtype)
        symbol_assumptions.update(kwargs)
        instance = super(TypedSymbol, cls).__xnew__(cls, name, **symbol_assumptions)
        try:
            instance.numpy_dtype = create_type(dtype)
        except (TypeError, ValueError):
            # on error keep the string
            instance.numpy_dtype = dtype
        return instance

    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    @property
    def dtype(self):
        """Data type of the symbol as stored at construction."""
        return self.numpy_dtype

    def _hashable_content(self):
        # the data type takes part in hashing and comparison
        return super()._hashable_content(), hash(self.numpy_dtype)

    def __getnewargs__(self):
        return self.name, self.dtype

    def __getnewargs_ex__(self):
        return (self.name, self.dtype), self.assumptions0

    @property
    def canonical(self):
        return self

    @property
    def reversed(self):
        return self

    @property
    def headers(self):
        """List of headers required by the symbol's type.

        ``"cuda_complex.hpp"`` is added once for a complex type and once for a complex base type.
        """
        required_headers = []
        dtype_getters = (
            lambda: self.dtype.numpy_dtype,
            lambda: self.dtype.base_type.numpy_dtype,
        )
        for get_numpy_dtype in dtype_getters:
            # types without the respective attribute are skipped silently
            try:
                if np.issubdtype(get_numpy_dtype(), np.complexfloating):
                    required_headers.append('"cuda_complex.hpp"')
            except Exception:
                pass
        return required_headers
SHAPE_DTYPE = BasicType('int64', const=True)
STRIDE_DTYPE = BasicType('int64', const=True)
class FieldStrideSymbol(TypedSymbol):
    """Sympy symbol representing the stride value of a field in a specific coordinate."""

    def __new__(cls, *args, **kwds):
        return FieldStrideSymbol.__xnew_cached_(cls, *args, **kwds)

    def __new_stage2__(cls, field_name, coordinate):
        # the symbol name encodes field and coordinate, e.g. _stride_<field>_<coordinate>
        symbol_name = f"_stride_{field_name}_{coordinate}"
        instance = super(FieldStrideSymbol, cls).__xnew__(cls, symbol_name, STRIDE_DTYPE, positive=True)
        instance.field_name = field_name
        instance.coordinate = coordinate
        return instance

    def __getnewargs__(self):
        return self.field_name, self.coordinate

    def __getnewargs_ex__(self):
        return (self.field_name, self.coordinate), {}

    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    def _hashable_content(self):
        # coordinate and field name take part in hashing and comparison
        return super()._hashable_content(), self.coordinate, self.field_name
class FieldShapeSymbol(TypedSymbol):
    """Sympy symbol representing the shape value of a sequence of fields. In a kernel iterating over multiple fields
    there is only one set of `FieldShapeSymbol`s since all the fields have to be of equal size."""

    def __new__(cls, *args, **kwds):
        # All construction goes through the cached second-stage constructor.
        return FieldShapeSymbol.__xnew_cached_(cls, *args, **kwds)

    def __new_stage2__(cls, field_names, coordinate):
        """Build the symbol ``_size_<name1>_<name2>..._<coordinate>`` for the given field names."""
        joined_names = "_".join(field_names)
        symbol_name = f"_size_{joined_names}_{coordinate}"
        symbol = super(FieldShapeSymbol, cls).__xnew__(cls, symbol_name, SHAPE_DTYPE, positive=True)
        # Stored as a tuple so the names can take part in hashing.
        symbol.field_names = tuple(field_names)
        symbol.coordinate = coordinate
        return symbol

    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    def __getnewargs__(self):
        """Positional arguments needed to re-create this symbol."""
        return self.field_names, self.coordinate

    def __getnewargs_ex__(self):
        """Positional and (empty) keyword arguments needed to re-create this symbol."""
        return (self.field_names, self.coordinate), {}

    def _hashable_content(self):
        """Parent's hashable content extended by coordinate and field names."""
        inherited_content = super()._hashable_content()
        return inherited_content, self.coordinate, self.field_names
class FieldPointerSymbol(TypedSymbol):
    """Sympy symbol representing the pointer to the beginning of the field data."""

    def __new__(cls, *args, **kwds):
        # All construction goes through the cached second-stage constructor.
        return FieldPointerSymbol.__xnew_cached_(cls, *args, **kwds)

    def __new_stage2__(cls, field_name, field_dtype, const):
        """Build the symbol ``_data_<field_name>`` typed as a restrict pointer to the field's base type."""
        from pystencils.typing.utilities import get_base_type

        pointer_dtype = PointerType(get_base_type(field_dtype), const=const, restrict=True)
        symbol = super(FieldPointerSymbol, cls).__xnew__(cls, f"_data_{field_name}", pointer_dtype)
        symbol.field_name = field_name
        return symbol

    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    def __getnewargs__(self):
        """Positional arguments needed to re-create this symbol (the pointer type is passed as field dtype)."""
        return self.field_name, self.dtype, self.dtype.const

    def __getnewargs_ex__(self):
        """Positional and (empty) keyword arguments needed to re-create this symbol."""
        return (self.field_name, self.dtype, self.dtype.const), {}

    def _hashable_content(self):
        """Parent's hashable content extended by the field name."""
        inherited_content = super()._hashable_content()
        return inherited_content, self.field_name
class CFunction(TypedSymbol):
    """Typed symbol whose name is the given ``function`` and whose type is the given ``dtype``."""

    def __new__(cls, function, dtype):
        # All construction goes through the cached second-stage constructor.
        return CFunction.__xnew_cached_(cls, function, dtype)

    def __new_stage2__(cls, function, dtype):
        """Create the underlying typed symbol, using ``function`` as the symbol name."""
        symbol = super(CFunction, cls).__xnew__(cls, function, dtype)
        return symbol

    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    def __getnewargs__(self):
        """Positional arguments needed to re-create this symbol."""
        return self.name, self.dtype

    def __getnewargs_ex__(self):
        """Positional and (empty) keyword arguments needed to re-create this symbol."""
        return (self.name, self.dtype), {}
from abc import abstractmethod
from typing import Union
import numpy as np
import sympy as sp
def is_supported_type(dtype: np.dtype):
    """Tell whether a numpy dtype is a plain floating point, integer or boolean scalar type.

    Args:
        dtype: numpy dtype to inspect

    Returns:
        True only if the dtype is a numpy generic type, its scalar type is a floating, integer or
        boolean type, and it has no fields, contains no Python objects and is not a sub-array type.
    """
    scalar_type = dtype.type
    is_generic = np.issubdtype(dtype, np.generic)
    is_plain_kind = issubclass(scalar_type, (np.floating, np.integer, np.bool_))
    is_unstructured = dtype.fields is None and dtype.hasobject is False and dtype.subdtype is None
    return is_generic and is_plain_kind and is_unstructured
def numpy_name_to_c(name: str) -> str:
    """
    Translate the name of a numpy dtype into the corresponding C type name.

    Args:
        name: np.dtype.name string

    Returns:
        type as a C string

    Raises:
        NotImplementedError: if no C name is known for the given numpy name
    """
    floating_point_names = (
        ('float64', 'double'),
        ('float32', 'float'),
        ('float16', 'half'),
        ('half', 'half'),
    )
    for numpy_name, c_name in floating_point_names:
        if name == numpy_name:
            return c_name

    # Fixed-width integers: 'int<width>' -> 'int<width>_t', 'uint<width>' -> 'uint<width>_t'
    for prefix in ('int', 'uint'):
        if name.startswith(prefix):
            width = int(name[len(prefix):])
            return f"{prefix}{width}_t"

    if name == 'bool':
        return 'bool'
    raise NotImplementedError(f"Can't map numpy to C name for {name}")
class AbstractType(sp.Atom):
    """Common base of all type classes in this module; a sympy atom constructed without sympy arguments."""
    # TODO: Is it necessary to inherit from sp.Atom?

    def __new__(cls, *args, **kwargs):
        # Constructor arguments are deliberately not forwarded to sympy; subclasses consume them in __init__.
        instance = sp.Basic.__new__(cls)
        return instance

    def _sympystr(self, *args, **kwargs):
        """String form used by sympy's string printer: simply ``str(self)``."""
        return str(self)

    @property
    @abstractmethod
    def base_type(self) -> Union[None, 'BasicType']:
        """
        Returns: Returns BasicType of a Vector or Pointer type, None otherwise
        """

    @property
    @abstractmethod
    def item_size(self) -> int:
        """
        Returns: Number of items.
            E.g. width * item_size(basic_type) in vector's case, or simple numpy itemsize in Struct's case.
        """
class BasicType(AbstractType):
    """
    Scalar type described by a numpy dtype plus a const qualifier.
    """

    def __init__(self, dtype: Union[type, 'BasicType', str], const: bool = False):
        """Create the type from anything ``np.dtype`` accepts, or copy another BasicType.

        When ``dtype`` is already a BasicType, both its numpy dtype and its const flag are copied and
        the ``const`` argument is not used.
        """
        if not isinstance(dtype, BasicType):
            self.numpy_dtype = np.dtype(dtype)
            self.const = const
        else:
            self.numpy_dtype = dtype.numpy_dtype
            self.const = dtype.const
        assert is_supported_type(self.numpy_dtype), f'Type {self.numpy_dtype} is currently not supported!'

    def __getnewargs__(self):
        """Positional arguments needed to re-create this type."""
        return self.numpy_dtype, self.const

    def __getnewargs_ex__(self):
        """Positional and (empty) keyword arguments needed to re-create this type."""
        return (self.numpy_dtype, self.const), {}

    @property
    def base_type(self):
        """A basic type wraps no other type, hence None."""
        return None

    @property
    def item_size(self):  # TODO: Do we want self.numpy_type.itemsize????
        """A basic type always counts as one item."""
        return 1

    def is_float(self):
        """True if the numpy scalar type is a floating point type."""
        scalar_type = self.numpy_dtype.type
        return issubclass(scalar_type, np.floating)

    def is_half(self):
        """True if the numpy scalar type is the half precision float type."""
        scalar_type = self.numpy_dtype.type
        return issubclass(scalar_type, np.half)

    def is_int(self):
        """True if the numpy scalar type is any integer type (signed or unsigned)."""
        scalar_type = self.numpy_dtype.type
        return issubclass(scalar_type, np.integer)

    def is_uint(self):
        """True if the numpy scalar type is an unsigned integer type."""
        scalar_type = self.numpy_dtype.type
        return issubclass(scalar_type, np.unsignedinteger)

    def is_sint(self):
        """True if the numpy scalar type is a signed integer type."""
        scalar_type = self.numpy_dtype.type
        return issubclass(scalar_type, np.signedinteger)

    def is_bool(self):
        """True if the numpy scalar type is the boolean type."""
        scalar_type = self.numpy_dtype.type
        return issubclass(scalar_type, np.bool_)

    def dtype_eq(self, other):
        """Compare only the numpy dtypes, ignoring const; False if ``other`` is no BasicType."""
        if isinstance(other, BasicType):
            return self.numpy_dtype == other.numpy_dtype
        return False

    @property
    def c_name(self) -> str:
        """C name of the type without qualifiers."""
        return numpy_name_to_c(self.numpy_dtype.name)

    def __str__(self):
        c_name = self.c_name
        const_suffix = " const" if self.const else ""
        return f'{c_name}{const_suffix}'

    def __repr__(self):
        return f'BasicType( {str(self)} )'

    def _repr_html_(self):
        return f'BasicType( {str(self)} )'

    def __eq__(self, other):
        # dtype_eq is False for non-BasicType objects, so other.const is only read for BasicTypes
        return self.dtype_eq(other) and self.const == other.const

    def __hash__(self):
        return hash(str(self))
class VectorType(AbstractType):
    """
    Vector of ``width`` elements of one BasicType.
    """
    # Class-wide setting; when not None it is indexed by 'int', 'double', 'float' and 'bool' in __str__.
    instruction_set = None

    def __init__(self, base_type: BasicType, width: int):
        self._base_type = base_type
        self.width = width

    @property
    def base_type(self):
        """Element type of the vector."""
        return self._base_type

    @property
    def item_size(self):
        """Vector width times the item size of the element type."""
        return self.width * self.base_type.item_size

    def __eq__(self, other):
        if isinstance(other, VectorType):
            return (self.base_type, self.width) == (other.base_type, other.width)
        return False

    def __str__(self):
        if self.instruction_set is None:
            return f"{self.base_type}[{self.width}]"

        # TODO VectorizationRevamp: this seems super weird. the instruction_set should know how to print a type out!
        # TODO VectorizationRevamp: this is error prone. base_type could be cons=True. Use dtype instead
        instruction_set_keys = (
            (("int64", "int32"), 'int'),
            (("float64",), 'double'),
            (("float32",), 'float'),
            (("bool",), 'bool'),
        )
        for type_names, key in instruction_set_keys:
            if any(self.base_type == create_type(type_name) for type_name in type_names):
                return self.instruction_set[key]
        raise NotImplementedError()

    def __hash__(self):
        return hash((self.base_type, self.width))

    def __getnewargs__(self):
        """Positional arguments needed to re-create this type."""
        return self._base_type, self.width

    def __getnewargs_ex__(self):
        """Positional and (empty) keyword arguments needed to re-create this type."""
        return (self._base_type, self.width), {}
class PointerType(AbstractType):
    """Pointer (or pointer to pointer) to a base type, with const and restrict qualifiers."""

    def __init__(self, base_type: BasicType, const: bool = False, restrict: bool = True, double_pointer: bool = False):
        self._base_type = base_type
        self.const = const
        self.restrict = restrict
        self.double_pointer = double_pointer

    def _as_tuple(self):
        # All attributes that define the identity of the pointer type.
        return self.base_type, self.const, self.restrict, self.double_pointer

    def __getnewargs__(self):
        """Positional arguments needed to re-create this type."""
        return self._as_tuple()

    def __getnewargs_ex__(self):
        """Positional and (empty) keyword arguments needed to re-create this type."""
        return self._as_tuple(), {}

    @property
    def alias(self):
        """True exactly when the pointer is not marked restrict."""
        return not self.restrict

    @property
    def base_type(self):
        """The type this pointer points to."""
        return self._base_type

    @property
    def item_size(self):
        """Item size of the base type; not available for double pointers."""
        if self.double_pointer:
            raise NotImplementedError("The item_size for double_pointer is not implemented")
        return self.base_type.item_size

    def __eq__(self, other):
        if isinstance(other, PointerType):
            return self._as_tuple() == (other.base_type, other.const, other.restrict, other.double_pointer)
        return False

    def __str__(self):
        restrict_str = "RESTRICT" if self.restrict else ""
        const_str = "const" if self.const else ""
        stars = "**" if self.double_pointer else "*"
        return f'{str(self.base_type)} {stars} {restrict_str} {const_str}'

    def __repr__(self):
        return str(self)

    def _repr_html_(self):
        return str(self)

    def __hash__(self):
        return hash((self._base_type, self.const, self.restrict, self.double_pointer))
class StructType(AbstractType):
    """
    A list of types (with C offsets).
    It is implemented with uint8_t and casts to the correct datatype.
    """

    def __init__(self, numpy_type, const=False):
        self.const = const
        self._dtype = np.dtype(numpy_type)

    def __getnewargs__(self):
        """Positional arguments needed to re-create this type."""
        return self.numpy_dtype, self.const

    def __getnewargs_ex__(self):
        """Positional and (empty) keyword arguments needed to re-create this type."""
        return (self.numpy_dtype, self.const), {}

    @property
    def base_type(self):
        """A struct type wraps no single base type, hence None."""
        return None

    @property
    def numpy_dtype(self):
        """The numpy dtype describing the struct layout."""
        return self._dtype

    @property
    def item_size(self):
        """Item size as reported by the numpy dtype."""
        return self.numpy_dtype.itemsize

    def _field_entry(self, element_name):
        # numpy stores (dtype, offset, ...) per field name
        return self.numpy_dtype.fields[element_name]

    def get_element_offset(self, element_name):
        """Offset of the named element as stored in the numpy dtype's field table."""
        return self._field_entry(element_name)[1]

    def get_element_type(self, element_name):
        """BasicType of the named element; it inherits the const flag of the struct."""
        element_dtype = self._field_entry(element_name)[0]
        return BasicType(element_dtype, self.const)

    def has_element(self, element_name):
        """True if the struct has a field with the given name."""
        return element_name in self.numpy_dtype.fields

    def __eq__(self, other):
        if isinstance(other, StructType):
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)
        return False

    def __str__(self):
        # structs are handled byte-wise
        return "uint8_t const" if self.const else "uint8_t"

    def __repr__(self):
        return str(self)

    def _repr_html_(self):
        return str(self)

    def __hash__(self):
        return hash((self.numpy_dtype, self.const))
def create_type(specification: Union[type, AbstractType, str]) -> AbstractType:
    # TODO: Deprecated Use the constructor of BasicType or StructType instead
    """Turn a type specification into a type object.

    Args:
        specification: an AbstractType instance, or anything ``np.dtype`` accepts (e.g. a string)

    Returns:
        The specification itself if it already is an AbstractType. Otherwise a non-const StructType if the
        resulting numpy dtype has fields, or a non-const BasicType if it has none.
    """
    if isinstance(specification, AbstractType):
        return specification

    parsed_dtype = np.dtype(specification)
    type_class = BasicType if parsed_dtype.fields is None else StructType
    return type_class(parsed_dtype, const=False)
from collections import defaultdict
from functools import partial
from typing import Tuple, Union, Sequence
import numpy as np
import sympy as sp
from sympy.logic.boolalg import Boolean, BooleanFunction
import pystencils
from pystencils.cache import memorycache_if_hashable
from pystencils.typing.types import BasicType, VectorType, PointerType, create_type
from pystencils.typing.cast_functions import CastFunc
from pystencils.typing.typed_sympy import TypedSymbol
from pystencils.utils import all_equal
def typed_symbols(names, dtype, **kwargs):
    """
    Counterpart of ``sympy.symbols`` that produces TypedSymbols.

    Args:
        names: See sympy.symbols
        dtype: The data type all symbols will have
        **kwargs: Key value arguments passed to sympy.symbols

    Returns:
        A tuple of TypedSymbols if ``sympy.symbols`` returned a tuple, otherwise a single TypedSymbol
    """
    plain_symbols = sp.symbols(names, **kwargs)
    if not isinstance(plain_symbols, Tuple):
        return TypedSymbol(str(plain_symbols), dtype)
    return tuple(TypedSymbol(str(symbol), dtype) for symbol in plain_symbols)
def get_base_type(data_type):
    """
    Follow the ``base_type`` chain of a type (e.g. Pointer or Vector) and return the innermost type,
    i.e. the first one whose ``base_type`` is None.
    """
    innermost = data_type
    while True:
        wrapped = innermost.base_type
        if wrapped is None:
            return innermost
        innermost = wrapped
def result_type(*args: np.dtype):
    """Pick the dtype that results from combining all given np.dtype arguments.

    numpy's own promotion is not used, because numpy casts don't behave exactly like C casts.
    The dtype kinds are ranked bool < unsigned int < signed int < float; among dtypes of the highest
    ranked kind the one with the largest item size is returned.

    Raises:
        NotImplementedError: if a dtype has a kind other than 'b', 'u', 'i' or 'f'
    """
    def kind_to_value(kind: str) -> int:
        # position in this string is the rank of the kind
        for rank, supported_kind in enumerate('buif'):
            if kind == supported_kind:
                return rank
        raise NotImplementedError(f'{kind=} is not a supported kind of a type. See "numpy.dtype.kind" for options')

    # Two stable sorts: the second keeps the item size order within each kind.
    ordered_by_size = sorted(args, key=lambda dtype: dtype.itemsize)
    ordered_by_kind = sorted(ordered_by_size, key=lambda dtype: kind_to_value(dtype.kind))
    return ordered_by_kind[-1]
def collate_types(types: Sequence[Union[BasicType, VectorType]]):
    """
    Determine the "common type" of a sequence of types, e.g. (float, double, float) -> double.

    - If a pointer type is present, it is returned; all other entries must then be integer BasicTypes
      and there must be only one pointer (pointer arithmetic).
    - Otherwise vector types are reduced to their base type, the scalar result is chosen by
      ``result_type`` and wrapped into a vector again if at least one vector type was given.

    Raises:
        ValueError: for two pointer types, for non-integer operands next to a pointer, or for vector
                    types of different width
    """
    # Pointer arithmetic case i.e. pointer + [int, uint] is allowed
    if any(isinstance(t, PointerType) for t in types):
        pointer_type = None
        for t in types:
            if isinstance(t, PointerType):
                if pointer_type is not None:
                    raise ValueError(f'Cannot collate the combination of two pointer types "{pointer_type}" and "{t}"')
                pointer_type = t
                continue
            if not isinstance(t, BasicType):
                raise ValueError("Invalid pointer arithmetic")
            if not (t.is_int() or t.is_uint()):
                raise ValueError("Invalid pointer arithmetic")
        return pointer_type

    # peel off vector types, if at least one vector type occurred the result will also be the vector type
    vector_types = [t for t in types if isinstance(t, VectorType)]
    if not all_equal(t.width for t in vector_types):
        raise ValueError("Collation failed because of vector types with different width")

    # TODO: check if peeling off nested vector types (repeatedly taking base_type) is needed
    scalar_types = [t.base_type if isinstance(t, VectorType) else t for t in types]

    # now we should have a list of basic types - struct types are not yet supported
    assert all(type(t) is BasicType for t in scalar_types)

    collated = BasicType(result_type(*(t.numpy_dtype for t in scalar_types)))
    if vector_types:
        collated = VectorType(collated, vector_types[0].width)
    return collated
# TODO get_type_of_expression should be used after leaf_typing. So no defaults should be necessary
@memorycache_if_hashable(maxsize=2048)
def get_type_of_expression(expr,
                           default_float_type='double',
                           default_int_type='int',
                           symbol_type_dict=None):
    """Determine the data type of a (sympified) expression.

    Args:
        expr: expression; it is passed through ``sp.sympify`` first
        default_float_type: type used for rational/float numbers and non-integer atoms
                            ('float' is treated as 'float32')
        default_int_type: type used for integer numbers and integer atoms
        symbol_type_dict: mapping from symbol name to type, consulted for untyped sympy symbols

    Returns:
        The type object of the expression; composite expressions collate the types of their arguments.

    Raises:
        ValueError: for an untyped symbol when ``symbol_type_dict`` is empty
        NotImplementedError: if no rule matches the expression
    """
    from pystencils.astnodes import ResolvedFieldAccess
    from pystencils.cpu.vectorization import vec_all, vec_any

    if default_float_type == 'float':
        default_float_type = 'float32'

    if not symbol_type_dict:
        symbol_type_dict = defaultdict(lambda: create_type('double'))

    # Recursive calls reuse the defaults and the symbol table of this call.
    get_type = partial(get_type_of_expression,
                       default_float_type=default_float_type,
                       default_int_type=default_int_type,
                       symbol_type_dict=symbol_type_dict)

    expr = sp.sympify(expr)

    # --- numbers
    if isinstance(expr, sp.Integer):
        return create_type(default_int_type)
    if isinstance(expr, (sp.Rational, sp.Float)):
        return create_type(default_float_type)

    # --- field accesses and symbols
    if isinstance(expr, ResolvedFieldAccess):
        return expr.field.dtype
    if isinstance(expr, pystencils.field.Field.Access):
        return expr.field.dtype
    if isinstance(expr, TypedSymbol):
        return expr.dtype
    if isinstance(expr, sp.Symbol):
        # TODO delete if case
        if symbol_type_dict:
            return symbol_type_dict[expr.name]
        raise ValueError("All symbols inside this expression have to be typed! ", str(expr))

    # --- casts and vector reductions
    if isinstance(expr, CastFunc):
        return expr.args[1]
    if isinstance(expr, (vec_any, vec_all)):
        return create_type("bool")

    # --- piecewise: result type from the branch values, widened to a vector if a condition is one
    if hasattr(expr, 'func') and expr.func == sp.Piecewise:
        value_type = collate_types(tuple(get_type(branch[0]) for branch in expr.args))
        condition_type = collate_types(tuple(get_type(branch[1]) for branch in expr.args))
        if type(condition_type) is VectorType and type(value_type) is not VectorType:
            value_type = VectorType(value_type, width=condition_type.width)
        return value_type

    if isinstance(expr, sp.Indexed):
        indexed_symbol = expr.base.label
        return indexed_symbol.dtype.base_type

    if isinstance(expr, (Boolean, BooleanFunction)):
        # if any arg is of vector type return a vector boolean, else return a normal scalar boolean
        bool_type = create_type("bool")
        vector_arg_types = [arg_type for arg_type in (get_type(a) for a in expr.args)
                            if isinstance(arg_type, VectorType)]
        if vector_arg_types:
            bool_type = VectorType(bool_type, width=vector_arg_types[0].width)
        return bool_type

    if isinstance(expr, sp.Pow):
        base_type = get_type(expr.args[0])
        if expr.exp.is_integer:
            return base_type
        return collate_types([create_type(default_float_type), base_type])

    if isinstance(expr, (sp.Sum, sp.Product)):
        return get_type(expr.args[0])

    # --- any other expression: collate argument types, or fall back to the defaults for atoms
    if isinstance(expr, sp.Expr):
        if expr.args:
            return collate_types(tuple(get_type(a) for a in expr.args))
        if expr.is_integer:
            return create_type(default_int_type)
        return create_type(default_float_type)

    raise NotImplementedError("Could not determine type for", expr, type(expr))
# Fix for sympy versions from 1.9
# The version is encoded as major * 100 + minor (1.9 -> 109, 1.11 -> 111) so it can be compared as an integer.
sympy_version = sp.__version__.split('.')
sympy_version_int = int(sympy_version[0]) * 100 + int(sympy_version[1])
if sympy_version_int >= 109:
# __setstate__ would bypass the constructor, so we remove it
if sympy_version_int >= 111:
# from 1.11 on: drop __setstate__ from both sp.Basic and sp.Symbol
del sp.Basic.__setstate__
del sp.Symbol.__setstate__
else:
# other versions: give sp.Number the __getstate__ of sp.Basic, then remove it from sp.Basic
sp.Number.__getstate__ = sp.Basic.__getstate__
del sp.Basic.__getstate__
# __reduce_ex__ would strip kwargs, so we override it
def basic_reduce_ex(self, protocol):
# Prefer __getnewargs_ex__ (args and kwargs); otherwise use __getnewargs__ with empty kwargs.
if hasattr(self, '__getnewargs_ex__'):
args, kwargs = self.__getnewargs_ex__()
else:
args, kwargs = self.__getnewargs__(), {}
# State is only included if the object provides __getstate__.
if hasattr(self, '__getstate__'):
state = self.__getstate__()
else:
state = None
# The callable is the object's type with the kwargs already bound; `protocol` is not used.
return partial(type(self), **kwargs), args, state
# Install the replacement on sp.Basic.
sp.Basic.__reduce_ex__ = basic_reduce_ex
def get_next_parent_of_type(node, parent_type):
    """Find the closest ancestor of ``node`` that is an instance of ``parent_type``.

    The ``parent`` attributes are followed upwards, starting with ``node.parent``.
    Returns the first matching ancestor, or None if the root is reached without a match.
    """
    ancestor = node.parent
    while ancestor is not None and not isinstance(ancestor, parent_type):
        ancestor = ancestor.parent
    return ancestor
def parents_of_type(node, parent_type, include_current=False):
    """Yield every ancestor of ``node`` that is an instance of ``parent_type``, nearest first.

    With ``include_current=True`` the node itself is considered as well.
    """
    if include_current:
        candidate = node
    else:
        candidate = node.parent

    while candidate is not None:
        if isinstance(candidate, parent_type):
            yield candidate
        candidate = candidate.parent
import os
import itertools
from itertools import groupby
from collections import Counter
from contextlib import contextmanager
from tempfile import NamedTemporaryFile
from typing import Mapping
import numpy as np
import sympy as sp
class DotDict(dict):
    """Normal dict with additional dot access for all keys

    Attribute reads go through ``dict.get``, so a missing key yields None instead of raising.
    Attribute writes and deletes map to item assignment and item deletion.
    """
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

    # Recursively make DotDict: https://stackoverflow.com/questions/13520421/recursive-dotdict
    def __init__(self, dct=None):
        """Copy all items of ``dct``; nested plain dicts are converted to DotDict as well.

        Args:
            dct: mapping to take the initial items from; None (the default) creates an empty DotDict
        """
        # None instead of a shared mutable `{}` default argument
        if dct is None:
            dct = {}
        for key, value in dct.items():
            if isinstance(value, dict):
                value = DotDict(value)
            self[key] = value
def all_equal(iterable):
    """
    Returns ``True`` if all the elements are equal to each other.
    An empty iterable counts as all-equal.
    Copied from: more-itertools 8.12.0
    """
    runs_of_equal_items = groupby(iterable)
    # No run at all means the iterable was empty.
    if next(runs_of_equal_items, None) is None:
        return True
    # A second run exists exactly when some element differs from its predecessor.
    return next(runs_of_equal_items, None) is None
def recursive_dict_update(d, u):
    """Updates the first dict argument, using second dictionary recursively.

    The first argument is not modified; an updated (shallow) copy is returned. Values of ``u`` that are
    mappings are merged into the value stored under the same key. If that stored value is missing or
    is not a mapping itself, it is replaced by a copy of the mapping from ``u``.

    Examples:
        >>> d = {'sub_dict': {'a': 1, 'b': 2}, 'outer': 42}
        >>> u = {'sub_dict': {'a': 5, 'c': 10}, 'outer': 41, 'outer2': 43}
        >>> recursive_dict_update(d, u)
        {'sub_dict': {'a': 5, 'b': 2, 'c': 10}, 'outer': 41, 'outer2': 43}
        >>> recursive_dict_update({'a': 1}, {'a': {'b': 2}})
        {'a': {'b': 2}}
    """
    d = d.copy()
    for k, v in u.items():
        if isinstance(v, Mapping):
            existing = d.get(k)
            # A scalar stored under this key cannot be merged into; start from an empty dict instead
            # of calling .copy() on it.
            if not isinstance(existing, Mapping):
                existing = {}
            d[k] = recursive_dict_update(existing, v)
        else:
            d[k] = v
    return d
@contextmanager
def atomic_file_write(file_path):
    """Context manager for writing a file atomically.

    Yields the name of a temporary file created in the same folder as ``file_path``. When the
    ``with`` body finishes without an exception, the temporary file replaces ``file_path`` via
    ``os.replace``. If the body (or the replacement) raises, the temporary file is removed again
    and the exception is propagated, so no stray temporary files are left in the target folder.

    Args:
        file_path: path of the file that should be (over)written
    """
    target_folder = os.path.dirname(os.path.abspath(file_path))
    with NamedTemporaryFile(delete=False, dir=target_folder) as f:
        # Only the name is handed out; the caller opens the file on its own.
        f.file.close()
        try:
            yield f.name
            os.replace(f.name, file_path)
        except BaseException:
            # Best-effort cleanup of the temporary file, then re-raise the original error.
            try:
                os.remove(f.name)
            except OSError:
                pass
            raise
def fully_contains(l1, l2):
    """Tests if elements of sequence 1 are in sequence 2 in same or higher number.

    >>> fully_contains([1, 1, 2], [1, 2])  # 1 is only present once in second list
    False
    >>> fully_contains([1, 1, 2], [1, 1, 4, 2])
    True
    """
    required_counts = Counter(l1)
    available_counts = Counter(l2)
    return all(available_counts[element] >= needed for element, needed in required_counts.items())
def boolean_array_bounding_box(boolean_array):
    """Returns bounding box around "true" area of boolean array

    The result is a list with one ``(begin, end)`` pair per axis, ``end`` being exclusive.

    >>> a = np.zeros((4, 4), dtype=bool)
    >>> a[1:-1, 1:-1] = True
    >>> boolean_array_bounding_box(a) == [(1, 3), (1, 3)]
    True
    """
    num_axes = boolean_array.ndim
    assert 0 not in boolean_array.shape, "Shape must not contain zero"

    bounding_box = []
    # Each combination lists all axes but one; reducing over them leaves a 1D mask for the remaining axis.
    for reduced_axes in itertools.combinations(reversed(range(num_axes)), num_axes - 1):
        axis_mask = np.any(boolean_array, axis=reduced_axes)
        true_indices = np.flatnonzero(axis_mask)
        begin, last = true_indices[0], true_indices[-1]
        bounding_box.append((begin, last + 1))
    return bounding_box
def binary_numbers(n):
    """Returns all binary numbers up to 2^n - 1

    Every number is given as a list of its binary digits, padded with leading zeros to ``n`` digits.

    Example:
        >>> binary_numbers(2)
        [[0, 0], [0, 1], [1, 0], [1, 1]]
    """
    digit_strings = (bin(number)[2:].zfill(n) for number in range(1 << n))
    return [[int(digit) for digit in digits] for digits in digit_strings]
class LinearEquationSystem:
    """Symbolic linear system of equations - consisting of matrix and right hand side.

    Equations can be added incrementally. System is held in reduced row echelon form to quickly determine if
    system has a single, multiple, or no solution.

    Example:
        >>> x, y= sp.symbols("x, y")
        >>> les = LinearEquationSystem([x, y])
        >>> les.add_equation(x - y - 3)
        >>> les.solution_structure()
        'multiple'
        >>> les.add_equation(x + y - 4)
        >>> les.solution_structure()
        'single'
        >>> les.solution()
        {x: 7/2, y: 1/2}

    """

    def __init__(self, unknowns):
        """Create an empty system for the given sequence of unknowns."""
        size = len(unknowns)
        # one row per unknown to start with; the extra last column holds the right hand side
        self._matrix = sp.zeros(size, size + 1)
        self.unknowns = unknowns
        # index of the first row that is not yet occupied by an equation
        self.next_zero_row = 0
        # True while the matrix is known to be in reduced row echelon form
        self._reduced = True

    def copy(self):
        """Returns a copy of the equation system."""
        new = LinearEquationSystem(self.unknowns)
        new._matrix = self._matrix.copy()
        new.next_zero_row = self.next_zero_row
        # Carry the flag over: otherwise the copy of an unreduced system would skip its reduction.
        new._reduced = self._reduced
        return new

    def add_equation(self, linear_equation):
        """Add a linear equation as sympy expression. Implicit "-0" is assumed. Equation has to be linear and contain
        only unknowns passed to the constructor otherwise a ValueError is raised. """
        self._resize_if_necessary()
        linear_equation = linear_equation.expand()
        zero_row_idx = self.next_zero_row
        self.next_zero_row += 1

        # Rebuild the linear part from the extracted coefficients ...
        control = 0
        for i, unknown in enumerate(self.unknowns):
            self._matrix[zero_row_idx, i] = linear_equation.coeff(unknown)
            control += unknown * self._matrix[zero_row_idx, i]
        # ... whatever remains must be a constant, i.e. must not contain any symbol
        rest = linear_equation - control
        if rest.atoms(sp.Symbol):
            raise ValueError("Not a linear equation in the unknowns")
        self._matrix[zero_row_idx, -1] = -rest
        self._reduced = False

    def add_equations(self, linear_equations):
        """Add a sequence of equations. For details see `add_equation`. """
        self._resize_if_necessary(len(linear_equations))
        for eq in linear_equations:
            self.add_equation(eq)

    def set_unknown_zero(self, unknown_idx):
        """Sets an unknown to zero - pass the index not the variable itself!"""
        assert unknown_idx < len(self.unknowns)
        self._resize_if_necessary()
        self._matrix[self.next_zero_row, unknown_idx] = 1
        self.next_zero_row += 1
        self._reduced = False

    def reduce(self):
        """Brings the system in reduced row echelon form."""
        if self._reduced:
            return
        self._matrix = self._matrix.rref()[0]
        self._update_next_zero_row()
        self._reduced = True

    @property
    def matrix(self):
        """Return a matrix that represents the equation system.
        Has one column more than unknowns for the affine part."""
        self.reduce()
        return self._matrix

    @property
    def rank(self):
        """Number of non-zero rows of the reduced system."""
        self.reduce()
        return self.next_zero_row

    def solution_structure(self):
        """Returns either 'multiple', 'none' or 'single' to indicate how many solutions the system has."""
        self.reduce()
        non_zero_rows = self.next_zero_row
        num_unknowns = len(self.unknowns)
        if non_zero_rows == 0:
            return 'multiple'

        # Inspect the last non-zero row: `left` is the coefficient of the last unknown, `right` the rhs entry.
        *row_begin, left, right = self._matrix.row(non_zero_rows - 1)
        if non_zero_rows > num_unknowns:
            return 'none'
        elif non_zero_rows == num_unknowns:
            if left == 0 and right != 0:
                return 'none'
            else:
                return 'single'
        elif non_zero_rows < num_unknowns:
            # a row of the form "0 = non-zero" is a contradiction
            if right != 0 and left == 0 and all(e == 0 for e in row_begin):
                return 'none'
            else:
                return 'multiple'

    def solution(self):
        """Solves the system. Under- and overdetermined systems are supported.
        Returns a dictionary mapping symbol to solution value."""
        return sp.solve_linear_system(self._matrix, *self.unknowns)

    def _resize_if_necessary(self, new_rows=1):
        """Append ``new_rows`` zero rows if the free rows do not suffice for that many new equations."""
        if self.next_zero_row + new_rows > self._matrix.shape[0]:
            self._matrix = self._matrix.row_insert(self._matrix.shape[0] + 1,
                                                   sp.zeros(new_rows, self._matrix.shape[1]))

    def _update_next_zero_row(self):
        """Set ``next_zero_row`` to the number of rows up to and including the last non-zero row."""
        result = self._matrix.shape[0]
        # Stop at 0: with `>= 0` an all-zero matrix would look at row -1 (the last row again)
        # and leave next_zero_row at -1.
        while result > 0:
            row_to_check = result - 1
            if any(e != 0 for e in self._matrix.row(row_to_check)):
                break
            result -= 1
        self.next_zero_row = result
class ContextVar:
    """Variable whose value can be overridden temporarily inside a ``with`` block.

    Values are kept on a stack: calling the object pushes a new value for the duration of the
    ``with`` block, ``get`` returns the value on top of the stack.
    """

    def __init__(self, value):
        """Start with ``value`` as the base value."""
        self.stack = [value]

    @contextmanager
    def __call__(self, new_value):
        """Context manager that makes ``new_value`` the current value and yields this object.

        The previous value is restored when the block is left - also if it is left by an exception.
        """
        self.stack.append(new_value)
        try:
            yield self
        finally:
            # without the finally an exception in the with-body would leave new_value on the stack
            self.stack.pop()

    def get(self):
        """Return the current (innermost) value."""
        return self.stack[-1]
import operator
import warnings
from collections import defaultdict
from collections.abc import Sequence

import sympy as sp
def fastSubs(term, subsDict):
    """Substitute sub-expressions of ``term`` according to ``subsDict`` (similar to sympy's subs).

    The expression tree is walked once from the top; every node that is a key of ``subsDict`` is
    replaced by its value without looking further into it. This is much faster than the sympy
    version for big substitution dictionaries."""

    def visit(expr):
        if expr in subsDict:
            return subsDict[expr]
        if hasattr(expr, 'args'):
            substitutedArgs = [visit(a) for a in expr.args]
            if substitutedArgs:
                return expr.func(*substitutedArgs)
        # leaf without a substitution: keep as is
        return expr

    return visit(term)
def replaceAdditive(expr, replacement, subExpression, requiredMatchReplacement=0.5, requiredMatchOriginal=None):
    """
    Replaces a given subexpression inside sums, even if the sum contains it only partially.

    Example 1:
        expr = 3*x + 3 * y
        replacement = k
        subExpression = x+y
        return = 3*k

    Example 2:
        expr = 3*x + 3 * y + z
        replacement = k
        subExpression = x+y+z
        return:
            if at least 3 matching terms are required the expression would not be altered
            if fewer than 3 are required the result is 3*k - 2*z

    :param expr: input expression
    :param replacement: expression that is inserted for subExpression (if found)
    :param subExpression: expression to replace
    :param requiredMatchReplacement:
        - if float: the percentage of terms of the subExpression that has to be matched in order to replace
        - if integer: the total number of terms that has to be matched in order to replace
        - None: is equal to integer 1
        - if both match parameters are given, both restrictions have to be fulfilled (i.e. logical AND)
    :param requiredMatchOriginal:
        - if float: the percentage of terms of the original addition expression that has to be matched
        - if integer: the total number of terms that has to be matched in order to replace
        - None: is equal to integer 1
    :return: new expression with replacement
    """
    def normalizeMatchParameter(matchParameter, expressingLength):
        # Turns None / percentage / absolute count into an absolute number of terms (at least 1).
        if matchParameter is None:
            return 1
        if isinstance(matchParameter, float):
            assert 0 <= matchParameter <= 1
            return max(int(matchParameter * expressingLength), 1)
        if isinstance(matchParameter, int):
            assert matchParameter > 0
            return matchParameter
        raise ValueError("Invalid parameter")

    normalizedReplacementMatch = normalizeMatchParameter(requiredMatchReplacement, len(subExpression.args))

    def replaceInSum(sumExpr):
        # Returns the sum with the subexpression replaced, or None if too few terms match.
        longerLength = max(len(sumExpr.args), len(subExpression.args))
        normalizedCurrentExprMatch = normalizeMatchParameter(requiredMatchOriginal, longerLength)
        requiredMatches = max(normalizedReplacementMatch, normalizedCurrentExprMatch)

        sumCoeffs = sumExpr.as_coefficients_dict()
        subexprCoeffs = subExpression.as_coefficients_dict()
        sharedTerms = set(subexprCoeffs.keys()).intersection(set(sumCoeffs))
        if len(sharedTerms) < requiredMatches:
            return None

        # find common factor: count how often each ratio of coefficients occurs
        factorCounts = defaultdict(lambda: 0)
        for sharedTerm in subexprCoeffs.keys():
            if sharedTerm not in sumCoeffs:
                continue
            ratio = sumCoeffs[sharedTerm] / subexprCoeffs[sharedTerm]
            factorCounts[sp.simplify(ratio)] += 1

        commonFactor = max(factorCounts.items(), key=operator.itemgetter(1))[0]
        if factorCounts[commonFactor] < requiredMatches:
            return None
        return sumExpr - commonFactor * subExpression + commonFactor * replacement

    def visit(currentExpr):
        if currentExpr.is_Add:
            replaced = replaceInSum(currentExpr)
            if replaced is not None:
                return replaced

        # if no subexpression was found
        visitedArgs = [visit(a) for a in currentExpr.args]
        if visitedArgs:
            return currentExpr.func(*visitedArgs, evaluate=False)
        return currentExpr

    return visit(expr)
def replaceSecondOrderProducts(expr, searchSymbols, positive=None, replaceMixed=None):
    """
    Replaces second order mixed terms like x*y by 1/2 * ( (x+y)**2 - x**2 - y**2 )
    (or by -1/2 * ( (x-y)**2 - x**2 - y**2 ), see 'positive').
    This makes the term longer - simplify usually is undoing these - however this
    transformation can be done to find more common sub-expressions

    :param expr: input expression
    :param searchSymbols: list of symbols that are searched for
                          Example: given [ x,y,z] terms like x*y, x*z, z*y are replaced
    :param positive: there are two ways to do this substitution, either with term
                     (x+y)**2 or (x-y)**2 . if positive=True the first version is done,
                     if positive=False the second version is done, if positive=None the
                     sign is determined by the sign of the mixed term that is replaced
    :param replaceMixed: if a list is passed here the expr x+y or x-y is replaced by a special new symbol
                         the replacement equation is added to the list
    :return: expression with the mixed products replaced; built with evaluate=False
    """
    if replaceMixed is not None:
        # symbols for which a replacement equation has already been recorded
        mixedSymbolsReplaced = {e.lhs for e in replaceMixed}

    if expr.is_Mul:
        distinctVelTerms = set()
        nrOfVelTerms = 0
        # Start with a sympy one (not a Python int), so that .atoms / .is_positive below also work
        # for a bare product x*y without any further factor.
        otherFactors = sp.Integer(1)
        for t in expr.args:
            if t in searchSymbols:
                nrOfVelTerms += 1
                distinctVelTerms.add(t)
            else:
                otherFactors *= t

        if len(distinctVelTerms) == 2 and nrOfVelTerms == 2:
            u, v = list(distinctVelTerms)
            if positive is None:
                # decide the sign from the prefactor with all of its symbols set to one
                otherFactorsWithoutSymbols = otherFactors
                for s in otherFactors.atoms(sp.Symbol):
                    otherFactorsWithoutSymbols = otherFactorsWithoutSymbols.subs(s, 1)
                positive = otherFactorsWithoutSymbols.is_positive
                assert positive is not None

            sign = 1 if positive else -1
            if replaceMixed is not None:
                # name of the new symbol: <u>P<v> for u+v, <u>M<v> for u-v, underscores removed
                newSymbolStr = 'P' if positive else 'M'
                mixedSymbolName = u.name + newSymbolStr + v.name
                mixedSymbol = sp.Symbol(mixedSymbolName.replace("_", ""))
                if mixedSymbol not in mixedSymbolsReplaced:
                    mixedSymbolsReplaced.add(mixedSymbol)
                    replaceMixed.append(sp.Eq(mixedSymbol, u + sign * v))
            else:
                mixedSymbol = u + sign * v
            return sp.Rational(1, 2) * sign * otherFactors * (mixedSymbol ** 2 - u ** 2 - v ** 2)

    # not a replaceable product: process the arguments recursively
    paramList = [replaceSecondOrderProducts(a, searchSymbols, positive, replaceMixed) for a in expr.args]
    result = expr.func(*paramList, evaluate=False) if paramList else expr
    return result
def removeHigherOrderTerms(term, order=3, symbols=None):
    """
    Remove all terms from a sum that contain more than 'order' factors of given 'symbols'

    Example: symbols = ['u_x', 'u_y'] and order = 2
             removes terms u_x**3, u_x**2 * u_y, u_x * u_y**2, ...
             but keeps terms like u_x**2 and u_x * u_y

    :param term: sympy expression; it is expanded before the terms are inspected
    :param order: highest number of symbol factors a term may contain
    :param symbols: symbols to count; if empty or None, u_0, u_1, u_2 are used
                    (both without assumptions and with real=True)
    :return: the expanded expression without the higher order terms; zero if no term remains
    """
    from sympy.core.power import Pow
    from sympy.core.add import Add, Mul

    term = term.expand()

    if not symbols:
        symbols = sp.symbols(" ".join(["u_%d" % (i,) for i in range(3)]))
        symbols += sp.symbols(" ".join(["u_%d" % (i,) for i in range(3)]), real=True)

    def velocityFactorsInProduct(product):
        # A summand is not necessarily a product: a single power like u_0**4 or a bare symbol has to
        # be counted as a whole instead of iterating over its arguments.
        factors = product.args if type(product) == Mul else (product,)
        uFactorCount = 0
        for factor in factors:
            if type(factor) == Pow:
                if factor.args[0] in symbols:
                    uFactorCount += factor.args[1]
            if factor in symbols:
                uFactorCount += 1
        return uFactorCount

    if type(term) == Add:
        result = 0
        for sumTerm in term.args:
            if velocityFactorsInProduct(sumTerm) <= order:
                result += sumTerm
        return result

    # a single term (product, power, symbol, ...): keep it or drop it as a whole
    if velocityFactorsInProduct(term) <= order:
        return term
    return sp.Rational(0, 1)
def completeTheSquare(expr, symbolToComplete, newVariable):
    """
    Removes the linear part of a second order polynomial by shifting the variable, i.e.
        a*symbolToComplete**2 + b*symbolToComplete + c
    is rewritten by substituting symbolToComplete = newVariable - b/(2a), so that only the squared
    term in newVariable and a constant remain.

    returns replacedExpr, "a tuple to replace newVariable such that old expr comes out again"

    if given expr is not a second order polynomial:
        return expr, None
    """
    coefficients = sp.Poly(expr, symbolToComplete).all_coeffs()
    if len(coefficients) != 3:
        return expr, None

    quadraticCoeff, linearCoeff = coefficients[0], coefficients[1]
    shift = linearCoeff / (2 * quadraticCoeff)
    shiftedExpr = expr.subs(symbolToComplete, newVariable - shift)
    backSubstitution = (newVariable, symbolToComplete + shift)
    return sp.simplify(shiftedExpr), backSubstitution
def makeExponentialFuncArgumentSquares(expr, variablesToCompleteSquares):
    """Completes squares in arguments of exponential which makes them simpler to integrate

    For every exp(...) in the simplified expression the square is completed in turn for each of the
    given variables; afterwards the placeholder symbols are replaced by the original variables again.
    Very useful for integrating Maxwell-Boltzmann and its moment generating function"""
    expr = sp.simplify(expr)
    # one placeholder symbol per variable
    placeholders = [sp.Dummy() for _ in range(len(variablesToCompleteSquares))]

    def visit(term):
        if term.func == sp.exp:
            exponent = term.args[0]
            for variable, placeholder in zip(variablesToCompleteSquares, placeholders):
                exponent, _ = completeTheSquare(exponent, variable, placeholder)
            return sp.exp(sp.simplify(exponent))

        visitedArgs = [visit(a) for a in term.args]
        if visitedArgs:
            return term.func(*visitedArgs)
        return term

    result = visit(expr)
    for variable, placeholder in zip(variablesToCompleteSquares, placeholders):
        result = result.subs(placeholder, variable)
    return result
def pow2mul(expr):
    """
    Convert integer powers in an expression to Muls, like a**2 => a*a.

    Powers with a negative integer exponent are replaced by the reciprocal of the
    repeated product, like a**-2 => 1/(a*a).

    :param expr: sympy expression
    :return: expr with every power substituted by an unevaluated repeated product
    :raises ValueError: if the expression contains a power with a non-integer exponent
    """
    replacements = []
    for power in expr.atoms(sp.Pow):
        base, exponent = power.as_base_exp()
        if not exponent.is_Integer:
            raise ValueError("A power contains a non-integer exponent")

        repetitions = int(exponent)
        product = sp.Mul(*[base] * abs(repetitions), evaluate=False)
        if repetitions < 0:
            # [base] * negative_number is an empty list, which would turn the whole
            # power into 1; build the positive product and invert it instead
            product = sp.Pow(product, -1, evaluate=False)
        replacements.append((power, product))
    return expr.subs(replacements)
def extractMostCommonFactor(term):
    """Processes a sum of fractions: splits term into its most common factor and the rest.

    The absolute values of the coefficients reported by as_coefficients_dict are counted
    and the most frequent one is taken as common factor. If no value occurs more than
    once and 1 is one of the values, 1 is used as common factor.

    :param term: sympy expression
    :return: tuple (commonFactor, term / commonFactor)
    """
    from collections import Counter
    from sympy.functions import Abs

    coefficients = term.as_coefficients_dict().values()
    frequency = Counter(Abs(coefficient) for coefficient in coefficients)
    commonFactor, occurrences = max(frequency.items(), key=lambda entry: entry[1])

    nothing_repeats = occurrences == 1
    if nothing_repeats and 1 in frequency:
        commonFactor = 1
    return commonFactor, term / commonFactor
def countNumberOfOperations(term):
    """
    Counts the number of additions, multiplications and divisions

    Counting is done on term.evalf(). An integer power x**n counts as n-1 multiplications,
    a negative integer power x**-n as one division plus n-1 multiplications. If such a
    negative power is a factor of a product, the division takes the place of one of the
    multiplications of that product. The base of a power is traversed as well, so
    operations inside of it are counted. None of the counts can become negative.

    A warning is issued for powers with non-integer exponent and for unknown node types,
    since the result is inaccurate then.

    :param term: a sympy term, equation or sequence of terms/equations
    :return: a dictionary with 'adds', 'muls' and 'divs' keys
    """
    result = {'adds': 0, 'muls': 0, 'divs': 0}

    if isinstance(term, Sequence):
        for element in term:
            for operation_name, count in countNumberOfOperations(element).items():
                result[operation_name] += count
        return result
    elif isinstance(term, sp.Eq):
        term = term.rhs

    term = term.evalf()

    def is_division(node):
        # a power with a negative integer number as exponent, e.g. x**-1
        if node.func is not sp.Pow:
            return False
        exponent = node.exp
        return bool(exponent.is_integer and exponent.is_number and exponent < 0)

    def visit(node, replaces_mul=False):
        # replaces_mul: node is a factor of a product and, if it is a division, stands for
        # one of the multiplications that have already been counted for that product
        children = node.args
        if node.func is sp.Add:
            result['adds'] += len(node.args) - 1
        elif node.func is sp.Mul:
            # multiplying with 1 or -1 is not counted as an operation
            unit_factors = sum(1 for arg in node.args if arg == 1 or arg == -1)
            spare_muls = max(len(node.args) - 1 - unit_factors, 0)
            result['muls'] += spare_muls
            for arg in node.args:
                turns_into_div = spare_muls > 0 and is_division(arg)
                if turns_into_div:
                    spare_muls -= 1
                visit(arg, replaces_mul=turns_into_div)
            children = ()  # the factors have been visited above
        elif node.func is sp.Float:
            pass
        elif isinstance(node, sp.Symbol):
            pass
        elif node.is_integer:
            pass
        elif node.func is sp.Pow:
            exponent = node.exp
            if exponent.is_integer and exponent.is_number:
                result['muls'] += max(abs(int(exponent)) - 1, 0)
                if exponent < 0:
                    result['divs'] += 1
                    if replaces_mul:
                        # a * b**-1 is a single division, not a multiplication and a division
                        result['muls'] -= 1
            else:
                warnings.warn("Counting operations: only integer exponents are supported in Pow, "
                              "counting will be inaccurate")
            # operations in the base have to be counted as well; the exponent is not traversed
            children = (node.base,)
        else:
            warnings.warn("Unknown sympy node of type " + str(node.func) + " counting will be inaccurate")

        for child in children:
            visit(child)

    visit(term)
    return result