Showing with 1898 additions and 570 deletions
@@ -92,7 +92,7 @@ class SimplificationStrategy:
assignment_collection = t(assignment_collection)
end_time = timeit.default_timer()
op = assignment_collection.operation_count
time_str = "%.2f ms" % ((end_time - start_time) * 1000,)
time_str = f"{(end_time - start_time) * 1000:.2f} ms"
total = op['adds'] + op['muls'] + op['divs']
report.add(ReportElement(t.__name__, time_str, op['adds'], op['muls'], op['divs'], total))
return report
@@ -129,7 +129,7 @@ class SimplificationStrategy:
def _repr_html_(self):
def print_assignment_collection(title, c):
text = '<h5 style="padding-bottom:10px">%s</h5> <div style="padding-left:20px;">' % (title, )
text = f'<h5 style="padding-bottom:10px">{title}</h5> <div style="padding-left:20px;">'
if self.restrict_symbols:
text += "\n".join(["$$" + sp.latex(e) + '$$'
for e in c.new_filtered(self.restrict_symbols).main_assignments])
@@ -151,5 +151,5 @@ class SimplificationStrategy:
def __repr__(self):
result = "Simplification Strategy:\n"
for t in self._rules:
result += " - %s\n" % (t.__name__,)
result += f" - {t.__name__}\n"
return result
import sympy as sp
from pystencils.sympyextensions import is_constant
# Subexpression Insertion
def insert_subexpressions(ac, selection_callback, skip=None):
"""
Removes a number of subexpressions from an assignment collection by
inserting their right-hand side wherever they occur.
Args:
- ac: the assignment collection to insert subexpressions into
- selection_callback: Function that is called to qualify subexpressions
for insertion. Should return `True` for any subexpression that is to be
inserted, and `False` otherwise.
- skip: Set of symbols (left-hand sides of subexpressions) that should be
ignored even if qualified by the callback.
"""
if skip is None:
skip = set()
i = 0
while i < len(ac.subexpressions):
exp = ac.subexpressions[i]
if exp.lhs not in skip and selection_callback(exp):
ac = ac.new_with_inserted_subexpression(exp.lhs)
else:
i += 1
return ac
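For illustration, a minimal usage sketch (assuming pystencils' Assignment and AssignmentCollection; all symbol names are placeholders):

import sympy as sp
from pystencils import Assignment, AssignmentCollection

x, y, a, b = sp.symbols('x y a b')
# b is a plain alias of a, so the callback below qualifies it for insertion
ac = AssignmentCollection(main_assignments=[Assignment(y, b + 1)],
                          subexpressions=[Assignment(a, 2 * x), Assignment(b, a)])
ac = insert_subexpressions(ac, lambda e: isinstance(e.rhs, sp.Symbol))
# the alias b is gone; y is now assigned a + 1 directly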
def insert_aliases(ac, **kwargs):
"""Inserts subexpressions that are aliases of other symbols,
i.e. their right-hand side is only another symbol."""
return insert_subexpressions(ac, lambda x: isinstance(x.rhs, sp.Symbol), **kwargs)
def insert_zeros(ac, **kwargs):
"""Inserts subexpressions whose right-hand side is zero."""
zero = sp.Integer(0)
return insert_subexpressions(ac, lambda x: x.rhs == zero, **kwargs)
def insert_constants(ac, **kwargs):
"""Inserts subexpressions whose right-hand side is constant,
i.e. contains no symbols."""
return insert_subexpressions(ac, lambda x: is_constant(x.rhs), **kwargs)
def insert_symbol_times_minus_one(ac, **kwargs):
"""Inserts subexpressions whose right-hand side is just a
negation of another symbol."""
def callback(exp):
rhs = exp.rhs
minus_one = sp.Integer(-1)
atoms = rhs.atoms(sp.Symbol)
return len(atoms) == 1 and rhs == minus_one * atoms.pop()
return insert_subexpressions(ac, callback, **kwargs)
def insert_constant_multiples(ac, **kwargs):
"""Inserts subexpressions whose right-hand side is a constant
multiplied with another symbol."""
def callback(exp):
rhs = exp.rhs
symbols = rhs.atoms(sp.Symbol)
numbers = rhs.atoms(sp.Number)
return len(symbols) == 1 and len(numbers) == 1 and \
rhs == numbers.pop() * symbols.pop()
return insert_subexpressions(ac, callback, **kwargs)
def insert_constant_additions(ac, **kwargs):
"""Inserts subexpressions whose right-hand side is a sum of a
constant and another symbol."""
def callback(exp):
rhs = exp.rhs
symbols = rhs.atoms(sp.Symbol)
numbers = rhs.atoms(sp.Number)
return len(symbols) == 1 and len(numbers) == 1 and \
rhs == numbers.pop() + symbols.pop()
return insert_subexpressions(ac, callback, **kwargs)
def insert_squares(ac, **kwargs):
"""Inserts subexpressions whose right-hand side is another symbol squared."""
def callback(exp):
rhs = exp.rhs
symbols = rhs.atoms(sp.Symbol)
return len(symbols) == 1 and rhs == symbols.pop() ** 2
return insert_subexpressions(ac, callback, **kwargs)
def bind_symbols_to_skip(insertion_function, skip):
return lambda ac: insertion_function(ac, skip=skip)
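bind_symbols_to_skip simply fixes the skip set, so a configured insertion pass can be registered like any other one-argument rule. A sketch (omega is a placeholder for a symbol one might want to protect):

import sympy as sp
protected = {sp.Symbol('omega')}  # subexpressions assigned to these symbols stay
insert_aliases_keeping_omega = bind_symbols_to_skip(insert_aliases, protected)
# usable wherever a one-argument simplification rule is expected, e.g.
# strategy.add(insert_aliases_keeping_omega)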
from pystencils.simp import (SimplificationStrategy, insert_constants, insert_symbol_times_minus_one,
insert_constant_multiples, insert_constant_additions, insert_squares, insert_zeros)
def create_simplification_strategy():
"""
Creates a default simplification `ps.simp.SimplificationStrategy`. The idea behind the default simplification
strategy is to reduce the number of subexpressions by inserting single constants and to evaluate constant
terms beforehand.
"""
s = SimplificationStrategy()
s.add(insert_symbol_times_minus_one)
s.add(insert_constant_multiples)
s.add(insert_constant_additions)
s.add(insert_squares)
s.add(insert_zeros)
s.add(insert_constants)
s.add(lambda ac: ac.new_without_unused_subexpressions())
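A strategy built this way is applied by calling it; a hedged sketch (assuming ac is an AssignmentCollection and that SimplificationStrategy objects are callable, as the class shown above suggests):

strategy = create_simplification_strategy()
simplified = strategy(ac)  # runs all rules in the order they were added
# per-rule timing and operation counts, as produced by the report code above:
report = strategy.create_simplification_report(ac)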
@@ -89,9 +89,12 @@ def shift_slice(slices, offset):
raise ValueError()
if hasattr(offset, '__len__'):
return [shift_slice_component(k, off) for k, off in zip(slices, offset)]
return tuple(shift_slice_component(k, off) for k, off in zip(slices, offset))
else:
return [shift_slice_component(k, offset) for k in slices]
if isinstance(slices, slice) or isinstance(slices, int) or isinstance(slices, float):
return shift_slice_component(slices, offset)
else:
return tuple(shift_slice_component(k, offset) for k in slices)
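A hedged example of the intended behavior (assuming shift_slice_component adds the offset to a slice's start and stop):

s = (slice(1, 5), slice(2, 6))
shift_slice(s, 1)            # same offset for every axis -> (slice(2, 6), slice(3, 7))
shift_slice(s, (1, -1))      # per-axis offsets           -> (slice(2, 6), slice(1, 5))
shift_slice(slice(0, 4), 2)  # bare slice, scalar offset  -> slice(2, 6)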
def slice_from_direction(direction_name, dim, normal_offset=0, tangential_offset=0):
@@ -5,6 +5,8 @@ from typing import Sequence
import numpy as np
import sympy as sp
from pystencils.utils import binary_numbers
def inverse_direction(direction):
"""Returns inverse i.e. negative of given direction tuple
@@ -34,6 +36,8 @@ def is_valid(stencil, max_neighborhood=None):
True
>>> is_valid([(2, 0), (1, 0)], max_neighborhood=1)
False
>>> is_valid([(2, 0), (1, 0)], max_neighborhood=2)
True
"""
expected_dim = len(stencil[0])
for d in stencil:
@@ -67,8 +71,11 @@ def have_same_entries(s1, s2):
Examples:
>>> stencil1 = [(1, 0), (-1, 0), (0, 1), (0, -1)]
>>> stencil2 = [(-1, 0), (0, -1), (1, 0), (0, 1)]
>>> stencil3 = [(-1, 0), (0, -1), (1, 0)]
>>> have_same_entries(stencil1, stencil2)
True
>>> have_same_entries(stencil1, stencil3)
False
"""
if len(s1) != len(s2):
return False
@@ -288,6 +295,38 @@ def direction_string_to_offset(direction: str, dim: int = 3):
return offset[:dim]
def adjacent_directions(direction):
"""
Returns all adjacent directions for a direction as a tuple of tuples. This is useful, for example, to find all
directions relevant for neighbour communication.
Args:
direction: tuple representing a direction. For example (0, 1, 0) for the northern side
Examples:
>>> adjacent_directions((0, 0, 0))
((0, 0, 0),)
>>> adjacent_directions((0, 1, 0))
((0, 1, 0),)
>>> adjacent_directions((0, 1, 1))
((0, 0, 1), (0, 1, 0), (0, 1, 1))
>>> adjacent_directions((-1, -1))
((-1, -1), (-1, 0), (0, -1))
"""
result = set()
if all(e == 0 for e in direction):
result.add(direction)
return tuple(result)
binary_numbers_list = binary_numbers(len(direction))
for adjacent_direction in binary_numbers_list:
for i, entry in enumerate(direction):
if entry == 0:
adjacent_direction[i] = 0
if entry == -1 and adjacent_direction[i] == 1:
adjacent_direction[i] = -1
if not all(e == 0 for e in adjacent_direction):
result.add(tuple(adjacent_direction))
return tuple(sorted(result))
# -------------------------------------- Visualization -----------------------------------------------------------------
@@ -314,6 +353,7 @@ def plot_2d(stencil, axes=None, figure=None, data=None, textsize='12', **kwargs)
Args:
stencil: sequence of directions
axes: optional matplotlib axes
figure: optional matplotlib figure
data: data to annotate the directions with; if none given, the indices are used
textsize: size of annotation text
"""
@@ -335,7 +375,7 @@ def plot_2d(stencil, axes=None, figure=None, data=None, textsize='12', **kwargs)
for direction, annotation in zip(stencil, data):
assert len(direction) == 2, "Works only for 2D stencils"
direction = tuple(int(i) for i in direction)
if not(direction[0] == 0 and direction[1] == 0):
if not (direction[0] == 0 and direction[1] == 0):
axes.arrow(0, 0, direction[0], direction[1], head_width=0.08, head_length=head_length, color='k')
if isinstance(annotation, sp.Basic):
@@ -351,7 +391,7 @@ def plot_2d(stencil, axes=None, figure=None, data=None, textsize='12', **kwargs)
else:
return 0
text_position = [direction[c] + position_correction(direction[c]) for c in (0, 1)]
axes.text(*text_position, annotation, verticalalignment='center',
axes.text(x=text_position[0], y=text_position[1], s=annotation, verticalalignment='center',
zorder=30, horizontalalignment='center', size=textsize,
bbox=dict(boxstyle=text_box_style, facecolor='#00b6eb', alpha=0.85, linewidth=0))
@@ -369,6 +409,7 @@ def plot_3d_slicing(stencil, slice_axis=2, figure=None, data=None, **kwargs):
Args:
stencil: stencil as sequence of directions
slice_axis: 0, 1, or 2 indicating the axis to slice through
figure: optional matplotlib figure
data: optional data to print as text beside the arrows
"""
import matplotlib.pyplot as plt
@@ -414,16 +455,17 @@ def plot_3d(stencil, figure=None, axes=None, data=None, textsize='8'):
FancyArrowPatch.__init__(self, (0, 0), (0, 0), *args, **kwargs)
self._verts3d = xs, ys, zs
def draw(self, renderer):
def do_3d_projection(self, *_):
xs3d, ys3d, zs3d = self._verts3d
xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, renderer.M)
xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M)
self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
FancyArrowPatch.draw(self, renderer)
return np.min(zs)
if axes is None:
if figure is None:
figure = plt.figure()
axes = figure.gca(projection='3d')
axes = figure.add_subplot(projection='3d')
try:
axes.set_aspect("equal")
except NotImplementedError:
@@ -439,7 +481,7 @@ def plot_3d(stencil, figure=None, axes=None, data=None, textsize='8'):
r = [-1, 1]
for s, e in combinations(np.array(list(product(r, r, r))), 2):
if np.sum(np.abs(s - e)) == r[1] - r[0]:
axes.plot3D(*zip(s, e), color="k", alpha=0.5)
axes.plot(*zip(s, e), color="k", alpha=0.5)
for d, annotation in zip(stencil, data):
assert len(d) == 3, "Works only for 3D stencils"
@@ -463,8 +505,8 @@ def plot_3d(stencil, figure=None, axes=None, data=None, textsize='8'):
else:
annotation = str(annotation)
axes.text(d[0] * text_offset, d[1] * text_offset, d[2] * text_offset,
annotation, verticalalignment='center', zorder=30,
axes.text(x=d[0] * text_offset, y=d[1] * text_offset, z=d[2] * text_offset,
s=annotation, verticalalignment='center', zorder=30,
size=textsize, bbox=dict(boxstyle=text_box_style, facecolor='#777777', alpha=0.6, linewidth=0))
axes.set_xlim([-text_offset * 1.1, text_offset * 1.1])
@@ -6,10 +6,14 @@ from functools import partial, reduce
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, TypeVar, Union
import sympy as sp
from sympy import PolynomialError
from sympy.functions import Abs
from sympy.core.numbers import Zero
from pystencils.assignment import Assignment
from pystencils.data_types import cast_func, get_base_type, get_type_of_expression
from pystencils.functions import DivFunc
from pystencils.typing import CastFunc, get_type_of_expression, PointerType, VectorType
from pystencils.typing.typed_sympy import FieldPointerSymbol
T = TypeVar('T')
@@ -156,17 +160,23 @@ def fast_subs(expression: T, substitutions: Dict,
if type(expression) is sp.Matrix:
return expression.copy().applyfunc(partial(fast_subs, substitutions=substitutions))
def visit(expr):
def visit(expr, evaluate=True):
if skip and skip(expr):
return expr
if hasattr(expr, "fast_subs"):
elif hasattr(expr, "fast_subs"):
return expr.fast_subs(substitutions, skip)
if expr in substitutions:
elif expr in substitutions:
return substitutions[expr]
if not hasattr(expr, 'args'):
elif not hasattr(expr, 'args'):
return expr
param_list = [visit(a) for a in expr.args]
return expr if not param_list else expr.func(*param_list)
elif isinstance(expr, (sp.UnevaluatedExpr, DivFunc)):
args = [visit(a, False) for a in expr.args]
return expr.func(*args)
else:
param_list = [visit(a, evaluate) for a in expr.args]
if isinstance(expr, (sp.Mul, sp.Add)):
return expr if not param_list else expr.func(*param_list, evaluate=evaluate)
return expr if not param_list else expr.func(*param_list)
if len(substitutions) == 0:
return expression
@@ -233,6 +243,9 @@ def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr,
normalized_replacement_match = normalize_match_parameter(required_match_replacement, len(subexpression.args))
if isinstance(subexpression, sp.Number):
return expr.subs({replacement: subexpression})
def visit(current_expr):
if current_expr.is_Add:
expr_max_length = max(len(current_expr.args), len(subexpression.args))
@@ -260,8 +273,8 @@ def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr,
if not param_list:
return current_expr
else:
if current_expr.func == sp.Mul and sp.numbers.Zero() in param_list:
return sp.numbers.Zero()
if current_expr.func == sp.Mul and Zero() in param_list:
return sp.simplify(current_expr)
else:
return current_expr.func(*param_list, evaluate=False)
@@ -271,7 +284,7 @@ def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr,
def replace_second_order_products(expr: sp.Expr, search_symbols: Iterable[sp.Symbol],
positive: Optional[bool] = None,
replace_mixed: Optional[List[Assignment]] = None) -> sp.Expr:
"""Replaces second order mixed terms like x*y by 2*( (x+y)**2 - x**2 - y**2 ).
"""Replaces second order mixed terms like 4*x*y by 2*( (x+y)**2 - x**2 - y**2 ).
This makes the term longer - simplify usually undoes it - but the
transformation can expose more common sub-expressions
@@ -292,7 +305,7 @@ def replace_second_order_products(expr: sp.Expr, search_symbols: Iterable[sp.Sym
if expr.is_Mul:
distinct_search_symbols = set()
nr_of_search_terms = 0
other_factors = 1
other_factors = sp.Integer(1)
for t in expr.args:
if t in search_symbols:
nr_of_search_terms += 1
@@ -343,7 +356,7 @@ def remove_higher_order_terms(expr: sp.Expr, symbols: Sequence[sp.Symbol], order
factor_count = 0
if type(product) is Mul:
for factor in product.args:
if type(factor) == Pow:
if type(factor) is Pow:
if factor.args[0] in symbols:
factor_count += factor.args[1]
if factor in symbols:
@@ -353,13 +366,13 @@ def remove_higher_order_terms(expr: sp.Expr, symbols: Sequence[sp.Symbol], order
factor_count += product.args[1]
return factor_count
if type(expr) == Mul or type(expr) == Pow:
if type(expr) is Mul or type(expr) is Pow:
if velocity_factors_in_product(expr) <= order:
return expr
else:
return sp.Rational(0, 1)
return Zero()
if type(expr) != Add:
if type(expr) is not Add:
return expr
for sum_term in expr.args:
@@ -429,7 +442,104 @@ def extract_most_common_factor(term):
return common_factor, term / common_factor
def count_operations(term: Union[sp.Expr, List[sp.Expr]],
def recursive_collect(expr, symbols, order_by_occurences=False):
"""Applies sympy.collect recursively for a list of symbols, collecting symbol 2 in the coefficients of symbol 1,
and so on.
``expr`` must be rewritable as a polynomial in the given ``symbols``.
If it is not, ``recursive_collect`` will fail quietly, returning the original expression.
Args:
expr: A sympy expression.
symbols: A sequence of symbols
order_by_occurences: If True, during recursive descent, always collect the symbol occurring
most often in the expression.
"""
if order_by_occurences:
symbols = list(expr.atoms(sp.Symbol) & set(symbols))
symbols = sorted(symbols, key=expr.count, reverse=True)
if len(symbols) == 0:
return expr
symbol = symbols[0]
collected = expr.collect(symbol)
try:
collected_poly = sp.Poly(collected, symbol)
except PolynomialError:
return expr
coeffs = collected_poly.all_coeffs()[::-1]
rec_sum = sum(symbol**i * recursive_collect(c, symbols[1:], order_by_occurences) for i, c in enumerate(coeffs))
return rec_sum
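A small self-contained example of the recursive collection:

import sympy as sp
x, y, z = sp.symbols('x y z')
expr = x**2*y + x**2*z + x*y + x*z + y + z
recursive_collect(expr, (x, y))
# x is collected first, then y inside each coefficient:
# -> x**2*(y + z) + x*(y + z) + y + z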
def summands(expr):
return set(expr.args) if isinstance(expr, sp.Add) else {expr}
def simplify_by_equality(expr, a, b, c):
"""
Uses the equality a = b + c, where a and b must be symbols, to simplify expr
by attempting to express additive combinations of two quantities by the third.
This works on expressions that are reducible to the form
:math:`a * (...) + b * (...) + c * (...)`,
without any mixed terms of a, b and c.
"""
if not isinstance(a, sp.Symbol) or not isinstance(b, sp.Symbol):
raise ValueError("a and b must be symbols.")
c = sp.sympify(c)
if not (isinstance(c, sp.Symbol) or is_constant(c)):
raise ValueError("c must be either a symbol or a constant!")
expr = sp.sympify(expr)
expr_expanded = sp.expand(expr)
a_coeff = expr_expanded.coeff(a, 1)
expr_expanded -= (a * a_coeff).expand()
b_coeff = expr_expanded.coeff(b, 1)
expr_expanded -= (b * b_coeff).expand()
if isinstance(c, sp.Symbol):
c_coeff = expr_expanded.coeff(c, 1)
rest = expr_expanded - (c * c_coeff).expand()
else:
c_coeff = expr_expanded / c
rest = 0
a_summands = summands(a_coeff)
b_summands = summands(b_coeff)
c_summands = summands(c_coeff)
# replace b + c by a
b_plus_c_coeffs = b_summands & c_summands
for coeff in b_plus_c_coeffs:
rest += a * coeff
b_summands -= b_plus_c_coeffs
c_summands -= b_plus_c_coeffs
# replace a - b by c
neg_b_summands = {-x for x in b_summands}
a_minus_b_coeffs = a_summands & neg_b_summands
for coeff in a_minus_b_coeffs:
rest += c * coeff
a_summands -= a_minus_b_coeffs
b_summands -= {-x for x in a_minus_b_coeffs}
# replace a - c by b
neg_c_summands = {-x for x in c_summands}
a_minus_c_coeffs = a_summands & neg_c_summands
for coeff in a_minus_c_coeffs:
rest += b * coeff
a_summands -= a_minus_c_coeffs
c_summands -= {-x for x in a_minus_c_coeffs}
# put it back together
return (rest + a * sum(a_summands) + b * sum(b_summands) + c * sum(c_summands)).expand()
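A short example of the b + c -> a replacement (all symbols are placeholders):

import sympy as sp
a, b, c, x = sp.symbols('a b c x')
expr = 2*b + 2*c + x*b + x*c
simplify_by_equality(expr, a, b, c)
# the coefficient sets of b and c share {2, x}, so b + c collapses to a:
# -> a*x + 2*a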
def count_operations(term: Union[sp.Expr, List[sp.Expr], List[Assignment]],
only_type: Optional[str] = 'real') -> Dict[str, int]:
"""Counts the number of additions, multiplications and division.
@@ -444,7 +554,6 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
result = {'adds': 0, 'muls': 0, 'divs': 0, 'sqrts': 0,
'fast_sqrts': 0, 'fast_inv_sqrts': 0, 'fast_div': 0}
if isinstance(term, Sequence):
for element in term:
r = count_operations(element, only_type)
@@ -454,16 +563,20 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
elif isinstance(term, Assignment):
term = term.rhs
if hasattr(term, 'evalf'):
term = term.evalf()
def check_type(e):
if only_type is None:
return True
if isinstance(e, FieldPointerSymbol) and only_type == "real":
return only_type == "int"
try:
base_type = get_base_type(get_type_of_expression(e))
base_type = get_type_of_expression(e)
except ValueError:
return False
if isinstance(base_type, VectorType):
return False
if isinstance(base_type, PointerType):
return only_type == 'int'
if only_type == 'int' and (base_type.is_int() or base_type.is_uint()):
return True
if only_type == 'real' and (base_type.is_float()):
@@ -492,7 +605,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
visit_children = False
elif t.is_integer:
pass
elif isinstance(t, cast_func):
elif isinstance(t, CastFunc):
visit_children = False
visit(t.args[0])
elif t.func is fast_sqrt:
@@ -508,13 +621,17 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
if t.exp >= 0:
result['muls'] += int(t.exp) - 1
else:
result['muls'] -= 1
if result['muls'] > 0:
result['muls'] -= 1
result['divs'] += 1
result['muls'] += (-int(t.exp)) - 1
elif sp.nsimplify(t.exp) == sp.Rational(1, 2):
result['sqrts'] += 1
elif sp.nsimplify(t.exp) == -sp.Rational(1, 2):
result["sqrts"] += 1
result["divs"] += 1
else:
warnings.warn("Cannot handle exponent", t.exp, " of sp.Pow node")
warnings.warn(f"Cannot handle exponent {t.exp} of sp.Pow node")
else:
warnings.warn("Counting operations: only integer exponents are supported in Pow, "
"counting will be inaccurate")
@@ -522,10 +639,12 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
for child_term, condition in t.args:
visit(child_term)
visit_children = False
elif isinstance(t, sp.Rel):
elif isinstance(t, (sp.Rel, sp.UnevaluatedExpr)):
pass
elif isinstance(t, DivFunc):
result["divs"] += 1
else:
warnings.warn("Unknown sympy node of type " + str(t.func) + " counting will be inaccurate")
warnings.warn(f"Unknown sympy node of type {str(t.func)} counting will be inaccurate")
if visit_children:
for a in t.args:
File moved
import hashlib
import pickle
import warnings
from collections import OrderedDict, defaultdict, namedtuple
from collections import OrderedDict
from copy import deepcopy
from types import MappingProxyType
from typing import Set
import numpy as np
import sympy as sp
from sympy.core.numbers import ImaginaryUnit
from sympy.logic.boolalg import Boolean, BooleanFunction
import pystencils as ps
import pystencils.astnodes as ast
import pystencils.integer_functions
from pystencils.assignment import Assignment
from pystencils.data_types import (
PointerType, StructType, TypedImaginaryUnit, TypedSymbol, cast_func, collate_types, create_type,
get_base_type, get_type_of_expression, pointer_arithmetic_func, reinterpret_cast_func)
from pystencils.field import AbstractField, Field, FieldType
from pystencils.kernelparameters import FieldPointerSymbol
from pystencils.typing import (CastFunc, PointerType, StructType, TypedSymbol, get_base_type,
ReinterpretCastFunc, get_next_parent_of_type, parents_of_type)
from pystencils.field import Field, FieldType
from pystencils.typing import FieldPointerSymbol
from pystencils.sympyextensions import fast_subs
from pystencils.simp.assignment_collection import AssignmentCollection
from pystencils.slicing import normalize_slice
from pystencils.integer_functions import int_div
class NestedScopes:
@@ -101,6 +100,45 @@ def generic_visit(term, visitor):
return visitor(term)
def iterate_loops_by_depth(node, nesting_depth):
"""Iterate all LoopOverCoordinate nodes in the given AST of the specified nesting depth.
Args:
node: Root node of the abstract syntax tree
nesting_depth: Nesting depth of the loops to yield.
Outermost loop has depth 0.
A depth of -1 indicates the innermost loops.
Returns: Iterable listing all loop nodes of given nesting depth.
"""
from pystencils.astnodes import LoopOverCoordinate
def _internal_default(node, nesting_depth):
isloop = isinstance(node, LoopOverCoordinate)
if nesting_depth < 0: # here, a negative value indicates end of descent
return
elif nesting_depth == 0 and isloop:
yield node
else:
next_depth = nesting_depth - 1 if isloop else nesting_depth
for arg in node.args:
yield from _internal_default(arg, next_depth)
def _internal_innermost(node):
if isinstance(node, LoopOverCoordinate) and node.is_innermost_loop:
yield node
else:
for arg in node.args:
yield from _internal_innermost(arg)
if nesting_depth >= 0:
yield from _internal_default(node, nesting_depth)
elif nesting_depth == -1:
yield from _internal_innermost(node)
else:
raise ValueError(f"Invalid nesting depth: {nesting_depth}. Choose a nonnegative number, or -1.")
def unify_shape_symbols(body, common_shape, fields):
"""Replaces symbols for array sizes to ensure they are represented by the same unique symbol.
@@ -125,9 +163,10 @@ def unify_shape_symbols(body, common_shape, fields):
body.subs(substitutions)
def get_common_shape(field_set):
"""Takes a set of pystencils Fields and returns their common spatial shape if it exists. Otherwise
ValueError is raised"""
def get_common_field(field_set):
"""Takes a set of pystencils Fields, checks if a common spatial shape exists and returns one
representative field, that can be used for shape information etc. in the kernel creation.
If the fields have different shapes ValueError is raised"""
nr_of_fixed_shaped_fields = 0
for f in field_set:
if f.has_fixed_shape:
@@ -137,7 +176,7 @@ def get_common_shape(field_set):
fixed_field_names = ",".join([f.name for f in field_set if f.has_fixed_shape])
var_field_names = ",".join([f.name for f in field_set if not f.has_fixed_shape])
msg = "Mixing fixed-shaped and variable-shape fields in a single kernel is not possible\n"
msg += "Variable shaped: %s \nFixed shaped: %s" % (var_field_names, fixed_field_names)
msg += f"Variable shaped: {var_field_names} \nFixed shaped: {fixed_field_names}"
raise ValueError(msg)
shape_set = set([f.spatial_shape for f in field_set])
@@ -145,8 +184,9 @@ def get_common_shape(field_set):
if len(shape_set) != 1:
raise ValueError("Differently sized field accesses in loop body: " + str(shape_set))
shape = list(sorted(shape_set, key=lambda e: str(e[0])))[0]
return shape
# Sort the fields by their name to ensure that always the same field is returned
reference_field = sorted(field_set, key=lambda e: str(e))[0]
return reference_field
def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_order=None):
@@ -164,9 +204,11 @@ def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_or
tuple of loop-node, ghost_layer_info
"""
# find correct ordering by inspecting participating FieldAccesses
field_accesses = body.atoms(AbstractField.AbstractAccess)
absolut_accesses_only = False
field_accesses = body.atoms(Field.Access)
field_accesses = {e for e in field_accesses if not e.is_absolute_access}
if len(field_accesses) == 0: # when kernel contains only absolute accesses
absolut_accesses_only = True
# exclude accesses to buffers from field_list, because buffers are treated separately
field_list = [e.field for e in field_accesses if not (FieldType.is_buffer(e.field) or FieldType.is_custom(e.field))]
if len(field_list) == 0: # when kernel contains only custom fields
@@ -177,14 +219,23 @@ def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_or
if loop_order is None:
loop_order = get_optimal_loop_ordering(fields)
shape = get_common_shape(fields)
unify_shape_symbols(body, common_shape=shape, fields=fields)
if absolut_accesses_only:
absolut_access_fields = {e.field for e in body.atoms(Field.Access)}
common_field = get_common_field(absolut_access_fields)
common_shape = common_field.spatial_shape
else:
common_field = get_common_field(fields)
common_shape = common_field.spatial_shape
unify_shape_symbols(body, common_shape=common_shape, fields=fields)
if iteration_slice is not None:
iteration_slice = normalize_slice(iteration_slice, shape)
iteration_slice = normalize_slice(iteration_slice, common_shape)
if ghost_layers is None:
required_ghost_layers = max([fa.required_ghost_layers for fa in field_accesses])
if absolut_accesses_only:
required_ghost_layers = 0
else:
required_ghost_layers = max([fa.required_ghost_layers for fa in field_accesses])
ghost_layers = [(required_ghost_layers, required_ghost_layers)] * len(loop_order)
if isinstance(ghost_layers, int):
ghost_layers = [(ghost_layers, ghost_layers)] * len(loop_order)
@@ -193,7 +244,7 @@ def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_or
for i, loop_coordinate in enumerate(reversed(loop_order)):
if iteration_slice is None:
begin = ghost_layers[loop_coordinate][0]
end = shape[loop_coordinate] - ghost_layers[loop_coordinate][1]
end = common_shape[loop_coordinate] - ghost_layers[loop_coordinate][1]
new_loop = ast.LoopOverCoordinate(current_body, loop_coordinate, begin, end, 1)
current_body = ast.Block([new_loop])
else:
@@ -210,6 +261,28 @@ def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_or
return current_body, ghost_layers
def get_common_indexed_element(indexed_elements: Set[sp.IndexedBase]) -> sp.IndexedBase:
assert len(indexed_elements) > 0, "indexed_elements can not be empty"
shape_set = {s.shape for s in indexed_elements}
if len(shape_set) != 1:
for shape in shape_set:
assert not isinstance(shape, int), "If indexed elements are used, they must all have the same shape"
return sorted(indexed_elements, key=lambda e: str(e))[0]
def add_outer_loop_over_indexed_elements(loop_node: ast.Block) -> ast.Block:
indexed_elements = loop_node.atoms(sp.Indexed)
if len(indexed_elements) == 0:
return loop_node
reference_element = get_common_indexed_element(indexed_elements)
index = reference_element.indices[0].atoms(TypedSymbol)
assert len(index) == 1, "index expressions must only contain one symbol representing the index"
new_loop = ast.LoopOverCoordinate(loop_node, 0, 0,
reference_element.shape[0], 1, custom_loop_ctr=index.pop())
return ast.Block([new_loop])
def create_intermediate_base_pointer(field_access, coordinates, previous_ptr):
r"""
Addressing elements in structured arrays is done with :math:`ptr\left[ \sum_i c_i \cdot s_i \right]`
@@ -326,7 +399,7 @@ def parse_base_pointer_info(base_pointer_specification, loop_order, spatial_dime
index = int(element[len("index"):])
add_new_element(spatial_dimensions + index)
else:
raise ValueError("Unknown specification %s" % (element,))
raise ValueError(f"Unknown specification {element}")
result.append(new_group)
@@ -345,7 +418,7 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None):
ast_node: ast before any field accesses are resolved
loop_counters: for CPU kernels: leave to default 'None' (can be determined from loop nodes)
for GPU kernels: list of 'loop counters' from inner to outer loop
loop_iterations: number of iterations of each loop from inner to outer, for CPU kernels leave to default
loop_iterations: iteration slice for each loop from inner to outer, for CPU kernels leave to default
Returns:
base buffer index - required by 'resolve_buffer_accesses' function
@@ -357,26 +430,46 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None):
assert len(loops) == len(parents_of_innermost_loop)
assert all(l1 is l2 for l1, l2 in zip(loops, parents_of_innermost_loop))
loop_iterations = [(l.stop - l.start) / l.step for l in loops]
loop_counters = [l.loop_counter_symbol for l in loops]
loop_counters = [loop.loop_counter_symbol for loop in loops]
loop_iterations = [slice(loop.start, loop.stop, loop.step) for loop in loops]
actual_sizes = list()
actual_steps = list()
for ctr, s in zip(loop_counters, loop_iterations):
if s.step != 1:
if (s.stop - s.start) % s.step == 0:
actual_sizes.append((s.stop - s.start) // s.step)
else:
actual_sizes.append(int_div((s.stop - s.start), s.step))
if (ctr - s.start) % s.step == 0:
actual_steps.append((ctr - s.start) // s.step)
else:
actual_steps.append(int_div((ctr - s.start), s.step))
else:
actual_sizes.append(s.stop - s.start)
actual_steps.append(ctr - s.start)
field_accesses = ast_node.atoms(AbstractField.AbstractAccess)
field_accesses = ast_node.atoms(Field.Access)
buffer_accesses = {fa for fa in field_accesses if FieldType.is_buffer(fa.field)}
loop_counters = [v * len(buffer_accesses) for v in loop_counters]
buffer_index_size = len(buffer_accesses)
base_buffer_index = loop_counters[0]
stride = 1
for idx, var in enumerate(loop_counters[1:]):
cur_stride = loop_iterations[idx]
stride *= int(cur_stride) if isinstance(cur_stride, float) else cur_stride
base_buffer_index += var * stride
return base_buffer_index
base_buffer_index = actual_steps[0]
actual_stride = 1
for idx, actual_step in enumerate(actual_steps[1:]):
cur_stride = actual_sizes[idx]
actual_stride *= int(cur_stride) if isinstance(cur_stride, float) else cur_stride
base_buffer_index += actual_stride * actual_step
return base_buffer_index * buffer_index_size
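The arithmetic above flattens the per-loop positions into one linear buffer index. A plain-Python sketch of the same formula (names are illustrative, not part of the API):

def flat_buffer_index(steps, sizes, accesses_per_cell):
    # steps[i]: normalized position (ctr_i - start_i) // step_i, innermost first
    # sizes[i]: iteration count of loop i
    index, stride = steps[0], 1
    for size, step in zip(sizes[:-1], steps[1:]):
        stride *= size
        index += stride * step
    return index * accesses_per_cell

flat_buffer_index([2, 1], [4, 3], 2)  # (2 + 4*1) * 2 = 12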
def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=set()):
def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=None):
if read_only_field_names is None:
read_only_field_names = set()
def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
if isinstance(expr, AbstractField.AbstractAccess):
if isinstance(expr, Field.Access):
field_access = expr
# Do not apply transformation if field is not a buffer
@@ -419,7 +512,7 @@ def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=s
return visit_node(ast_node)
def resolve_field_accesses(ast_node, read_only_field_names=set(),
def resolve_field_accesses(ast_node, read_only_field_names=None,
field_to_base_pointer_info=MappingProxyType({}),
field_to_fixed_coordinates=MappingProxyType({})):
"""
@@ -436,11 +529,13 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
Returns
transformed AST
"""
if read_only_field_names is None:
read_only_field_names = set()
field_to_base_pointer_info = OrderedDict(sorted(field_to_base_pointer_info.items(), key=lambda pair: pair[0]))
field_to_fixed_coordinates = OrderedDict(sorted(field_to_fixed_coordinates.items(), key=lambda pair: pair[0]))
def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
if isinstance(expr, AbstractField.AbstractAccess):
if isinstance(expr, Field.Access):
field_access = expr
field = field_access.field
@@ -456,10 +551,7 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
if field.name in field_to_base_pointer_info:
base_pointer_info = field_to_base_pointer_info[field.name]
else:
base_pointer_info = [
list(
range(field.index_dimensions + field.spatial_dimensions))
]
base_pointer_info = [list(range(field.index_dimensions + field.spatial_dimensions))]
field_ptr = FieldPointerSymbol(
field.name,
@@ -500,7 +592,7 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
coord_dict = create_coordinate_dict(group)
new_ptr, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer)
if new_ptr not in enclosing_block.symbols_defined:
new_assignment = ast.SympyAssignment(new_ptr, last_pointer + offset, is_const=False)
new_assignment = ast.SympyAssignment(new_ptr, last_pointer + offset, is_const=False, use_auto=False)
enclosing_block.insert_before(new_assignment, sympy_assignment)
last_pointer = new_ptr
@@ -514,7 +606,7 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
if isinstance(accessed_field_name, sp.Symbol):
accessed_field_name = accessed_field_name.name
new_type = field_access.field.dtype.get_element_type(accessed_field_name)
result = reinterpret_cast_func(result, new_type)
result = ReinterpretCastFunc(result, new_type)
return visit_sympy_expr(result, enclosing_block, sympy_assignment)
else:
@@ -564,21 +656,65 @@ def move_constants_before_loop(ast_node):
"""
assert isinstance(node.parent, ast.Block)
def modifies_or_declares(node: ast.Node, symbol_names: Set[str]) -> bool:
if isinstance(node, (ps.Assignment, ast.SympyAssignment)):
if isinstance(node.lhs, ast.ResolvedFieldAccess):
return node.lhs.typed_symbol.name in symbol_names
else:
return node.lhs.name in symbol_names
elif isinstance(node, ast.Block):
for arg in node.args:
if isinstance(arg, ast.SympyAssignment) and arg.is_declaration:
continue
if modifies_or_declares(arg, symbol_names):
return True
return False
elif isinstance(node, ast.LoopOverCoordinate):
return modifies_or_declares(node.body, symbol_names)
elif isinstance(node, ast.Conditional):
return (
modifies_or_declares(node.true_block, symbol_names)
or (node.false_block and modifies_or_declares(node.false_block, symbol_names))
)
elif isinstance(node, ast.KernelFunction):
return False
else:
defs = {s.name for s in node.symbols_defined}
return bool(symbol_names.intersection(defs))
dependencies = {s.name for s in node.undefined_symbols}
last_block = node.parent
last_block_child = node
element = node.parent
prev_element = node
while element:
if isinstance(element, ast.Block):
if isinstance(element, (ast.Conditional, ast.KernelFunction)):
# Never move out of Conditionals or KernelFunctions.
break
elif isinstance(element, ast.Block):
last_block = element
last_block_child = prev_element
if isinstance(element, ast.Conditional):
break
if any(modifies_or_declares(sibling, dependencies) for sibling in element.args):
# The node depends on one of the statements in this block.
# Do not move further out.
break
elif isinstance(element, ast.LoopOverCoordinate):
if element.loop_counter_symbol.name in dependencies:
# The node depends on the loop counter.
# Do not move out of this loop.
break
else:
critical_symbols = element.symbols_defined
if node.undefined_symbols.intersection(critical_symbols):
break
raise NotImplementedError(f'Due to defensive programming we handle only specific expressions.\n'
f'The expression {element} of type {type(element)} is not known yet.')
# No dependencies to symbols defined/modified within the current element.
# We can move the node up one level and in front of the current element.
prev_element = element
element = element.parent
return last_block, last_block_child
@@ -602,13 +738,7 @@ def move_constants_before_loop(ast_node):
get_blocks(ast_node, all_blocks)
for block in all_blocks:
children = block.take_child_nodes()
# Every time a symbol can be replaced in the current block because the assignment
# was found in a parent block, but with a different lhs symbol (same rhs)
# the outer symbol is inserted here as key.
substitute_variables = {}
for child in children:
# Before traversing the next child, all symbols are substituted first.
child.subs(substitute_variables)
if not isinstance(child, ast.SympyAssignment): # only move SympyAssignments
block.append(child)
@@ -624,14 +754,7 @@ def move_constants_before_loop(ast_node):
exists_already = False
if not exists_already:
rhs_identical = check_if_assignment_already_in_block(child, target, True)
if rhs_identical:
# there is already an assignment out there with the same rhs
# -> replace all lhs symbols in this block with the lhs of the outer assignment
# -> remove the local assignment (do not re-append child to the former block)
substitute_variables[child.lhs] = rhs_identical.lhs
else:
target.insert_before(child, child_to_insert_before)
target.insert_before(child, child_to_insert_before)
elif exists_already and exists_already.rhs == child.rhs:
if target.args.index(exists_already) > target.args.index(child_to_insert_before):
assert target.args.count(exists_already) == 1
@@ -645,7 +768,7 @@ def move_constants_before_loop(ast_node):
new_symbol = TypedSymbol(sp.Dummy().name, child.lhs.dtype)
target.insert_before(ast.SympyAssignment(new_symbol, child.rhs, is_const=child.is_const),
child_to_insert_before)
substitute_variables[child.lhs] = new_symbol
block.append(ast.SympyAssignment(child.lhs, new_symbol, is_const=child.is_const))
def split_inner_loop(ast_node: ast.Node, symbol_groups):
......@@ -659,11 +782,11 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups):
and which no symbol in a symbol group depends on, are not updated!
"""
all_loops = ast_node.atoms(ast.LoopOverCoordinate)
inner_loop = [l for l in all_loops if l.is_innermost_loop]
inner_loop = [loop for loop in all_loops if loop.is_innermost_loop]
assert len(inner_loop) == 1, "Error in AST: multiple innermost loops. Was split transformation already called?"
inner_loop = inner_loop[0]
assert type(inner_loop.body) is ast.Block
outer_loop = [l for l in all_loops if l.is_outermost_loop]
outer_loop = [loop for loop in all_loops if loop.is_outermost_loop]
assert len(outer_loop) == 1, "Error in AST, multiple outermost loops."
outer_loop = outer_loop[0]
@@ -682,13 +805,13 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups):
if s in assignment_map: # if there is no assignment inside the loop body it is independent already
for new_symbol in assignment_map[s].rhs.atoms(sp.Symbol):
if not isinstance(new_symbol, AbstractField.AbstractAccess) and \
if not isinstance(new_symbol, Field.Access) and \
new_symbol not in symbols_with_temporary_array:
symbols_to_process.append(new_symbol)
symbols_resolved.add(s)
for symbol in symbol_group:
if not isinstance(symbol, AbstractField.AbstractAccess):
if not isinstance(symbol, Field.Access):
assert type(symbol) is TypedSymbol
new_ts = TypedSymbol(symbol.name, PointerType(symbol.dtype))
symbols_with_temporary_array[symbol] = sp.IndexedBase(
@@ -697,9 +820,9 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups):
assignment_group = []
for assignment in inner_loop.body.args:
if assignment.lhs in symbols_resolved:
new_rhs = assignment.rhs.subs(
symbols_with_temporary_array.items())
if not isinstance(assignment.lhs, AbstractField.AbstractAccess) and assignment.lhs in symbol_group:
# use fast_subs here because it checks if multiplications should be evaluated or not
new_rhs = fast_subs(assignment.rhs, symbols_with_temporary_array)
if not isinstance(assignment.lhs, Field.Access) and assignment.lhs in symbol_group:
assert type(assignment.lhs) is TypedSymbol
new_ts = TypedSymbol(assignment.lhs.name, PointerType(assignment.lhs.dtype))
new_lhs = sp.IndexedBase(new_ts, shape=(1, ))[inner_loop.loop_counter_symbol]
@@ -728,7 +851,8 @@ def cut_loop(loop_node, cutting_points):
One loop is transformed into len(cutting_points)+1 new loops that range from
old_begin to cutting_points[0], ..., cutting_points[-1] to old_end
Modifies the ast in place
Modifies the AST in place. Note SymPy issue #5783: deepcopy will evaluate Mul nodes
(https://github.com/sympy/sympy/issues/5783)
Returns:
list of new loop nodes
@@ -766,11 +890,16 @@ def simplify_conditionals(node: ast.Node, loop_counter_simplification: bool = Fa
This analysis needs the integer set library (ISL) islpy, so it is not done by
default.
"""
from sympy.codegen.rewriting import ReplaceOptim, optimize
remove_casts = ReplaceOptim(lambda e: isinstance(e, CastFunc), lambda p: p.expr)
for conditional in node.atoms(ast.Conditional):
conditional.condition_expr = sp.simplify(conditional.condition_expr)
if conditional.condition_expr == sp.true:
# TODO simplify conditional before the type system! Casts make it very hard here
condition_expression = optimize(conditional.condition_expr, [remove_casts])
condition_expression = sp.simplify(condition_expression)
if condition_expression == sp.true:
conditional.parent.replace(conditional, [conditional.true_block])
elif conditional.condition_expr == sp.false:
elif condition_expression == sp.false:
conditional.parent.replace(conditional, [conditional.false_block] if conditional.false_block else [])
elif loop_counter_simplification:
try:
@@ -796,282 +925,6 @@ def cleanup_blocks(node: ast.Node) -> None:
cleanup_blocks(a)
class KernelConstraintsCheck:
"""Checks if the input to create_kernel is valid.
Test the following conditions:
- SSA Form for pure symbols:
- Every pure symbol may occur only once as left-hand-side of an assignment
- Every pure symbol that is read, may not be written to later
- Independence / Parallelization condition:
- a field that is written may only be read at exactly the same spatial position
(Pure symbols are symbols that are not Field.Accesses)
"""
FieldAndIndex = namedtuple('FieldAndIndex', ['field', 'index'])
def __init__(self, type_for_symbol, check_independence_condition, check_double_write_condition=True):
self._type_for_symbol = type_for_symbol
self.scopes = NestedScopes()
self._field_writes = defaultdict(set)
self.fields_read = set()
self.check_independence_condition = check_independence_condition
self.check_double_write_condition = check_double_write_condition
def process_assignment(self, assignment):
# for checks it is crucial to process rhs before lhs to catch e.g. a = a + 1
new_rhs = self.process_expression(assignment.rhs)
new_lhs = self._process_lhs(assignment.lhs)
return ast.SympyAssignment(new_lhs, new_rhs)
def process_expression(self, rhs, type_constants=True):
from pystencils.interpolation_astnodes import InterpolatorAccess
self._update_accesses_rhs(rhs)
if isinstance(rhs, AbstractField.AbstractAccess):
self.fields_read.add(rhs.field)
self.fields_read.update(rhs.indirect_addressing_fields)
return rhs
elif isinstance(rhs, InterpolatorAccess):
new_args = [self.process_expression(arg, type_constants) for arg in rhs.offsets]
if new_args:
rhs.offsets = new_args
return rhs
elif isinstance(rhs, ImaginaryUnit):
return TypedImaginaryUnit(create_type(self._type_for_symbol['_complex_type']))
elif isinstance(rhs, TypedSymbol):
return rhs
elif isinstance(rhs, sp.Symbol):
return TypedSymbol(rhs.name, self._type_for_symbol[rhs.name])
elif type_constants and isinstance(rhs, np.generic):
return cast_func(rhs, create_type(rhs.dtype))
elif type_constants and isinstance(rhs, sp.Number):
return cast_func(rhs, create_type(self._type_for_symbol['_constant']))
# Very important that this clause comes before BooleanFunction
elif isinstance(rhs, cast_func):
return cast_func(
self.process_expression(rhs.args[0], type_constants=False),
rhs.dtype)
elif isinstance(rhs, BooleanFunction) or \
type(rhs) in pystencils.integer_functions.__dict__.values():
new_args = [self.process_expression(a, type_constants) for a in rhs.args]
types_of_expressions = [get_type_of_expression(a) for a in new_args]
arg_type = collate_types(types_of_expressions, forbid_collation_to_float=True)
new_args = [a if not hasattr(a, 'dtype') or a.dtype == arg_type
else cast_func(a, arg_type)
for a in new_args]
return rhs.func(*new_args)
elif isinstance(rhs, sp.Mul):
new_args = [
self.process_expression(arg, type_constants)
if arg not in (-1, 1) else arg for arg in rhs.args
]
return rhs.func(*new_args) if new_args else rhs
elif isinstance(rhs, sp.Indexed):
return rhs
else:
if isinstance(rhs, sp.Pow):
# don't process exponents -> they should remain integers
return sp.Pow(
self.process_expression(rhs.args[0], type_constants),
rhs.args[1])
else:
new_args = [
self.process_expression(arg, type_constants)
for arg in rhs.args
]
return rhs.func(*new_args) if new_args else rhs
@property
def fields_written(self):
return set(k.field for k, v in self._field_writes.items() if len(v))
def _process_lhs(self, lhs):
assert isinstance(lhs, sp.Symbol)
self._update_accesses_lhs(lhs)
if not isinstance(lhs, (AbstractField.AbstractAccess, TypedSymbol)):
return TypedSymbol(lhs.name, self._type_for_symbol[lhs.name])
else:
return lhs
def _update_accesses_lhs(self, lhs):
if isinstance(lhs, AbstractField.AbstractAccess):
fai = self.FieldAndIndex(lhs.field, lhs.index)
self._field_writes[fai].add(lhs.offsets)
if self.check_double_write_condition and len(self._field_writes[fai]) > 1:
raise ValueError(
"Field {} is written at two different locations".format(
lhs.field.name))
elif isinstance(lhs, sp.Symbol):
if self.scopes.is_defined_locally(lhs):
raise ValueError("Assignments not in SSA form, multiple assignments to {}".format(lhs.name))
if lhs in self.scopes.free_parameters:
raise ValueError("Symbol {} is written, after it has been read".format(lhs.name))
self.scopes.define_symbol(lhs)
def _update_accesses_rhs(self, rhs):
if isinstance(rhs, AbstractField.AbstractAccess) and self.check_independence_condition:
writes = self._field_writes[self.FieldAndIndex(
rhs.field, rhs.index)]
for write_offset in writes:
assert len(writes) == 1
if write_offset != rhs.offsets:
raise ValueError("Violation of loop independence condition. Field "
"{} is read at {} and written at {}".format(rhs.field, rhs.offsets, write_offset))
self.fields_read.add(rhs.field)
elif isinstance(rhs, sp.Symbol):
self.scopes.access_symbol(rhs)
def add_types(eqs, type_for_symbol, check_independence_condition):
"""Traverses AST and replaces every :class:`sympy.Symbol` by a :class:`pystencils.typedsymbol.TypedSymbol`.
Additionally returns sets of all fields which are read/written
Args:
eqs: list of equations
type_for_symbol: dict mapping symbol names to types. Types are strings of C types like 'int' or 'double'
check_independence_condition: check that loop iterations are independent - this has to be skipped for indexed
kernels
Returns:
``fields_read, fields_written, typed_equations`` set of read fields, set of written fields,
list of equations where symbols have been replaced by typed symbols
"""
if isinstance(type_for_symbol, (str, type)) or not hasattr(type_for_symbol, '__getitem__'):
type_for_symbol = typing_from_sympy_inspection(eqs, type_for_symbol)
check = KernelConstraintsCheck(type_for_symbol, check_independence_condition)
def visit(obj):
if isinstance(obj, (list, tuple)):
return [visit(e) for e in obj]
if isinstance(obj, (sp.Eq, ast.SympyAssignment, Assignment)):
return check.process_assignment(obj)
elif isinstance(obj, ast.Conditional):
check.scopes.push()
# Disable double write check inside conditionals
# would be triggered by e.g. in-kernel boundaries
check.check_double_write_condition = False
false_block = None if obj.false_block is None else visit(
obj.false_block)
result = ast.Conditional(check.process_expression(
obj.condition_expr, type_constants=False),
true_block=visit(obj.true_block),
false_block=false_block)
check.check_double_write_condition = True
check.scopes.pop()
return result
elif isinstance(obj, ast.Block):
check.scopes.push()
result = ast.Block([visit(e) for e in obj.args])
check.scopes.pop()
return result
elif isinstance(obj, ast.Node) and not isinstance(obj, ast.LoopOverCoordinate):
return obj
else:
raise ValueError("Invalid object in kernel " + str(type(obj)))
typed_equations = visit(eqs)
return check.fields_read, check.fields_written, typed_equations
def insert_casts(node):
"""Checks the types and inserts casts and pointer arithmetic where necessary.
Args:
node: the head node of the ast
Returns:
modified AST
"""
def cast(zipped_args_types, target_dtype):
"""
Adds casts to the arguments if their type differs from the target type
:param zipped_args_types: a zipped list of args and types
:param target_dtype: The target data type
:return: args with possible casts
"""
casted_args = []
for argument, data_type in zipped_args_types:
if data_type.numpy_dtype != target_dtype.numpy_dtype: # ignoring const
casted_args.append(cast_func(argument, target_dtype))
else:
casted_args.append(argument)
return casted_args
def pointer_arithmetic(expr_args):
"""
Creates a valid pointer arithmetic function
:param expr_args: Arguments of the add expression
:return: pointer_arithmetic_func
"""
pointer = None
new_args = []
for arg, data_type in expr_args:
if data_type.func is PointerType:
assert pointer is None
pointer = arg
for arg, data_type in expr_args:
if arg != pointer:
assert data_type.is_int() or data_type.is_uint()
new_args.append(arg)
new_args = sp.Add(*new_args) if len(new_args) > 0 else new_args
return pointer_arithmetic_func(pointer, new_args)
if isinstance(node, sp.AtomicExpr) or isinstance(node, cast_func):
return node
args = []
for arg in node.args:
args.append(insert_casts(arg))
# TODO indexed, LoopOverCoordinate
if node.func in (sp.Add, sp.Mul, sp.Or, sp.And, sp.Pow, sp.Eq, sp.Ne, sp.Lt, sp.Le, sp.Gt, sp.Ge):
# TODO optimize pow, don't cast integer on double
types = [get_type_of_expression(arg) for arg in args]
assert len(types) > 0
# Never ever, ever collate to float type for boolean functions!
target = collate_types(types, forbid_collation_to_float=isinstance(node.func, BooleanFunction))
zipped = list(zip(args, types))
if target.func is PointerType:
assert node.func is sp.Add
return pointer_arithmetic(zipped)
else:
return node.func(*cast(zipped, target))
elif node.func is ast.SympyAssignment:
lhs = args[0]
rhs = args[1]
target = get_type_of_expression(lhs)
if target.func is PointerType:
return node.func(*args) # TODO fix, not complete
else:
return node.func(lhs, *cast([(rhs, get_type_of_expression(rhs))], target))
elif node.func is ast.ResolvedFieldAccess:
return node
elif node.func is ast.Block:
for old_arg, new_arg in zip(node.args, args):
node.replace(old_arg, new_arg)
return node
elif node.func is ast.LoopOverCoordinate:
for old_arg, new_arg in zip(node.args, args):
node.replace(old_arg, new_arg)
return node
elif node.func is sp.Piecewise:
expressions = [expr for (expr, _) in args]
types = [get_type_of_expression(expr) for expr in expressions]
target = collate_types(types)
zipped = list(zip(expressions, types))
casted_expressions = cast(zipped, target)
args = [
arg.func(*[expr, arg.cond])
for (arg, expr) in zip(args, casted_expressions)
]
return node.func(*args)
def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction, include_first=True) -> None:
"""Removes conditionals of a kernel that iterates over staggered positions by splitting the loops at last or
first and last element"""
@@ -1094,73 +947,6 @@ def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction, i
# --------------------------------------- Helper Functions -------------------------------------------------------------
def typing_from_sympy_inspection(eqs, default_type="double", default_int_type='int64'):
"""
Creates a default symbol name to type mapping.
If a sympy Boolean is assigned to a symbol, it is assumed to be 'bool'; otherwise the default type (usually 'double') is used
Args:
eqs: list of equations
default_type: the type for non-boolean symbols
Returns:
dictionary, mapping symbol name to type
"""
result = defaultdict(lambda: default_type)
if hasattr(default_type, 'numpy_dtype'):
result['_complex_type'] = (np.zeros((1,), default_type.numpy_dtype) * 1j).dtype
else:
result['_complex_type'] = (np.zeros((1,), default_type) * 1j).dtype
for eq in eqs:
if isinstance(eq, ast.Conditional):
result.update(typing_from_sympy_inspection(eq.true_block.args))
if eq.false_block:
result.update(typing_from_sympy_inspection(
eq.false_block.args))
elif isinstance(eq, ast.Node) and not isinstance(eq, ast.SympyAssignment):
continue
else:
from pystencils.cpu.vectorization import vec_all, vec_any
if isinstance(eq.rhs, (vec_all, vec_any)):
result[eq.lhs.name] = "bool"
# problematic case here is when rhs is a symbol: then it is impossible to decide here without
# further information what type the left hand side is - default fallback is the dict value then
if isinstance(eq.rhs, Boolean) and not isinstance(eq.rhs, sp.Symbol):
result[eq.lhs.name] = "bool"
try:
result[eq.lhs.name] = get_type_of_expression(eq.rhs,
default_float_type=default_type,
default_int_type=default_int_type,
symbol_type_dict=result)
except Exception:
pass # gracefully fail in case get_type_of_expression cannot determine type
return result
def get_next_parent_of_type(node, parent_type):
"""Returns the next parent node of given type or None, if root is reached.
Traverses the AST nodes parents until a parent of given type was found.
If no such parent is found, None is returned
"""
parent = node.parent
while parent is not None:
if isinstance(parent, parent_type):
return parent
parent = parent.parent
return None
def parents_of_type(node, parent_type, include_current=False):
"""Generator for all parent nodes of given type"""
parent = node if include_current else node.parent
while parent is not None:
if isinstance(parent, parent_type):
yield parent
parent = parent.parent
def get_optimal_loop_ordering(fields):
"""
Determines the optimal loop order for a given set of fields.
@@ -1206,13 +992,13 @@ def get_loop_hierarchy(ast_node):
return reversed(result)
def get_loop_counter_symbol_hierarchy(astNode):
def get_loop_counter_symbol_hierarchy(ast_node):
"""Determines the loop counter symbols around a given AST node.
:param astNode: the AST node
:param ast_node: the AST node
:return: list of loop counter symbols, where the first list entry is the symbol of the innermost loop
"""
result = []
node = astNode
node = ast_node
while node is not None:
node = get_next_parent_of_type(node, ast.LoopOverCoordinate)
if node:
@@ -1258,7 +1044,8 @@ def loop_blocking(ast_node: ast.KernelFunction, block_size) -> int:
Args:
ast_node: kernel function node before vectorization transformation has been applied
block_size: sequence defining block size in x, y, (z) direction
block_size: sequence defining block size in x, y, (z) direction.
If a component is chosen as zero, that direction will not be blocked.
Returns:
number of dimensions blocked
@@ -1270,8 +1057,10 @@ def loop_blocking(ast_node: ast.KernelFunction, block_size) -> int:
body = ast_node.body
coordinates = []
coordinates_taken_into_account = 0
loop_starts = {}
loop_stops = {}
for loop in loops:
coord = loop.coordinate_to_loop_over
if coord not in coordinates:
@@ -1280,11 +1069,14 @@ def loop_blocking(ast_node: ast.KernelFunction, block_size) -> int:
loop_stops[coord] = loop.stop
else:
assert loop.start == loop_starts[coord] and loop.stop == loop_stops[coord], \
"Multiple loops over coordinate {} with different loop bounds".format(coord)
f"Multiple loops over coordinate {coord} with different loop bounds"
# Create the outer loops that iterate over the blocks
outer_loop = None
for coord in reversed(coordinates):
if block_size[coord] == 0:
continue
coordinates_taken_into_account += 1
body = ast.Block([outer_loop]) if outer_loop else body
outer_loop = ast.LoopOverCoordinate(body,
coord,
......@@ -1298,6 +1090,8 @@ def loop_blocking(ast_node: ast.KernelFunction, block_size) -> int:
# modify the existing loops to only iterate within one block
for inner_loop in loops:
coord = inner_loop.coordinate_to_loop_over
if block_size[coord] == 0:
continue
block_ctr = ast.LoopOverCoordinate.get_block_loop_counter_symbol(coord)
loop_range = inner_loop.stop - inner_loop.start
if sp.sympify(
......@@ -1307,66 +1101,4 @@ def loop_blocking(ast_node: ast.KernelFunction, block_size) -> int:
stop = sp.Min(inner_loop.stop, block_ctr + block_size[coord])
inner_loop.start = block_ctr
inner_loop.stop = stop
return len(coordinates)
def implement_interpolations(ast_node: ast.Node,
implement_by_texture_accesses: bool = False,
vectorize: bool = False,
use_hardware_interpolation_for_f32=True):
from pystencils.interpolation_astnodes import (InterpolatorAccess, TextureCachedField)
# TODO: perform this function on assignments, when unify_shape_symbols allows differently sized fields
assert not(implement_by_texture_accesses and vectorize), \
"can only implement interpolations either by texture accesses or CPU vectorization"
FLOAT32_T = create_type('float32')
interpolation_accesses = ast_node.atoms(InterpolatorAccess)
if not interpolation_accesses:
return ast_node
def can_use_hw_interpolation(i):
return (use_hardware_interpolation_for_f32
and implement_by_texture_accesses
and i.dtype == FLOAT32_T
and isinstance(i.symbol.interpolator, TextureCachedField))
if implement_by_texture_accesses:
for i in interpolation_accesses:
from pystencils.interpolation_astnodes import _InterpolationSymbol
try:
import pycuda.driver as cuda
texture = TextureCachedField.from_interpolator(i.interpolator)
if can_use_hw_interpolation(i):
texture.filter_mode = cuda.filter_mode.LINEAR
else:
texture.filter_mode = cuda.filter_mode.POINT
texture.read_as_integer = True
except Exception as e:
raise e
i.symbol = _InterpolationSymbol(str(texture), i.symbol.field, texture)
# from pystencils.math_optimizations import ReplaceOptim, optimize_ast
# ImplementInterpolationByStencils = ReplaceOptim(lambda e: isinstance(e, InterpolatorAccess)
# and not can_use_hw_interpolation(i),
# lambda e: e.implementation_with_stencils()
# )
# RemoveConjugate = ReplaceOptim(lambda e: isinstance(e, sp.conjugate),
# lambda e: e.args[0]
# )
if vectorize:
# TODO can be done in _interpolator_access_to_stencils field.absolute_access == simd_gather
raise NotImplementedError()
else:
substitutions = {i: i.implementation_with_stencils()
for i in interpolation_accesses if not can_use_hw_interpolation(i)}
if isinstance(ast_node, AssignmentCollection):
ast_node = ast_node.subs(substitutions)
else:
ast_node.subs(substitutions)
return ast_node
return coordinates_taken_into_account
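# Usage sketch (illustrative; assumes the public pystencils API for field and kernel
# creation): block the x- and y-loops in 32x32 tiles; a zero entry in block_size
# leaves the z-direction unblocked, so two dimensions are reported as blocked.
# >>> import pystencils as ps
# >>> src, dst = ps.fields('src, dst: double[3D]')
# >>> kernel_ast = ps.create_kernel(ps.Assignment(dst.center, 2 * src.center))
# >>> loop_blocking(kernel_ast, block_size=(32, 32, 0))
# 2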
from pystencils.typing.cast_functions import (CastFunc, BooleanCastFunc, VectorMemoryAccess, ReinterpretCastFunc,
PointerArithmeticFunc)
from pystencils.typing.types import (is_supported_type, numpy_name_to_c, AbstractType, BasicType, VectorType,
PointerType, StructType, create_type)
from pystencils.typing.typed_sympy import (assumptions_from_dtype, TypedSymbol, FieldStrideSymbol, FieldShapeSymbol,
FieldPointerSymbol, CFunction)
from pystencils.typing.utilities import (typed_symbols, get_base_type, result_type, collate_types,
get_type_of_expression, get_next_parent_of_type, parents_of_type)
__all__ = ['CastFunc', 'BooleanCastFunc', 'VectorMemoryAccess', 'ReinterpretCastFunc', 'PointerArithmeticFunc',
'is_supported_type', 'numpy_name_to_c', 'AbstractType', 'BasicType',
'VectorType', 'PointerType', 'StructType', 'create_type', 'assumptions_from_dtype',
'TypedSymbol', 'FieldStrideSymbol', 'FieldShapeSymbol', 'FieldPointerSymbol', 'CFunction',
'typed_symbols', 'get_base_type', 'result_type', 'collate_types',
'get_type_of_expression', 'get_next_parent_of_type', 'parents_of_type']
import numpy as np
import sympy as sp
from sympy.logic.boolalg import Boolean
from pystencils.typing.types import AbstractType, BasicType
from pystencils.typing.typed_sympy import TypedSymbol
class CastFunc(sp.Function):
"""
CastFunc is used in order to introduce static casts. They are especially useful as a way to signal what type
a certain node should have, if it is impossible to add a type to a node, e.g. a sp.Number.
"""
is_Atom = True
def __new__(cls, *args, **kwargs):
if len(args) != 2:
pass  # subclasses such as VectorMemoryAccess legitimately pass more than two arguments
expr, dtype, *other_args = args
# If we have two consecutive casts, throw the inner one away.
# This optimisation is only available for simple casts. Thus the == is intended here!
if expr.__class__ == CastFunc:
expr = expr.args[0]
if not isinstance(dtype, AbstractType):
dtype = BasicType(dtype)
# to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well
# however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads
# to problems when for example comparing cast_func's for equality
#
# lhs = bitwise_and(a, cast_func(1, 'int'))
# rhs = cast_func(0, 'int')
# print( sp.Ne(lhs, rhs) ) # would give true if all cast_funcs are booleans
# -> thus a separate class boolean_cast_func is introduced
if isinstance(expr, Boolean) and (not isinstance(expr, TypedSymbol) or expr.dtype == BasicType('bool')):
cls = BooleanCastFunc
return sp.Function.__new__(cls, expr, dtype, *other_args, **kwargs)
@property
def canonical(self):
if hasattr(self.args[0], 'canonical'):
return self.args[0].canonical
else:
raise NotImplementedError()
@property
def is_commutative(self):
return self.args[0].is_commutative
@property
def dtype(self):
return self.args[1]
@property
def expr(self):
return self.args[0]
@property
def is_integer(self):
"""
Uses Numpy type hierarchy to determine :func:`sympy.Expr.is_integer` predicate
For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
"""
if hasattr(self.dtype, 'numpy_dtype'):
return np.issubdtype(self.dtype.numpy_dtype, np.integer) or super().is_integer
else:
return super().is_integer
@property
def is_negative(self):
"""
See :func:`.TypedSymbol.is_integer`
"""
if hasattr(self.dtype, 'numpy_dtype'):
if np.issubdtype(self.dtype.numpy_dtype, np.unsignedinteger):
return False
return super().is_negative
@property
def is_nonnegative(self):
"""
See :func:`.TypedSymbol.is_integer`
"""
if self.is_negative is False:
return True
else:
return super().is_nonnegative
@property
def is_real(self):
"""
See :func:`.TypedSymbol.is_integer`
"""
if hasattr(self.dtype, 'numpy_dtype'):
return np.issubdtype(self.dtype.numpy_dtype, np.integer) or np.issubdtype(self.dtype.numpy_dtype,
np.floating) or super().is_real
else:
return super().is_real
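# Usage sketch (illustrative): consecutive simple casts collapse, and numpy-backed
# dtypes feed the SymPy predicates defined above.
# >>> x = sp.Symbol('x')
# >>> outer = CastFunc(CastFunc(x, 'int32'), 'float64')
# >>> outer.expr == x  # the inner cast was thrown away
# True
# >>> CastFunc(x, 'int32').is_integer
# True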
class BooleanCastFunc(CastFunc, Boolean):
# TODO: documentation
pass
class VectorMemoryAccess(CastFunc):
"""
Special memory access for vectorized kernel.
Arguments: read/write expression, type, aligned, non-temporal, mask (or none), stride
"""
nargs = (6,)
class ReinterpretCastFunc(CastFunc):
"""
Reinterpret cast is necessary for the StructType
"""
pass
class PointerArithmeticFunc(sp.Function, Boolean):
# TODO: documentation, or deprecate!
@property
def canonical(self):
if hasattr(self.args[0], 'canonical'):
return self.args[0].canonical
else:
raise NotImplementedError()
from collections import namedtuple
from typing import Union, Tuple, Any, DefaultDict
import logging
import numpy as np
import sympy as sp
from sympy import Piecewise
from sympy.core.numbers import NegativeOne
from sympy.core.relational import Relational
from sympy.functions.elementary.piecewise import ExprCondPair
from sympy.functions.elementary.trigonometric import TrigonometricFunction, InverseTrigonometricFunction
from sympy.functions.elementary.hyperbolic import HyperbolicFunction
from sympy.functions.elementary.integers import RoundFunction
from sympy.logic.boolalg import BooleanFunction
from sympy.logic.boolalg import BooleanAtom
from pystencils import astnodes as ast
from pystencils.functions import DivFunc, AddressOf
from pystencils.cpu.vectorization import vec_all, vec_any
from pystencils.field import Field
from pystencils.typing.types import BasicType, PointerType
from pystencils.typing.utilities import collate_types
from pystencils.typing.cast_functions import CastFunc, BooleanCastFunc
from pystencils.typing.typed_sympy import TypedSymbol
from pystencils.fast_approximation import fast_sqrt, fast_division, fast_inv_sqrt
from pystencils.utils import ContextVar
class TypeAdder:
# TODO: specification -> jupyter notebook
"""Checks if the input to create_kernel is valid.
Tests the following conditions:
- SSA form for pure symbols:
- Every pure symbol may occur only once as the left-hand side of an assignment
- Every pure symbol that is read may not be written to later
- Independence / parallelization condition:
- A field that is written may only be read at exactly the same spatial position
(Pure symbols are symbols that are not Field.Accesses)
"""
FieldAndIndex = namedtuple('FieldAndIndex', ['field', 'index'])
def __init__(self, type_for_symbol: DefaultDict[str, BasicType], default_number_float: BasicType,
default_number_int: BasicType):
self.type_for_symbol = type_for_symbol
self.default_number_float = ContextVar(default_number_float)
self.default_number_int = ContextVar(default_number_int)
def visit(self, obj):
if isinstance(obj, (list, tuple)):
return [self.visit(e) for e in obj]
if isinstance(obj, ast.SympyAssignment):
return self.process_assignment(obj)
elif isinstance(obj, ast.Conditional):
condition, condition_type = self.figure_out_type(obj.condition_expr)
assert condition_type == BasicType('bool')
true_block = self.visit(obj.true_block)
false_block = None if obj.false_block is None else self.visit(
obj.false_block)
return ast.Conditional(condition, true_block=true_block, false_block=false_block)
elif isinstance(obj, ast.Block):
return ast.Block([self.visit(e) for e in obj.args])
elif isinstance(obj, ast.Node) and not isinstance(obj, ast.LoopOverCoordinate):
return obj
else:
raise ValueError("Invalid object in kernel " + str(type(obj)))
def process_assignment(self, assignment: ast.SympyAssignment) -> ast.SympyAssignment:
# for checks it is crucial to process rhs before lhs to catch e.g. a = a + 1
new_rhs, rhs_type = self.figure_out_type(assignment.rhs)
lhs = assignment.lhs
if not isinstance(lhs, (Field.Access, TypedSymbol)):
if isinstance(lhs, sp.Symbol):
self.type_for_symbol[lhs.name] = rhs_type
else:
raise ValueError(f'Lhs: `{lhs}` is not a subtype of sp.Symbol')
new_lhs, lhs_type = self.figure_out_type(lhs)
assert isinstance(new_lhs, (Field.Access, TypedSymbol))
if lhs_type != rhs_type:
logging.debug(f'Lhs"{new_lhs} of type "{lhs_type}" is assigned with a different datatype '
f'rhs: "{new_rhs}" of type "{rhs_type}".')
return ast.SympyAssignment(new_lhs, CastFunc(new_rhs, lhs_type), assignment.is_const, assignment.use_auto)
else:
return ast.SympyAssignment(new_lhs, new_rhs, assignment.is_const, assignment.use_auto)
# Type System Specification
# - Defined types: TypedSymbol, Field, Field.Access, ...?
# - Indexed: always unsigned_integer64
# - Undefined types: Symbol
# - These are specified in the config, either in the data-type dict or via 'default_type'; on the lhs they behave like `auto`.
# - Constants/numbers: are either integer or floating point. Precision and sign are specified via the config.
# - Example: the literal 1.4 with config float32 -> float32
# - Expressions deduce their types from their arguments
# - Functions deduce their types from their arguments
# - default_type, default_float and default_int can be given for a whole list of assignments, or
# individually per assignment
# Possible problems - do we need to support these?
# - Mixing int and float in one expression
# - Mixing uint64 and sint64 in one expression
# TODO Logging: the lowest log level should log all casts -> a cast factory / make_cast helper should contain the logging
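# Example (illustrative): with default_number_float=float32 and default_number_int=int64,
# the rhs literal 1.4 is materialised as CastFunc(1.4, float32) and the literal 2 as
# CastFunc(2, int64); a mixed expression then collates its argument types as described below.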
def figure_out_type(self, expr) -> Tuple[Any, Union[BasicType, PointerType]]:
# Trivial cases
from pystencils.field import Field
import pystencils.integer_functions
from pystencils.bit_masks import flag_cond
bool_type = BasicType('bool')
# TODO: check the access
if isinstance(expr, Field.Access):
return expr, expr.dtype
elif isinstance(expr, TypedSymbol):
return expr, expr.dtype
elif isinstance(expr, sp.Symbol):
t = TypedSymbol(expr.name, self.type_for_symbol[expr.name])
return t, t.dtype
elif isinstance(expr, np.generic):
assert False, f'Why do we have a np.generic in rhs???? {expr}'
elif isinstance(expr, (sp.core.numbers.Infinity, sp.core.numbers.NegativeInfinity)):
return expr, BasicType('float32') # see https://en.cppreference.com/w/cpp/numeric/math/INFINITY
elif isinstance(expr, sp.Number):
if expr.is_Integer:
data_type = self.default_number_int.get()
elif expr.is_Float or expr.is_Rational:
data_type = self.default_number_float.get()
else:
assert False, f'{expr} is neither Float nor Integer'
return CastFunc(expr, data_type), data_type
elif isinstance(expr, AddressOf):
of = expr.args[0]
# TODO Basically this should do address_of already
assert isinstance(of, (Field.Access, TypedSymbol, Field))
return expr, PointerType(of.dtype)
elif isinstance(expr, BooleanAtom):
return expr, bool_type
elif isinstance(expr, Relational):
# TODO Jan: Code duplication with general case
args_types = [self.figure_out_type(arg) for arg in expr.args]
collated_type = collate_types([t for _, t in args_types])
if isinstance(expr, sp.Equality) and collated_type.is_float():
logging.warning(f"Using floating point numbers in equality comparison: {expr}")
new_args = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types]
new_eq = expr.func(*new_args)
return new_eq, bool_type
elif isinstance(expr, CastFunc):
new_expr, _ = self.figure_out_type(expr.expr)
return expr.func(*[new_expr, expr.dtype]), expr.dtype
elif isinstance(expr, ast.ConditionalFieldAccess):
access, access_type = self.figure_out_type(expr.access)
value, value_type = self.figure_out_type(expr.outofbounds_value)
condition, condition_type = self.figure_out_type(expr.outofbounds_condition)
assert condition_type == bool_type
collated_type = collate_types([access_type, value_type])
if collated_type == access_type:
new_access = access
else:
logging.warning(f"In {expr} the Field Access had to be casted to {collated_type}. This is "
f"probably due to a type missmatch of the Field and the value of "
f"ConditionalFieldAccess")
new_access = CastFunc(access, collated_type)
new_value = value if value_type == collated_type else CastFunc(value, collated_type)
return expr.func(new_access, condition, new_value), collated_type
elif isinstance(expr, (vec_any, vec_all)):
return expr, bool_type
elif isinstance(expr, BooleanFunction):
args_types = [self.figure_out_type(a) for a in expr.args]
new_args = [a if t.dtype_eq(bool_type) else BooleanCastFunc(a, bool_type) for a, t in args_types]
return expr.func(*new_args), bool_type
elif type(expr) in pystencils.integer_functions.__dict__.values() or isinstance(expr, sp.Mod):
args_types = [self.figure_out_type(a) for a in expr.args]
collated_type = collate_types([t for _, t in args_types])
# TODO: should we downcast to integer? If yes then which integer type?
if not collated_type.is_int():
raise ValueError(f"Integer functions or Modulo need to be used with integer types "
f"but {collated_type} was given")
return expr, collated_type
elif isinstance(expr, flag_cond):
# do not process the arguments to the bit shift - they must remain integers
args_types = [self.figure_out_type(a) for a in (expr.args[i] for i in range(2, len(expr.args)))]
collated_type = collate_types([t for _, t in args_types])
new_expressions = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types]
return expr.func(expr.args[0], expr.args[1], *new_expressions), collated_type
# elif isinstance(expr, sp.Mul):
# raise NotImplementedError('sp.Mul')
# # TODO can we ignore this and move it to general expr handling, i.e. removing Mul? (See todo in backend)
# # args_types = [self.figure_out_type(arg) for arg in expr.args if arg not in (-1, 1)]
elif isinstance(expr, sp.Indexed):
typed_symbol = expr.base.label
return expr, typed_symbol.dtype
elif isinstance(expr, ExprCondPair):
expr_expr, expr_type = self.figure_out_type(expr.expr)
condition, condition_type = self.figure_out_type(expr.cond)
if condition_type != bool_type:
logging.warning(f'Condition "{condition}" is of type "{condition_type}" and not "bool"')
return expr.func(expr_expr, condition), expr_type
elif isinstance(expr, Piecewise):
args_types = [self.figure_out_type(arg) for arg in expr.args]
collated_type = collate_types([t for _, t in args_types])
new_args = []
for a, t in args_types:
if t != collated_type:
if isinstance(a, ExprCondPair):
new_args.append(a.func(CastFunc(a.expr, collated_type), a.cond))
else:
new_args.append(CastFunc(a, collated_type))
else:
new_args.append(a)
return expr.func(*new_args) if new_args else expr, collated_type
elif isinstance(expr, (sp.Pow, sp.exp, InverseTrigonometricFunction, TrigonometricFunction,
HyperbolicFunction, sp.log, RoundFunction)):
args_types = [self.figure_out_type(arg) for arg in expr.args]
collated_type = collate_types([t for _, t in args_types])
new_args = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types]
new_func = expr.func(*new_args) if new_args else expr
if collated_type == BasicType('float64'):
return new_func, collated_type
else:
return CastFunc(new_func, collated_type), collated_type
elif isinstance(expr, (fast_sqrt, fast_division, fast_inv_sqrt)):
args_types = [self.figure_out_type(arg) for arg in expr.args]
collated_type = BasicType('float32')
new_args = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types]
new_func = expr.func(*new_args) if new_args else expr
return CastFunc(new_func, collated_type), collated_type
elif isinstance(expr, (sp.Add, sp.Mul, sp.Abs, sp.Min, sp.Max, DivFunc, sp.UnevaluatedExpr)):
# Subtraction is realised as a multiplication by -1 in SymPy. Thus we exclude the coefficient in this case
# and resolve the typing entirely from the remaining expression
if isinstance(expr, sp.Mul):
c, e = expr.as_coeff_Mul()
if c == NegativeOne():
args_types = self.figure_out_type(e)
new_args = [NegativeOne(), args_types[0]]
return expr.func(*new_args, evaluate=False), args_types[1]
args_types = [self.figure_out_type(arg) for arg in expr.args]
collated_type = collate_types([t for _, t in args_types])
if isinstance(collated_type, PointerType):
if isinstance(expr, sp.Add):
return expr.func(*[a for a, _ in args_types]), collated_type
else:
raise NotImplementedError(f'Pointer Arithmetic is implemented only for Add, not {expr}')
new_args = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types]
if isinstance(expr, (sp.Add, sp.Mul)):
return expr.func(*new_args, evaluate=False) if new_args else expr, collated_type
else:
return expr.func(*new_args) if new_args else expr, collated_type
else:
raise NotImplementedError(f'expr {type(expr)}: {expr} unknown to typing')
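# Usage sketch (illustrative; assumes only the classes imported above): typing a
# single assignment through TypeAdder directly.
# >>> from collections import defaultdict
# >>> x, y = sp.symbols('x, y')
# >>> adder = TypeAdder(type_for_symbol=defaultdict(lambda: BasicType('float64')),
# ...                   default_number_float=BasicType('float64'),
# ...                   default_number_int=BasicType('int64'))
# >>> typed, = adder.visit([ast.SympyAssignment(x, y + 1)])
# >>> typed.lhs.dtype
# BasicType( double )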
from typing import List
from pystencils.astnodes import Node
from pystencils.config import CreateKernelConfig
from pystencils.typing.leaf_typing import TypeAdder
def add_types(node_list: List[Node], config: CreateKernelConfig):
"""Traverses AST and replaces every :class:`sympy.Symbol` by a :class:`pystencils.typedsymbol.TypedSymbol`.
The AST needs to be a pystencils AST. Thus, in the list of nodes every entry must be inherited from
`pystencils.astnodes.Node`
Args:
node_list: List of pystencils Nodes.
config: CreateKernelConfig
Returns:
``typed_equations`` list of equations where symbols have been replaced by typed symbols
"""
check = TypeAdder(type_for_symbol=config.data_type,
default_number_float=config.default_number_float,
default_number_int=config.default_number_int)
return check.visit(node_list)
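# Usage sketch (illustrative; parameter names follow the CreateKernelConfig fields
# referenced above): add_types is normally invoked from create_kernel, but it can be
# called directly on a list of pystencils AST nodes:
#   config = CreateKernelConfig(default_number_float='float32')
#   typed_nodes = add_types(assignment_nodes, config)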
"""Special symbols representing kernel parameters related to fields/arrays.
A `KernelFunction` node determines parameters that have to be passed to the function by searching for all undefined
symbols. Some symbols are not directly defined by the user, but are related to the `Field`s used in the kernel:
For each field a `FieldPointerSymbol` needs to be passed in, which is the pointer to the memory region where
the field is stored. This pointer is represented by the `FieldPointerSymbol` class that additionally stores the
name of the corresponding field. For fields where the size is not known at compile time, additionally shape and stride
information has to be passed in at runtime. These values are represented by `FieldShapeSymbol`
and `FieldStrideSymbol`.
The special symbols in this module store only the field name instead of a field reference. Storing a field reference
directly leads to problems with copying and pickling behaviour due to the circular dependency of `Field` and
e.g. `FieldShapeSymbol`, since a Field contains `FieldShapeSymbol`s in its shape, and a `FieldShapeSymbol`
would reference back to the field.
"""
from typing import Union
import numpy as np
import sympy as sp
from sympy.core.cache import cacheit
from pystencils.data_types import (
PointerType, TypedSymbol, create_composite_type_from_string, get_base_type)
from pystencils.typing.types import BasicType, create_type, PointerType
def assumptions_from_dtype(dtype: Union[BasicType, np.dtype]):
"""Derives SymPy assumptions from :class:`BasicType` or a Numpy dtype
Args:
dtype (BasicType, np.dtype): a Numpy data type
Returns:
A dict of SymPy assumptions
"""
if hasattr(dtype, 'numpy_dtype'):
dtype = dtype.numpy_dtype
assumptions = dict()
try:
if np.issubdtype(dtype, np.integer):
assumptions.update({'integer': True})
if np.issubdtype(dtype, np.unsignedinteger):
assumptions.update({'negative': False})
if np.issubdtype(dtype, np.integer) or \
np.issubdtype(dtype, np.floating):
assumptions.update({'real': True})
except Exception: # TODO this is dirty
pass
return assumptions
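# Examples (illustrative):
# >>> assumptions_from_dtype(np.dtype('int32')) == {'integer': True, 'real': True}
# True
# >>> assumptions_from_dtype(np.dtype('uint8')) == {'integer': True, 'negative': False, 'real': True}
# True
# >>> assumptions_from_dtype(np.dtype('float64')) == {'real': True}
# True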
class TypedSymbol(sp.Symbol):
def __new__(cls, *args, **kwds):
obj = TypedSymbol.__xnew_cached_(cls, *args, **kwds)
return obj
def __new_stage2__(cls, name, dtype, **kwargs): # TODO does not match signature of sp.Symbol???
# TODO: also Symbol should be allowed ---> see sympy Variable
assumptions = assumptions_from_dtype(dtype)
assumptions.update(kwargs)
obj = super(TypedSymbol, cls).__xnew__(cls, name, **assumptions)
try:
obj.numpy_dtype = create_type(dtype)
except (TypeError, ValueError):
# on error keep the string
obj.numpy_dtype = dtype
return obj
__xnew__ = staticmethod(__new_stage2__)
__xnew_cached_ = staticmethod(cacheit(__new_stage2__))
@property
def dtype(self):
return self.numpy_dtype
def _hashable_content(self):
return super()._hashable_content(), hash(self.numpy_dtype)
def __getnewargs__(self):
return self.name, self.dtype
def __getnewargs_ex__(self):
return (self.name, self.dtype), self.assumptions0
@property
def canonical(self):
return self
@property
def reversed(self):
return self
@property
def headers(self):
headers = []
try:
if np.issubdtype(self.dtype.numpy_dtype, np.complexfloating):
headers.append('"cuda_complex.hpp"')
except Exception:
pass
try:
if np.issubdtype(self.dtype.base_type.numpy_dtype, np.complexfloating):
headers.append('"cuda_complex.hpp"')
except Exception:
pass
SHAPE_DTYPE = create_composite_type_from_string("const int64")
STRIDE_DTYPE = create_composite_type_from_string("const int64")
return headers
SHAPE_DTYPE = BasicType('int64', const=True)
STRIDE_DTYPE = BasicType('int64', const=True)
class FieldStrideSymbol(TypedSymbol):
......@@ -29,7 +105,7 @@ class FieldStrideSymbol(TypedSymbol):
return obj
def __new_stage2__(cls, field_name, coordinate):
name = "_stride_{name}_{i}".format(name=field_name, i=coordinate)
name = f"_stride_{field_name}_{coordinate}"
obj = super(FieldStrideSymbol, cls).__xnew__(cls, name, STRIDE_DTYPE, positive=True)
obj.field_name = field_name
obj.coordinate = coordinate
......@@ -38,6 +114,9 @@ class FieldStrideSymbol(TypedSymbol):
def __getnewargs__(self):
return self.field_name, self.coordinate
def __getnewargs_ex__(self):
return (self.field_name, self.coordinate), {}
__xnew__ = staticmethod(__new_stage2__)
__xnew_cached_ = staticmethod(cacheit(__new_stage2__))
......@@ -54,7 +133,7 @@ class FieldShapeSymbol(TypedSymbol):
def __new_stage2__(cls, field_names, coordinate):
names = "_".join([field_name for field_name in field_names])
name = "_size_{names}_{i}".format(names=names, i=coordinate)
name = f"_size_{names}_{coordinate}"
obj = super(FieldShapeSymbol, cls).__xnew__(cls, name, SHAPE_DTYPE, positive=True)
obj.field_names = tuple(field_names)
obj.coordinate = coordinate
......@@ -63,6 +142,9 @@ class FieldShapeSymbol(TypedSymbol):
def __getnewargs__(self):
return self.field_names, self.coordinate
def __getnewargs_ex__(self):
return (self.field_names, self.coordinate), {}
__xnew__ = staticmethod(__new_stage2__)
__xnew_cached_ = staticmethod(cacheit(__new_stage2__))
......@@ -77,7 +159,9 @@ class FieldPointerSymbol(TypedSymbol):
return obj
def __new_stage2__(cls, field_name, field_dtype, const):
name = "_data_{name}".format(name=field_name)
from pystencils.typing.utilities import get_base_type
name = f"_data_{field_name}"
dtype = PointerType(get_base_type(field_dtype), const=const, restrict=True)
obj = super(FieldPointerSymbol, cls).__xnew__(cls, name, dtype)
obj.field_name = field_name
......@@ -86,8 +170,28 @@ class FieldPointerSymbol(TypedSymbol):
def __getnewargs__(self):
return self.field_name, self.dtype, self.dtype.const
def __getnewargs_ex__(self):
return (self.field_name, self.dtype, self.dtype.const), {}
def _hashable_content(self):
return super()._hashable_content(), self.field_name
__xnew__ = staticmethod(__new_stage2__)
__xnew_cached_ = staticmethod(cacheit(__new_stage2__))
class CFunction(TypedSymbol):
def __new__(cls, function, dtype):
return CFunction.__xnew_cached_(cls, function, dtype)
def __new_stage2__(cls, function, dtype):
return super(CFunction, cls).__xnew__(cls, function, dtype)
__xnew__ = staticmethod(__new_stage2__)
__xnew_cached_ = staticmethod(cacheit(__new_stage2__))
def __getnewargs__(self):
return self.name, self.dtype
def __getnewargs_ex__(self):
return (self.name, self.dtype), {}
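# Examples (illustrative): the special symbols encode only the field name and the
# coordinate, which keeps them hashable and picklable.
# >>> FieldStrideSymbol('f', 0).name
# '_stride_f_0'
# >>> FieldShapeSymbol(['f', 'g'], 1).name
# '_size_f_g_1'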
from abc import abstractmethod
from typing import Union
import numpy as np
import sympy as sp
def is_supported_type(dtype: np.dtype):
scalar = dtype.type
c = np.issubdtype(dtype, np.generic)
subclass = issubclass(scalar, np.floating) or issubclass(scalar, np.integer) or issubclass(scalar, np.bool_)
additional_checks = dtype.fields is None and dtype.hasobject is False and dtype.subdtype is None
return c and subclass and additional_checks
def numpy_name_to_c(name: str) -> str:
"""
Converts a np.dtype.name into a C type
Args:
name: np.dtype.name string
Returns:
type as a C string
"""
if name == 'float64':
return 'double'
elif name == 'float32':
return 'float'
elif name == 'float16' or name == 'half':
return 'half'
elif name.startswith('int'):
width = int(name[len("int"):])
return f"int{width}_t"
elif name.startswith('uint'):
width = int(name[len("uint"):])
return f"uint{width}_t"
elif name == 'bool':
return 'bool'
else:
raise NotImplementedError(f"Can't map numpy to C name for {name}")
class AbstractType(sp.Atom):
# TODO: Is it necessary to inherit from sp.Atom?
def __new__(cls, *args, **kwargs):
return sp.Basic.__new__(cls)
def _sympystr(self, *args, **kwargs):
return str(self)
@property
@abstractmethod
def base_type(self) -> Union[None, 'BasicType']:
"""
Returns: BasicType of a Vector or Pointer type, None otherwise
"""
pass
@property
@abstractmethod
def item_size(self) -> int:
"""
Returns: Number of items.
E.g. width * item_size(basic_type) in vector's case, or simple numpy itemsize in Struct's case.
"""
pass
class BasicType(AbstractType):
"""
BasicType is defined with a const qualifier and a np.dtype.
"""
def __init__(self, dtype: Union[type, 'BasicType', str], const: bool = False):
if isinstance(dtype, BasicType):
self.numpy_dtype = dtype.numpy_dtype
self.const = dtype.const
else:
self.numpy_dtype = np.dtype(dtype)
self.const = const
assert is_supported_type(self.numpy_dtype), f'Type {self.numpy_dtype} is currently not supported!'
def __getnewargs__(self):
return self.numpy_dtype, self.const
def __getnewargs_ex__(self):
return (self.numpy_dtype, self.const), {}
@property
def base_type(self):
return None
@property
def item_size(self): # TODO: Do we want self.numpy_type.itemsize????
return 1
def is_float(self):
return issubclass(self.numpy_dtype.type, np.floating)
def is_half(self):
return issubclass(self.numpy_dtype.type, np.half)
def is_int(self):
return issubclass(self.numpy_dtype.type, np.integer)
def is_uint(self):
return issubclass(self.numpy_dtype.type, np.unsignedinteger)
def is_sint(self):
return issubclass(self.numpy_dtype.type, np.signedinteger)
def is_bool(self):
return issubclass(self.numpy_dtype.type, np.bool_)
def dtype_eq(self, other):
if not isinstance(other, BasicType):
return False
else:
return self.numpy_dtype == other.numpy_dtype
@property
def c_name(self) -> str:
return numpy_name_to_c(self.numpy_dtype.name)
def __str__(self):
return f'{self.c_name}{" const" if self.const else ""}'
def __repr__(self):
return f'BasicType( {str(self)} )'
def _repr_html_(self):
return f'BasicType( {str(self)} )'
def __eq__(self, other):
return self.dtype_eq(other) and self.const == other.const
def __hash__(self):
return hash(str(self))
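# Examples (illustrative): dtype_eq ignores the const qualifier, while __eq__
# takes it into account.
# >>> t = BasicType('float32', const=True)
# >>> str(t)
# 'float const'
# >>> t.dtype_eq(BasicType('float32'))
# True
# >>> t == BasicType('float32')
# False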
class VectorType(AbstractType):
"""
VectorType consists of a BasicType and a width.
"""
instruction_set = None
def __init__(self, base_type: BasicType, width: int):
self._base_type = base_type
self.width = width
@property
def base_type(self):
return self._base_type
@property
def item_size(self):
return self.width * self.base_type.item_size
def __eq__(self, other):
if not isinstance(other, VectorType):
return False
else:
return (self.base_type, self.width) == (other.base_type, other.width)
def __str__(self):
if self.instruction_set is None:
return f"{self.base_type}[{self.width}]"
else:
# TODO VectorizationRevamp: this seems super weird. The instruction_set should know how to print a type out!
# TODO VectorizationRevamp: this is error-prone. base_type could be const=True. Use dtype instead
if self.base_type == create_type("int64") or self.base_type == create_type("int32"):
return self.instruction_set['int']
elif self.base_type == create_type("float64"):
return self.instruction_set['double']
elif self.base_type == create_type("float32"):
return self.instruction_set['float']
elif self.base_type == create_type("bool"):
return self.instruction_set['bool']
else:
raise NotImplementedError()
def __hash__(self):
return hash((self.base_type, self.width))
def __getnewargs__(self):
return self._base_type, self.width
def __getnewargs_ex__(self):
return (self._base_type, self.width), {}
class PointerType(AbstractType):
def __init__(self, base_type: BasicType, const: bool = False, restrict: bool = True, double_pointer: bool = False):
self._base_type = base_type
self.const = const
self.restrict = restrict
self.double_pointer = double_pointer
def __getnewargs__(self):
return self.base_type, self.const, self.restrict, self.double_pointer
def __getnewargs_ex__(self):
return (self.base_type, self.const, self.restrict, self.double_pointer), {}
@property
def alias(self):
return not self.restrict
@property
def base_type(self):
return self._base_type
@property
def item_size(self):
if self.double_pointer:
raise NotImplementedError("The item_size for double_pointer is not implemented")
else:
return self.base_type.item_size
def __eq__(self, other):
if not isinstance(other, PointerType):
return False
else:
own = (self.base_type, self.const, self.restrict, self.double_pointer)
return own == (other.base_type, other.const, other.restrict, other.double_pointer)
def __str__(self):
restrict_str = "RESTRICT" if self.restrict else ""
const_str = "const" if self.const else ""
if self.double_pointer:
return f'{str(self.base_type)} ** {restrict_str} {const_str}'
else:
return f'{str(self.base_type)} * {restrict_str} {const_str}'
def __repr__(self):
return str(self)
def _repr_html_(self):
return str(self)
def __hash__(self):
return hash((self._base_type, self.const, self.restrict, self.double_pointer))
class StructType(AbstractType):
"""
A list of types (with C offsets).
It is implemented with uint8_t and casts to the correct datatype.
"""
def __init__(self, numpy_type, const=False):
self.const = const
self._dtype = np.dtype(numpy_type)
def __getnewargs__(self):
return self.numpy_dtype, self.const
def __getnewargs_ex__(self):
return (self.numpy_dtype, self.const), {}
@property
def base_type(self):
return None
@property
def numpy_dtype(self):
return self._dtype
@property
def item_size(self):
return self.numpy_dtype.itemsize
def get_element_offset(self, element_name):
return self.numpy_dtype.fields[element_name][1]
def get_element_type(self, element_name):
np_element_type = self.numpy_dtype.fields[element_name][0]
return BasicType(np_element_type, self.const)
def has_element(self, element_name):
return element_name in self.numpy_dtype.fields
def __eq__(self, other):
if not isinstance(other, StructType):
return False
else:
return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)
def __str__(self):
# structs are handled byte-wise
result = "uint8_t"
if self.const:
result += " const"
return result
def __repr__(self):
return str(self)
def _repr_html_(self):
return str(self)
def __hash__(self):
return hash((self.numpy_dtype, self.const))
def create_type(specification: Union[type, AbstractType, str]) -> AbstractType:
# TODO: Deprecated. Use the constructor of BasicType or StructType instead.
"""Creates an instance of a Type subclass from a string, a numpy dtype specification,
or an existing Type object.
Args:
specification: Type object, or a string
Returns:
the Type object passed in, or a new Type object parsed from the string
"""
if isinstance(specification, AbstractType):
return specification
else:
numpy_dtype = np.dtype(specification)
if numpy_dtype.fields is None:
return BasicType(numpy_dtype, const=False)
else:
return StructType(numpy_dtype, const=False)
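# Examples (illustrative):
# >>> create_type('float64')
# BasicType( double )
# >>> type(create_type(np.dtype([('x', np.int32), ('y', np.float64)]))).__name__
# 'StructType'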
from collections import defaultdict
from functools import partial
from typing import Tuple, Union, Sequence
import numpy as np
import sympy as sp
from sympy.logic.boolalg import Boolean, BooleanFunction
import pystencils
from pystencils.cache import memorycache_if_hashable
from pystencils.typing.types import BasicType, VectorType, PointerType, create_type
from pystencils.typing.cast_functions import CastFunc
from pystencils.typing.typed_sympy import TypedSymbol
from pystencils.utils import all_equal
def typed_symbols(names, dtype, **kwargs):
"""
Creates TypedSymbols with the same functionality as sympy.symbols
Args:
names: See sympy.symbols
dtype: The data type all symbols will have
**kwargs: Key value arguments passed to sympy.symbols
Returns:
TypedSymbols
"""
symbols = sp.symbols(names, **kwargs)
if isinstance(symbols, tuple):
return tuple(TypedSymbol(str(s), dtype) for s in symbols)
else:
return TypedSymbol(str(symbols), dtype)
def get_base_type(data_type):
"""
Returns the BasicType of a Pointer or a Vector
"""
while data_type.base_type is not None:
data_type = data_type.base_type
return data_type
def result_type(*args: np.dtype):
"""Returns the type of the result if the np.dtype arguments would be collated.
We can't use numpy functionality, because numpy casts don't behave exactly like C casts"""
s = sorted(args, key=lambda x: x.itemsize)
def kind_to_value(kind: str) -> int:
if kind == 'f':
return 3
elif kind == 'i':
return 2
elif kind == 'u':
return 1
elif kind == 'b':
return 0
else:
raise NotImplementedError(f'{kind=} is not a supported kind of a type. See "numpy.dtype.kind" for options')
s = sorted(s, key=lambda x: kind_to_value(x.kind))
return s[-1]
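# Examples (illustrative): floats beat integers and signed beats unsigned; among
# equal kinds the larger itemsize wins because the itemsize sort is applied first.
# >>> result_type(np.dtype('int32'), np.dtype('float32'))
# dtype('float32')
# >>> result_type(np.dtype('uint64'), np.dtype('int64'))
# dtype('int64')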
def collate_types(types: Sequence[Union[BasicType, VectorType]]):
"""
Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
Uses the collation rules from numpy.
"""
# Pointer arithmetic case i.e. pointer + [int, uint] is allowed
if any(isinstance(t, PointerType) for t in types):
pointer_type = None
for t in types:
if isinstance(t, PointerType):
if pointer_type is not None:
raise ValueError(f'Cannot collate the combination of two pointer types "{pointer_type}" and "{t}"')
pointer_type = t
elif isinstance(t, BasicType):
if not (t.is_int() or t.is_uint()):
raise ValueError("Invalid pointer arithmetic")
else:
raise ValueError("Invalid pointer arithmetic")
return pointer_type
# peel off vector types; if at least one vector type occurred, the result will also be a vector type
vector_type = [t for t in types if isinstance(t, VectorType)]
if not all_equal(t.width for t in vector_type):
raise ValueError("Collation failed because of vector types with different width")
# TODO: check if this is needed
# def peel_off_type(dtype, type_to_peel_off):
# while type(dtype) is type_to_peel_off:
# dtype = dtype.base_type
# return dtype
# types = [peel_off_type(t, VectorType) for t in types]
types = [t.base_type if isinstance(t, VectorType) else t for t in types]
# now we should have a list of basic types - struct types are not yet supported
assert all(type(t) is BasicType for t in types)
result_numpy_type = result_type(*(t.numpy_dtype for t in types))
result = BasicType(result_numpy_type)
if vector_type:
result = VectorType(result, vector_type[0].width)
return result
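# Examples (illustrative):
# >>> collate_types([BasicType('float32'), BasicType('float64')])
# BasicType( double )
# >>> collate_types([VectorType(BasicType('float32'), 4), BasicType('float64')]).width
# 4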
# TODO get_type_of_expression should be used after leaf_typing. So no defaults should be necessary
@memorycache_if_hashable(maxsize=2048)
def get_type_of_expression(expr,
default_float_type='double',
default_int_type='int',
symbol_type_dict=None):
from pystencils.astnodes import ResolvedFieldAccess
from pystencils.cpu.vectorization import vec_all, vec_any
if default_float_type == 'float':
default_float_type = 'float32'
if not symbol_type_dict:
symbol_type_dict = defaultdict(lambda: create_type('double'))
# TODO this line is quite hard to understand, simplify if possible
get_type = partial(get_type_of_expression,
default_float_type=default_float_type,
default_int_type=default_int_type,
symbol_type_dict=symbol_type_dict)
expr = sp.sympify(expr)
if isinstance(expr, sp.Integer):
return create_type(default_int_type)
elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float):
return create_type(default_float_type)
elif isinstance(expr, ResolvedFieldAccess):
return expr.field.dtype
elif isinstance(expr, pystencils.field.Field.Access):
return expr.field.dtype
elif isinstance(expr, TypedSymbol):
return expr.dtype
elif isinstance(expr, sp.Symbol):
# TODO delete if case
if symbol_type_dict:
return symbol_type_dict[expr.name]
else:
raise ValueError("All symbols inside this expression have to be typed! ", str(expr))
elif isinstance(expr, CastFunc):
return expr.args[1]
elif isinstance(expr, (vec_any, vec_all)):
return create_type("bool")
elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
collated_result_type = collate_types(tuple(get_type(a[0]) for a in expr.args))
collated_condition_type = collate_types(tuple(get_type(a[1]) for a in expr.args))
if type(collated_condition_type) is VectorType and type(collated_result_type) is not VectorType:
collated_result_type = VectorType(collated_result_type, width=collated_condition_type.width)
return collated_result_type
elif isinstance(expr, sp.Indexed):
typed_symbol = expr.base.label
return typed_symbol.dtype.base_type
elif isinstance(expr, (Boolean, BooleanFunction)):
# if any arg is of vector type return a vector boolean, else return a normal scalar boolean
result = create_type("bool")
vec_args = [get_type(a) for a in expr.args if isinstance(get_type(a), VectorType)]
if vec_args:
result = VectorType(result, width=vec_args[0].width)
return result
elif isinstance(expr, sp.Pow):
base_type = get_type(expr.args[0])
if expr.exp.is_integer:
return base_type
else:
return collate_types([create_type(default_float_type), base_type])
elif isinstance(expr, (sp.Sum, sp.Product)):
return get_type(expr.args[0])
elif isinstance(expr, sp.Expr):
expr: sp.Expr
if expr.args:
types = tuple(get_type(a) for a in expr.args)
return collate_types(types)
else:
if expr.is_integer:
return create_type(default_int_type)
else:
return create_type(default_float_type)
raise NotImplementedError("Could not determine type for", expr, type(expr))
# Fix for sympy versions from 1.9
sympy_version = sp.__version__.split('.')
sympy_version_int = int(sympy_version[0]) * 100 + int(sympy_version[1])
if sympy_version_int >= 109:
# __setstate__ would bypass the constructor, so we remove it
if sympy_version_int >= 111:
del sp.Basic.__setstate__
del sp.Symbol.__setstate__
else:
sp.Number.__getstate__ = sp.Basic.__getstate__
del sp.Basic.__getstate__
# __reduce_ex__ would strip kwargs, so we override it
def basic_reduce_ex(self, protocol):
if hasattr(self, '__getnewargs_ex__'):
args, kwargs = self.__getnewargs_ex__()
else:
args, kwargs = self.__getnewargs__(), {}
if hasattr(self, '__getstate__'):
state = self.__getstate__()
else:
state = None
return partial(type(self), **kwargs), args, state
sp.Basic.__reduce_ex__ = basic_reduce_ex
def get_next_parent_of_type(node, parent_type):
"""Returns the next parent node of given type or None, if root is reached.
Traverses the AST nodes parents until a parent of given type was found.
If no such parent is found, None is returned
"""
parent = node.parent
while parent is not None:
if isinstance(parent, parent_type):
return parent
parent = parent.parent
return None
def parents_of_type(node, parent_type, include_current=False):
"""Generator for all parent nodes of given type"""
parent = node if include_current else node.parent
while parent is not None:
if isinstance(parent, parent_type):
yield parent
parent = parent.parent
import os
import itertools
from itertools import groupby
from collections import Counter
from contextlib import contextmanager
from tempfile import NamedTemporaryFile
......@@ -14,14 +16,21 @@ class DotDict(dict):
__setattr__ = dict.__setitem__
__delattr__ = dict.__delitem__
# Recursively make DotDict: https://stackoverflow.com/questions/13520421/recursive-dotdict
def __init__(self, dct=None):
for key, value in (dct or {}).items():
if isinstance(value, dict):
value = DotDict(value)
self[key] = value
def all_equal(iterator):
iterator = iter(iterator)
try:
first = next(iterator)
except StopIteration:
return True
return all(first == rest for rest in iterator)
def all_equal(iterable):
"""
Returns ``True`` if all the elements are equal to each other.
Copied from: more-itertools 8.12.0
"""
g = groupby(iterable)
return next(g, True) and not next(g, False)
def recursive_dict_update(d, u):
......@@ -43,33 +52,13 @@ def recursive_dict_update(d, u):
return d
@contextmanager
def file_handle_for_atomic_write(file_path):
"""Open temporary file object that atomically moves to destination upon exiting.
Allows reading and writing to and from the same filename.
The file will not be moved to destination in case of an exception.
Args:
file_path: path to file to be opened
"""
target_folder = os.path.dirname(os.path.abspath(file_path))
with NamedTemporaryFile(delete=False, dir=target_folder, mode='w') as f:
try:
yield f
finally:
f.flush()
os.fsync(f.fileno())
os.rename(f.name, file_path)
@contextmanager
def atomic_file_write(file_path):
target_folder = os.path.dirname(os.path.abspath(file_path))
with NamedTemporaryFile(delete=False, dir=target_folder) as f:
f.file.close()
yield f.name
os.rename(f.name, file_path)
os.replace(f.name, file_path)
def fully_contains(l1, l2):
......@@ -89,19 +78,39 @@ def fully_contains(l1, l2):
def boolean_array_bounding_box(boolean_array):
"""Returns bounding box around "true" area of boolean array"""
dim = len(boolean_array.shape)
"""Returns bounding box around "true" area of boolean array
>>> a = np.zeros((4, 4), dtype=bool)
>>> a[1:-1, 1:-1] = True
>>> boolean_array_bounding_box(a) == [(1, 3), (1, 3)]
True
"""
dim = boolean_array.ndim
shape = boolean_array.shape
assert 0 not in shape, "Shape must not contain zero"
bounds = []
for i in range(dim):
for j in range(dim):
if i != j:
arr_1d = np.any(boolean_array, axis=j)
begin = np.argmax(arr_1d)
end = begin + np.argmin(arr_1d[begin:])
bounds.append((begin, end))
for ax in itertools.combinations(reversed(range(dim)), dim - 1):
nonzero = np.any(boolean_array, axis=ax)
t = np.where(nonzero)[0][[0, -1]]
bounds.append((t[0], t[1] + 1))
return bounds
def binary_numbers(n):
"""Returns all binary numbers up to 2^n - 1
Example:
>>> binary_numbers(2)
[[0, 0], [0, 1], [1, 0], [1, 1]]
"""
result = list()
for i in range(1 << n):
binary_number = bin(i)[2:]
binary_number = '0' * (n - len(binary_number)) + binary_number
result.append((list(map(int, binary_number))))
return result
class LinearEquationSystem:
"""Symbolic linear system of equations - consisting of matrix and right hand side.
......@@ -210,7 +219,8 @@ class LinearEquationSystem:
return 'multiple'
def solution(self):
"""Solves the system if it has a single solution. Returns a dictionary mapping symbol to solution value."""
"""Solves the system. Under- and overdetermined systems are supported.
Returns a dictionary mapping symbol to solution value."""
return sp.solve_linear_system(self._matrix, *self.unknowns)
def _resize_if_necessary(self, new_rows=1):
......@@ -228,6 +238,15 @@ class LinearEquationSystem:
self.next_zero_row = result
def find_unique_solutions_with_zeros(system: LinearEquationSystem):
if system.solution_structure() != 'multiple':
raise ValueError("Function works only for underdetermined systems")
class ContextVar:
def __init__(self, value):
self.stack = [value]
@contextmanager
def __call__(self, new_value):
self.stack.append(new_value)
yield self
self.stack.pop()
def get(self):
return self.stack[-1]
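# Usage sketch (illustrative): ContextVar keeps a stack of values; the context
# manager pushes a new value on entry and pops it again on exit.
# >>> v = ContextVar(10)
# >>> with v(42):
# ...     v.get()
# 42
# >>> v.get()
# 10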
import pytest
import sympy as sp
import numpy
import pystencils
from pystencils.datahandling import create_data_handling
@pytest.mark.parametrize('dtype', ["float64", "float32"])
@pytest.mark.parametrize('sympy_function', [sp.Min, sp.Max])
def test_max(dtype, sympy_function):
dh = create_data_handling(domain_size=(10, 10), periodicity=True)
x = dh.add_array('x', values_per_cell=1, dtype=dtype)
dh.fill("x", 0.0, ghost_layers=True)
y = dh.add_array('y', values_per_cell=1, dtype=dtype)
dh.fill("y", 1.0, ghost_layers=True)
z = dh.add_array('z', values_per_cell=1, dtype=dtype)
dh.fill("z", 2.0, ghost_layers=True)
config = pystencils.CreateKernelConfig(default_number_float=dtype)
# test sp.Min / sp.Max with one argument
assignment_1 = pystencils.Assignment(x.center, sympy_function(y.center + 3.3))
ast_1 = pystencils.create_kernel(assignment_1, config=config)
kernel_1 = ast_1.compile()
# pystencils.show_code(ast_1)
# test sp.Min / sp.Max with two arguments
assignment_2 = pystencils.Assignment(x.center, sympy_function(0.5, y.center - 1.5))
ast_2 = pystencils.create_kernel(assignment_2, config=config)
kernel_2 = ast_2.compile()
# pystencils.show_code(ast_2)
# test sp.Min / sp.Max with many arguments
assignment_3 = pystencils.Assignment(x.center, sympy_function(z.center, 4.5, y.center - 1.5, y.center + z.center))
ast_3 = pystencils.create_kernel(assignment_3, config=config)
kernel_3 = ast_3.compile()
# pystencils.show_code(ast_3)
if sympy_function is sp.Max:
results = [4.3, 0.5, 4.5]
else:
results = [4.3, -0.5, -0.5]
dh.run_kernel(kernel_1)
assert numpy.all(dh.gather_array('x') == results[0])
dh.run_kernel(kernel_2)
assert numpy.all(dh.gather_array('x') == results[1])
dh.run_kernel(kernel_3)
assert numpy.all(dh.gather_array('x') == results[2])
@pytest.mark.parametrize('dtype', ["int64", 'int32'])
@pytest.mark.parametrize('sympy_function', [sp.Min, sp.Max])
def test_max_integer(dtype, sympy_function):
dh = create_data_handling(domain_size=(10, 10), periodicity=True)
x = dh.add_array('x', values_per_cell=1, dtype=dtype)
dh.fill("x", 0, ghost_layers=True)
y = dh.add_array('y', values_per_cell=1, dtype=dtype)
dh.fill("y", 1, ghost_layers=True)
z = dh.add_array('z', values_per_cell=1, dtype=dtype)
dh.fill("z", 2, ghost_layers=True)
config = pystencils.CreateKernelConfig(default_number_int=dtype)
# test sp.Min / sp.Max with one argument
assignment_1 = pystencils.Assignment(x.center, sympy_function(y.center + 3))
ast_1 = pystencils.create_kernel(assignment_1, config=config)
kernel_1 = ast_1.compile()
# pystencils.show_code(ast_1)
# test sp.Min / sp.Max with two arguments
assignment_2 = pystencils.Assignment(x.center, sympy_function(1, y.center - 1))
ast_2 = pystencils.create_kernel(assignment_2, config=config)
kernel_2 = ast_2.compile()
# pystencils.show_code(ast_2)
# test sp.Min / sp.Max with many arguments
assignment_3 = pystencils.Assignment(x.center, sympy_function(z.center, 4, y.center - 1, y.center + z.center))
ast_3 = pystencils.create_kernel(assignment_3, config=config)
kernel_3 = ast_3.compile()
# pystencils.show_code(ast_3)
if sympy_function is sp.Max:
results = [4, 1, 4]
else:
results = [4, 0, 0]
dh.run_kernel(kernel_1)
assert numpy.all(dh.gather_array('x') == results[0])
dh.run_kernel(kernel_2)
assert numpy.all(dh.gather_array('x') == results[1])
dh.run_kernel(kernel_3)
assert numpy.all(dh.gather_array('x') == results[2])
import pytest
import pystencils.config
import sympy
import pystencils as ps
from pystencils.typing import CastFunc, create_type
@pytest.mark.parametrize('target', (ps.Target.CPU, ps.Target.GPU))
def test_abs(target):
x, y, z = ps.fields('x, y, z: float64[2d]')
default_int_type = create_type('int64')
assignments = ps.AssignmentCollection({x[0, 0]: sympy.Abs(CastFunc(y[0, 0], default_int_type))})
config = pystencils.config.CreateKernelConfig(target=target)
ast = ps.create_kernel(assignments, config=config)
code = ps.get_code_str(ast)
print(code)
assert 'fabs(' not in code