Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
No results found
Show changes
Showing
with 1920 additions and 613 deletions
...@@ -6,8 +6,7 @@ import sympy as sp ...@@ -6,8 +6,7 @@ import sympy as sp
import pystencils import pystencils
from pystencils.assignment import Assignment from pystencils.assignment import Assignment
from pystencils.simp.simplifications import ( from pystencils.simp.simplifications import (sort_assignments_topologically, transform_lhs_and_rhs, transform_rhs)
sort_assignments_topologically, transform_lhs_and_rhs, transform_rhs)
from pystencils.sympyextensions import count_operations, fast_subs from pystencils.sympyextensions import count_operations, fast_subs
...@@ -20,13 +19,14 @@ class AssignmentCollection: ...@@ -20,13 +19,14 @@ class AssignmentCollection:
Additionally a dictionary of simplification hints is stored, which are set by the functions that create Additionally a dictionary of simplification hints is stored, which are set by the functions that create
assignment collections to transport information to the simplification system. assignment collections to transport information to the simplification system.
Attributes: Args:
main_assignments: list of assignments main_assignments: List of assignments. Main assignments are characterised, that the right hand side of each
subexpressions: list of assignments defining subexpressions used in main equations assignment is a field access. Thus the generated equations write on arrays.
simplification_hints: dict that is used to annotate the assignment collection with hints that are subexpressions: List of assignments defining subexpressions used in main equations
simplification_hints: Dict that is used to annotate the assignment collection with hints that are
used by the simplification system. See documentation of the simplification rules for used by the simplification system. See documentation of the simplification rules for
potentially required hints and their meaning. potentially required hints and their meaning.
subexpression_symbol_generator: generator for new symbols that are used when new subexpressions are added subexpression_symbol_generator: Generator for new symbols that are used when new subexpressions are added
used to get new symbols that are unique for this AssignmentCollection used to get new symbols that are unique for this AssignmentCollection
""" """
...@@ -34,9 +34,13 @@ class AssignmentCollection: ...@@ -34,9 +34,13 @@ class AssignmentCollection:
# ------------------------------- Creation & Inplace Manipulation -------------------------------------------------- # ------------------------------- Creation & Inplace Manipulation --------------------------------------------------
def __init__(self, main_assignments: Union[List[Assignment], Dict[sp.Expr, sp.Expr]], def __init__(self, main_assignments: Union[List[Assignment], Dict[sp.Expr, sp.Expr]],
subexpressions: Union[List[Assignment], Dict[sp.Expr, sp.Expr]] = {}, subexpressions: Union[List[Assignment], Dict[sp.Expr, sp.Expr]] = None,
simplification_hints: Optional[Dict[str, Any]] = None, simplification_hints: Optional[Dict[str, Any]] = None,
subexpression_symbol_generator: Iterator[sp.Symbol] = None) -> None: subexpression_symbol_generator: Iterator[sp.Symbol] = None) -> None:
if subexpressions is None:
subexpressions = {}
if isinstance(main_assignments, Dict): if isinstance(main_assignments, Dict):
main_assignments = [Assignment(k, v) main_assignments = [Assignment(k, v)
for k, v in main_assignments.items()] for k, v in main_assignments.items()]
...@@ -57,8 +61,11 @@ class AssignmentCollection: ...@@ -57,8 +61,11 @@ class AssignmentCollection:
self.simplification_hints = simplification_hints self.simplification_hints = simplification_hints
ctrs = [int(n.name[3:])for n in self.rhs_symbols if "xi_" in n.name]
max_ctr = max(ctrs) + 1 if len(ctrs) > 0 else 0
if subexpression_symbol_generator is None: if subexpression_symbol_generator is None:
self.subexpression_symbol_generator = SymbolGen() self.subexpression_symbol_generator = SymbolGen(ctr=max_ctr)
else: else:
self.subexpression_symbol_generator = subexpression_symbol_generator self.subexpression_symbol_generator = subexpression_symbol_generator
...@@ -103,16 +110,21 @@ class AssignmentCollection: ...@@ -103,16 +110,21 @@ class AssignmentCollection:
return self.subexpressions + self.main_assignments return self.subexpressions + self.main_assignments
@property @property
def free_symbols(self) -> Set[sp.Symbol]: def rhs_symbols(self) -> Set[sp.Symbol]:
"""All symbols used in the assignment collection, which do not occur as left hand sides in any assignment.""" """All symbols used in the assignment collection, which occur on the rhs of any assignment."""
free_symbols = set() rhs_symbols = set()
for eq in self.all_assignments: for eq in self.all_assignments:
if isinstance(eq, Assignment): if isinstance(eq, Assignment):
free_symbols.update(eq.rhs.atoms(sp.Symbol)) rhs_symbols.update(eq.rhs.atoms(sp.Symbol))
elif isinstance(eq, pystencils.astnodes.Node): elif isinstance(eq, pystencils.astnodes.Node):
free_symbols.update(eq.undefined_symbols) rhs_symbols.update(eq.undefined_symbols)
return free_symbols - self.bound_symbols return rhs_symbols
@property
def free_symbols(self) -> Set[sp.Symbol]:
"""All symbols used in the assignment collection, which do not occur as left hand sides in any assignment."""
return self.rhs_symbols - self.bound_symbols
@property @property
def bound_symbols(self) -> Set[sp.Symbol]: def bound_symbols(self) -> Set[sp.Symbol]:
...@@ -127,11 +139,15 @@ class AssignmentCollection: ...@@ -127,11 +139,15 @@ class AssignmentCollection:
bound_symbols_set = bound_symbols_set.union(*[ bound_symbols_set = bound_symbols_set.union(*[
assignment.symbols_defined for assignment in self.all_assignments assignment.symbols_defined for assignment in self.all_assignments
if isinstance(assignment, pystencils.astnodes.Node) if isinstance(assignment, pystencils.astnodes.Node)
] ])
)
return bound_symbols_set return bound_symbols_set
@property
def rhs_fields(self):
"""All fields accessed in the assignment collection, which do not occur as left hand sides in any assignment."""
return {s.field for s in self.rhs_symbols if hasattr(s, 'field')}
@property @property
def free_fields(self): def free_fields(self):
"""All fields accessed in the assignment collection, which do not occur as left hand sides in any assignment.""" """All fields accessed in the assignment collection, which do not occur as left hand sides in any assignment."""
...@@ -145,11 +161,9 @@ class AssignmentCollection: ...@@ -145,11 +161,9 @@ class AssignmentCollection:
@property @property
def defined_symbols(self) -> Set[sp.Symbol]: def defined_symbols(self) -> Set[sp.Symbol]:
"""All symbols which occur as left-hand-sides of one of the main equations""" """All symbols which occur as left-hand-sides of one of the main equations"""
return (set( lhs_set = set([assignment.lhs for assignment in self.main_assignments if isinstance(assignment, Assignment)])
[assignment.lhs for assignment in self.main_assignments if isinstance(assignment, Assignment)] return (lhs_set.union(*[assignment.symbols_defined for assignment in self.main_assignments
).union(*[assignment.symbols_defined for assignment in self.main_assignments if isinstance( if isinstance(assignment, pystencils.astnodes.Node)]))
assignment, pystencils.astnodes.Node)]
))
@property @property
def operation_count(self): def operation_count(self):
...@@ -210,6 +224,7 @@ class AssignmentCollection: ...@@ -210,6 +224,7 @@ class AssignmentCollection:
return {s: func(*args, **kwargs) for s, func in lambdas.items()} return {s: func(*args, **kwargs) for s, func in lambdas.items()}
return f return f
# ---------------------------- Creating new modified collections --------------------------------------------------- # ---------------------------- Creating new modified collections ---------------------------------------------------
def copy(self, def copy(self,
...@@ -263,7 +278,7 @@ class AssignmentCollection: ...@@ -263,7 +278,7 @@ class AssignmentCollection:
own_definitions = set([e.lhs for e in self.main_assignments]) own_definitions = set([e.lhs for e in self.main_assignments])
other_definitions = set([e.lhs for e in other.main_assignments]) other_definitions = set([e.lhs for e in other.main_assignments])
assert len(own_definitions.intersection(other_definitions)) == 0, \ assert len(own_definitions.intersection(other_definitions)) == 0, \
"Cannot new_merged, since both collection define the same symbols" "Cannot merge collections, since both define the same symbols"
own_subexpression_symbols = {e.lhs: e.rhs for e in self.subexpressions} own_subexpression_symbols = {e.lhs: e.rhs for e in self.subexpressions}
substitution_dict = {} substitution_dict = {}
...@@ -271,12 +286,13 @@ class AssignmentCollection: ...@@ -271,12 +286,13 @@ class AssignmentCollection:
processed_other_subexpression_equations = [] processed_other_subexpression_equations = []
for other_subexpression_eq in other.subexpressions: for other_subexpression_eq in other.subexpressions:
if other_subexpression_eq.lhs in own_subexpression_symbols: if other_subexpression_eq.lhs in own_subexpression_symbols:
if other_subexpression_eq.rhs == own_subexpression_symbols[other_subexpression_eq.lhs]: new_rhs = fast_subs(other_subexpression_eq.rhs, substitution_dict)
if new_rhs == own_subexpression_symbols[other_subexpression_eq.lhs]:
continue # exact the same subexpression equation exists already continue # exact the same subexpression equation exists already
else: else:
# different definition - a new name has to be introduced # different definition - a new name has to be introduced
new_lhs = next(self.subexpression_symbol_generator) new_lhs = next(self.subexpression_symbol_generator)
new_eq = Assignment(new_lhs, fast_subs(other_subexpression_eq.rhs, substitution_dict)) new_eq = Assignment(new_lhs, new_rhs)
processed_other_subexpression_equations.append(new_eq) processed_other_subexpression_equations.append(new_eq)
substitution_dict[other_subexpression_eq.lhs] = new_lhs substitution_dict[other_subexpression_eq.lhs] = new_lhs
else: else:
...@@ -299,9 +315,9 @@ class AssignmentCollection: ...@@ -299,9 +315,9 @@ class AssignmentCollection:
if eq.lhs in symbols_to_extract: if eq.lhs in symbols_to_extract:
new_assignments.append(eq) new_assignments.append(eq)
new_sub_expr = [eq for eq in self.subexpressions new_sub_expr = [eq for eq in self.all_assignments
if eq.lhs in dependent_symbols and eq.lhs not in symbols_to_extract] if eq.lhs in dependent_symbols and eq.lhs not in symbols_to_extract]
return AssignmentCollection(new_assignments, new_sub_expr) return self.copy(new_assignments, new_sub_expr)
def new_without_unused_subexpressions(self) -> 'AssignmentCollection': def new_without_unused_subexpressions(self) -> 'AssignmentCollection':
"""Returns new collection that only contains subexpressions required to compute the main assignments.""" """Returns new collection that only contains subexpressions required to compute the main assignments."""
...@@ -324,8 +340,10 @@ class AssignmentCollection: ...@@ -324,8 +340,10 @@ class AssignmentCollection:
new_eqs = [Assignment(eq.lhs, fast_subs(eq.rhs, subs_dict)) for eq in self.main_assignments] new_eqs = [Assignment(eq.lhs, fast_subs(eq.rhs, subs_dict)) for eq in self.main_assignments]
return self.copy(new_eqs, new_subexpressions) return self.copy(new_eqs, new_subexpressions)
def new_without_subexpressions(self, subexpressions_to_keep: Set[sp.Symbol] = set()) -> 'AssignmentCollection': def new_without_subexpressions(self, subexpressions_to_keep=None) -> 'AssignmentCollection':
"""Returns a new collection where all subexpressions have been inserted.""" """Returns a new collection where all subexpressions have been inserted."""
if subexpressions_to_keep is None:
subexpressions_to_keep = set()
if len(self.subexpressions) == 0: if len(self.subexpressions) == 0:
return self.copy() return self.copy()
...@@ -334,7 +352,7 @@ class AssignmentCollection: ...@@ -334,7 +352,7 @@ class AssignmentCollection:
kept_subexpressions = [] kept_subexpressions = []
if self.subexpressions[0].lhs in subexpressions_to_keep: if self.subexpressions[0].lhs in subexpressions_to_keep:
substitution_dict = {} substitution_dict = {}
kept_subexpressions = self.subexpressions[0] kept_subexpressions.append(self.subexpressions[0])
else: else:
substitution_dict = {self.subexpressions[0].lhs: self.subexpressions[0].rhs} substitution_dict = {self.subexpressions[0].lhs: self.subexpressions[0].rhs}
...@@ -353,6 +371,7 @@ class AssignmentCollection: ...@@ -353,6 +371,7 @@ class AssignmentCollection:
def _repr_html_(self): def _repr_html_(self):
"""Interface to Jupyter notebook, to display as a nicely formatted HTML table""" """Interface to Jupyter notebook, to display as a nicely formatted HTML table"""
def make_html_equation_table(equations): def make_html_equation_table(equations):
no_border = 'style="border:none"' no_border = 'style="border:none"'
html_table = '<table style="border:none; width: 100%; ">' html_table = '<table style="border:none; width: 100%; ">'
...@@ -378,10 +397,10 @@ class AssignmentCollection: ...@@ -378,10 +397,10 @@ class AssignmentCollection:
def __str__(self): def __str__(self):
result = "Subexpressions:\n" result = "Subexpressions:\n"
for eq in self.subexpressions: for eq in self.subexpressions:
result += "\t{eq}\n".format(eq=eq) result += f"\t{eq}\n"
result += "Main Assignments:\n" result += "Main Assignments:\n"
for eq in self.main_assignments: for eq in self.main_assignments:
result += "\t{eq}\n".format(eq=eq) result += f"\t{eq}\n"
return result return result
def __iter__(self): def __iter__(self):
...@@ -438,14 +457,17 @@ class AssignmentCollection: ...@@ -438,14 +457,17 @@ class AssignmentCollection:
class SymbolGen: class SymbolGen:
"""Default symbol generator producing number symbols ζ_0, ζ_1, ...""" """Default symbol generator producing number symbols ζ_0, ζ_1, ..."""
def __init__(self, symbol="xi"): def __init__(self, symbol="xi", dtype=None, ctr=0):
self._ctr = 0 self._ctr = ctr
self._symbol = symbol self._symbol = symbol
self._dtype = dtype
def __iter__(self): def __iter__(self):
return self return self
def __next__(self): def __next__(self):
name = "{}_{}".format(self._symbol, self._ctr) name = f"{self._symbol}_{self._ctr}"
self._ctr += 1 self._ctr += 1
if self._dtype is not None:
return pystencils.TypedSymbol(name, self._dtype)
return sp.Symbol(name) return sp.Symbol(name)
from itertools import chain from itertools import chain
from typing import Callable, List, Sequence, Union from typing import Callable, List, Sequence, Union
from collections import defaultdict
import sympy as sp import sympy as sp
from pystencils.assignment import Assignment from pystencils.assignment import Assignment
from pystencils.astnodes import Node from pystencils.astnodes import Node
from pystencils.field import AbstractField, Field from pystencils.field import Field
from pystencils.sympyextensions import subs_additive from pystencils.sympyextensions import subs_additive, is_constant, recursive_collect
from pystencils.typing import TypedSymbol
def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]: def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]:
...@@ -18,7 +20,7 @@ def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node] ...@@ -18,7 +20,7 @@ def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]
elif isinstance(e1, Node): elif isinstance(e1, Node):
symbols = e1.symbols_defined symbols = e1.symbols_defined
else: else:
raise NotImplementedError("Cannot sort topologically. Object of type " + type(e1) + " cannot be handled.") raise NotImplementedError(f"Cannot sort topologically. Object of type {type(e1)} cannot be handled.")
for lhs in symbols: for lhs in symbols:
for c2, e2 in enumerate(assignments): for c2, e2 in enumerate(assignments):
...@@ -83,6 +85,39 @@ def subexpression_substitution_in_main_assignments(ac): ...@@ -83,6 +85,39 @@ def subexpression_substitution_in_main_assignments(ac):
return ac.copy(result) return ac.copy(result)
def add_subexpressions_for_constants(ac):
"""Extracts constant factors to subexpressions in the given assignment collection.
SymPy will exclude common factors from a sum only if they are symbols. This simplification
can be applied to exclude common numeric constants from multiple terms of a sum. As a consequence,
the number of multiplications is reduced and in some cases, more common subexpressions can be found.
"""
constants_to_subexp_dict = defaultdict(lambda: next(ac.subexpression_symbol_generator))
def visit(expr):
args = list(expr.args)
if len(args) == 0:
return expr
if isinstance(expr, sp.Add) or isinstance(expr, sp.Mul):
for i, arg in enumerate(args):
if is_constant(arg) and abs(arg) != 1:
if arg < 0:
args[i] = - constants_to_subexp_dict[- arg]
else:
args[i] = constants_to_subexp_dict[arg]
return expr.func(*(visit(a) for a in args))
main_assignments = [Assignment(a.lhs, visit(a.rhs)) for a in ac.main_assignments]
subexpressions = [Assignment(a.lhs, visit(a.rhs)) for a in ac.subexpressions]
symbols_to_collect = set(constants_to_subexp_dict.values())
main_assignments = [Assignment(a.lhs, recursive_collect(a.rhs, symbols_to_collect, True)) for a in main_assignments]
subexpressions = [Assignment(a.lhs, recursive_collect(a.rhs, symbols_to_collect, True)) for a in subexpressions]
subexpressions = [Assignment(symb, c) for c, symb in constants_to_subexp_dict.items()] + subexpressions
return ac.copy(main_assignments=main_assignments, subexpressions=subexpressions)
def add_subexpressions_for_divisions(ac): def add_subexpressions_for_divisions(ac):
r"""Introduces subexpressions for all divisions which have no constant in the denominator. r"""Introduces subexpressions for all divisions which have no constant in the denominator.
...@@ -112,14 +147,14 @@ def add_subexpressions_for_sums(ac): ...@@ -112,14 +147,14 @@ def add_subexpressions_for_sums(ac):
addends = [] addends = []
def contains_sum(term): def contains_sum(term):
if term.func == sp.add.Add: if term.func == sp.Add:
return True return True
if term.is_Atom: if term.is_Atom:
return False return False
return any([contains_sum(a) for a in term.args]) return any([contains_sum(a) for a in term.args])
def search_addends(term): def search_addends(term):
if term.func == sp.add.Add: if term.func == sp.Add:
if all([not contains_sum(a) for a in term.args]): if all([not contains_sum(a) for a in term.args]):
addends.extend(term.args) addends.extend(term.args)
for a in term.args: for a in term.args:
...@@ -128,18 +163,20 @@ def add_subexpressions_for_sums(ac): ...@@ -128,18 +163,20 @@ def add_subexpressions_for_sums(ac):
for eq in ac.all_assignments: for eq in ac.all_assignments:
search_addends(eq.rhs) search_addends(eq.rhs)
addends = [a for a in addends if not isinstance(a, sp.Symbol) or isinstance(a, AbstractField.AbstractAccess)] addends = [a for a in addends if not isinstance(a, sp.Symbol) or isinstance(a, Field.Access)]
new_symbol_gen = ac.subexpression_symbol_generator new_symbol_gen = ac.subexpression_symbol_generator
substitutions = {addend: new_symbol for new_symbol, addend in zip(new_symbol_gen, addends)} substitutions = {addend: new_symbol for new_symbol, addend in zip(new_symbol_gen, addends)}
return ac.new_with_substitutions(substitutions, True, substitute_on_lhs=False) return ac.new_with_substitutions(substitutions, True, substitute_on_lhs=False)
def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments=True): def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments=True, data_type=None):
r"""Substitutes field accesses on rhs of assignments with subexpressions r"""Substitutes field accesses on rhs of assignments with subexpressions
Can change semantics of the update rule (which is the goal of this transformation) Can change semantics of the update rule (which is the goal of this transformation)
This is useful if a field should be update in place - all values are loaded before into subexpression variables, This is useful if a field should be update in place - all values are loaded before into subexpression variables,
then the new values are computed and written to the same field in-place. then the new values are computed and written to the same field in-place.
Additionally, if a datatype is given to the function the rhs symbol of the new isolated field read will have
this data type. This is useful for mixed precision kernels
""" """
field_reads = set() field_reads = set()
to_iterate = [] to_iterate = []
...@@ -151,7 +188,17 @@ def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments ...@@ -151,7 +188,17 @@ def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments
for assignment in to_iterate: for assignment in to_iterate:
if hasattr(assignment, 'lhs') and hasattr(assignment, 'rhs'): if hasattr(assignment, 'lhs') and hasattr(assignment, 'rhs'):
field_reads.update(assignment.rhs.atoms(Field.Access)) field_reads.update(assignment.rhs.atoms(Field.Access))
substitutions = {fa: next(ac.subexpression_symbol_generator) for fa in field_reads}
if not field_reads:
return ac
substitutions = dict()
for fa in field_reads:
lhs = next(ac.subexpression_symbol_generator)
if data_type is not None:
substitutions.update({fa: TypedSymbol(lhs.name, data_type)})
else:
substitutions.update({fa: lhs})
return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True, return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True,
substitute_on_lhs=False, sort_topologically=False) substitute_on_lhs=False, sort_topologically=False)
...@@ -172,7 +219,7 @@ def transform_lhs_and_rhs(assignment_list, transformation, *args, **kwargs): ...@@ -172,7 +219,7 @@ def transform_lhs_and_rhs(assignment_list, transformation, *args, **kwargs):
def apply_to_all_assignments(operation: Callable[[sp.Expr], sp.Expr]): def apply_to_all_assignments(operation: Callable[[sp.Expr], sp.Expr]):
"""Applies sympy expand operation to all equations in collection.""" """Applies a given operation to all equations in collection."""
def f(ac): def f(ac):
return ac.copy(transform_rhs(ac.main_assignments, operation)) return ac.copy(transform_rhs(ac.main_assignments, operation))
...@@ -189,3 +236,31 @@ def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]): ...@@ -189,3 +236,31 @@ def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]):
f.__name__ = operation.__name__ f.__name__ = operation.__name__
return f return f
# TODO Markus
# make this really work for Assignmentcollections
# this function should ONLY evaluate
# do the optims_c99 elsewhere optionally
# def apply_sympy_optimisations(ac: AssignmentCollection):
# """ Evaluates constant expressions (e.g. :math:`\\sqrt{3}` will be replaced by its floating point representation)
# and applies the default sympy optimisations. See sympy.codegen.rewriting
# """
#
# # Evaluates all constant terms
#
# assignments = ac.all_assignments
#
# evaluate_constant_terms = ReplaceOptim(lambda e: hasattr(e, 'is_constant') and e.is_constant and not e.is_integer,
# lambda p: p.evalf())
#
# sympy_optimisations = [evaluate_constant_terms] + list(optims_c99)
#
# assignments = [Assignment(a.lhs, optimize(a.rhs, sympy_optimisations))
# if hasattr(a, 'lhs')
# else a for a in assignments]
# assignments_nodes = [a.atoms(SympyAssignment) for a in assignments]
# for a in chain.from_iterable(assignments_nodes):
# a.optimize(sympy_optimisations)
#
# return AssignmentCollection(assignments)
...@@ -92,7 +92,7 @@ class SimplificationStrategy: ...@@ -92,7 +92,7 @@ class SimplificationStrategy:
assignment_collection = t(assignment_collection) assignment_collection = t(assignment_collection)
end_time = timeit.default_timer() end_time = timeit.default_timer()
op = assignment_collection.operation_count op = assignment_collection.operation_count
time_str = "%.2f ms" % ((end_time - start_time) * 1000,) time_str = f"{(end_time - start_time) * 1000:.2f} ms"
total = op['adds'] + op['muls'] + op['divs'] total = op['adds'] + op['muls'] + op['divs']
report.add(ReportElement(t.__name__, time_str, op['adds'], op['muls'], op['divs'], total)) report.add(ReportElement(t.__name__, time_str, op['adds'], op['muls'], op['divs'], total))
return report return report
...@@ -129,7 +129,7 @@ class SimplificationStrategy: ...@@ -129,7 +129,7 @@ class SimplificationStrategy:
def _repr_html_(self): def _repr_html_(self):
def print_assignment_collection(title, c): def print_assignment_collection(title, c):
text = '<h5 style="padding-bottom:10px">%s</h5> <div style="padding-left:20px;">' % (title, ) text = f'<h5 style="padding-bottom:10px">{title}</h5> <div style="padding-left:20px;">'
if self.restrict_symbols: if self.restrict_symbols:
text += "\n".join(["$$" + sp.latex(e) + '$$' text += "\n".join(["$$" + sp.latex(e) + '$$'
for e in c.new_filtered(self.restrict_symbols).main_assignments]) for e in c.new_filtered(self.restrict_symbols).main_assignments])
...@@ -151,5 +151,5 @@ class SimplificationStrategy: ...@@ -151,5 +151,5 @@ class SimplificationStrategy:
def __repr__(self): def __repr__(self):
result = "Simplification Strategy:\n" result = "Simplification Strategy:\n"
for t in self._rules: for t in self._rules:
result += " - %s\n" % (t.__name__,) result += f" - {t.__name__}\n"
return result return result
import sympy as sp
from pystencils.sympyextensions import is_constant
# Subexpression Insertion
def insert_subexpressions(ac, selection_callback, skip=None):
"""
Removes a number of subexpressions from an assignment collection by
inserting their right-hand side wherever they occur.
Args:
- selection_callback: Function that is called to qualify subexpressions
for insertion. Should return `True` for any subexpression that is to be
inserted, and `False` otherwise.
- skip: Set of symbols (left-hand sides of subexpressions) that should be
ignored even if qualified by the callback.
"""
if skip is None:
skip = set()
i = 0
while i < len(ac.subexpressions):
exp = ac.subexpressions[i]
if exp.lhs not in skip and selection_callback(exp):
ac = ac.new_with_inserted_subexpression(exp.lhs)
else:
i += 1
return ac
def insert_aliases(ac, **kwargs):
"""Inserts subexpressions that are aliases of other symbols,
i.e. their right-hand side is only another symbol."""
return insert_subexpressions(ac, lambda x: isinstance(x.rhs, sp.Symbol), **kwargs)
def insert_zeros(ac, **kwargs):
"""Inserts subexpressions whose right-hand side is zero."""
zero = sp.Integer(0)
return insert_subexpressions(ac, lambda x: x.rhs == zero, **kwargs)
def insert_constants(ac, **kwargs):
"""Inserts subexpressions whose right-hand side is constant,
i.e. contains no symbols."""
return insert_subexpressions(ac, lambda x: is_constant(x.rhs), **kwargs)
def insert_symbol_times_minus_one(ac, **kwargs):
"""Inserts subexpressions whose right-hand side is just a
negation of another symbol."""
def callback(exp):
rhs = exp.rhs
minus_one = sp.Integer(-1)
atoms = rhs.atoms(sp.Symbol)
return len(atoms) == 1 and rhs == minus_one * atoms.pop()
return insert_subexpressions(ac, callback, **kwargs)
def insert_constant_multiples(ac, **kwargs):
"""Inserts subexpressions whose right-hand side is a constant
multiplied with another symbol."""
def callback(exp):
rhs = exp.rhs
symbols = rhs.atoms(sp.Symbol)
numbers = rhs.atoms(sp.Number)
return len(symbols) == 1 and len(numbers) == 1 and \
rhs == numbers.pop() * symbols.pop()
return insert_subexpressions(ac, callback, **kwargs)
def insert_constant_additions(ac, **kwargs):
"""Inserts subexpressions whose right-hand side is a sum of a
constant and another symbol."""
def callback(exp):
rhs = exp.rhs
symbols = rhs.atoms(sp.Symbol)
numbers = rhs.atoms(sp.Number)
return len(symbols) == 1 and len(numbers) == 1 and \
rhs == numbers.pop() + symbols.pop()
return insert_subexpressions(ac, callback, **kwargs)
def insert_squares(ac, **kwargs):
"""Inserts subexpressions whose right-hand side is another symbol squared."""
def callback(exp):
rhs = exp.rhs
symbols = rhs.atoms(sp.Symbol)
return len(symbols) == 1 and rhs == symbols.pop() ** 2
return insert_subexpressions(ac, callback, **kwargs)
def bind_symbols_to_skip(insertion_function, skip):
return lambda ac: insertion_function(ac, skip=skip)
from pystencils.simp import (SimplificationStrategy, insert_constants, insert_symbol_times_minus_one,
insert_constant_multiples, insert_constant_additions, insert_squares, insert_zeros)
def create_simplification_strategy():
"""
Creates a default simplification `ps.simp.SimplificationStrategy`. The idea behind the default simplification
strategy is to reduce the number of subexpressions by inserting single constants and to evaluate constant
terms beforehand.
"""
s = SimplificationStrategy()
s.add(insert_symbol_times_minus_one)
s.add(insert_constant_multiples)
s.add(insert_constant_additions)
s.add(insert_squares)
s.add(insert_zeros)
s.add(insert_constants)
s.add(lambda ac: ac.new_without_unused_subexpressions())
...@@ -89,9 +89,12 @@ def shift_slice(slices, offset): ...@@ -89,9 +89,12 @@ def shift_slice(slices, offset):
raise ValueError() raise ValueError()
if hasattr(offset, '__len__'): if hasattr(offset, '__len__'):
return [shift_slice_component(k, off) for k, off in zip(slices, offset)] return tuple(shift_slice_component(k, off) for k, off in zip(slices, offset))
else: else:
return [shift_slice_component(k, offset) for k in slices] if isinstance(slices, slice) or isinstance(slices, int) or isinstance(slices, float):
return shift_slice_component(slices, offset)
else:
return tuple(shift_slice_component(k, offset) for k in slices)
def slice_from_direction(direction_name, dim, normal_offset=0, tangential_offset=0): def slice_from_direction(direction_name, dim, normal_offset=0, tangential_offset=0):
......
...@@ -5,6 +5,8 @@ from typing import Sequence ...@@ -5,6 +5,8 @@ from typing import Sequence
import numpy as np import numpy as np
import sympy as sp import sympy as sp
from pystencils.utils import binary_numbers
def inverse_direction(direction): def inverse_direction(direction):
"""Returns inverse i.e. negative of given direction tuple """Returns inverse i.e. negative of given direction tuple
...@@ -34,6 +36,8 @@ def is_valid(stencil, max_neighborhood=None): ...@@ -34,6 +36,8 @@ def is_valid(stencil, max_neighborhood=None):
True True
>>> is_valid([(2, 0), (1, 0)], max_neighborhood=1) >>> is_valid([(2, 0), (1, 0)], max_neighborhood=1)
False False
>>> is_valid([(2, 0), (1, 0)], max_neighborhood=2)
True
""" """
expected_dim = len(stencil[0]) expected_dim = len(stencil[0])
for d in stencil: for d in stencil:
...@@ -67,8 +71,11 @@ def have_same_entries(s1, s2): ...@@ -67,8 +71,11 @@ def have_same_entries(s1, s2):
Examples: Examples:
>>> stencil1 = [(1, 0), (-1, 0), (0, 1), (0, -1)] >>> stencil1 = [(1, 0), (-1, 0), (0, 1), (0, -1)]
>>> stencil2 = [(-1, 0), (0, -1), (1, 0), (0, 1)] >>> stencil2 = [(-1, 0), (0, -1), (1, 0), (0, 1)]
>>> stencil3 = [(-1, 0), (0, -1), (1, 0)]
>>> have_same_entries(stencil1, stencil2) >>> have_same_entries(stencil1, stencil2)
True True
>>> have_same_entries(stencil1, stencil3)
False
""" """
if len(s1) != len(s2): if len(s1) != len(s2):
return False return False
...@@ -288,6 +295,38 @@ def direction_string_to_offset(direction: str, dim: int = 3): ...@@ -288,6 +295,38 @@ def direction_string_to_offset(direction: str, dim: int = 3):
return offset[:dim] return offset[:dim]
def adjacent_directions(direction):
"""
Returns all adjacent directions for a direction as tuple of tuples. This is useful for exmple to find all directions
relevant for neighbour communication.
Args:
direction: tuple representing a direction. For example (0, 1, 0) for the northern side
Examples:
>>> adjacent_directions((0, 0, 0))
((0, 0, 0),)
>>> adjacent_directions((0, 1, 0))
((0, 1, 0),)
>>> adjacent_directions((0, 1, 1))
((0, 0, 1), (0, 1, 0), (0, 1, 1))
>>> adjacent_directions((-1, -1))
((-1, -1), (-1, 0), (0, -1))
"""
result = set()
if all(e == 0 for e in direction):
result.add(direction)
return tuple(result)
binary_numbers_list = binary_numbers(len(direction))
for adjacent_direction in binary_numbers_list:
for i, entry in enumerate(direction):
if entry == 0:
adjacent_direction[i] = 0
if entry == -1 and adjacent_direction[i] == 1:
adjacent_direction[i] = -1
if not all(e == 0 for e in adjacent_direction):
result.add(tuple(adjacent_direction))
return tuple(sorted(result))
# -------------------------------------- Visualization ----------------------------------------------------------------- # -------------------------------------- Visualization -----------------------------------------------------------------
...@@ -314,6 +353,7 @@ def plot_2d(stencil, axes=None, figure=None, data=None, textsize='12', **kwargs) ...@@ -314,6 +353,7 @@ def plot_2d(stencil, axes=None, figure=None, data=None, textsize='12', **kwargs)
Args: Args:
stencil: sequence of directions stencil: sequence of directions
axes: optional matplotlib axes axes: optional matplotlib axes
figure: optional matplotlib figure
data: data to annotate the directions with, if none given, the indices are used data: data to annotate the directions with, if none given, the indices are used
textsize: size of annotation text textsize: size of annotation text
""" """
...@@ -335,7 +375,7 @@ def plot_2d(stencil, axes=None, figure=None, data=None, textsize='12', **kwargs) ...@@ -335,7 +375,7 @@ def plot_2d(stencil, axes=None, figure=None, data=None, textsize='12', **kwargs)
for direction, annotation in zip(stencil, data): for direction, annotation in zip(stencil, data):
assert len(direction) == 2, "Works only for 2D stencils" assert len(direction) == 2, "Works only for 2D stencils"
direction = tuple(int(i) for i in direction) direction = tuple(int(i) for i in direction)
if not(direction[0] == 0 and direction[1] == 0): if not (direction[0] == 0 and direction[1] == 0):
axes.arrow(0, 0, direction[0], direction[1], head_width=0.08, head_length=head_length, color='k') axes.arrow(0, 0, direction[0], direction[1], head_width=0.08, head_length=head_length, color='k')
if isinstance(annotation, sp.Basic): if isinstance(annotation, sp.Basic):
...@@ -351,7 +391,7 @@ def plot_2d(stencil, axes=None, figure=None, data=None, textsize='12', **kwargs) ...@@ -351,7 +391,7 @@ def plot_2d(stencil, axes=None, figure=None, data=None, textsize='12', **kwargs)
else: else:
return 0 return 0
text_position = [direction[c] + position_correction(direction[c]) for c in (0, 1)] text_position = [direction[c] + position_correction(direction[c]) for c in (0, 1)]
axes.text(*text_position, annotation, verticalalignment='center', axes.text(x=text_position[0], y=text_position[1], s=annotation, verticalalignment='center',
zorder=30, horizontalalignment='center', size=textsize, zorder=30, horizontalalignment='center', size=textsize,
bbox=dict(boxstyle=text_box_style, facecolor='#00b6eb', alpha=0.85, linewidth=0)) bbox=dict(boxstyle=text_box_style, facecolor='#00b6eb', alpha=0.85, linewidth=0))
...@@ -369,6 +409,7 @@ def plot_3d_slicing(stencil, slice_axis=2, figure=None, data=None, **kwargs): ...@@ -369,6 +409,7 @@ def plot_3d_slicing(stencil, slice_axis=2, figure=None, data=None, **kwargs):
Args: Args:
stencil: stencil as sequence of directions stencil: stencil as sequence of directions
slice_axis: 0, 1, or 2 indicating the axis to slice through slice_axis: 0, 1, or 2 indicating the axis to slice through
figure: optional matplotlib figure
data: optional data to print as text besides the arrows data: optional data to print as text besides the arrows
""" """
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
...@@ -414,16 +455,17 @@ def plot_3d(stencil, figure=None, axes=None, data=None, textsize='8'): ...@@ -414,16 +455,17 @@ def plot_3d(stencil, figure=None, axes=None, data=None, textsize='8'):
FancyArrowPatch.__init__(self, (0, 0), (0, 0), *args, **kwargs) FancyArrowPatch.__init__(self, (0, 0), (0, 0), *args, **kwargs)
self._verts3d = xs, ys, zs self._verts3d = xs, ys, zs
def draw(self, renderer): def do_3d_projection(self, *_):
xs3d, ys3d, zs3d = self._verts3d xs3d, ys3d, zs3d = self._verts3d
xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, renderer.M) xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M)
self.set_positions((xs[0], ys[0]), (xs[1], ys[1])) self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
FancyArrowPatch.draw(self, renderer)
return np.min(zs)
if axes is None: if axes is None:
if figure is None: if figure is None:
figure = plt.figure() figure = plt.figure()
axes = figure.gca(projection='3d') axes = figure.add_subplot(projection='3d')
try: try:
axes.set_aspect("equal") axes.set_aspect("equal")
except NotImplementedError: except NotImplementedError:
...@@ -439,7 +481,7 @@ def plot_3d(stencil, figure=None, axes=None, data=None, textsize='8'): ...@@ -439,7 +481,7 @@ def plot_3d(stencil, figure=None, axes=None, data=None, textsize='8'):
r = [-1, 1] r = [-1, 1]
for s, e in combinations(np.array(list(product(r, r, r))), 2): for s, e in combinations(np.array(list(product(r, r, r))), 2):
if np.sum(np.abs(s - e)) == r[1] - r[0]: if np.sum(np.abs(s - e)) == r[1] - r[0]:
axes.plot3D(*zip(s, e), color="k", alpha=0.5) axes.plot(*zip(s, e), color="k", alpha=0.5)
for d, annotation in zip(stencil, data): for d, annotation in zip(stencil, data):
assert len(d) == 3, "Works only for 3D stencils" assert len(d) == 3, "Works only for 3D stencils"
...@@ -463,8 +505,8 @@ def plot_3d(stencil, figure=None, axes=None, data=None, textsize='8'): ...@@ -463,8 +505,8 @@ def plot_3d(stencil, figure=None, axes=None, data=None, textsize='8'):
else: else:
annotation = str(annotation) annotation = str(annotation)
axes.text(d[0] * text_offset, d[1] * text_offset, d[2] * text_offset, axes.text(x=d[0] * text_offset, y=d[1] * text_offset, z=d[2] * text_offset,
annotation, verticalalignment='center', zorder=30, s=annotation, verticalalignment='center', zorder=30,
size=textsize, bbox=dict(boxstyle=text_box_style, facecolor='#777777', alpha=0.6, linewidth=0)) size=textsize, bbox=dict(boxstyle=text_box_style, facecolor='#777777', alpha=0.6, linewidth=0))
axes.set_xlim([-text_offset * 1.1, text_offset * 1.1]) axes.set_xlim([-text_offset * 1.1, text_offset * 1.1])
......
...@@ -6,10 +6,14 @@ from functools import partial, reduce ...@@ -6,10 +6,14 @@ from functools import partial, reduce
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, TypeVar, Union from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, TypeVar, Union
import sympy as sp import sympy as sp
from sympy import PolynomialError
from sympy.functions import Abs from sympy.functions import Abs
from sympy.core.numbers import Zero
from pystencils.assignment import Assignment from pystencils.assignment import Assignment
from pystencils.data_types import cast_func, get_base_type, get_type_of_expression from pystencils.functions import DivFunc
from pystencils.typing import CastFunc, get_type_of_expression, PointerType, VectorType
from pystencils.typing.typed_sympy import FieldPointerSymbol
T = TypeVar('T') T = TypeVar('T')
...@@ -156,17 +160,23 @@ def fast_subs(expression: T, substitutions: Dict, ...@@ -156,17 +160,23 @@ def fast_subs(expression: T, substitutions: Dict,
if type(expression) is sp.Matrix: if type(expression) is sp.Matrix:
return expression.copy().applyfunc(partial(fast_subs, substitutions=substitutions)) return expression.copy().applyfunc(partial(fast_subs, substitutions=substitutions))
def visit(expr): def visit(expr, evaluate=True):
if skip and skip(expr): if skip and skip(expr):
return expr return expr
if hasattr(expr, "fast_subs"): elif hasattr(expr, "fast_subs"):
return expr.fast_subs(substitutions, skip) return expr.fast_subs(substitutions, skip)
if expr in substitutions: elif expr in substitutions:
return substitutions[expr] return substitutions[expr]
if not hasattr(expr, 'args'): elif not hasattr(expr, 'args'):
return expr return expr
param_list = [visit(a) for a in expr.args] elif isinstance(expr, (sp.UnevaluatedExpr, DivFunc)):
return expr if not param_list else expr.func(*param_list) args = [visit(a, False) for a in expr.args]
return expr.func(*args)
else:
param_list = [visit(a, evaluate) for a in expr.args]
if isinstance(expr, (sp.Mul, sp.Add)):
return expr if not param_list else expr.func(*param_list, evaluate=evaluate)
return expr if not param_list else expr.func(*param_list)
if len(substitutions) == 0: if len(substitutions) == 0:
return expression return expression
...@@ -233,6 +243,9 @@ def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr, ...@@ -233,6 +243,9 @@ def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr,
normalized_replacement_match = normalize_match_parameter(required_match_replacement, len(subexpression.args)) normalized_replacement_match = normalize_match_parameter(required_match_replacement, len(subexpression.args))
if isinstance(subexpression, sp.Number):
return expr.subs({replacement: subexpression})
def visit(current_expr): def visit(current_expr):
if current_expr.is_Add: if current_expr.is_Add:
expr_max_length = max(len(current_expr.args), len(subexpression.args)) expr_max_length = max(len(current_expr.args), len(subexpression.args))
...@@ -260,8 +273,8 @@ def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr, ...@@ -260,8 +273,8 @@ def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr,
if not param_list: if not param_list:
return current_expr return current_expr
else: else:
if current_expr.func == sp.Mul and sp.numbers.Zero() in param_list: if current_expr.func == sp.Mul and Zero() in param_list:
return sp.numbers.Zero() return sp.simplify(current_expr)
else: else:
return current_expr.func(*param_list, evaluate=False) return current_expr.func(*param_list, evaluate=False)
...@@ -271,7 +284,7 @@ def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr, ...@@ -271,7 +284,7 @@ def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr,
def replace_second_order_products(expr: sp.Expr, search_symbols: Iterable[sp.Symbol], def replace_second_order_products(expr: sp.Expr, search_symbols: Iterable[sp.Symbol],
positive: Optional[bool] = None, positive: Optional[bool] = None,
replace_mixed: Optional[List[Assignment]] = None) -> sp.Expr: replace_mixed: Optional[List[Assignment]] = None) -> sp.Expr:
"""Replaces second order mixed terms like x*y by 2*( (x+y)**2 - x**2 - y**2 ). """Replaces second order mixed terms like 4*x*y by 2*( (x+y)**2 - x**2 - y**2 ).
This makes the term longer - simplify usually is undoing these - however this This makes the term longer - simplify usually is undoing these - however this
transformation can be done to find more common sub-expressions transformation can be done to find more common sub-expressions
...@@ -292,7 +305,7 @@ def replace_second_order_products(expr: sp.Expr, search_symbols: Iterable[sp.Sym ...@@ -292,7 +305,7 @@ def replace_second_order_products(expr: sp.Expr, search_symbols: Iterable[sp.Sym
if expr.is_Mul: if expr.is_Mul:
distinct_search_symbols = set() distinct_search_symbols = set()
nr_of_search_terms = 0 nr_of_search_terms = 0
other_factors = 1 other_factors = sp.Integer(1)
for t in expr.args: for t in expr.args:
if t in search_symbols: if t in search_symbols:
nr_of_search_terms += 1 nr_of_search_terms += 1
...@@ -343,7 +356,7 @@ def remove_higher_order_terms(expr: sp.Expr, symbols: Sequence[sp.Symbol], order ...@@ -343,7 +356,7 @@ def remove_higher_order_terms(expr: sp.Expr, symbols: Sequence[sp.Symbol], order
factor_count = 0 factor_count = 0
if type(product) is Mul: if type(product) is Mul:
for factor in product.args: for factor in product.args:
if type(factor) == Pow: if type(factor) is Pow:
if factor.args[0] in symbols: if factor.args[0] in symbols:
factor_count += factor.args[1] factor_count += factor.args[1]
if factor in symbols: if factor in symbols:
...@@ -353,13 +366,13 @@ def remove_higher_order_terms(expr: sp.Expr, symbols: Sequence[sp.Symbol], order ...@@ -353,13 +366,13 @@ def remove_higher_order_terms(expr: sp.Expr, symbols: Sequence[sp.Symbol], order
factor_count += product.args[1] factor_count += product.args[1]
return factor_count return factor_count
if type(expr) == Mul or type(expr) == Pow: if type(expr) is Mul or type(expr) is Pow:
if velocity_factors_in_product(expr) <= order: if velocity_factors_in_product(expr) <= order:
return expr return expr
else: else:
return sp.Rational(0, 1) return Zero()
if type(expr) != Add: if type(expr) is not Add:
return expr return expr
for sum_term in expr.args: for sum_term in expr.args:
...@@ -429,7 +442,104 @@ def extract_most_common_factor(term): ...@@ -429,7 +442,104 @@ def extract_most_common_factor(term):
return common_factor, term / common_factor return common_factor, term / common_factor
def count_operations(term: Union[sp.Expr, List[sp.Expr]], def recursive_collect(expr, symbols, order_by_occurences=False):
"""Applies sympy.collect recursively for a list of symbols, collecting symbol 2 in the coefficients of symbol 1,
and so on.
``expr`` must be rewritable as a polynomial in the given ``symbols``.
It it is not, ``recursive_collect`` will fail quietly, returning the original expression.
Args:
expr: A sympy expression.
symbols: A sequence of symbols
order_by_occurences: If True, during recursive descent, always collect the symbol occuring
most often in the expression.
"""
if order_by_occurences:
symbols = list(expr.atoms(sp.Symbol) & set(symbols))
symbols = sorted(symbols, key=expr.count, reverse=True)
if len(symbols) == 0:
return expr
symbol = symbols[0]
collected = expr.collect(symbol)
try:
collected_poly = sp.Poly(collected, symbol)
except PolynomialError:
return expr
coeffs = collected_poly.all_coeffs()[::-1]
rec_sum = sum(symbol**i * recursive_collect(c, symbols[1:], order_by_occurences) for i, c in enumerate(coeffs))
return rec_sum
def summands(expr):
return set(expr.args) if isinstance(expr, sp.Add) else {expr}
def simplify_by_equality(expr, a, b, c):
"""
Uses the equality a = b + c, where a and b must be symbols, to simplify expr
by attempting to express additive combinations of two quantities by the third.
This works on expressions that are reducible to the form
:math:`a * (...) + b * (...) + c * (...)`,
without any mixed terms of a, b and c.
"""
if not isinstance(a, sp.Symbol) or not isinstance(b, sp.Symbol):
raise ValueError("a and b must be symbols.")
c = sp.sympify(c)
if not (isinstance(c, sp.Symbol) or is_constant(c)):
raise ValueError("c must be either a symbol or a constant!")
expr = sp.sympify(expr)
expr_expanded = sp.expand(expr)
a_coeff = expr_expanded.coeff(a, 1)
expr_expanded -= (a * a_coeff).expand()
b_coeff = expr_expanded.coeff(b, 1)
expr_expanded -= (b * b_coeff).expand()
if isinstance(c, sp.Symbol):
c_coeff = expr_expanded.coeff(c, 1)
rest = expr_expanded - (c * c_coeff).expand()
else:
c_coeff = expr_expanded / c
rest = 0
a_summands = summands(a_coeff)
b_summands = summands(b_coeff)
c_summands = summands(c_coeff)
# replace b + c by a
b_plus_c_coeffs = b_summands & c_summands
for coeff in b_plus_c_coeffs:
rest += a * coeff
b_summands -= b_plus_c_coeffs
c_summands -= b_plus_c_coeffs
# replace a - b by c
neg_b_summands = {-x for x in b_summands}
a_minus_b_coeffs = a_summands & neg_b_summands
for coeff in a_minus_b_coeffs:
rest += c * coeff
a_summands -= a_minus_b_coeffs
b_summands -= {-x for x in a_minus_b_coeffs}
# replace a - c by b
neg_c_summands = {-x for x in c_summands}
a_minus_c_coeffs = a_summands & neg_c_summands
for coeff in a_minus_c_coeffs:
rest += b * coeff
a_summands -= a_minus_c_coeffs
c_summands -= {-x for x in a_minus_c_coeffs}
# put it back together
return (rest + a * sum(a_summands) + b * sum(b_summands) + c * sum(c_summands)).expand()
def count_operations(term: Union[sp.Expr, List[sp.Expr], List[Assignment]],
only_type: Optional[str] = 'real') -> Dict[str, int]: only_type: Optional[str] = 'real') -> Dict[str, int]:
"""Counts the number of additions, multiplications and division. """Counts the number of additions, multiplications and division.
...@@ -444,7 +554,6 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]], ...@@ -444,7 +554,6 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
result = {'adds': 0, 'muls': 0, 'divs': 0, 'sqrts': 0, result = {'adds': 0, 'muls': 0, 'divs': 0, 'sqrts': 0,
'fast_sqrts': 0, 'fast_inv_sqrts': 0, 'fast_div': 0} 'fast_sqrts': 0, 'fast_inv_sqrts': 0, 'fast_div': 0}
if isinstance(term, Sequence): if isinstance(term, Sequence):
for element in term: for element in term:
r = count_operations(element, only_type) r = count_operations(element, only_type)
...@@ -454,16 +563,20 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]], ...@@ -454,16 +563,20 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
elif isinstance(term, Assignment): elif isinstance(term, Assignment):
term = term.rhs term = term.rhs
if hasattr(term, 'evalf'):
term = term.evalf()
def check_type(e): def check_type(e):
if only_type is None: if only_type is None:
return True return True
if isinstance(e, FieldPointerSymbol) and only_type == "real":
return only_type == "int"
try: try:
base_type = get_base_type(get_type_of_expression(e)) base_type = get_type_of_expression(e)
except ValueError: except ValueError:
return False return False
if isinstance(base_type, VectorType):
return False
if isinstance(base_type, PointerType):
return only_type == 'int'
if only_type == 'int' and (base_type.is_int() or base_type.is_uint()): if only_type == 'int' and (base_type.is_int() or base_type.is_uint()):
return True return True
if only_type == 'real' and (base_type.is_float()): if only_type == 'real' and (base_type.is_float()):
...@@ -492,7 +605,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]], ...@@ -492,7 +605,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
visit_children = False visit_children = False
elif t.is_integer: elif t.is_integer:
pass pass
elif isinstance(t, cast_func): elif isinstance(t, CastFunc):
visit_children = False visit_children = False
visit(t.args[0]) visit(t.args[0])
elif t.func is fast_sqrt: elif t.func is fast_sqrt:
...@@ -508,13 +621,17 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]], ...@@ -508,13 +621,17 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
if t.exp >= 0: if t.exp >= 0:
result['muls'] += int(t.exp) - 1 result['muls'] += int(t.exp) - 1
else: else:
result['muls'] -= 1 if result['muls'] > 0:
result['muls'] -= 1
result['divs'] += 1 result['divs'] += 1
result['muls'] += (-int(t.exp)) - 1 result['muls'] += (-int(t.exp)) - 1
elif sp.nsimplify(t.exp) == sp.Rational(1, 2): elif sp.nsimplify(t.exp) == sp.Rational(1, 2):
result['sqrts'] += 1 result['sqrts'] += 1
elif sp.nsimplify(t.exp) == -sp.Rational(1, 2):
result["sqrts"] += 1
result["divs"] += 1
else: else:
warnings.warn("Cannot handle exponent", t.exp, " of sp.Pow node") warnings.warn(f"Cannot handle exponent {t.exp} of sp.Pow node")
else: else:
warnings.warn("Counting operations: only integer exponents are supported in Pow, " warnings.warn("Counting operations: only integer exponents are supported in Pow, "
"counting will be inaccurate") "counting will be inaccurate")
...@@ -522,10 +639,12 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]], ...@@ -522,10 +639,12 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
for child_term, condition in t.args: for child_term, condition in t.args:
visit(child_term) visit(child_term)
visit_children = False visit_children = False
elif isinstance(t, sp.Rel): elif isinstance(t, (sp.Rel, sp.UnevaluatedExpr)):
pass pass
elif isinstance(t, DivFunc):
result["divs"] += 1
else: else:
warnings.warn("Unknown sympy node of type " + str(t.func) + " counting will be inaccurate") warnings.warn(f"Unknown sympy node of type {str(t.func)} counting will be inaccurate")
if visit_children: if visit_children:
for a in t.args: for a in t.args:
......
File moved
import hashlib import hashlib
import pickle import pickle
import warnings import warnings
from collections import OrderedDict, defaultdict, namedtuple from collections import OrderedDict
from copy import deepcopy from copy import deepcopy
from types import MappingProxyType from types import MappingProxyType
from typing import Set
import numpy as np
import sympy as sp import sympy as sp
from sympy.core.numbers import ImaginaryUnit
from sympy.logic.boolalg import Boolean, BooleanFunction
import pystencils as ps
import pystencils.astnodes as ast import pystencils.astnodes as ast
import pystencils.integer_functions
from pystencils.assignment import Assignment from pystencils.assignment import Assignment
from pystencils.data_types import ( from pystencils.typing import (CastFunc, PointerType, StructType, TypedSymbol, get_base_type,
PointerType, StructType, TypedImaginaryUnit, TypedSymbol, cast_func, collate_types, create_type, ReinterpretCastFunc, get_next_parent_of_type, parents_of_type)
get_base_type, get_type_of_expression, pointer_arithmetic_func, reinterpret_cast_func) from pystencils.field import Field, FieldType
from pystencils.field import AbstractField, Field, FieldType from pystencils.typing import FieldPointerSymbol
from pystencils.kernelparameters import FieldPointerSymbol from pystencils.sympyextensions import fast_subs
from pystencils.simp.assignment_collection import AssignmentCollection from pystencils.simp.assignment_collection import AssignmentCollection
from pystencils.slicing import normalize_slice from pystencils.slicing import normalize_slice
from pystencils.integer_functions import int_div
class NestedScopes: class NestedScopes:
...@@ -101,6 +100,45 @@ def generic_visit(term, visitor): ...@@ -101,6 +100,45 @@ def generic_visit(term, visitor):
return visitor(term) return visitor(term)
def iterate_loops_by_depth(node, nesting_depth):
"""Iterate all LoopOverCoordinate nodes in the given AST of the specified nesting depth.
Args:
node: Root node of the abstract syntax tree
nesting_depth: Nesting depth of the loops the pragmas should be applied to.
Outermost loop has depth 0.
A depth of -1 indicates the innermost loops.
Returns: Iterable listing all loop nodes of given nesting depth.
"""
from pystencils.astnodes import LoopOverCoordinate
def _internal_default(node, nesting_depth):
isloop = isinstance(node, LoopOverCoordinate)
if nesting_depth < 0: # here, a negative value indicates end of descent
return
elif nesting_depth == 0 and isloop:
yield node
else:
next_depth = nesting_depth - 1 if isloop else nesting_depth
for arg in node.args:
yield from _internal_default(arg, next_depth)
def _internal_innermost(node):
if isinstance(node, LoopOverCoordinate) and node.is_innermost_loop:
yield node
else:
for arg in node.args:
yield from _internal_innermost(arg)
if nesting_depth >= 0:
yield from _internal_default(node, nesting_depth)
elif nesting_depth == -1:
yield from _internal_innermost(node)
else:
raise ValueError(f"Invalid nesting depth: {nesting_depth}. Choose a nonnegative number, or -1.")
def unify_shape_symbols(body, common_shape, fields): def unify_shape_symbols(body, common_shape, fields):
"""Replaces symbols for array sizes to ensure they are represented by the same unique symbol. """Replaces symbols for array sizes to ensure they are represented by the same unique symbol.
...@@ -125,9 +163,10 @@ def unify_shape_symbols(body, common_shape, fields): ...@@ -125,9 +163,10 @@ def unify_shape_symbols(body, common_shape, fields):
body.subs(substitutions) body.subs(substitutions)
def get_common_shape(field_set): def get_common_field(field_set):
"""Takes a set of pystencils Fields and returns their common spatial shape if it exists. Otherwise """Takes a set of pystencils Fields, checks if a common spatial shape exists and returns one
ValueError is raised""" representative field, that can be used for shape information etc. in the kernel creation.
If the fields have different shapes ValueError is raised"""
nr_of_fixed_shaped_fields = 0 nr_of_fixed_shaped_fields = 0
for f in field_set: for f in field_set:
if f.has_fixed_shape: if f.has_fixed_shape:
...@@ -137,7 +176,7 @@ def get_common_shape(field_set): ...@@ -137,7 +176,7 @@ def get_common_shape(field_set):
fixed_field_names = ",".join([f.name for f in field_set if f.has_fixed_shape]) fixed_field_names = ",".join([f.name for f in field_set if f.has_fixed_shape])
var_field_names = ",".join([f.name for f in field_set if not f.has_fixed_shape]) var_field_names = ",".join([f.name for f in field_set if not f.has_fixed_shape])
msg = "Mixing fixed-shaped and variable-shape fields in a single kernel is not possible\n" msg = "Mixing fixed-shaped and variable-shape fields in a single kernel is not possible\n"
msg += "Variable shaped: %s \nFixed shaped: %s" % (var_field_names, fixed_field_names) msg += f"Variable shaped: {var_field_names} \nFixed shaped: {fixed_field_names}"
raise ValueError(msg) raise ValueError(msg)
shape_set = set([f.spatial_shape for f in field_set]) shape_set = set([f.spatial_shape for f in field_set])
...@@ -145,8 +184,9 @@ def get_common_shape(field_set): ...@@ -145,8 +184,9 @@ def get_common_shape(field_set):
if len(shape_set) != 1: if len(shape_set) != 1:
raise ValueError("Differently sized field accesses in loop body: " + str(shape_set)) raise ValueError("Differently sized field accesses in loop body: " + str(shape_set))
shape = list(sorted(shape_set, key=lambda e: str(e[0])))[0] # Sort the fields by their name to ensure that always the same field is returned
return shape reference_field = sorted(field_set, key=lambda e: str(e))[0]
return reference_field
def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_order=None): def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_order=None):
...@@ -164,9 +204,11 @@ def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_or ...@@ -164,9 +204,11 @@ def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_or
tuple of loop-node, ghost_layer_info tuple of loop-node, ghost_layer_info
""" """
# find correct ordering by inspecting participating FieldAccesses # find correct ordering by inspecting participating FieldAccesses
field_accesses = body.atoms(AbstractField.AbstractAccess) absolut_accesses_only = False
field_accesses = body.atoms(Field.Access)
field_accesses = {e for e in field_accesses if not e.is_absolute_access} field_accesses = {e for e in field_accesses if not e.is_absolute_access}
if len(field_accesses) == 0: # when kernel contains only absolute accesses
absolut_accesses_only = True
# exclude accesses to buffers from field_list, because buffers are treated separately # exclude accesses to buffers from field_list, because buffers are treated separately
field_list = [e.field for e in field_accesses if not (FieldType.is_buffer(e.field) or FieldType.is_custom(e.field))] field_list = [e.field for e in field_accesses if not (FieldType.is_buffer(e.field) or FieldType.is_custom(e.field))]
if len(field_list) == 0: # when kernel contains only custom fields if len(field_list) == 0: # when kernel contains only custom fields
...@@ -177,14 +219,23 @@ def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_or ...@@ -177,14 +219,23 @@ def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_or
if loop_order is None: if loop_order is None:
loop_order = get_optimal_loop_ordering(fields) loop_order = get_optimal_loop_ordering(fields)
shape = get_common_shape(fields) if absolut_accesses_only:
unify_shape_symbols(body, common_shape=shape, fields=fields) absolut_access_fields = {e.field for e in body.atoms(Field.Access)}
common_field = get_common_field(absolut_access_fields)
common_shape = common_field.spatial_shape
else:
common_field = get_common_field(fields)
common_shape = common_field.spatial_shape
unify_shape_symbols(body, common_shape=common_shape, fields=fields)
if iteration_slice is not None: if iteration_slice is not None:
iteration_slice = normalize_slice(iteration_slice, shape) iteration_slice = normalize_slice(iteration_slice, common_shape)
if ghost_layers is None: if ghost_layers is None:
required_ghost_layers = max([fa.required_ghost_layers for fa in field_accesses]) if absolut_accesses_only:
required_ghost_layers = 0
else:
required_ghost_layers = max([fa.required_ghost_layers for fa in field_accesses])
ghost_layers = [(required_ghost_layers, required_ghost_layers)] * len(loop_order) ghost_layers = [(required_ghost_layers, required_ghost_layers)] * len(loop_order)
if isinstance(ghost_layers, int): if isinstance(ghost_layers, int):
ghost_layers = [(ghost_layers, ghost_layers)] * len(loop_order) ghost_layers = [(ghost_layers, ghost_layers)] * len(loop_order)
...@@ -193,7 +244,7 @@ def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_or ...@@ -193,7 +244,7 @@ def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_or
for i, loop_coordinate in enumerate(reversed(loop_order)): for i, loop_coordinate in enumerate(reversed(loop_order)):
if iteration_slice is None: if iteration_slice is None:
begin = ghost_layers[loop_coordinate][0] begin = ghost_layers[loop_coordinate][0]
end = shape[loop_coordinate] - ghost_layers[loop_coordinate][1] end = common_shape[loop_coordinate] - ghost_layers[loop_coordinate][1]
new_loop = ast.LoopOverCoordinate(current_body, loop_coordinate, begin, end, 1) new_loop = ast.LoopOverCoordinate(current_body, loop_coordinate, begin, end, 1)
current_body = ast.Block([new_loop]) current_body = ast.Block([new_loop])
else: else:
...@@ -210,6 +261,28 @@ def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_or ...@@ -210,6 +261,28 @@ def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_or
return current_body, ghost_layers return current_body, ghost_layers
def get_common_indexed_element(indexed_elements: Set[sp.IndexedBase]) -> sp.IndexedBase:
assert len(indexed_elements) > 0, "indexed_elements can not be empty"
shape_set = {s.shape for s in indexed_elements}
if len(shape_set) != 1:
for shape in shape_set:
assert not isinstance(shape, int), "If indexed elements are used, they must all have the same shape"
return sorted(indexed_elements, key=lambda e: str(e))[0]
def add_outer_loop_over_indexed_elements(loop_node: ast.Block) -> ast.Block:
indexed_elements = loop_node.atoms(sp.Indexed)
if len(indexed_elements) == 0:
return loop_node
reference_element = get_common_indexed_element(indexed_elements)
index = reference_element.indices[0].atoms(TypedSymbol)
assert len(index) == 1, "index expressions must only contain one symbol representing the index"
new_loop = ast.LoopOverCoordinate(loop_node, 0, 0,
reference_element.shape[0], 1, custom_loop_ctr=index.pop())
return ast.Block([new_loop])
def create_intermediate_base_pointer(field_access, coordinates, previous_ptr): def create_intermediate_base_pointer(field_access, coordinates, previous_ptr):
r""" r"""
Addressing elements in structured arrays is done with :math:`ptr\left[ \sum_i c_i \cdot s_i \right]` Addressing elements in structured arrays is done with :math:`ptr\left[ \sum_i c_i \cdot s_i \right]`
...@@ -326,7 +399,7 @@ def parse_base_pointer_info(base_pointer_specification, loop_order, spatial_dime ...@@ -326,7 +399,7 @@ def parse_base_pointer_info(base_pointer_specification, loop_order, spatial_dime
index = int(element[len("index"):]) index = int(element[len("index"):])
add_new_element(spatial_dimensions + index) add_new_element(spatial_dimensions + index)
else: else:
raise ValueError("Unknown specification %s" % (element,)) raise ValueError(f"Unknown specification {element}")
result.append(new_group) result.append(new_group)
...@@ -345,7 +418,7 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None): ...@@ -345,7 +418,7 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None):
ast_node: ast before any field accesses are resolved ast_node: ast before any field accesses are resolved
loop_counters: for CPU kernels: leave to default 'None' (can be determined from loop nodes) loop_counters: for CPU kernels: leave to default 'None' (can be determined from loop nodes)
for GPU kernels: list of 'loop counters' from inner to outer loop for GPU kernels: list of 'loop counters' from inner to outer loop
loop_iterations: number of iterations of each loop from inner to outer, for CPU kernels leave to default loop_iterations: iteration slice for each loop from inner to outer, for CPU kernels leave to default
Returns: Returns:
base buffer index - required by 'resolve_buffer_accesses' function base buffer index - required by 'resolve_buffer_accesses' function
...@@ -357,26 +430,46 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None): ...@@ -357,26 +430,46 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None):
assert len(loops) == len(parents_of_innermost_loop) assert len(loops) == len(parents_of_innermost_loop)
assert all(l1 is l2 for l1, l2 in zip(loops, parents_of_innermost_loop)) assert all(l1 is l2 for l1, l2 in zip(loops, parents_of_innermost_loop))
loop_iterations = [(l.stop - l.start) / l.step for l in loops] loop_counters = [loop.loop_counter_symbol for loop in loops]
loop_counters = [l.loop_counter_symbol for l in loops] loop_iterations = [slice(loop.start, loop.stop, loop.step) for loop in loops]
actual_sizes = list()
actual_steps = list()
for ctr, s in zip(loop_counters, loop_iterations):
if s.step != 1:
if (s.stop - s.start) % s.step == 0:
actual_sizes.append((s.stop - s.start) // s.step)
else:
actual_sizes.append(int_div((s.stop - s.start), s.step))
if (ctr - s.start) % s.step == 0:
actual_steps.append((ctr - s.start) // s.step)
else:
actual_steps.append(int_div((ctr - s.start), s.step))
else:
actual_sizes.append(s.stop - s.start)
actual_steps.append(ctr - s.start)
field_accesses = ast_node.atoms(AbstractField.AbstractAccess) field_accesses = ast_node.atoms(Field.Access)
buffer_accesses = {fa for fa in field_accesses if FieldType.is_buffer(fa.field)} buffer_accesses = {fa for fa in field_accesses if FieldType.is_buffer(fa.field)}
loop_counters = [v * len(buffer_accesses) for v in loop_counters] buffer_index_size = len(buffer_accesses)
base_buffer_index = loop_counters[0] base_buffer_index = actual_steps[0]
stride = 1 actual_stride = 1
for idx, var in enumerate(loop_counters[1:]): for idx, actual_step in enumerate(actual_steps[1:]):
cur_stride = loop_iterations[idx] cur_stride = actual_sizes[idx]
stride *= int(cur_stride) if isinstance(cur_stride, float) else cur_stride actual_stride *= int(cur_stride) if isinstance(cur_stride, float) else cur_stride
base_buffer_index += var * stride base_buffer_index += actual_stride * actual_step
return base_buffer_index return base_buffer_index * buffer_index_size
def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=set()): def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=None):
if read_only_field_names is None:
read_only_field_names = set()
def visit_sympy_expr(expr, enclosing_block, sympy_assignment): def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
if isinstance(expr, AbstractField.AbstractAccess): if isinstance(expr, Field.Access):
field_access = expr field_access = expr
# Do not apply transformation if field is not a buffer # Do not apply transformation if field is not a buffer
...@@ -419,7 +512,7 @@ def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=s ...@@ -419,7 +512,7 @@ def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=s
return visit_node(ast_node) return visit_node(ast_node)
def resolve_field_accesses(ast_node, read_only_field_names=set(), def resolve_field_accesses(ast_node, read_only_field_names=None,
field_to_base_pointer_info=MappingProxyType({}), field_to_base_pointer_info=MappingProxyType({}),
field_to_fixed_coordinates=MappingProxyType({})): field_to_fixed_coordinates=MappingProxyType({})):
""" """
...@@ -436,11 +529,13 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(), ...@@ -436,11 +529,13 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
Returns Returns
transformed AST transformed AST
""" """
if read_only_field_names is None:
read_only_field_names = set()
field_to_base_pointer_info = OrderedDict(sorted(field_to_base_pointer_info.items(), key=lambda pair: pair[0])) field_to_base_pointer_info = OrderedDict(sorted(field_to_base_pointer_info.items(), key=lambda pair: pair[0]))
field_to_fixed_coordinates = OrderedDict(sorted(field_to_fixed_coordinates.items(), key=lambda pair: pair[0])) field_to_fixed_coordinates = OrderedDict(sorted(field_to_fixed_coordinates.items(), key=lambda pair: pair[0]))
def visit_sympy_expr(expr, enclosing_block, sympy_assignment): def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
if isinstance(expr, AbstractField.AbstractAccess): if isinstance(expr, Field.Access):
field_access = expr field_access = expr
field = field_access.field field = field_access.field
...@@ -456,10 +551,7 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(), ...@@ -456,10 +551,7 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
if field.name in field_to_base_pointer_info: if field.name in field_to_base_pointer_info:
base_pointer_info = field_to_base_pointer_info[field.name] base_pointer_info = field_to_base_pointer_info[field.name]
else: else:
base_pointer_info = [ base_pointer_info = [list(range(field.index_dimensions + field.spatial_dimensions))]
list(
range(field.index_dimensions + field.spatial_dimensions))
]
field_ptr = FieldPointerSymbol( field_ptr = FieldPointerSymbol(
field.name, field.name,
...@@ -500,7 +592,7 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(), ...@@ -500,7 +592,7 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
coord_dict = create_coordinate_dict(group) coord_dict = create_coordinate_dict(group)
new_ptr, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer) new_ptr, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer)
if new_ptr not in enclosing_block.symbols_defined: if new_ptr not in enclosing_block.symbols_defined:
new_assignment = ast.SympyAssignment(new_ptr, last_pointer + offset, is_const=False) new_assignment = ast.SympyAssignment(new_ptr, last_pointer + offset, is_const=False, use_auto=False)
enclosing_block.insert_before(new_assignment, sympy_assignment) enclosing_block.insert_before(new_assignment, sympy_assignment)
last_pointer = new_ptr last_pointer = new_ptr
...@@ -514,7 +606,7 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(), ...@@ -514,7 +606,7 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
if isinstance(accessed_field_name, sp.Symbol): if isinstance(accessed_field_name, sp.Symbol):
accessed_field_name = accessed_field_name.name accessed_field_name = accessed_field_name.name
new_type = field_access.field.dtype.get_element_type(accessed_field_name) new_type = field_access.field.dtype.get_element_type(accessed_field_name)
result = reinterpret_cast_func(result, new_type) result = ReinterpretCastFunc(result, new_type)
return visit_sympy_expr(result, enclosing_block, sympy_assignment) return visit_sympy_expr(result, enclosing_block, sympy_assignment)
else: else:
...@@ -564,21 +656,65 @@ def move_constants_before_loop(ast_node): ...@@ -564,21 +656,65 @@ def move_constants_before_loop(ast_node):
""" """
assert isinstance(node.parent, ast.Block) assert isinstance(node.parent, ast.Block)
def modifies_or_declares(node: ast.Node, symbol_names: Set[str]) -> bool:
if isinstance(node, (ps.Assignment, ast.SympyAssignment)):
if isinstance(node.lhs, ast.ResolvedFieldAccess):
return node.lhs.typed_symbol.name in symbol_names
else:
return node.lhs.name in symbol_names
elif isinstance(node, ast.Block):
for arg in node.args:
if isinstance(arg, ast.SympyAssignment) and arg.is_declaration:
continue
if modifies_or_declares(arg, symbol_names):
return True
return False
elif isinstance(node, ast.LoopOverCoordinate):
return modifies_or_declares(node.body, symbol_names)
elif isinstance(node, ast.Conditional):
return (
modifies_or_declares(node.true_block, symbol_names)
or (node.false_block and modifies_or_declares(node.false_block, symbol_names))
)
elif isinstance(node, ast.KernelFunction):
return False
else:
defs = {s.name for s in node.symbols_defined}
return bool(symbol_names.intersection(defs))
dependencies = {s.name for s in node.undefined_symbols}
last_block = node.parent last_block = node.parent
last_block_child = node last_block_child = node
element = node.parent element = node.parent
prev_element = node prev_element = node
while element: while element:
if isinstance(element, ast.Block): if isinstance(element, (ast.Conditional, ast.KernelFunction)):
# Never move out of Conditionals or KernelFunctions.
break
elif isinstance(element, ast.Block):
last_block = element last_block = element
last_block_child = prev_element last_block_child = prev_element
if isinstance(element, ast.Conditional): if any(modifies_or_declares(sibling, dependencies) for sibling in element.args):
break # The node depends on one of the statements in this block.
# Do not move further out.
break
elif isinstance(element, ast.LoopOverCoordinate):
if element.loop_counter_symbol.name in dependencies:
# The node depends on the loop counter.
# Do not move out of this loop.
break
else: else:
critical_symbols = element.symbols_defined raise NotImplementedError(f'Due to defensive programming we handle only specific expressions.\n'
if node.undefined_symbols.intersection(critical_symbols): f'The expression {element} of type {type(element)} is not known yet.')
break
# No dependencies to symbols defined/modified within the current element.
# We can move the node up one level and in front of the current element.
prev_element = element prev_element = element
element = element.parent element = element.parent
return last_block, last_block_child return last_block, last_block_child
...@@ -602,13 +738,7 @@ def move_constants_before_loop(ast_node): ...@@ -602,13 +738,7 @@ def move_constants_before_loop(ast_node):
get_blocks(ast_node, all_blocks) get_blocks(ast_node, all_blocks)
for block in all_blocks: for block in all_blocks:
children = block.take_child_nodes() children = block.take_child_nodes()
# Every time a symbol can be replaced in the current block because the assignment
# was found in a parent block, but with a different lhs symbol (same rhs)
# the outer symbol is inserted here as key.
substitute_variables = {}
for child in children: for child in children:
# Before traversing the next child, all symbols are substituted first.
child.subs(substitute_variables)
if not isinstance(child, ast.SympyAssignment): # only move SympyAssignments if not isinstance(child, ast.SympyAssignment): # only move SympyAssignments
block.append(child) block.append(child)
...@@ -624,14 +754,7 @@ def move_constants_before_loop(ast_node): ...@@ -624,14 +754,7 @@ def move_constants_before_loop(ast_node):
exists_already = False exists_already = False
if not exists_already: if not exists_already:
rhs_identical = check_if_assignment_already_in_block(child, target, True) target.insert_before(child, child_to_insert_before)
if rhs_identical:
# there is already an assignment out there with the same rhs
# -> replace all lhs symbols in this block with the lhs of the outer assignment
# -> remove the local assignment (do not re-append child to the former block)
substitute_variables[child.lhs] = rhs_identical.lhs
else:
target.insert_before(child, child_to_insert_before)
elif exists_already and exists_already.rhs == child.rhs: elif exists_already and exists_already.rhs == child.rhs:
if target.args.index(exists_already) > target.args.index(child_to_insert_before): if target.args.index(exists_already) > target.args.index(child_to_insert_before):
assert target.args.count(exists_already) == 1 assert target.args.count(exists_already) == 1
...@@ -645,7 +768,7 @@ def move_constants_before_loop(ast_node): ...@@ -645,7 +768,7 @@ def move_constants_before_loop(ast_node):
new_symbol = TypedSymbol(sp.Dummy().name, child.lhs.dtype) new_symbol = TypedSymbol(sp.Dummy().name, child.lhs.dtype)
target.insert_before(ast.SympyAssignment(new_symbol, child.rhs, is_const=child.is_const), target.insert_before(ast.SympyAssignment(new_symbol, child.rhs, is_const=child.is_const),
child_to_insert_before) child_to_insert_before)
substitute_variables[child.lhs] = new_symbol block.append(ast.SympyAssignment(child.lhs, new_symbol, is_const=child.is_const))
def split_inner_loop(ast_node: ast.Node, symbol_groups): def split_inner_loop(ast_node: ast.Node, symbol_groups):
...@@ -659,11 +782,11 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups): ...@@ -659,11 +782,11 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups):
and which no symbol in a symbol group depends on, are not updated! and which no symbol in a symbol group depends on, are not updated!
""" """
all_loops = ast_node.atoms(ast.LoopOverCoordinate) all_loops = ast_node.atoms(ast.LoopOverCoordinate)
inner_loop = [l for l in all_loops if l.is_innermost_loop] inner_loop = [loop for loop in all_loops if loop.is_innermost_loop]
assert len(inner_loop) == 1, "Error in AST: multiple innermost loops. Was split transformation already called?" assert len(inner_loop) == 1, "Error in AST: multiple innermost loops. Was split transformation already called?"
inner_loop = inner_loop[0] inner_loop = inner_loop[0]
assert type(inner_loop.body) is ast.Block assert type(inner_loop.body) is ast.Block
outer_loop = [l for l in all_loops if l.is_outermost_loop] outer_loop = [loop for loop in all_loops if loop.is_outermost_loop]
assert len(outer_loop) == 1, "Error in AST, multiple outermost loops." assert len(outer_loop) == 1, "Error in AST, multiple outermost loops."
outer_loop = outer_loop[0] outer_loop = outer_loop[0]
...@@ -682,13 +805,13 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups): ...@@ -682,13 +805,13 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups):
if s in assignment_map: # if there is no assignment inside the loop body it is independent already if s in assignment_map: # if there is no assignment inside the loop body it is independent already
for new_symbol in assignment_map[s].rhs.atoms(sp.Symbol): for new_symbol in assignment_map[s].rhs.atoms(sp.Symbol):
if not isinstance(new_symbol, AbstractField.AbstractAccess) and \ if not isinstance(new_symbol, Field.Access) and \
new_symbol not in symbols_with_temporary_array: new_symbol not in symbols_with_temporary_array:
symbols_to_process.append(new_symbol) symbols_to_process.append(new_symbol)
symbols_resolved.add(s) symbols_resolved.add(s)
for symbol in symbol_group: for symbol in symbol_group:
if not isinstance(symbol, AbstractField.AbstractAccess): if not isinstance(symbol, Field.Access):
assert type(symbol) is TypedSymbol assert type(symbol) is TypedSymbol
new_ts = TypedSymbol(symbol.name, PointerType(symbol.dtype)) new_ts = TypedSymbol(symbol.name, PointerType(symbol.dtype))
symbols_with_temporary_array[symbol] = sp.IndexedBase( symbols_with_temporary_array[symbol] = sp.IndexedBase(
...@@ -697,9 +820,9 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups): ...@@ -697,9 +820,9 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups):
assignment_group = [] assignment_group = []
for assignment in inner_loop.body.args: for assignment in inner_loop.body.args:
if assignment.lhs in symbols_resolved: if assignment.lhs in symbols_resolved:
new_rhs = assignment.rhs.subs( # use fast_subs here because it checks if multiplications should be evaluated or not
symbols_with_temporary_array.items()) new_rhs = fast_subs(assignment.rhs, symbols_with_temporary_array)
if not isinstance(assignment.lhs, AbstractField.AbstractAccess) and assignment.lhs in symbol_group: if not isinstance(assignment.lhs, Field.Access) and assignment.lhs in symbol_group:
assert type(assignment.lhs) is TypedSymbol assert type(assignment.lhs) is TypedSymbol
new_ts = TypedSymbol(assignment.lhs.name, PointerType(assignment.lhs.dtype)) new_ts = TypedSymbol(assignment.lhs.name, PointerType(assignment.lhs.dtype))
new_lhs = sp.IndexedBase(new_ts, shape=(1, ))[inner_loop.loop_counter_symbol] new_lhs = sp.IndexedBase(new_ts, shape=(1, ))[inner_loop.loop_counter_symbol]
...@@ -728,7 +851,8 @@ def cut_loop(loop_node, cutting_points): ...@@ -728,7 +851,8 @@ def cut_loop(loop_node, cutting_points):
One loop is transformed into len(cuttingPoints)+1 new loops that range from One loop is transformed into len(cuttingPoints)+1 new loops that range from
old_begin to cutting_points[1], ..., cutting_points[-1] to old_end old_begin to cutting_points[1], ..., cutting_points[-1] to old_end
Modifies the ast in place Modifies the ast in place. Note Issue #5783 of SymPy. Deepcopy will evaluate mul
https://github.com/sympy/sympy/issues/5783
Returns: Returns:
list of new loop nodes list of new loop nodes
...@@ -766,11 +890,16 @@ def simplify_conditionals(node: ast.Node, loop_counter_simplification: bool = Fa ...@@ -766,11 +890,16 @@ def simplify_conditionals(node: ast.Node, loop_counter_simplification: bool = Fa
This analysis needs the integer set library (ISL) islpy, so it is not done by This analysis needs the integer set library (ISL) islpy, so it is not done by
default. default.
""" """
from sympy.codegen.rewriting import ReplaceOptim, optimize
remove_casts = ReplaceOptim(lambda e: isinstance(e, CastFunc), lambda p: p.expr)
for conditional in node.atoms(ast.Conditional): for conditional in node.atoms(ast.Conditional):
conditional.condition_expr = sp.simplify(conditional.condition_expr) # TODO simplify conditional before the type system! Casts make it very hard here
if conditional.condition_expr == sp.true: condition_expression = optimize(conditional.condition_expr, [remove_casts])
condition_expression = sp.simplify(condition_expression)
if condition_expression == sp.true:
conditional.parent.replace(conditional, [conditional.true_block]) conditional.parent.replace(conditional, [conditional.true_block])
elif conditional.condition_expr == sp.false: elif condition_expression == sp.false:
conditional.parent.replace(conditional, [conditional.false_block] if conditional.false_block else []) conditional.parent.replace(conditional, [conditional.false_block] if conditional.false_block else [])
elif loop_counter_simplification: elif loop_counter_simplification:
try: try:
...@@ -796,282 +925,6 @@ def cleanup_blocks(node: ast.Node) -> None: ...@@ -796,282 +925,6 @@ def cleanup_blocks(node: ast.Node) -> None:
cleanup_blocks(a) cleanup_blocks(a)
class KernelConstraintsCheck:
"""Checks if the input to create_kernel is valid.
Test the following conditions:
- SSA Form for pure symbols:
- Every pure symbol may occur only once as left-hand-side of an assignment
- Every pure symbol that is read, may not be written to later
- Independence / Parallelization condition:
- a field that is written may only be read at exact the same spatial position
(Pure symbols are symbols that are not Field.Accesses)
"""
FieldAndIndex = namedtuple('FieldAndIndex', ['field', 'index'])
def __init__(self, type_for_symbol, check_independence_condition, check_double_write_condition=True):
self._type_for_symbol = type_for_symbol
self.scopes = NestedScopes()
self._field_writes = defaultdict(set)
self.fields_read = set()
self.check_independence_condition = check_independence_condition
self.check_double_write_condition = check_double_write_condition
def process_assignment(self, assignment):
# for checks it is crucial to process rhs before lhs to catch e.g. a = a + 1
new_rhs = self.process_expression(assignment.rhs)
new_lhs = self._process_lhs(assignment.lhs)
return ast.SympyAssignment(new_lhs, new_rhs)
def process_expression(self, rhs, type_constants=True):
from pystencils.interpolation_astnodes import InterpolatorAccess
self._update_accesses_rhs(rhs)
if isinstance(rhs, AbstractField.AbstractAccess):
self.fields_read.add(rhs.field)
self.fields_read.update(rhs.indirect_addressing_fields)
return rhs
elif isinstance(rhs, InterpolatorAccess):
new_args = [self.process_expression(arg, type_constants) for arg in rhs.offsets]
if new_args:
rhs.offsets = new_args
return rhs
elif isinstance(rhs, ImaginaryUnit):
return TypedImaginaryUnit(create_type(self._type_for_symbol['_complex_type']))
elif isinstance(rhs, TypedSymbol):
return rhs
elif isinstance(rhs, sp.Symbol):
return TypedSymbol(rhs.name, self._type_for_symbol[rhs.name])
elif type_constants and isinstance(rhs, np.generic):
return cast_func(rhs, create_type(rhs.dtype))
elif type_constants and isinstance(rhs, sp.Number):
return cast_func(rhs, create_type(self._type_for_symbol['_constant']))
# Very important that this clause comes before BooleanFunction
elif isinstance(rhs, cast_func):
return cast_func(
self.process_expression(rhs.args[0], type_constants=False),
rhs.dtype)
elif isinstance(rhs, BooleanFunction) or \
type(rhs) in pystencils.integer_functions.__dict__.values():
new_args = [self.process_expression(a, type_constants) for a in rhs.args]
types_of_expressions = [get_type_of_expression(a) for a in new_args]
arg_type = collate_types(types_of_expressions, forbid_collation_to_float=True)
new_args = [a if not hasattr(a, 'dtype') or a.dtype == arg_type
else cast_func(a, arg_type)
for a in new_args]
return rhs.func(*new_args)
elif isinstance(rhs, sp.Mul):
new_args = [
self.process_expression(arg, type_constants)
if arg not in (-1, 1) else arg for arg in rhs.args
]
return rhs.func(*new_args) if new_args else rhs
elif isinstance(rhs, sp.Indexed):
return rhs
else:
if isinstance(rhs, sp.Pow):
# don't process exponents -> they should remain integers
return sp.Pow(
self.process_expression(rhs.args[0], type_constants),
rhs.args[1])
else:
new_args = [
self.process_expression(arg, type_constants)
for arg in rhs.args
]
return rhs.func(*new_args) if new_args else rhs
@property
def fields_written(self):
return set(k.field for k, v in self._field_writes.items() if len(v))
def _process_lhs(self, lhs):
assert isinstance(lhs, sp.Symbol)
self._update_accesses_lhs(lhs)
if not isinstance(lhs, (AbstractField.AbstractAccess, TypedSymbol)):
return TypedSymbol(lhs.name, self._type_for_symbol[lhs.name])
else:
return lhs
def _update_accesses_lhs(self, lhs):
if isinstance(lhs, AbstractField.AbstractAccess):
fai = self.FieldAndIndex(lhs.field, lhs.index)
self._field_writes[fai].add(lhs.offsets)
if self.check_double_write_condition and len(self._field_writes[fai]) > 1:
raise ValueError(
"Field {} is written at two different locations".format(
lhs.field.name))
elif isinstance(lhs, sp.Symbol):
if self.scopes.is_defined_locally(lhs):
raise ValueError("Assignments not in SSA form, multiple assignments to {}".format(lhs.name))
if lhs in self.scopes.free_parameters:
raise ValueError("Symbol {} is written, after it has been read".format(lhs.name))
self.scopes.define_symbol(lhs)
def _update_accesses_rhs(self, rhs):
if isinstance(rhs, AbstractField.AbstractAccess) and self.check_independence_condition:
writes = self._field_writes[self.FieldAndIndex(
rhs.field, rhs.index)]
for write_offset in writes:
assert len(writes) == 1
if write_offset != rhs.offsets:
raise ValueError("Violation of loop independence condition. Field "
"{} is read at {} and written at {}".format(rhs.field, rhs.offsets, write_offset))
self.fields_read.add(rhs.field)
elif isinstance(rhs, sp.Symbol):
self.scopes.access_symbol(rhs)
def add_types(eqs, type_for_symbol, check_independence_condition):
"""Traverses AST and replaces every :class:`sympy.Symbol` by a :class:`pystencils.typedsymbol.TypedSymbol`.
Additionally returns sets of all fields which are read/written
Args:
eqs: list of equations
type_for_symbol: dict mapping symbol names to types. Types are strings of C types like 'int' or 'double'
check_independence_condition: check that loop iterations are independent - this has to be skipped for indexed
kernels
Returns:
``fields_read, fields_written, typed_equations`` set of read fields, set of written fields,
list of equations where symbols have been replaced by typed symbols
"""
if isinstance(type_for_symbol, (str, type)) or not hasattr(type_for_symbol, '__getitem__'):
type_for_symbol = typing_from_sympy_inspection(eqs, type_for_symbol)
check = KernelConstraintsCheck(type_for_symbol, check_independence_condition)
def visit(obj):
if isinstance(obj, (list, tuple)):
return [visit(e) for e in obj]
if isinstance(obj, (sp.Eq, ast.SympyAssignment, Assignment)):
return check.process_assignment(obj)
elif isinstance(obj, ast.Conditional):
check.scopes.push()
# Disable double write check inside conditionals
# would be triggered by e.g. in-kernel boundaries
check.check_double_write_condition = False
false_block = None if obj.false_block is None else visit(
obj.false_block)
result = ast.Conditional(check.process_expression(
obj.condition_expr, type_constants=False),
true_block=visit(obj.true_block),
false_block=false_block)
check.check_double_write_condition = True
check.scopes.pop()
return result
elif isinstance(obj, ast.Block):
check.scopes.push()
result = ast.Block([visit(e) for e in obj.args])
check.scopes.pop()
return result
elif isinstance(obj, ast.Node) and not isinstance(obj, ast.LoopOverCoordinate):
return obj
else:
raise ValueError("Invalid object in kernel " + str(type(obj)))
typed_equations = visit(eqs)
return check.fields_read, check.fields_written, typed_equations
def insert_casts(node):
"""Checks the types and inserts casts and pointer arithmetic where necessary.
Args:
node: the head node of the ast
Returns:
modified AST
"""
def cast(zipped_args_types, target_dtype):
"""
Adds casts to the arguments if their type differs from the target type
:param zipped_args_types: a zipped list of args and types
:param target_dtype: The target data type
:return: args with possible casts
"""
casted_args = []
for argument, data_type in zipped_args_types:
if data_type.numpy_dtype != target_dtype.numpy_dtype: # ignoring const
casted_args.append(cast_func(argument, target_dtype))
else:
casted_args.append(argument)
return casted_args
def pointer_arithmetic(expr_args):
"""
Creates a valid pointer arithmetic function
:param expr_args: Arguments of the add expression
:return: pointer_arithmetic_func
"""
pointer = None
new_args = []
for arg, data_type in expr_args:
if data_type.func is PointerType:
assert pointer is None
pointer = arg
for arg, data_type in expr_args:
if arg != pointer:
assert data_type.is_int() or data_type.is_uint()
new_args.append(arg)
new_args = sp.Add(*new_args) if len(new_args) > 0 else new_args
return pointer_arithmetic_func(pointer, new_args)
if isinstance(node, sp.AtomicExpr) or isinstance(node, cast_func):
return node
args = []
for arg in node.args:
args.append(insert_casts(arg))
# TODO indexed, LoopOverCoordinate
if node.func in (sp.Add, sp.Mul, sp.Or, sp.And, sp.Pow, sp.Eq, sp.Ne, sp.Lt, sp.Le, sp.Gt, sp.Ge):
# TODO optimize pow, don't cast integer on double
types = [get_type_of_expression(arg) for arg in args]
assert len(types) > 0
# Never ever, ever collate to float type for boolean functions!
target = collate_types(types, forbid_collation_to_float=isinstance(node.func, BooleanFunction))
zipped = list(zip(args, types))
if target.func is PointerType:
assert node.func is sp.Add
return pointer_arithmetic(zipped)
else:
return node.func(*cast(zipped, target))
elif node.func is ast.SympyAssignment:
lhs = args[0]
rhs = args[1]
target = get_type_of_expression(lhs)
if target.func is PointerType:
return node.func(*args) # TODO fix, not complete
else:
return node.func(lhs, *cast([(rhs, get_type_of_expression(rhs))], target))
elif node.func is ast.ResolvedFieldAccess:
return node
elif node.func is ast.Block:
for old_arg, new_arg in zip(node.args, args):
node.replace(old_arg, new_arg)
return node
elif node.func is ast.LoopOverCoordinate:
for old_arg, new_arg in zip(node.args, args):
node.replace(old_arg, new_arg)
return node
elif node.func is sp.Piecewise:
expressions = [expr for (expr, _) in args]
types = [get_type_of_expression(expr) for expr in expressions]
target = collate_types(types)
zipped = list(zip(expressions, types))
casted_expressions = cast(zipped, target)
args = [
arg.func(*[expr, arg.cond])
for (arg, expr) in zip(args, casted_expressions)
]
return node.func(*args)
def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction, include_first=True) -> None: def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction, include_first=True) -> None:
"""Removes conditionals of a kernel that iterates over staggered positions by splitting the loops at last or """Removes conditionals of a kernel that iterates over staggered positions by splitting the loops at last or
first and last element""" first and last element"""
...@@ -1094,73 +947,6 @@ def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction, i ...@@ -1094,73 +947,6 @@ def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction, i
# --------------------------------------- Helper Functions ------------------------------------------------------------- # --------------------------------------- Helper Functions -------------------------------------------------------------
def typing_from_sympy_inspection(eqs, default_type="double", default_int_type='int64'):
"""
Creates a default symbol name to type mapping.
If a sympy Boolean is assigned to a symbol it is assumed to be 'bool' otherwise the default type, usually ('double')
Args:
eqs: list of equations
default_type: the type for non-boolean symbols
Returns:
dictionary, mapping symbol name to type
"""
result = defaultdict(lambda: default_type)
if hasattr(default_type, 'numpy_dtype'):
result['_complex_type'] = (np.zeros((1,), default_type.numpy_dtype) * 1j).dtype
else:
result['_complex_type'] = (np.zeros((1,), default_type) * 1j).dtype
for eq in eqs:
if isinstance(eq, ast.Conditional):
result.update(typing_from_sympy_inspection(eq.true_block.args))
if eq.false_block:
result.update(typing_from_sympy_inspection(
eq.false_block.args))
elif isinstance(eq, ast.Node) and not isinstance(eq, ast.SympyAssignment):
continue
else:
from pystencils.cpu.vectorization import vec_all, vec_any
if isinstance(eq.rhs, (vec_all, vec_any)):
result[eq.lhs.name] = "bool"
# problematic case here is when rhs is a symbol: then it is impossible to decide here without
# further information what type the left hand side is - default fallback is the dict value then
if isinstance(eq.rhs, Boolean) and not isinstance(eq.rhs, sp.Symbol):
result[eq.lhs.name] = "bool"
try:
result[eq.lhs.name] = get_type_of_expression(eq.rhs,
default_float_type=default_type,
default_int_type=default_int_type,
symbol_type_dict=result)
except Exception:
pass # gracefully fail in case get_type_of_expression cannot determine type
return result
def get_next_parent_of_type(node, parent_type):
"""Returns the next parent node of given type or None, if root is reached.
Traverses the AST nodes parents until a parent of given type was found.
If no such parent is found, None is returned
"""
parent = node.parent
while parent is not None:
if isinstance(parent, parent_type):
return parent
parent = parent.parent
return None
def parents_of_type(node, parent_type, include_current=False):
"""Generator for all parent nodes of given type"""
parent = node if include_current else node.parent
while parent is not None:
if isinstance(parent, parent_type):
yield parent
parent = parent.parent
def get_optimal_loop_ordering(fields): def get_optimal_loop_ordering(fields):
""" """
Determines the optimal loop order for a given set of fields. Determines the optimal loop order for a given set of fields.
...@@ -1206,13 +992,13 @@ def get_loop_hierarchy(ast_node): ...@@ -1206,13 +992,13 @@ def get_loop_hierarchy(ast_node):
return reversed(result) return reversed(result)
def get_loop_counter_symbol_hierarchy(astNode): def get_loop_counter_symbol_hierarchy(ast_node):
"""Determines the loop counter symbols around a given AST node. """Determines the loop counter symbols around a given AST node.
:param astNode: the AST node :param ast_node: the AST node
:return: list of loop counter symbols, where the first list entry is the symbol of the innermost loop :return: list of loop counter symbols, where the first list entry is the symbol of the innermost loop
""" """
result = [] result = []
node = astNode node = ast_node
while node is not None: while node is not None:
node = get_next_parent_of_type(node, ast.LoopOverCoordinate) node = get_next_parent_of_type(node, ast.LoopOverCoordinate)
if node: if node:
...@@ -1258,7 +1044,8 @@ def loop_blocking(ast_node: ast.KernelFunction, block_size) -> int: ...@@ -1258,7 +1044,8 @@ def loop_blocking(ast_node: ast.KernelFunction, block_size) -> int:
Args: Args:
ast_node: kernel function node before vectorization transformation has been applied ast_node: kernel function node before vectorization transformation has been applied
block_size: sequence defining block size in x, y, (z) direction block_size: sequence defining block size in x, y, (z) direction.
If chosen as zero the direction will not be used for blocking.
Returns: Returns:
number of dimensions blocked number of dimensions blocked
...@@ -1270,8 +1057,10 @@ def loop_blocking(ast_node: ast.KernelFunction, block_size) -> int: ...@@ -1270,8 +1057,10 @@ def loop_blocking(ast_node: ast.KernelFunction, block_size) -> int:
body = ast_node.body body = ast_node.body
coordinates = [] coordinates = []
coordinates_taken_into_account = 0
loop_starts = {} loop_starts = {}
loop_stops = {} loop_stops = {}
for loop in loops: for loop in loops:
coord = loop.coordinate_to_loop_over coord = loop.coordinate_to_loop_over
if coord not in coordinates: if coord not in coordinates:
...@@ -1280,11 +1069,14 @@ def loop_blocking(ast_node: ast.KernelFunction, block_size) -> int: ...@@ -1280,11 +1069,14 @@ def loop_blocking(ast_node: ast.KernelFunction, block_size) -> int:
loop_stops[coord] = loop.stop loop_stops[coord] = loop.stop
else: else:
assert loop.start == loop_starts[coord] and loop.stop == loop_stops[coord], \ assert loop.start == loop_starts[coord] and loop.stop == loop_stops[coord], \
"Multiple loops over coordinate {} with different loop bounds".format(coord) f"Multiple loops over coordinate {coord} with different loop bounds"
# Create the outer loops that iterate over the blocks # Create the outer loops that iterate over the blocks
outer_loop = None outer_loop = None
for coord in reversed(coordinates): for coord in reversed(coordinates):
if block_size[coord] == 0:
continue
coordinates_taken_into_account += 1
body = ast.Block([outer_loop]) if outer_loop else body body = ast.Block([outer_loop]) if outer_loop else body
outer_loop = ast.LoopOverCoordinate(body, outer_loop = ast.LoopOverCoordinate(body,
coord, coord,
...@@ -1298,6 +1090,8 @@ def loop_blocking(ast_node: ast.KernelFunction, block_size) -> int: ...@@ -1298,6 +1090,8 @@ def loop_blocking(ast_node: ast.KernelFunction, block_size) -> int:
# modify the existing loops to only iterate within one block # modify the existing loops to only iterate within one block
for inner_loop in loops: for inner_loop in loops:
coord = inner_loop.coordinate_to_loop_over coord = inner_loop.coordinate_to_loop_over
if block_size[coord] == 0:
continue
block_ctr = ast.LoopOverCoordinate.get_block_loop_counter_symbol(coord) block_ctr = ast.LoopOverCoordinate.get_block_loop_counter_symbol(coord)
loop_range = inner_loop.stop - inner_loop.start loop_range = inner_loop.stop - inner_loop.start
if sp.sympify( if sp.sympify(
...@@ -1307,66 +1101,4 @@ def loop_blocking(ast_node: ast.KernelFunction, block_size) -> int: ...@@ -1307,66 +1101,4 @@ def loop_blocking(ast_node: ast.KernelFunction, block_size) -> int:
stop = sp.Min(inner_loop.stop, block_ctr + block_size[coord]) stop = sp.Min(inner_loop.stop, block_ctr + block_size[coord])
inner_loop.start = block_ctr inner_loop.start = block_ctr
inner_loop.stop = stop inner_loop.stop = stop
return len(coordinates) return coordinates_taken_into_account
def implement_interpolations(ast_node: ast.Node,
implement_by_texture_accesses: bool = False,
vectorize: bool = False,
use_hardware_interpolation_for_f32=True):
from pystencils.interpolation_astnodes import (InterpolatorAccess, TextureCachedField)
# TODO: perform this function on assignments, when unify_shape_symbols allows differently sized fields
assert not(implement_by_texture_accesses and vectorize), \
"can only implement interpolations either by texture accesses or CPU vectorization"
FLOAT32_T = create_type('float32')
interpolation_accesses = ast_node.atoms(InterpolatorAccess)
if not interpolation_accesses:
return ast_node
def can_use_hw_interpolation(i):
return (use_hardware_interpolation_for_f32
and implement_by_texture_accesses
and i.dtype == FLOAT32_T
and isinstance(i.symbol.interpolator, TextureCachedField))
if implement_by_texture_accesses:
for i in interpolation_accesses:
from pystencils.interpolation_astnodes import _InterpolationSymbol
try:
import pycuda.driver as cuda
texture = TextureCachedField.from_interpolator(i.interpolator)
if can_use_hw_interpolation(i):
texture.filter_mode = cuda.filter_mode.LINEAR
else:
texture.filter_mode = cuda.filter_mode.POINT
texture.read_as_integer = True
except Exception as e:
raise e
i.symbol = _InterpolationSymbol(str(texture), i.symbol.field, texture)
# from pystencils.math_optimizations import ReplaceOptim, optimize_ast
# ImplementInterpolationByStencils = ReplaceOptim(lambda e: isinstance(e, InterpolatorAccess)
# and not can_use_hw_interpolation(i),
# lambda e: e.implementation_with_stencils()
# )
# RemoveConjugate = ReplaceOptim(lambda e: isinstance(e, sp.conjugate),
# lambda e: e.args[0]
# )
if vectorize:
# TODO can be done in _interpolator_access_to_stencils field.absolute_access == simd_gather
raise NotImplementedError()
else:
substitutions = {i: i.implementation_with_stencils()
for i in interpolation_accesses if not can_use_hw_interpolation(i)}
if isinstance(ast_node, AssignmentCollection):
ast_node = ast_node.subs(substitutions)
else:
ast_node.subs(substitutions)
return ast_node
from pystencils.typing.cast_functions import (CastFunc, BooleanCastFunc, VectorMemoryAccess, ReinterpretCastFunc,
PointerArithmeticFunc)
from pystencils.typing.types import (is_supported_type, numpy_name_to_c, AbstractType, BasicType, VectorType,
PointerType, StructType, create_type)
from pystencils.typing.typed_sympy import (assumptions_from_dtype, TypedSymbol, FieldStrideSymbol, FieldShapeSymbol,
FieldPointerSymbol, CFunction)
from pystencils.typing.utilities import (typed_symbols, get_base_type, result_type, collate_types,
get_type_of_expression, get_next_parent_of_type, parents_of_type)
__all__ = ['CastFunc', 'BooleanCastFunc', 'VectorMemoryAccess', 'ReinterpretCastFunc', 'PointerArithmeticFunc',
'is_supported_type', 'numpy_name_to_c', 'AbstractType', 'BasicType',
'VectorType', 'PointerType', 'StructType', 'create_type', 'assumptions_from_dtype',
'TypedSymbol', 'FieldStrideSymbol', 'FieldShapeSymbol', 'FieldPointerSymbol', 'CFunction',
'typed_symbols', 'get_base_type', 'result_type', 'collate_types',
'get_type_of_expression', 'get_next_parent_of_type', 'parents_of_type']
import numpy as np
import sympy as sp
from sympy.logic.boolalg import Boolean
from pystencils.typing.types import AbstractType, BasicType
from pystencils.typing.typed_sympy import TypedSymbol
class CastFunc(sp.Function):
"""
CastFunc is used in order to introduce static casts. They are especially useful as a way to signal what type
a certain node should have, if it is impossible to add a type to a node, e.g. a sp.Number.
"""
is_Atom = True
def __new__(cls, *args, **kwargs):
if len(args) != 2:
pass
expr, dtype, *other_args = args
# If we have two consecutive casts, throw the inner one away.
# This optimisation is only available for simple casts. Thus the == is intended here!
if expr.__class__ == CastFunc:
expr = expr.args[0]
if not isinstance(dtype, AbstractType):
dtype = BasicType(dtype)
# to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well
# however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads
# to problems when for example comparing cast_func's for equality
#
# lhs = bitwise_and(a, cast_func(1, 'int'))
# rhs = cast_func(0, 'int')
# print( sp.Ne(lhs, rhs) ) # would give true if all cast_funcs are booleans
# -> thus a separate class boolean_cast_func is introduced
if isinstance(expr, Boolean) and (not isinstance(expr, TypedSymbol) or expr.dtype == BasicType('bool')):
cls = BooleanCastFunc
return sp.Function.__new__(cls, expr, dtype, *other_args, **kwargs)
@property
def canonical(self):
if hasattr(self.args[0], 'canonical'):
return self.args[0].canonical
else:
raise NotImplementedError()
@property
def is_commutative(self):
return self.args[0].is_commutative
@property
def dtype(self):
return self.args[1]
@property
def expr(self):
return self.args[0]
@property
def is_integer(self):
"""
Uses Numpy type hierarchy to determine :func:`sympy.Expr.is_integer` predicate
For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
"""
if hasattr(self.dtype, 'numpy_dtype'):
return np.issubdtype(self.dtype.numpy_dtype, np.integer) or super().is_integer
else:
return super().is_integer
@property
def is_negative(self):
"""
See :func:`.TypedSymbol.is_integer`
"""
if hasattr(self.dtype, 'numpy_dtype'):
if np.issubdtype(self.dtype.numpy_dtype, np.unsignedinteger):
return False
return super().is_negative
@property
def is_nonnegative(self):
"""
See :func:`.TypedSymbol.is_integer`
"""
if self.is_negative is False:
return True
else:
return super().is_nonnegative
@property
def is_real(self):
"""
See :func:`.TypedSymbol.is_integer`
"""
if hasattr(self.dtype, 'numpy_dtype'):
return np.issubdtype(self.dtype.numpy_dtype, np.integer) or np.issubdtype(self.dtype.numpy_dtype,
np.floating) or super().is_real
else:
return super().is_real
class BooleanCastFunc(CastFunc, Boolean):
# TODO: documentation
pass
class VectorMemoryAccess(CastFunc):
"""
Special memory access for vectorized kernel.
Arguments: read/write expression, type, aligned, non-temporal, mask (or none), stride
"""
nargs = (6,)
class ReinterpretCastFunc(CastFunc):
"""
Reinterpret cast is necessary for the StructType
"""
pass
class PointerArithmeticFunc(sp.Function, Boolean):
# TODO: documentation, or deprecate!
@property
def canonical(self):
if hasattr(self.args[0], 'canonical'):
return self.args[0].canonical
else:
raise NotImplementedError()
from collections import namedtuple
from typing import Union, Tuple, Any, DefaultDict
import logging
import numpy as np
import sympy as sp
from sympy import Piecewise
from sympy.core.numbers import NegativeOne
from sympy.core.relational import Relational
from sympy.functions.elementary.piecewise import ExprCondPair
from sympy.functions.elementary.trigonometric import TrigonometricFunction, InverseTrigonometricFunction
from sympy.functions.elementary.hyperbolic import HyperbolicFunction
from sympy.functions.elementary.integers import RoundFunction
from sympy.logic.boolalg import BooleanFunction
from sympy.logic.boolalg import BooleanAtom
from pystencils import astnodes as ast
from pystencils.functions import DivFunc, AddressOf
from pystencils.cpu.vectorization import vec_all, vec_any
from pystencils.field import Field
from pystencils.typing.types import BasicType, PointerType
from pystencils.typing.utilities import collate_types
from pystencils.typing.cast_functions import CastFunc, BooleanCastFunc
from pystencils.typing.typed_sympy import TypedSymbol
from pystencils.fast_approximation import fast_sqrt, fast_division, fast_inv_sqrt
from pystencils.utils import ContextVar
class TypeAdder:
# TODO: specification -> jupyter notebook
"""Checks if the input to create_kernel is valid.
Test the following conditions:
- SSA Form for pure symbols:
- Every pure symbol may occur only once as left-hand-side of an assignment
- Every pure symbol that is read, may not be written to later
- Independence / Parallelization condition:
- a field that is written may only be read at exact the same spatial position
(Pure symbols are symbols that are not Field.Accesses)
"""
FieldAndIndex = namedtuple('FieldAndIndex', ['field', 'index'])
def __init__(self, type_for_symbol: DefaultDict[str, BasicType], default_number_float: BasicType,
default_number_int: BasicType):
self.type_for_symbol = type_for_symbol
self.default_number_float = ContextVar(default_number_float)
self.default_number_int = ContextVar(default_number_int)
def visit(self, obj):
if isinstance(obj, (list, tuple)):
return [self.visit(e) for e in obj]
if isinstance(obj, ast.SympyAssignment):
return self.process_assignment(obj)
elif isinstance(obj, ast.Conditional):
condition, condition_type = self.figure_out_type(obj.condition_expr)
assert condition_type == BasicType('bool')
true_block = self.visit(obj.true_block)
false_block = None if obj.false_block is None else self.visit(
obj.false_block)
return ast.Conditional(condition, true_block=true_block, false_block=false_block)
elif isinstance(obj, ast.Block):
return ast.Block([self.visit(e) for e in obj.args])
elif isinstance(obj, ast.Node) and not isinstance(obj, ast.LoopOverCoordinate):
return obj
else:
raise ValueError("Invalid object in kernel " + str(type(obj)))
def process_assignment(self, assignment: ast.SympyAssignment) -> ast.SympyAssignment:
# for checks it is crucial to process rhs before lhs to catch e.g. a = a + 1
new_rhs, rhs_type = self.figure_out_type(assignment.rhs)
lhs = assignment.lhs
if not isinstance(lhs, (Field.Access, TypedSymbol)):
if isinstance(lhs, sp.Symbol):
self.type_for_symbol[lhs.name] = rhs_type
else:
raise ValueError(f'Lhs: `{lhs}` is not a subtype of sp.Symbol')
new_lhs, lhs_type = self.figure_out_type(lhs)
assert isinstance(new_lhs, (Field.Access, TypedSymbol))
if lhs_type != rhs_type:
logging.debug(f'Lhs"{new_lhs} of type "{lhs_type}" is assigned with a different datatype '
f'rhs: "{new_rhs}" of type "{rhs_type}".')
return ast.SympyAssignment(new_lhs, CastFunc(new_rhs, lhs_type), assignment.is_const, assignment.use_auto)
else:
return ast.SympyAssignment(new_lhs, new_rhs, assignment.is_const, assignment.use_auto)
# Type System Specification
# - Defined Types: TypedSymbol, Field, Field.Access, ...?
# - Indexed: always unsigned_integer64
# - Undefined Types: Symbol
# - Is specified in Config in the dict or as 'default_type' or behaves like `auto` in the case of lhs.
# - Constants/Numbers: Are either integer or floating. The precision and sign is specified via config
# - Example: 1.4 config:float32 -> float32
# - Expressions deduce types from arguments
# - Functions deduce types from arguments
# - default_type and default_float and default_int can be given for a list of assignment, or
# individually as a list for assignment
# Possible Problems - Do we need to support this?
# - Mixture in expression with int and float
# - Mixture in expression with uint64 and sint64
# TODO Logging: Lowest log level should log all casts ----> cast factory, make cast should contain logging
def figure_out_type(self, expr) -> Tuple[Any, Union[BasicType, PointerType]]:
# Trivial cases
from pystencils.field import Field
import pystencils.integer_functions
from pystencils.bit_masks import flag_cond
bool_type = BasicType('bool')
# TOOO: check the access
if isinstance(expr, Field.Access):
return expr, expr.dtype
elif isinstance(expr, TypedSymbol):
return expr, expr.dtype
elif isinstance(expr, sp.Symbol):
t = TypedSymbol(expr.name, self.type_for_symbol[expr.name])
return t, t.dtype
elif isinstance(expr, np.generic):
assert False, f'Why do we have a np.generic in rhs???? {expr}'
elif isinstance(expr, (sp.core.numbers.Infinity, sp.core.numbers.NegativeInfinity)):
return expr, BasicType('float32') # see https://en.cppreference.com/w/cpp/numeric/math/INFINITY
elif isinstance(expr, sp.Number):
if expr.is_Integer:
data_type = self.default_number_int.get()
elif expr.is_Float or expr.is_Rational:
data_type = self.default_number_float.get()
else:
assert False, f'{sp.Number} is neither Float nor Integer'
return CastFunc(expr, data_type), data_type
elif isinstance(expr, AddressOf):
of = expr.args[0]
# TODO Basically this should do address_of already
assert isinstance(of, (Field.Access, TypedSymbol, Field))
return expr, PointerType(of.dtype)
elif isinstance(expr, BooleanAtom):
return expr, bool_type
elif isinstance(expr, Relational):
# TODO Jan: Code duplication with general case
args_types = [self.figure_out_type(arg) for arg in expr.args]
collated_type = collate_types([t for _, t in args_types])
if isinstance(expr, sp.Equality) and collated_type.is_float():
logging.warning(f"Using floating point numbers in equality comparison: {expr}")
new_args = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types]
new_eq = expr.func(*new_args)
return new_eq, bool_type
elif isinstance(expr, CastFunc):
new_expr, _ = self.figure_out_type(expr.expr)
return expr.func(*[new_expr, expr.dtype]), expr.dtype
elif isinstance(expr, ast.ConditionalFieldAccess):
access, access_type = self.figure_out_type(expr.access)
value, value_type = self.figure_out_type(expr.outofbounds_value)
condition, condition_type = self.figure_out_type(expr.outofbounds_condition)
assert condition_type == bool_type
collated_type = collate_types([access_type, value_type])
if collated_type == access_type:
new_access = access
else:
logging.warning(f"In {expr} the Field Access had to be casted to {collated_type}. This is "
f"probably due to a type missmatch of the Field and the value of "
f"ConditionalFieldAccess")
new_access = CastFunc(access, collated_type)
new_value = value if value_type == collated_type else CastFunc(value, collated_type)
return expr.func(new_access, condition, new_value), collated_type
elif isinstance(expr, (vec_any, vec_all)):
return expr, bool_type
elif isinstance(expr, BooleanFunction):
args_types = [self.figure_out_type(a) for a in expr.args]
new_args = [a if t.dtype_eq(bool_type) else BooleanCastFunc(a, bool_type) for a, t in args_types]
return expr.func(*new_args), bool_type
elif type(expr, ) in pystencils.integer_functions.__dict__.values() or isinstance(expr, sp.Mod):
args_types = [self.figure_out_type(a) for a in expr.args]
collated_type = collate_types([t for _, t in args_types])
# TODO: should we downcast to integer? If yes then which integer type?
if not collated_type.is_int():
raise ValueError(f"Integer functions or Modulo need to be used with integer types "
f"but {collated_type} was given")
return expr, collated_type
elif isinstance(expr, flag_cond):
# do not process the arguments to the bit shift - they must remain integers
args_types = [self.figure_out_type(a) for a in (expr.args[i] for i in range(2, len(expr.args)))]
collated_type = collate_types([t for _, t in args_types])
new_expressions = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types]
return expr.func(expr.args[0], expr.args[1], *new_expressions), collated_type
# elif isinstance(expr, sp.Mul):
# raise NotImplementedError('sp.Mul')
# # TODO can we ignore this and move it to general expr handling, i.e. removing Mul? (See todo in backend)
# # args_types = [self.figure_out_type(arg) for arg in expr.args if arg not in (-1, 1)]
elif isinstance(expr, sp.Indexed):
typed_symbol = expr.base.label
return expr, typed_symbol.dtype
elif isinstance(expr, ExprCondPair):
expr_expr, expr_type = self.figure_out_type(expr.expr)
condition, condition_type = self.figure_out_type(expr.cond)
if condition_type != bool_type:
logging.warning(f'Condition "{condition}" is of type "{condition_type}" and not "bool"')
return expr.func(expr_expr, condition), expr_type
elif isinstance(expr, Piecewise):
args_types = [self.figure_out_type(arg) for arg in expr.args]
collated_type = collate_types([t for _, t in args_types])
new_args = []
for a, t in args_types:
if t != collated_type:
if isinstance(a, ExprCondPair):
new_args.append(a.func(CastFunc(a.expr, collated_type), a.cond))
else:
new_args.append(CastFunc(a, collated_type))
else:
new_args.append(a)
return expr.func(*new_args) if new_args else expr, collated_type
elif isinstance(expr, (sp.Pow, sp.exp, InverseTrigonometricFunction, TrigonometricFunction,
HyperbolicFunction, sp.log, RoundFunction)):
args_types = [self.figure_out_type(arg) for arg in expr.args]
collated_type = collate_types([t for _, t in args_types])
new_args = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types]
new_func = expr.func(*new_args) if new_args else expr
if collated_type == BasicType('float64'):
return new_func, collated_type
else:
return CastFunc(new_func, collated_type), collated_type
elif isinstance(expr, (fast_sqrt, fast_division, fast_inv_sqrt)):
args_types = [self.figure_out_type(arg) for arg in expr.args]
collated_type = BasicType('float32')
new_args = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types]
new_func = expr.func(*new_args) if new_args else expr
return CastFunc(new_func, collated_type), collated_type
elif isinstance(expr, (sp.Add, sp.Mul, sp.Abs, sp.Min, sp.Max, DivFunc, sp.UnevaluatedExpr)):
# Subtraction is realised a multiplication with -1 in SymPy. Thus we exclude the coefficient in this case
# and resolve the typing entirely with the expression itself
if isinstance(expr, sp.Mul):
c, e = expr.as_coeff_Mul()
if c == NegativeOne():
args_types = self.figure_out_type(e)
new_args = [NegativeOne(), args_types[0]]
return expr.func(*new_args, evaluate=False), args_types[1]
args_types = [self.figure_out_type(arg) for arg in expr.args]
collated_type = collate_types([t for _, t in args_types])
if isinstance(collated_type, PointerType):
if isinstance(expr, sp.Add):
return expr.func(*[a for a, _ in args_types]), collated_type
else:
raise NotImplementedError(f'Pointer Arithmetic is implemented only for Add, not {expr}')
new_args = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types]
if isinstance(expr, (sp.Add, sp.Mul)):
return expr.func(*new_args, evaluate=False) if new_args else expr, collated_type
else:
return expr.func(*new_args) if new_args else expr, collated_type
else:
raise NotImplementedError(f'expr {type(expr)}: {expr} unknown to typing')
from typing import List
from pystencils.astnodes import Node
from pystencils.config import CreateKernelConfig
from pystencils.typing.leaf_typing import TypeAdder
def add_types(node_list: List[Node], config: CreateKernelConfig):
"""Traverses AST and replaces every :class:`sympy.Symbol` by a :class:`pystencils.typedsymbol.TypedSymbol`.
The AST needs to be a pystencils AST. Thus, in the list of nodes every entry must be inherited from
`pystencils.astnodes.Node`
Additionally returns sets of all fields which are read/written
Args:
node_list: List of pystencils Nodes.
config: CreateKernelConfig
Returns:
``typed_equations`` list of equations where symbols have been replaced by typed symbols
"""
check = TypeAdder(type_for_symbol=config.data_type,
default_number_float=config.default_number_float,
default_number_int=config.default_number_int)
return check.visit(node_list)
"""Special symbols representing kernel parameters related to fields/arrays. from typing import Union
A `KernelFunction` node determines parameters that have to be passed to the function by searching for all undefined import numpy as np
symbols. Some symbols are not directly defined by the user, but are related to the `Field`s used in the kernel: import sympy as sp
For each field a `FieldPointerSymbol` needs to be passed in, which is the pointer to the memory region where
the field is stored. This pointer is represented by the `FieldPointerSymbol` class that additionally stores the
name of the corresponding field. For fields where the size is not known at compile time, additionally shape and stride
information has to be passed in at runtime. These values are represented by `FieldShapeSymbol`
and `FieldPointerSymbol`.
The special symbols in this module store only the field name instead of a field reference. Storing a field reference
directly leads to problems with copying and pickling behaviour due to the circular dependency of `Field` and
e.g. `FieldShapeSymbol`, since a Field contains `FieldShapeSymbol`s in its shape, and a `FieldShapeSymbol`
would reference back to the field.
"""
from sympy.core.cache import cacheit from sympy.core.cache import cacheit
from pystencils.data_types import ( from pystencils.typing.types import BasicType, create_type, PointerType
PointerType, TypedSymbol, create_composite_type_from_string, get_base_type)
def assumptions_from_dtype(dtype: Union[BasicType, np.dtype]):
"""Derives SymPy assumptions from :class:`BasicType` or a Numpy dtype
Args:
dtype (BasicType, np.dtype): a Numpy data type
Returns:
A dict of SymPy assumptions
"""
if hasattr(dtype, 'numpy_dtype'):
dtype = dtype.numpy_dtype
assumptions = dict()
try:
if np.issubdtype(dtype, np.integer):
assumptions.update({'integer': True})
if np.issubdtype(dtype, np.unsignedinteger):
assumptions.update({'negative': False})
if np.issubdtype(dtype, np.integer) or \
np.issubdtype(dtype, np.floating):
assumptions.update({'real': True})
except Exception: # TODO this is dirty
pass
return assumptions
class TypedSymbol(sp.Symbol):
def __new__(cls, *args, **kwds):
obj = TypedSymbol.__xnew_cached_(cls, *args, **kwds)
return obj
def __new_stage2__(cls, name, dtype, **kwargs): # TODO does not match signature of sp.Symbol???
# TODO: also Symbol should be allowed ---> see sympy Variable
assumptions = assumptions_from_dtype(dtype)
assumptions.update(kwargs)
obj = super(TypedSymbol, cls).__xnew__(cls, name, **assumptions)
try:
obj.numpy_dtype = create_type(dtype)
except (TypeError, ValueError):
# on error keep the string
obj.numpy_dtype = dtype
return obj
__xnew__ = staticmethod(__new_stage2__)
__xnew_cached_ = staticmethod(cacheit(__new_stage2__))
@property
def dtype(self):
return self.numpy_dtype
def _hashable_content(self):
return super()._hashable_content(), hash(self.numpy_dtype)
def __getnewargs__(self):
return self.name, self.dtype
def __getnewargs_ex__(self):
return (self.name, self.dtype), self.assumptions0
@property
def canonical(self):
return self
@property
def reversed(self):
return self
@property
def headers(self):
headers = []
try:
if np.issubdtype(self.dtype.numpy_dtype, np.complexfloating):
headers.append('"cuda_complex.hpp"')
except Exception:
pass
try:
if np.issubdtype(self.dtype.base_type.numpy_dtype, np.complexfloating):
headers.append('"cuda_complex.hpp"')
except Exception:
pass
SHAPE_DTYPE = create_composite_type_from_string("const int64") return headers
STRIDE_DTYPE = create_composite_type_from_string("const int64")
SHAPE_DTYPE = BasicType('int64', const=True)
STRIDE_DTYPE = BasicType('int64', const=True)
class FieldStrideSymbol(TypedSymbol): class FieldStrideSymbol(TypedSymbol):
...@@ -29,7 +105,7 @@ class FieldStrideSymbol(TypedSymbol): ...@@ -29,7 +105,7 @@ class FieldStrideSymbol(TypedSymbol):
return obj return obj
def __new_stage2__(cls, field_name, coordinate): def __new_stage2__(cls, field_name, coordinate):
name = "_stride_{name}_{i}".format(name=field_name, i=coordinate) name = f"_stride_{field_name}_{coordinate}"
obj = super(FieldStrideSymbol, cls).__xnew__(cls, name, STRIDE_DTYPE, positive=True) obj = super(FieldStrideSymbol, cls).__xnew__(cls, name, STRIDE_DTYPE, positive=True)
obj.field_name = field_name obj.field_name = field_name
obj.coordinate = coordinate obj.coordinate = coordinate
...@@ -38,6 +114,9 @@ class FieldStrideSymbol(TypedSymbol): ...@@ -38,6 +114,9 @@ class FieldStrideSymbol(TypedSymbol):
def __getnewargs__(self): def __getnewargs__(self):
return self.field_name, self.coordinate return self.field_name, self.coordinate
def __getnewargs_ex__(self):
return (self.field_name, self.coordinate), {}
__xnew__ = staticmethod(__new_stage2__) __xnew__ = staticmethod(__new_stage2__)
__xnew_cached_ = staticmethod(cacheit(__new_stage2__)) __xnew_cached_ = staticmethod(cacheit(__new_stage2__))
...@@ -54,7 +133,7 @@ class FieldShapeSymbol(TypedSymbol): ...@@ -54,7 +133,7 @@ class FieldShapeSymbol(TypedSymbol):
def __new_stage2__(cls, field_names, coordinate): def __new_stage2__(cls, field_names, coordinate):
names = "_".join([field_name for field_name in field_names]) names = "_".join([field_name for field_name in field_names])
name = "_size_{names}_{i}".format(names=names, i=coordinate) name = f"_size_{names}_{coordinate}"
obj = super(FieldShapeSymbol, cls).__xnew__(cls, name, SHAPE_DTYPE, positive=True) obj = super(FieldShapeSymbol, cls).__xnew__(cls, name, SHAPE_DTYPE, positive=True)
obj.field_names = tuple(field_names) obj.field_names = tuple(field_names)
obj.coordinate = coordinate obj.coordinate = coordinate
...@@ -63,6 +142,9 @@ class FieldShapeSymbol(TypedSymbol): ...@@ -63,6 +142,9 @@ class FieldShapeSymbol(TypedSymbol):
def __getnewargs__(self): def __getnewargs__(self):
return self.field_names, self.coordinate return self.field_names, self.coordinate
def __getnewargs_ex__(self):
return (self.field_names, self.coordinate), {}
__xnew__ = staticmethod(__new_stage2__) __xnew__ = staticmethod(__new_stage2__)
__xnew_cached_ = staticmethod(cacheit(__new_stage2__)) __xnew_cached_ = staticmethod(cacheit(__new_stage2__))
...@@ -77,7 +159,9 @@ class FieldPointerSymbol(TypedSymbol): ...@@ -77,7 +159,9 @@ class FieldPointerSymbol(TypedSymbol):
return obj return obj
def __new_stage2__(cls, field_name, field_dtype, const): def __new_stage2__(cls, field_name, field_dtype, const):
name = "_data_{name}".format(name=field_name) from pystencils.typing.utilities import get_base_type
name = f"_data_{field_name}"
dtype = PointerType(get_base_type(field_dtype), const=const, restrict=True) dtype = PointerType(get_base_type(field_dtype), const=const, restrict=True)
obj = super(FieldPointerSymbol, cls).__xnew__(cls, name, dtype) obj = super(FieldPointerSymbol, cls).__xnew__(cls, name, dtype)
obj.field_name = field_name obj.field_name = field_name
...@@ -86,8 +170,28 @@ class FieldPointerSymbol(TypedSymbol): ...@@ -86,8 +170,28 @@ class FieldPointerSymbol(TypedSymbol):
def __getnewargs__(self): def __getnewargs__(self):
return self.field_name, self.dtype, self.dtype.const return self.field_name, self.dtype, self.dtype.const
def __getnewargs_ex__(self):
return (self.field_name, self.dtype, self.dtype.const), {}
def _hashable_content(self): def _hashable_content(self):
return super()._hashable_content(), self.field_name return super()._hashable_content(), self.field_name
__xnew__ = staticmethod(__new_stage2__) __xnew__ = staticmethod(__new_stage2__)
__xnew_cached_ = staticmethod(cacheit(__new_stage2__)) __xnew_cached_ = staticmethod(cacheit(__new_stage2__))
class CFunction(TypedSymbol):
def __new__(cls, function, dtype):
return CFunction.__xnew_cached_(cls, function, dtype)
def __new_stage2__(cls, function, dtype):
return super(CFunction, cls).__xnew__(cls, function, dtype)
__xnew__ = staticmethod(__new_stage2__)
__xnew_cached_ = staticmethod(cacheit(__new_stage2__))
def __getnewargs__(self):
return self.name, self.dtype
def __getnewargs_ex__(self):
return (self.name, self.dtype), {}
from abc import abstractmethod
from typing import Union
import numpy as np
import sympy as sp
def is_supported_type(dtype: np.dtype):
scalar = dtype.type
c = np.issubdtype(dtype, np.generic)
subclass = issubclass(scalar, np.floating) or issubclass(scalar, np.integer) or issubclass(scalar, np.bool_)
additional_checks = dtype.fields is None and dtype.hasobject is False and dtype.subdtype is None
return c and subclass and additional_checks
def numpy_name_to_c(name: str) -> str:
"""
Converts a np.dtype.name into a C type
Args:
name: np.dtype.name string
Returns:
type as a C string
"""
if name == 'float64':
return 'double'
elif name == 'float32':
return 'float'
elif name == 'float16' or name == 'half':
return 'half'
elif name.startswith('int'):
width = int(name[len("int"):])
return f"int{width}_t"
elif name.startswith('uint'):
width = int(name[len("uint"):])
return f"uint{width}_t"
elif name == 'bool':
return 'bool'
else:
raise NotImplementedError(f"Can't map numpy to C name for {name}")
class AbstractType(sp.Atom):
# TODO: Is it necessary to ineherit from sp.Atom?
def __new__(cls, *args, **kwargs):
return sp.Basic.__new__(cls)
def _sympystr(self, *args, **kwargs):
return str(self)
@property
@abstractmethod
def base_type(self) -> Union[None, 'BasicType']:
"""
Returns: Returns BasicType of a Vector or Pointer type, None otherwise
"""
pass
@property
@abstractmethod
def item_size(self) -> int:
"""
Returns: Number of items.
E.g. width * item_size(basic_type) in vector's case, or simple numpy itemsize in Struct's case.
"""
pass
class BasicType(AbstractType):
"""
BasicType is defined with a const qualifier and a np.dtype.
"""
def __init__(self, dtype: Union[type, 'BasicType', str], const: bool = False):
if isinstance(dtype, BasicType):
self.numpy_dtype = dtype.numpy_dtype
self.const = dtype.const
else:
self.numpy_dtype = np.dtype(dtype)
self.const = const
assert is_supported_type(self.numpy_dtype), f'Type {self.numpy_dtype} is currently not supported!'
def __getnewargs__(self):
return self.numpy_dtype, self.const
def __getnewargs_ex__(self):
return (self.numpy_dtype, self.const), {}
@property
def base_type(self):
return None
@property
def item_size(self): # TODO: Do we want self.numpy_type.itemsize????
return 1
def is_float(self):
return issubclass(self.numpy_dtype.type, np.floating)
def is_half(self):
return issubclass(self.numpy_dtype.type, np.half)
def is_int(self):
return issubclass(self.numpy_dtype.type, np.integer)
def is_uint(self):
return issubclass(self.numpy_dtype.type, np.unsignedinteger)
def is_sint(self):
return issubclass(self.numpy_dtype.type, np.signedinteger)
def is_bool(self):
return issubclass(self.numpy_dtype.type, np.bool_)
def dtype_eq(self, other):
if not isinstance(other, BasicType):
return False
else:
return self.numpy_dtype == other.numpy_dtype
@property
def c_name(self) -> str:
return numpy_name_to_c(self.numpy_dtype.name)
def __str__(self):
return f'{self.c_name}{" const" if self.const else ""}'
def __repr__(self):
return f'BasicType( {str(self)} )'
def _repr_html_(self):
return f'BasicType( {str(self)} )'
def __eq__(self, other):
return self.dtype_eq(other) and self.const == other.const
def __hash__(self):
return hash(str(self))
class VectorType(AbstractType):
"""
VectorType consists of a BasicType and a width.
"""
instruction_set = None
def __init__(self, base_type: BasicType, width: int):
self._base_type = base_type
self.width = width
@property
def base_type(self):
return self._base_type
@property
def item_size(self):
return self.width * self.base_type.item_size
def __eq__(self, other):
if not isinstance(other, VectorType):
return False
else:
return (self.base_type, self.width) == (other.base_type, other.width)
def __str__(self):
if self.instruction_set is None:
return f"{self.base_type}[{self.width}]"
else:
# TODO VectorizationRevamp: this seems super weird. the instruction_set should know how to print a type out!
# TODO VectorizationRevamp: this is error prone. base_type could be cons=True. Use dtype instead
if self.base_type == create_type("int64") or self.base_type == create_type("int32"):
return self.instruction_set['int']
elif self.base_type == create_type("float64"):
return self.instruction_set['double']
elif self.base_type == create_type("float32"):
return self.instruction_set['float']
elif self.base_type == create_type("bool"):
return self.instruction_set['bool']
else:
raise NotImplementedError()
def __hash__(self):
return hash((self.base_type, self.width))
def __getnewargs__(self):
return self._base_type, self.width
def __getnewargs_ex__(self):
return (self._base_type, self.width), {}
class PointerType(AbstractType):
def __init__(self, base_type: BasicType, const: bool = False, restrict: bool = True, double_pointer: bool = False):
self._base_type = base_type
self.const = const
self.restrict = restrict
self.double_pointer = double_pointer
def __getnewargs__(self):
return self.base_type, self.const, self.restrict, self.double_pointer
def __getnewargs_ex__(self):
return (self.base_type, self.const, self.restrict, self.double_pointer), {}
@property
def alias(self):
return not self.restrict
@property
def base_type(self):
return self._base_type
@property
def item_size(self):
if self.double_pointer:
raise NotImplementedError("The item_size for double_pointer is not implemented")
else:
return self.base_type.item_size
def __eq__(self, other):
if not isinstance(other, PointerType):
return False
else:
own = (self.base_type, self.const, self.restrict, self.double_pointer)
return own == (other.base_type, other.const, other.restrict, other.double_pointer)
def __str__(self):
restrict_str = "RESTRICT" if self.restrict else ""
const_str = "const" if self.const else ""
if self.double_pointer:
return f'{str(self.base_type)} ** {restrict_str} {const_str}'
else:
return f'{str(self.base_type)} * {restrict_str} {const_str}'
def __repr__(self):
return str(self)
def _repr_html_(self):
return str(self)
def __hash__(self):
return hash((self._base_type, self.const, self.restrict, self.double_pointer))
class StructType(AbstractType):
"""
A list of types (with C offsets).
It is implemented with uint8_t and casts to the correct datatype.
"""
def __init__(self, numpy_type, const=False):
self.const = const
self._dtype = np.dtype(numpy_type)
def __getnewargs__(self):
return self.numpy_dtype, self.const
def __getnewargs_ex__(self):
return (self.numpy_dtype, self.const), {}
@property
def base_type(self):
return None
@property
def numpy_dtype(self):
return self._dtype
@property
def item_size(self):
return self.numpy_dtype.itemsize
def get_element_offset(self, element_name):
return self.numpy_dtype.fields[element_name][1]
def get_element_type(self, element_name):
np_element_type = self.numpy_dtype.fields[element_name][0]
return BasicType(np_element_type, self.const)
def has_element(self, element_name):
return element_name in self.numpy_dtype.fields
def __eq__(self, other):
if not isinstance(other, StructType):
return False
else:
return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)
def __str__(self):
# structs are handled byte-wise
result = "uint8_t"
if self.const:
result += " const"
return result
def __repr__(self):
return str(self)
def _repr_html_(self):
return str(self)
def __hash__(self):
return hash((self.numpy_dtype, self.const))
def create_type(specification: Union[type, AbstractType, str]) -> AbstractType:
# TODO: Deprecated Use the constructor of BasicType or StructType instead
"""Creates a subclass of Type according to a string or an object of subclass Type.
Args:
specification: Type object, or a string
Returns:
Type object, or a new Type object parsed from the string
"""
if isinstance(specification, AbstractType):
return specification
else:
numpy_dtype = np.dtype(specification)
if numpy_dtype.fields is None:
return BasicType(numpy_dtype, const=False)
else:
return StructType(numpy_dtype, const=False)
from collections import defaultdict
from functools import partial
from typing import Tuple, Union, Sequence
import numpy as np
import sympy as sp
from sympy.logic.boolalg import Boolean, BooleanFunction
import pystencils
from pystencils.cache import memorycache_if_hashable
from pystencils.typing.types import BasicType, VectorType, PointerType, create_type
from pystencils.typing.cast_functions import CastFunc
from pystencils.typing.typed_sympy import TypedSymbol
from pystencils.utils import all_equal
def typed_symbols(names, dtype, **kwargs):
"""
Creates TypedSymbols with the same functionality as sympy.symbols
Args:
names: See sympy.symbols
dtype: The data type all symbols will have
**kwargs: Key value arguments passed to sympy.symbols
Returns:
TypedSymbols
"""
symbols = sp.symbols(names, **kwargs)
if isinstance(symbols, Tuple):
return tuple(TypedSymbol(str(s), dtype) for s in symbols)
else:
return TypedSymbol(str(symbols), dtype)
def get_base_type(data_type):
"""
Returns the BasicType of a Pointer or a Vector
"""
while data_type.base_type is not None:
data_type = data_type.base_type
return data_type
def result_type(*args: np.dtype):
"""Returns the type of the result if the np.dtype arguments would be collated.
We can't use numpy functionality, because numpy casts don't behave exactly like C casts"""
s = sorted(args, key=lambda x: x.itemsize)
def kind_to_value(kind: str) -> int:
if kind == 'f':
return 3
elif kind == 'i':
return 2
elif kind == 'u':
return 1
elif kind == 'b':
return 0
else:
raise NotImplementedError(f'{kind=} is not a supported kind of a type. See "numpy.dtype.kind" for options')
s = sorted(s, key=lambda x: kind_to_value(x.kind))
return s[-1]
def collate_types(types: Sequence[Union[BasicType, VectorType]]):
"""
Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
Uses the collation rules from numpy.
"""
# Pointer arithmetic case i.e. pointer + [int, uint] is allowed
if any(isinstance(t, PointerType) for t in types):
pointer_type = None
for t in types:
if isinstance(t, PointerType):
if pointer_type is not None:
raise ValueError(f'Cannot collate the combination of two pointer types "{pointer_type}" and "{t}"')
pointer_type = t
elif isinstance(t, BasicType):
if not (t.is_int() or t.is_uint()):
raise ValueError("Invalid pointer arithmetic")
else:
raise ValueError("Invalid pointer arithmetic")
return pointer_type
# # peel of vector types, if at least one vector type occurred the result will also be the vector type
vector_type = [t for t in types if isinstance(t, VectorType)]
if not all_equal(t.width for t in vector_type):
raise ValueError("Collation failed because of vector types with different width")
# TODO: check if this is needed
# def peel_off_type(dtype, type_to_peel_off):
# while type(dtype) is type_to_peel_off:
# dtype = dtype.base_type
# return dtype
# types = [peel_off_type(t, VectorType) for t in types]
types = [t.base_type if isinstance(t, VectorType) else t for t in types]
# now we should have a list of basic types - struct types are not yet supported
assert all(type(t) is BasicType for t in types)
result_numpy_type = result_type(*(t.numpy_dtype for t in types))
result = BasicType(result_numpy_type)
if vector_type:
result = VectorType(result, vector_type[0].width)
return result
# TODO get_type_of_expression should be used after leaf_typing. So no defaults should be necessary
@memorycache_if_hashable(maxsize=2048)
def get_type_of_expression(expr,
default_float_type='double',
default_int_type='int',
symbol_type_dict=None):
from pystencils.astnodes import ResolvedFieldAccess
from pystencils.cpu.vectorization import vec_all, vec_any
if default_float_type == 'float':
default_float_type = 'float32'
if not symbol_type_dict:
symbol_type_dict = defaultdict(lambda: create_type('double'))
# TODO this line is quite hard to understand, if possible simpl
get_type = partial(get_type_of_expression,
default_float_type=default_float_type,
default_int_type=default_int_type,
symbol_type_dict=symbol_type_dict)
expr = sp.sympify(expr)
if isinstance(expr, sp.Integer):
return create_type(default_int_type)
elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float):
return create_type(default_float_type)
elif isinstance(expr, ResolvedFieldAccess):
return expr.field.dtype
elif isinstance(expr, pystencils.field.Field.Access):
return expr.field.dtype
elif isinstance(expr, TypedSymbol):
return expr.dtype
elif isinstance(expr, sp.Symbol):
# TODO delete if case
if symbol_type_dict:
return symbol_type_dict[expr.name]
else:
raise ValueError("All symbols inside this expression have to be typed! ", str(expr))
elif isinstance(expr, CastFunc):
return expr.args[1]
elif isinstance(expr, (vec_any, vec_all)):
return create_type("bool")
elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
collated_result_type = collate_types(tuple(get_type(a[0]) for a in expr.args))
collated_condition_type = collate_types(tuple(get_type(a[1]) for a in expr.args))
if type(collated_condition_type) is VectorType and type(collated_result_type) is not VectorType:
collated_result_type = VectorType(collated_result_type, width=collated_condition_type.width)
return collated_result_type
elif isinstance(expr, sp.Indexed):
typed_symbol = expr.base.label
return typed_symbol.dtype.base_type
elif isinstance(expr, (Boolean, BooleanFunction)):
# if any arg is of vector type return a vector boolean, else return a normal scalar boolean
result = create_type("bool")
vec_args = [get_type(a) for a in expr.args if isinstance(get_type(a), VectorType)]
if vec_args:
result = VectorType(result, width=vec_args[0].width)
return result
elif isinstance(expr, sp.Pow):
base_type = get_type(expr.args[0])
if expr.exp.is_integer:
return base_type
else:
return collate_types([create_type(default_float_type), base_type])
elif isinstance(expr, (sp.Sum, sp.Product)):
return get_type(expr.args[0])
elif isinstance(expr, sp.Expr):
expr: sp.Expr
if expr.args:
types = tuple(get_type(a) for a in expr.args)
return collate_types(types)
else:
if expr.is_integer:
return create_type(default_int_type)
else:
return create_type(default_float_type)
raise NotImplementedError("Could not determine type for", expr, type(expr))
# Fix for sympy versions from 1.9
sympy_version = sp.__version__.split('.')
sympy_version_int = int(sympy_version[0]) * 100 + int(sympy_version[1])
if sympy_version_int >= 109:
# __setstate__ would bypass the contructor, so we remove it
if sympy_version_int >= 111:
del sp.Basic.__setstate__
del sp.Symbol.__setstate__
else:
sp.Number.__getstate__ = sp.Basic.__getstate__
del sp.Basic.__getstate__
# __reduce_ex__ would strip kwargs, so we override it
def basic_reduce_ex(self, protocol):
if hasattr(self, '__getnewargs_ex__'):
args, kwargs = self.__getnewargs_ex__()
else:
args, kwargs = self.__getnewargs__(), {}
if hasattr(self, '__getstate__'):
state = self.__getstate__()
else:
state = None
return partial(type(self), **kwargs), args, state
sp.Basic.__reduce_ex__ = basic_reduce_ex
def get_next_parent_of_type(node, parent_type):
"""Returns the next parent node of given type or None, if root is reached.
Traverses the AST nodes parents until a parent of given type was found.
If no such parent is found, None is returned
"""
parent = node.parent
while parent is not None:
if isinstance(parent, parent_type):
return parent
parent = parent.parent
return None
def parents_of_type(node, parent_type, include_current=False):
"""Generator for all parent nodes of given type"""
parent = node if include_current else node.parent
while parent is not None:
if isinstance(parent, parent_type):
yield parent
parent = parent.parent
import os import os
import itertools
from itertools import groupby
from collections import Counter from collections import Counter
from contextlib import contextmanager from contextlib import contextmanager
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
...@@ -14,14 +16,21 @@ class DotDict(dict): ...@@ -14,14 +16,21 @@ class DotDict(dict):
__setattr__ = dict.__setitem__ __setattr__ = dict.__setitem__
__delattr__ = dict.__delitem__ __delattr__ = dict.__delitem__
# Recursively make DotDict: https://stackoverflow.com/questions/13520421/recursive-dotdict
def __init__(self, dct={}):
for key, value in dct.items():
if isinstance(value, dict):
value = DotDict(value)
self[key] = value
def all_equal(iterator):
iterator = iter(iterator) def all_equal(iterable):
try: """
first = next(iterator) Returns ``True`` if all the elements are equal to each other.
except StopIteration: Copied from: more-itertools 8.12.0
return True """
return all(first == rest for rest in iterator) g = groupby(iterable)
return next(g, True) and not next(g, False)
def recursive_dict_update(d, u): def recursive_dict_update(d, u):
...@@ -43,33 +52,13 @@ def recursive_dict_update(d, u): ...@@ -43,33 +52,13 @@ def recursive_dict_update(d, u):
return d return d
@contextmanager
def file_handle_for_atomic_write(file_path):
"""Open temporary file object that atomically moves to destination upon exiting.
Allows reading and writing to and from the same filename.
The file will not be moved to destination in case of an exception.
Args:
file_path: path to file to be opened
"""
target_folder = os.path.dirname(os.path.abspath(file_path))
with NamedTemporaryFile(delete=False, dir=target_folder, mode='w') as f:
try:
yield f
finally:
f.flush()
os.fsync(f.fileno())
os.rename(f.name, file_path)
@contextmanager @contextmanager
def atomic_file_write(file_path): def atomic_file_write(file_path):
target_folder = os.path.dirname(os.path.abspath(file_path)) target_folder = os.path.dirname(os.path.abspath(file_path))
with NamedTemporaryFile(delete=False, dir=target_folder) as f: with NamedTemporaryFile(delete=False, dir=target_folder) as f:
f.file.close() f.file.close()
yield f.name yield f.name
os.rename(f.name, file_path) os.replace(f.name, file_path)
def fully_contains(l1, l2): def fully_contains(l1, l2):
...@@ -89,19 +78,39 @@ def fully_contains(l1, l2): ...@@ -89,19 +78,39 @@ def fully_contains(l1, l2):
def boolean_array_bounding_box(boolean_array): def boolean_array_bounding_box(boolean_array):
"""Returns bounding box around "true" area of boolean array""" """Returns bounding box around "true" area of boolean array
dim = len(boolean_array.shape)
>>> a = np.zeros((4, 4), dtype=bool)
>>> a[1:-1, 1:-1] = True
>>> boolean_array_bounding_box(a) == [(1, 3), (1, 3)]
True
"""
dim = boolean_array.ndim
shape = boolean_array.shape
assert 0 not in shape, "Shape must not contain zero"
bounds = [] bounds = []
for i in range(dim): for ax in itertools.combinations(reversed(range(dim)), dim - 1):
for j in range(dim): nonzero = np.any(boolean_array, axis=ax)
if i != j: t = np.where(nonzero)[0][[0, -1]]
arr_1d = np.any(boolean_array, axis=j) bounds.append((t[0], t[1] + 1))
begin = np.argmax(arr_1d)
end = begin + np.argmin(arr_1d[begin:])
bounds.append((begin, end))
return bounds return bounds
def binary_numbers(n):
"""Returns all binary numbers up to 2^n - 1
Example:
>>> binary_numbers(2)
[[0, 0], [0, 1], [1, 0], [1, 1]]
"""
result = list()
for i in range(1 << n):
binary_number = bin(i)[2:]
binary_number = '0' * (n - len(binary_number)) + binary_number
result.append((list(map(int, binary_number))))
return result
class LinearEquationSystem: class LinearEquationSystem:
"""Symbolic linear system of equations - consisting of matrix and right hand side. """Symbolic linear system of equations - consisting of matrix and right hand side.
...@@ -210,7 +219,8 @@ class LinearEquationSystem: ...@@ -210,7 +219,8 @@ class LinearEquationSystem:
return 'multiple' return 'multiple'
def solution(self): def solution(self):
"""Solves the system if it has a single solution. Returns a dictionary mapping symbol to solution value.""" """Solves the system. Under- and overdetermined systems are supported.
Returns a dictionary mapping symbol to solution value."""
return sp.solve_linear_system(self._matrix, *self.unknowns) return sp.solve_linear_system(self._matrix, *self.unknowns)
def _resize_if_necessary(self, new_rows=1): def _resize_if_necessary(self, new_rows=1):
...@@ -228,6 +238,15 @@ class LinearEquationSystem: ...@@ -228,6 +238,15 @@ class LinearEquationSystem:
self.next_zero_row = result self.next_zero_row = result
def find_unique_solutions_with_zeros(system: LinearEquationSystem): class ContextVar:
if not system.solution_structure() != 'multiple': def __init__(self, value):
raise ValueError("Function works only for underdetermined systems") self.stack = [value]
@contextmanager
def __call__(self, new_value):
self.stack.append(new_value)
yield self
self.stack.pop()
def get(self):
return self.stack[-1]