Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
No results found
Show changes
Showing
with 335 additions and 109 deletions
......@@ -6,10 +6,12 @@ from functools import partial, reduce
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, TypeVar, Union
import sympy as sp
from sympy import PolynomialError
from sympy.functions import Abs
from sympy.core.numbers import Zero
from pystencils.assignment import Assignment
from pystencils.functions import DivFunc
from pystencils.typing import CastFunc, get_type_of_expression, PointerType, VectorType
from pystencils.typing.typed_sympy import FieldPointerSymbol
......@@ -158,17 +160,23 @@ def fast_subs(expression: T, substitutions: Dict,
if type(expression) is sp.Matrix:
return expression.copy().applyfunc(partial(fast_subs, substitutions=substitutions))
def visit(expr):
def visit(expr, evaluate=True):
if skip and skip(expr):
return expr
if hasattr(expr, "fast_subs"):
elif hasattr(expr, "fast_subs"):
return expr.fast_subs(substitutions, skip)
if expr in substitutions:
elif expr in substitutions:
return substitutions[expr]
if not hasattr(expr, 'args'):
elif not hasattr(expr, 'args'):
return expr
param_list = [visit(a) for a in expr.args]
return expr if not param_list else expr.func(*param_list)
elif isinstance(expr, (sp.UnevaluatedExpr, DivFunc)):
args = [visit(a, False) for a in expr.args]
return expr.func(*args)
else:
param_list = [visit(a, evaluate) for a in expr.args]
if isinstance(expr, (sp.Mul, sp.Add)):
return expr if not param_list else expr.func(*param_list, evaluate=evaluate)
return expr if not param_list else expr.func(*param_list)
if len(substitutions) == 0:
return expression
......@@ -348,7 +356,7 @@ def remove_higher_order_terms(expr: sp.Expr, symbols: Sequence[sp.Symbol], order
factor_count = 0
if type(product) is Mul:
for factor in product.args:
if type(factor) == Pow:
if type(factor) is Pow:
if factor.args[0] in symbols:
factor_count += factor.args[1]
if factor in symbols:
......@@ -358,13 +366,13 @@ def remove_higher_order_terms(expr: sp.Expr, symbols: Sequence[sp.Symbol], order
factor_count += product.args[1]
return factor_count
if type(expr) == Mul or type(expr) == Pow:
if type(expr) is Mul or type(expr) is Pow:
if velocity_factors_in_product(expr) <= order:
return expr
else:
return Zero()
if type(expr) != Add:
if type(expr) is not Add:
return expr
for sum_term in expr.args:
......@@ -435,11 +443,14 @@ def extract_most_common_factor(term):
def recursive_collect(expr, symbols, order_by_occurences=False):
"""Applies sympy.collect recursively for a list of symbols, collecting symbol 2 in the coefficients of symbol 1,
"""Applies sympy.collect recursively for a list of symbols, collecting symbol 2 in the coefficients of symbol 1,
and so on.
``expr`` must be rewritable as a polynomial in the given ``symbols``.
If it is not, ``recursive_collect`` will fail quietly, returning the original expression.
Args:
expr: A sympy expression
expr: A sympy expression.
symbols: A sequence of symbols
order_by_occurences: If True, during recursive descent, always collect the symbol occurring
most often in the expression.
......@@ -450,7 +461,13 @@ def recursive_collect(expr, symbols, order_by_occurences=False):
if len(symbols) == 0:
return expr
symbol = symbols[0]
collected_poly = sp.Poly(expr.collect(symbol), symbol)
collected = expr.collect(symbol)
try:
collected_poly = sp.Poly(collected, symbol)
except PolynomialError:
return expr
coeffs = collected_poly.all_coeffs()[::-1]
rec_sum = sum(symbol**i * recursive_collect(c, symbols[1:], order_by_occurences) for i, c in enumerate(coeffs))
return rec_sum
......@@ -622,8 +639,10 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr], List[Assignment]],
for child_term, condition in t.args:
visit(child_term)
visit_children = False
elif isinstance(t, sp.Rel):
elif isinstance(t, (sp.Rel, sp.UnevaluatedExpr)):
pass
elif isinstance(t, DivFunc):
result["divs"] += 1
else:
warnings.warn(f"Unknown sympy node of type {str(t.func)} counting will be inaccurate")
......
File moved
......@@ -4,15 +4,18 @@ import warnings
from collections import OrderedDict
from copy import deepcopy
from types import MappingProxyType
from typing import Set
import sympy as sp
import pystencils as ps
import pystencils.astnodes as ast
from pystencils.assignment import Assignment
from pystencils.typing import (
PointerType, StructType, TypedSymbol, get_base_type, ReinterpretCastFunc, get_next_parent_of_type, parents_of_type)
from pystencils.typing import (CastFunc, PointerType, StructType, TypedSymbol, get_base_type,
ReinterpretCastFunc, get_next_parent_of_type, parents_of_type)
from pystencils.field import Field, FieldType
from pystencils.typing import FieldPointerSymbol
from pystencils.sympyextensions import fast_subs
from pystencils.simp.assignment_collection import AssignmentCollection
from pystencils.slicing import normalize_slice
from pystencils.integer_functions import int_div
......@@ -97,6 +100,45 @@ def generic_visit(term, visitor):
return visitor(term)
def iterate_loops_by_depth(node, nesting_depth):
    """Yield all ``LoopOverCoordinate`` nodes below ``node`` that sit at the requested nesting depth.

    Args:
        node: Root node of the abstract syntax tree that is searched.
        nesting_depth: Nesting depth of the loops to yield.
                       The outermost loop has depth 0.
                       A depth of -1 selects the innermost loops.

    Returns:
        Generator over the matching loop nodes, visited depth-first in the order of ``args``.

    Raises:
        ValueError: if ``nesting_depth`` is neither nonnegative nor -1. Since this is a generator,
                    the error surfaces when iteration starts, not when the function is called.
    """
    from pystencils.astnodes import LoopOverCoordinate

    def _loops_at_depth(current, levels_left):
        # ``levels_left`` is the number of loop levels that still have to be entered
        is_loop = isinstance(current, LoopOverCoordinate)
        if levels_left < 0:  # a negative value marks the end of the descent
            return
        if is_loop and levels_left == 0:
            yield current
            return
        # only entering a loop consumes a level, any other node is transparent
        levels_below = levels_left - 1 if is_loop else levels_left
        for child in current.args:
            yield from _loops_at_depth(child, levels_below)

    def _innermost_loops(current):
        if isinstance(current, LoopOverCoordinate) and current.is_innermost_loop:
            yield current
            return
        for child in current.args:
            yield from _innermost_loops(child)

    if nesting_depth == -1:
        yield from _innermost_loops(node)
    elif nesting_depth >= 0:
        yield from _loops_at_depth(node, nesting_depth)
    else:
        raise ValueError(f"Invalid nesting depth: {nesting_depth}. Choose a nonnegative number, or -1.")
def unify_shape_symbols(body, common_shape, fields):
"""Replaces symbols for array sizes to ensure they are represented by the same unique symbol.
......@@ -121,9 +163,10 @@ def unify_shape_symbols(body, common_shape, fields):
body.subs(substitutions)
def get_common_shape(field_set):
"""Takes a set of pystencils Fields and returns their common spatial shape if it exists. Otherwise
ValueError is raised"""
def get_common_field(field_set):
"""Takes a set of pystencils Fields, checks if a common spatial shape exists and returns one
representative field, that can be used for shape information etc. in the kernel creation.
If the fields have different shapes ValueError is raised"""
nr_of_fixed_shaped_fields = 0
for f in field_set:
if f.has_fixed_shape:
......@@ -141,8 +184,9 @@ def get_common_shape(field_set):
if len(shape_set) != 1:
raise ValueError("Differently sized field accesses in loop body: " + str(shape_set))
shape = list(sorted(shape_set, key=lambda e: str(e[0])))[0]
return shape
# Sort the fields by their name to ensure that always the same field is returned
reference_field = sorted(field_set, key=lambda e: str(e))[0]
return reference_field
def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_order=None):
......@@ -160,9 +204,11 @@ def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_or
tuple of loop-node, ghost_layer_info
"""
# find correct ordering by inspecting participating FieldAccesses
absolut_accesses_only = False
field_accesses = body.atoms(Field.Access)
field_accesses = {e for e in field_accesses if not e.is_absolute_access}
if len(field_accesses) == 0: # when kernel contains only absolute accesses
absolut_accesses_only = True
# exclude accesses to buffers from field_list, because buffers are treated separately
field_list = [e.field for e in field_accesses if not (FieldType.is_buffer(e.field) or FieldType.is_custom(e.field))]
if len(field_list) == 0: # when kernel contains only custom fields
......@@ -173,14 +219,23 @@ def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_or
if loop_order is None:
loop_order = get_optimal_loop_ordering(fields)
shape = get_common_shape(fields)
unify_shape_symbols(body, common_shape=shape, fields=fields)
if absolut_accesses_only:
absolut_access_fields = {e.field for e in body.atoms(Field.Access)}
common_field = get_common_field(absolut_access_fields)
common_shape = common_field.spatial_shape
else:
common_field = get_common_field(fields)
common_shape = common_field.spatial_shape
unify_shape_symbols(body, common_shape=common_shape, fields=fields)
if iteration_slice is not None:
iteration_slice = normalize_slice(iteration_slice, shape)
iteration_slice = normalize_slice(iteration_slice, common_shape)
if ghost_layers is None:
required_ghost_layers = max([fa.required_ghost_layers for fa in field_accesses])
if absolut_accesses_only:
required_ghost_layers = 0
else:
required_ghost_layers = max([fa.required_ghost_layers for fa in field_accesses])
ghost_layers = [(required_ghost_layers, required_ghost_layers)] * len(loop_order)
if isinstance(ghost_layers, int):
ghost_layers = [(ghost_layers, ghost_layers)] * len(loop_order)
......@@ -189,7 +244,7 @@ def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_or
for i, loop_coordinate in enumerate(reversed(loop_order)):
if iteration_slice is None:
begin = ghost_layers[loop_coordinate][0]
end = shape[loop_coordinate] - ghost_layers[loop_coordinate][1]
end = common_shape[loop_coordinate] - ghost_layers[loop_coordinate][1]
new_loop = ast.LoopOverCoordinate(current_body, loop_coordinate, begin, end, 1)
current_body = ast.Block([new_loop])
else:
......@@ -206,6 +261,28 @@ def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_or
return current_body, ghost_layers
def get_common_indexed_element(indexed_elements: Set[sp.IndexedBase]) -> sp.IndexedBase:
    """Select one representative element from a non-empty set of indexed elements.

    The representative is the element whose string representation sorts first, so the
    same element is chosen on every call for the same set.

    If the elements do not all report the same ``shape``, every distinct shape is asserted
    not to be an ``int``.
    NOTE(review): presumably this is meant to reject differing fixed-size shapes while
    tolerating symbolic ones - confirm what type ``shape`` has for the elements passed in.

    Args:
        indexed_elements: set of indexed elements, must not be empty

    Returns:
        the element with the smallest ``str()`` representation
    """
    assert len(indexed_elements) > 0, "indexed_elements can not be empty"

    distinct_shapes = {element.shape for element in indexed_elements}
    if len(distinct_shapes) != 1:
        for candidate in distinct_shapes:
            assert not isinstance(candidate, int), "If indexed elements are used, they must all have the same shape"

    return min(indexed_elements, key=str)
def add_outer_loop_over_indexed_elements(loop_node: ast.Block) -> ast.Block:
    """Wrap ``loop_node`` in an additional loop running over the first dimension of its indexed elements.

    If ``loop_node`` contains no ``sp.Indexed`` atoms it is returned unchanged. Otherwise a
    reference element is picked with ``get_common_indexed_element`` and a new
    ``LoopOverCoordinate`` is created around ``loop_node`` with start 0, stop
    ``reference.shape[0]`` and step 1. The single ``TypedSymbol`` found in the first index
    of the reference element is used as the loop counter.

    Args:
        loop_node: block that becomes the body of the new loop

    Returns:
        ``loop_node`` itself if nothing is indexed, otherwise a new block holding only the new loop
    """
    indexed = loop_node.atoms(sp.Indexed)
    if len(indexed) == 0:
        return loop_node

    reference = get_common_indexed_element(indexed)
    counter_candidates = reference.indices[0].atoms(TypedSymbol)
    assert len(counter_candidates) == 1, "index expressions must only contain one symbol representing the index"

    extent = reference.shape[0]
    loop_counter = counter_candidates.pop()
    outer_loop = ast.LoopOverCoordinate(loop_node, 0, 0, extent, 1, custom_loop_ctr=loop_counter)
    return ast.Block([outer_loop])
def create_intermediate_base_pointer(field_access, coordinates, previous_ptr):
r"""
Addressing elements in structured arrays is done with :math:`ptr\left[ \sum_i c_i \cdot s_i \right]`
......@@ -341,7 +418,7 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None):
ast_node: ast before any field accesses are resolved
loop_counters: for CPU kernels: leave to default 'None' (can be determined from loop nodes)
for GPU kernels: list of 'loop counters' from inner to outer loop
loop_iterations: number of iterations of each loop from inner to outer, for CPU kernels leave to default
loop_iterations: iteration slice for each loop from inner to outer, for CPU kernels leave to default
Returns:
base buffer index - required by 'resolve_buffer_accesses' function
......@@ -353,15 +430,25 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None):
assert len(loops) == len(parents_of_innermost_loop)
assert all(l1 is l2 for l1, l2 in zip(loops, parents_of_innermost_loop))
actual_sizes = [int_div((loop.stop - loop.start), loop.step)
if loop.step != 1 else loop.stop - loop.start for loop in loops]
loop_counters = [loop.loop_counter_symbol for loop in loops]
loop_iterations = [slice(loop.start, loop.stop, loop.step) for loop in loops]
actual_steps = [int_div((loop.loop_counter_symbol - loop.start), loop.step)
if loop.step != 1 else loop.loop_counter_symbol - loop.start for loop in loops]
actual_sizes = list()
actual_steps = list()
for ctr, s in zip(loop_counters, loop_iterations):
if s.step != 1:
if (s.stop - s.start) % s.step == 0:
actual_sizes.append((s.stop - s.start) // s.step)
else:
actual_sizes.append(int_div((s.stop - s.start), s.step))
else:
actual_sizes = loop_iterations
actual_steps = loop_counters
if (ctr - s.start) % s.step == 0:
actual_steps.append((ctr - s.start) // s.step)
else:
actual_steps.append(int_div((ctr - s.start), s.step))
else:
actual_sizes.append(s.stop - s.start)
actual_steps.append(ctr - s.start)
field_accesses = ast_node.atoms(Field.Access)
buffer_accesses = {fa for fa in field_accesses if FieldType.is_buffer(fa.field)}
......@@ -505,7 +592,7 @@ def resolve_field_accesses(ast_node, read_only_field_names=None,
coord_dict = create_coordinate_dict(group)
new_ptr, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer)
if new_ptr not in enclosing_block.symbols_defined:
new_assignment = ast.SympyAssignment(new_ptr, last_pointer + offset, is_const=False)
new_assignment = ast.SympyAssignment(new_ptr, last_pointer + offset, is_const=False, use_auto=False)
enclosing_block.insert_before(new_assignment, sympy_assignment)
last_pointer = new_ptr
......@@ -569,21 +656,65 @@ def move_constants_before_loop(ast_node):
"""
assert isinstance(node.parent, ast.Block)
def modifies_or_declares(node: ast.Node, symbol_names: Set[str]) -> bool:
    """Check whether ``node`` writes to or defines one of the symbols named in ``symbol_names``.

    The check depends on the node type:

    - assignments: the name of the left-hand side is looked up (for a ``ResolvedFieldAccess``
      the name of its ``typed_symbol`` is used)
    - blocks: any contained statement matches; declarations that sit directly in the block
      are skipped (NOTE(review): presumably because they are local to that block - confirm)
    - loops: the loop body is checked
    - conditionals: the true block and, if present, the false block are checked
    - kernel functions: never match
    - any other node: the names in ``symbols_defined`` are compared against ``symbol_names``

    Args:
        node: AST node to inspect
        symbol_names: names of the symbols of interest

    Returns:
        True if a match was found, False otherwise. The result is always a real ``bool``,
        also for conditionals without a false block.
    """
    if isinstance(node, (ps.Assignment, ast.SympyAssignment)):
        lhs = node.lhs
        if isinstance(lhs, ast.ResolvedFieldAccess):
            return lhs.typed_symbol.name in symbol_names
        return lhs.name in symbol_names
    elif isinstance(node, ast.Block):
        return any(
            modifies_or_declares(arg, symbol_names)
            for arg in node.args
            if not (isinstance(arg, ast.SympyAssignment) and arg.is_declaration)
        )
    elif isinstance(node, ast.LoopOverCoordinate):
        return modifies_or_declares(node.body, symbol_names)
    elif isinstance(node, ast.Conditional):
        if modifies_or_declares(node.true_block, symbol_names):
            return True
        # ``x and y`` evaluates to ``x`` itself when x is falsy, so a missing false block
        # has to be turned into a bool explicitly to honour the return annotation
        return bool(node.false_block) and modifies_or_declares(node.false_block, symbol_names)
    elif isinstance(node, ast.KernelFunction):
        return False
    else:
        defined_names = {s.name for s in node.symbols_defined}
        return bool(symbol_names.intersection(defined_names))
dependencies = {s.name for s in node.undefined_symbols}
last_block = node.parent
last_block_child = node
element = node.parent
prev_element = node
while element:
if isinstance(element, ast.Block):
if isinstance(element, (ast.Conditional, ast.KernelFunction)):
# Never move out of Conditionals or KernelFunctions.
break
elif isinstance(element, ast.Block):
last_block = element
last_block_child = prev_element
if isinstance(element, ast.Conditional):
break
if any(modifies_or_declares(sibling, dependencies) for sibling in element.args):
# The node depends on one of the statements in this block.
# Do not move further out.
break
elif isinstance(element, ast.LoopOverCoordinate):
if element.loop_counter_symbol.name in dependencies:
# The node depends on the loop counter.
# Do not move out of this loop.
break
else:
critical_symbols = set([s.name for s in element.symbols_defined])
if set([s.name for s in node.undefined_symbols]).intersection(critical_symbols):
break
raise NotImplementedError(f'Due to defensive programming we handle only specific expressions.\n'
f'The expression {element} of type {type(element)} is not known yet.')
# No dependencies to symbols defined/modified within the current element.
# We can move the node up one level and in front of the current element.
prev_element = element
element = element.parent
return last_block, last_block_child
......@@ -607,13 +738,7 @@ def move_constants_before_loop(ast_node):
get_blocks(ast_node, all_blocks)
for block in all_blocks:
children = block.take_child_nodes()
# Every time a symbol can be replaced in the current block because the assignment
# was found in a parent block, but with a different lhs symbol (same rhs)
# the outer symbol is inserted here as key.
substitute_variables = {}
for child in children:
# Before traversing the next child, all symbols are substituted first.
child.subs(substitute_variables)
if not isinstance(child, ast.SympyAssignment): # only move SympyAssignments
block.append(child)
......@@ -629,14 +754,7 @@ def move_constants_before_loop(ast_node):
exists_already = False
if not exists_already:
rhs_identical = check_if_assignment_already_in_block(child, target, True)
if rhs_identical:
# there is already an assignment out there with the same rhs
# -> replace all lhs symbols in this block with the lhs of the outer assignment
# -> remove the local assignment (do not re-append child to the former block)
substitute_variables[child.lhs] = rhs_identical.lhs
else:
target.insert_before(child, child_to_insert_before)
target.insert_before(child, child_to_insert_before)
elif exists_already and exists_already.rhs == child.rhs:
if target.args.index(exists_already) > target.args.index(child_to_insert_before):
assert target.args.count(exists_already) == 1
......@@ -650,7 +768,7 @@ def move_constants_before_loop(ast_node):
new_symbol = TypedSymbol(sp.Dummy().name, child.lhs.dtype)
target.insert_before(ast.SympyAssignment(new_symbol, child.rhs, is_const=child.is_const),
child_to_insert_before)
substitute_variables[child.lhs] = new_symbol
block.append(ast.SympyAssignment(child.lhs, new_symbol, is_const=child.is_const))
def split_inner_loop(ast_node: ast.Node, symbol_groups):
......@@ -664,11 +782,11 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups):
and which no symbol in a symbol group depends on, are not updated!
"""
all_loops = ast_node.atoms(ast.LoopOverCoordinate)
inner_loop = [l for l in all_loops if l.is_innermost_loop]
inner_loop = [loop for loop in all_loops if loop.is_innermost_loop]
assert len(inner_loop) == 1, "Error in AST: multiple innermost loops. Was split transformation already called?"
inner_loop = inner_loop[0]
assert type(inner_loop.body) is ast.Block
outer_loop = [l for l in all_loops if l.is_outermost_loop]
outer_loop = [loop for loop in all_loops if loop.is_outermost_loop]
assert len(outer_loop) == 1, "Error in AST, multiple outermost loops."
outer_loop = outer_loop[0]
......@@ -702,8 +820,8 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups):
assignment_group = []
for assignment in inner_loop.body.args:
if assignment.lhs in symbols_resolved:
new_rhs = assignment.rhs.subs(
symbols_with_temporary_array.items())
# use fast_subs here because it checks if multiplications should be evaluated or not
new_rhs = fast_subs(assignment.rhs, symbols_with_temporary_array)
if not isinstance(assignment.lhs, Field.Access) and assignment.lhs in symbol_group:
assert type(assignment.lhs) is TypedSymbol
new_ts = TypedSymbol(assignment.lhs.name, PointerType(assignment.lhs.dtype))
......@@ -733,7 +851,8 @@ def cut_loop(loop_node, cutting_points):
One loop is transformed into len(cuttingPoints)+1 new loops that range from
old_begin to cutting_points[1], ..., cutting_points[-1] to old_end
Modifies the ast in place
Modifies the ast in place. Note Issue #5783 of SymPy. Deepcopy will evaluate mul
https://github.com/sympy/sympy/issues/5783
Returns:
list of new loop nodes
......@@ -771,12 +890,16 @@ def simplify_conditionals(node: ast.Node, loop_counter_simplification: bool = Fa
This analysis needs the integer set library (ISL) islpy, so it is not done by
default.
"""
from sympy.codegen.rewriting import ReplaceOptim, optimize
remove_casts = ReplaceOptim(lambda e: isinstance(e, CastFunc), lambda p: p.expr)
for conditional in node.atoms(ast.Conditional):
# TODO simplify conditional before the type system! Casts make it very hard here
# conditional.condition_expr = sp.simplify(conditional.condition_expr)
if conditional.condition_expr == sp.true:
condition_expression = optimize(conditional.condition_expr, [remove_casts])
condition_expression = sp.simplify(condition_expression)
if condition_expression == sp.true:
conditional.parent.replace(conditional, [conditional.true_block])
elif conditional.condition_expr == sp.false:
elif condition_expression == sp.false:
conditional.parent.replace(conditional, [conditional.false_block] if conditional.false_block else [])
elif loop_counter_simplification:
try:
......
......@@ -3,14 +3,14 @@ from pystencils.typing.cast_functions import (CastFunc, BooleanCastFunc, VectorM
from pystencils.typing.types import (is_supported_type, numpy_name_to_c, AbstractType, BasicType, VectorType,
PointerType, StructType, create_type)
from pystencils.typing.typed_sympy import (assumptions_from_dtype, TypedSymbol, FieldStrideSymbol, FieldShapeSymbol,
FieldPointerSymbol)
FieldPointerSymbol, CFunction)
from pystencils.typing.utilities import (typed_symbols, get_base_type, result_type, collate_types,
get_type_of_expression, get_next_parent_of_type, parents_of_type)
__all__ = ['CastFunc', 'BooleanCastFunc', 'VectorMemoryAccess', 'ReinterpretCastFunc', 'PointerArithmeticFunc',
'is_supported_type', 'numpy_name_to_c', 'AbstractType', 'BasicType',
'VectorType', 'PointerType', 'StructType', 'create_type',
'assumptions_from_dtype', 'TypedSymbol', 'FieldStrideSymbol', 'FieldShapeSymbol', 'FieldPointerSymbol',
'VectorType', 'PointerType', 'StructType', 'create_type', 'assumptions_from_dtype',
'TypedSymbol', 'FieldStrideSymbol', 'FieldShapeSymbol', 'FieldPointerSymbol', 'CFunction',
'typed_symbols', 'get_base_type', 'result_type', 'collate_types',
'get_type_of_expression', 'get_next_parent_of_type', 'parents_of_type']
......@@ -6,11 +6,12 @@ import numpy as np
import sympy as sp
from sympy import Piecewise
from sympy.core.numbers import NegativeOne
from sympy.core.relational import Relational
from sympy.functions.elementary.piecewise import ExprCondPair
from sympy.functions.elementary.trigonometric import TrigonometricFunction, InverseTrigonometricFunction
from sympy.functions.elementary.hyperbolic import HyperbolicFunction
from sympy.codegen import Assignment
from sympy.functions.elementary.integers import RoundFunction
from sympy.logic.boolalg import BooleanFunction
from sympy.logic.boolalg import BooleanAtom
......@@ -51,7 +52,7 @@ class TypeAdder:
def visit(self, obj):
if isinstance(obj, (list, tuple)):
return [self.visit(e) for e in obj]
if isinstance(obj, (sp.Eq, ast.SympyAssignment, Assignment)):
if isinstance(obj, ast.SympyAssignment):
return self.process_assignment(obj)
elif isinstance(obj, ast.Conditional):
condition, condition_type = self.figure_out_type(obj.condition_expr)
......@@ -67,7 +68,7 @@ class TypeAdder:
else:
raise ValueError("Invalid object in kernel " + str(type(obj)))
def process_assignment(self, assignment: Union[sp.Eq, ast.SympyAssignment, Assignment]) -> ast.SympyAssignment:
def process_assignment(self, assignment: ast.SympyAssignment) -> ast.SympyAssignment:
# for checks it is crucial to process rhs before lhs to catch e.g. a = a + 1
new_rhs, rhs_type = self.figure_out_type(assignment.rhs)
......@@ -81,11 +82,11 @@ class TypeAdder:
assert isinstance(new_lhs, (Field.Access, TypedSymbol))
if lhs_type != rhs_type:
logging.warning(f'Lhs"{new_lhs} of type "{lhs_type}" is assigned with a different datatype '
f'rhs: "{new_rhs}" of type "{rhs_type}".')
return ast.SympyAssignment(new_lhs, CastFunc(new_rhs, lhs_type))
logging.debug(f'Lhs"{new_lhs} of type "{lhs_type}" is assigned with a different datatype '
f'rhs: "{new_rhs}" of type "{rhs_type}".')
return ast.SympyAssignment(new_lhs, CastFunc(new_rhs, lhs_type), assignment.is_const, assignment.use_auto)
else:
return ast.SympyAssignment(new_lhs, new_rhs)
return ast.SympyAssignment(new_lhs, new_rhs, assignment.is_const, assignment.use_auto)
# Type System Specification
# - Defined Types: TypedSymbol, Field, Field.Access, ...?
......@@ -171,12 +172,13 @@ class TypeAdder:
args_types = [self.figure_out_type(a) for a in expr.args]
new_args = [a if t.dtype_eq(bool_type) else BooleanCastFunc(a, bool_type) for a, t in args_types]
return expr.func(*new_args), bool_type
elif type(expr, ) in pystencils.integer_functions.__dict__.values():
elif type(expr, ) in pystencils.integer_functions.__dict__.values() or isinstance(expr, sp.Mod):
args_types = [self.figure_out_type(a) for a in expr.args]
collated_type = collate_types([t for _, t in args_types])
# TODO: should we downcast to integer? If yes then which integer type?
if not collated_type.is_int():
raise ValueError(f"Integer functions need to be used with integer types but {collated_type} was given")
raise ValueError(f"Integer functions or Modulo need to be used with integer types "
f"but {collated_type} was given")
return expr, collated_type
elif isinstance(expr, flag_cond):
......@@ -212,7 +214,7 @@ class TypeAdder:
new_args.append(a)
return expr.func(*new_args) if new_args else expr, collated_type
elif isinstance(expr, (sp.Pow, sp.exp, InverseTrigonometricFunction, TrigonometricFunction,
HyperbolicFunction)):
HyperbolicFunction, sp.log, RoundFunction)):
args_types = [self.figure_out_type(arg) for arg in expr.args]
collated_type = collate_types([t for _, t in args_types])
new_args = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types]
......@@ -228,6 +230,15 @@ class TypeAdder:
new_func = expr.func(*new_args) if new_args else expr
return CastFunc(new_func, collated_type), collated_type
elif isinstance(expr, (sp.Add, sp.Mul, sp.Abs, sp.Min, sp.Max, DivFunc, sp.UnevaluatedExpr)):
# Subtraction is realised as a multiplication with -1 in SymPy. Thus we exclude the coefficient in this case
# and resolve the typing entirely with the expression itself
if isinstance(expr, sp.Mul):
c, e = expr.as_coeff_Mul()
if c == NegativeOne():
args_types = self.figure_out_type(e)
new_args = [NegativeOne(), args_types[0]]
return expr.func(*new_args, evaluate=False), args_types[1]
args_types = [self.figure_out_type(arg) for arg in expr.args]
collated_type = collate_types([t for _, t in args_types])
if isinstance(collated_type, PointerType):
......@@ -236,6 +247,10 @@ class TypeAdder:
else:
raise NotImplementedError(f'Pointer Arithmetic is implemented only for Add, not {expr}')
new_args = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types]
return expr.func(*new_args) if new_args else expr, collated_type
if isinstance(expr, (sp.Add, sp.Mul)):
return expr.func(*new_args, evaluate=False) if new_args else expr, collated_type
else:
return expr.func(*new_args) if new_args else expr, collated_type
else:
raise NotImplementedError(f'expr {type(expr)}: {expr} unknown to typing')
from typing import List
from pystencils.astnodes import Node
from pystencils.config import CreateKernelConfig
from pystencils.typing.leaf_typing import TypeAdder
from sympy.codegen import Assignment
def add_types(eqs: List[Assignment], config: CreateKernelConfig):
def add_types(node_list: List[Node], config: CreateKernelConfig):
"""Traverses AST and replaces every :class:`sympy.Symbol` by a :class:`pystencils.typedsymbol.TypedSymbol`.
The AST needs to be a pystencils AST. Thus, in the list of nodes every entry must be inherited from
`pystencils.astnodes.Node`
Additionally returns sets of all fields which are read/written
Args:
eqs: list of equations
node_list: List of pystencils Nodes.
config: CreateKernelConfig
Returns:
......@@ -22,4 +24,4 @@ def add_types(eqs: List[Assignment], config: CreateKernelConfig):
default_number_float=config.default_number_float,
default_number_int=config.default_number_int)
return check.visit(eqs)
return check.visit(node_list)
......@@ -178,3 +178,20 @@ class FieldPointerSymbol(TypedSymbol):
__xnew__ = staticmethod(__new_stage2__)
__xnew_cached_ = staticmethod(cacheit(__new_stage2__))
class CFunction(TypedSymbol):
    """Typed symbol that is created from a ``function`` argument and a ``dtype``.

    NOTE(review): presumably represents a C function by name in generated code - confirm
    against the code that instantiates this class.
    """

    def __new__(cls, function, dtype):
        # Construction always goes through the cached constructor defined below
        return CFunction.__xnew_cached_(cls, function, dtype)

    def __new_stage2__(cls, function, dtype):
        # The actual object creation is delegated to ``__xnew__`` of the parent class
        return super(CFunction, cls).__xnew__(cls, function, dtype)

    # Uncached and ``cacheit``-wrapped variants of the stage-2 constructor
    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    def __getnewargs__(self):
        # Positional arguments for ``__new__`` when the object is re-created (pickle / copy protocol)
        return self.name, self.dtype

    def __getnewargs_ex__(self):
        # Same arguments in (args, kwargs) form; no keyword arguments are needed
        return (self.name, self.dtype), {}
......@@ -7,7 +7,7 @@ import sympy as sp
def is_supported_type(dtype: np.dtype):
    """Check whether a numpy dtype is a plain floating point, integer or boolean scalar type.

    A dtype is accepted if its scalar type is a numpy floating, integer or bool type and
    the dtype has no fields, contains no Python objects and is not a sub-array dtype.

    Args:
        dtype: numpy dtype to check

    Returns:
        True if the dtype is supported, False otherwise
    """
    scalar_type = dtype.type
    # ``np.issctype`` is not used here: it was removed in NumPy 2.0
    is_generic = np.issubdtype(dtype, np.generic)
    is_number_or_bool = issubclass(scalar_type, (np.floating, np.integer, np.bool_))
    is_plain = dtype.fields is None and dtype.hasobject is False and dtype.subdtype is None
    return is_generic and is_number_or_bool and is_plain
......@@ -25,6 +25,8 @@ def numpy_name_to_c(name: str) -> str:
return 'double'
elif name == 'float32':
return 'float'
elif name == 'float16' or name == 'half':
return 'half'
elif name.startswith('int'):
width = int(name[len("int"):])
return f"int{width}_t"
......@@ -68,7 +70,7 @@ class BasicType(AbstractType):
BasicType is defined with a const qualifier and a np.dtype.
"""
def __init__(self, dtype: Union[np.dtype, 'BasicType', str], const: bool = False):
def __init__(self, dtype: Union[type, 'BasicType', str], const: bool = False):
if isinstance(dtype, BasicType):
self.numpy_dtype = dtype.numpy_dtype
self.const = dtype.const
......@@ -94,6 +96,9 @@ class BasicType(AbstractType):
def is_float(self):
return issubclass(self.numpy_dtype.type, np.floating)
def is_half(self):
return issubclass(self.numpy_dtype.type, np.half)
def is_int(self):
return issubclass(self.numpy_dtype.type, np.integer)
......@@ -120,7 +125,10 @@ class BasicType(AbstractType):
return f'{self.c_name}{" const" if self.const else ""}'
def __repr__(self):
return str(self)
return f'BasicType( {str(self)} )'
def _repr_html_(self):
return f'BasicType( {str(self)} )'
def __eq__(self, other):
return self.dtype_eq(other) and self.const == other.const
......@@ -181,16 +189,17 @@ class VectorType(AbstractType):
class PointerType(AbstractType):
def __init__(self, base_type: BasicType, const: bool = False, restrict: bool = True):
def __init__(self, base_type: BasicType, const: bool = False, restrict: bool = True, double_pointer: bool = False):
self._base_type = base_type
self.const = const
self.restrict = restrict
self.double_pointer = double_pointer
def __getnewargs__(self):
return self.base_type, self.const, self.restrict
return self.base_type, self.const, self.restrict, self.double_pointer
def __getnewargs_ex__(self):
return (self.base_type, self.const, self.restrict), {}
return (self.base_type, self.const, self.restrict, self.double_pointer), {}
@property
def alias(self):
......@@ -202,22 +211,34 @@ class PointerType(AbstractType):
@property
def item_size(self):
return self.base_type.item_size
if self.double_pointer:
raise NotImplementedError("The item_size for double_pointer is not implemented")
else:
return self.base_type.item_size
def __eq__(self, other):
if not isinstance(other, PointerType):
return False
else:
return (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict)
own = (self.base_type, self.const, self.restrict, self.double_pointer)
return own == (other.base_type, other.const, other.restrict, other.double_pointer)
def __str__(self):
return f'{str(self.base_type)} * {"RESTRICT " if self.restrict else "" }{"const" if self.const else ""}'
restrict_str = "RESTRICT" if self.restrict else ""
const_str = "const" if self.const else ""
if self.double_pointer:
return f'{str(self.base_type)} ** {restrict_str} {const_str}'
else:
return f'{str(self.base_type)} * {restrict_str} {const_str}'
def __repr__(self):
return str(self)
def _repr_html_(self):
return str(self)
def __hash__(self):
return hash((self._base_type, self.const, self.restrict))
return hash((self._base_type, self.const, self.restrict, self.double_pointer))
class StructType(AbstractType):
......@@ -273,11 +294,14 @@ class StructType(AbstractType):
def __repr__(self):
return str(self)
def _repr_html_(self):
return str(self)
def __hash__(self):
return hash((self.numpy_dtype, self.const))
def create_type(specification: Union[np.dtype, AbstractType, str]) -> AbstractType:
def create_type(specification: Union[type, AbstractType, str]) -> AbstractType:
# TODO: Deprecated Use the constructor of BasicType or StructType instead
"""Creates a subclass of Type according to a string or an object of subclass Type.
......
......@@ -187,18 +187,15 @@ def get_type_of_expression(expr,
# Fix for sympy versions from 1.9
sympy_version = sp.__version__.split('.')
if int(sympy_version[0]) * 100 + int(sympy_version[1]) >= 109:
sympy_version_int = int(sympy_version[0]) * 100 + int(sympy_version[1])
if sympy_version_int >= 109:
# __setstate__ would bypass the constructor, so we remove it
sp.Number.__getstate__ = sp.Basic.__getstate__
del sp.Basic.__getstate__
class FunctorWithStoredKwargs:
def __init__(self, func, **kwargs):
self.func = func
self.kwargs = kwargs
def __call__(self, *args):
return self.func(*args, **self.kwargs)
if sympy_version_int >= 111:
del sp.Basic.__setstate__
del sp.Symbol.__setstate__
else:
sp.Number.__getstate__ = sp.Basic.__getstate__
del sp.Basic.__getstate__
# __reduce_ex__ would strip kwargs, so we override it
def basic_reduce_ex(self, protocol):
......@@ -210,9 +207,7 @@ if int(sympy_version[0]) * 100 + int(sympy_version[1]) >= 109:
state = self.__getstate__()
else:
state = None
return FunctorWithStoredKwargs(type(self), **kwargs), args, state
sp.Number.__reduce_ex__ = sp.Basic.__reduce_ex__
return partial(type(self), **kwargs), args, state
sp.Basic.__reduce_ex__ = basic_reduce_ex
......
......@@ -82,8 +82,8 @@ def boolean_array_bounding_box(boolean_array):
>>> a = np.zeros((4, 4), dtype=bool)
>>> a[1:-1, 1:-1] = True
>>> boolean_array_bounding_box(a)
[(1, 3), (1, 3)]
>>> boolean_array_bounding_box(a) == [(1, 3), (1, 3)]
True
"""
dim = boolean_array.ndim
shape = boolean_array.shape
......@@ -96,6 +96,21 @@ def boolean_array_bounding_box(boolean_array):
return bounds
def binary_numbers(n):
    """Returns all binary numbers up to 2^n - 1

    Every number is given as a list of its binary digits, most significant digit first,
    zero-padded to ``n`` digits. The numbers are listed in ascending order.

    Args:
        n: number of binary digits

    Returns:
        list with ``2^n`` entries, each of them a list of 0/1 integers
        (for ``n = 0`` the single entry is ``[0]``)

    Example:
        >>> binary_numbers(2)
        [[0, 0], [0, 1], [1, 0], [1, 1]]
    """
    # the format spec '0<n>b' renders the binary representation zero-padded to n digits
    return [[int(digit) for digit in format(i, f'0{n}b')] for i in range(1 << n)]
class LinearEquationSystem:
"""Symbolic linear system of equations - consisting of matrix and right hand side.
......
File moved
File moved
......@@ -170,3 +170,19 @@ def test_new_merged():
assert ps.Assignment(d[0, 0](0), sp.symbols('xi_0')) in merged_ac.main_assignments
assert a1 in merged_ac.subexpressions
assert a3 in merged_ac.subexpressions
a1 = ps.Assignment(a, 20)
a2 = ps.Assignment(a, 10)
acommon = ps.Assignment(b, a)
# main assignments
a3 = ps.Assignment(f[0, 0](0), b)
a4 = ps.Assignment(d[0, 0](0), b)
ac = ps.AssignmentCollection([a3], subexpressions=[a1, acommon])
ac2 = ps.AssignmentCollection([a4], subexpressions=[a2, acommon])
merged_ac = ac.new_merged(ac2).new_without_subexpressions()
assert ps.Assignment(f[0, 0](0), 20) in merged_ac.main_assignments
assert ps.Assignment(d[0, 0](0), 10) in merged_ac.main_assignments