diff --git a/astnodes.py b/astnodes.py index 362eb7a38b380bb6d0e87616ca9ac5a5cc8259fd..96d9353918c4df3d447a842c9a34ddf59db731c8 100644 --- a/astnodes.py +++ b/astnodes.py @@ -109,6 +109,14 @@ class Conditional(Node): def __repr__(self): return 'if:({!r}) '.format(self.condition_expr) + def replace_by_true_block(self): + """Replaces the conditional by its True block""" + self.parent.replace(self, [self.true_block]) + + def replace_by_false_block(self): + """Replaces the conditional by its False block""" + self.parent.replace(self, [self.false_block] if self.false_block else []) + class KernelFunction(Node): diff --git a/gpucuda/kernelcreation.py b/gpucuda/kernelcreation.py index d17d23aa0ec5d32bc17df7dbe38a0ff72f2a4029..460db7e129f086a576a5c04f02ae7eab4370bb58 100644 --- a/gpucuda/kernelcreation.py +++ b/gpucuda/kernelcreation.py @@ -2,7 +2,7 @@ from functools import partial from pystencils.gpucuda.indexing import BlockIndexing from pystencils.transformations import resolve_field_accesses, type_all_equations, parse_base_pointer_info, \ - get_common_shape, substitute_array_accesses_with_constants, resolve_buffer_accesses + get_common_shape, substitute_array_accesses_with_constants, resolve_buffer_accesses, unify_shape_symbols from pystencils.astnodes import Block, KernelFunction, SympyAssignment, LoopOverCoordinate from pystencils.data_types import TypedSymbol, BasicType, StructType from pystencils import Field, FieldType @@ -22,8 +22,7 @@ def create_cuda_kernel(assignments, function_name="kernel", type_info=None, inde num_buffer_accesses = 0 for eq in assignments: field_accesses.update(eq.atoms(Field.Access)) - - num_buffer_accesses += sum([1 for access in eq.atoms(Field.Access) if FieldType.is_buffer(access.field)]) + num_buffer_accesses += sum(1 for access in eq.atoms(Field.Access) if FieldType.is_buffer(access.field)) common_shape = get_common_shape(fields_without_buffers) @@ -51,6 +50,8 @@ def create_cuda_kernel(assignments, function_name="kernel", 
type_info=None, inde assignments = cell_idx_assignments + assignments block = Block(assignments) + unify_shape_symbols(block, common_shape=common_shape, fields=fields_without_buffers) + block = indexing.guard(block, common_shape) ast = KernelFunction(block, function_name=function_name, ghost_layers=ghost_layers, backend='gpucuda') ast.global_variables.update(indexing.index_variables) diff --git a/integer_set_analysis.py b/integer_set_analysis.py index 034f0049b0e64cf7b0ff50817631aa4ed6fc84ef..8d01c9ee7367ed682eb6cde0f7d3867b59f5f39a 100644 --- a/integer_set_analysis.py +++ b/integer_set_analysis.py @@ -2,71 +2,62 @@ import sympy as sp import islpy as isl -from typing import Tuple import pystencils.astnodes as ast from pystencils.transformations import parents_of_type -#context = isl.Context() -""" -- find all Condition nodes -- check if they depend on integers only -- create ISL space containing names of all loop symbols (counter and bounds) and all integers in Conditional expression -- build up pre-condition set by iteration over each enclosing loop add ISL constraints -- build up ISL space for condition -- if pre_condition_set.intersect(conditional_set) == pre_condition_set - always use True condition - elif pre_condition_set.intersect(conditional_set).is_empty(): - always use False condition -""" +def remove_brackets(s): + return s.replace('[', '').replace(']', '') + + +def _degrees_of_freedom_as_string(expr): + expr = sp.sympify(expr) + indexed = expr.atoms(sp.Indexed) + symbols = expr.atoms(sp.Symbol) + symbols_without_indexed_base = symbols - {ind.base.args[0] for ind in indexed} + symbols_without_indexed_base.update(indexed) + return {remove_brackets(str(s)) for s in symbols_without_indexed_base} def isl_iteration_set(node: ast.Node): """Builds up an ISL set describing the iteration space by analysing the enclosing loops of the given node. 
""" conditions = [] - loop_symbols = set() + degrees_of_freedom = set() + for loop in parents_of_type(node, ast.LoopOverCoordinate): if loop.step != 1: raise NotImplementedError("Loops with strides != 1 are not yet supported.") - loop_symbols.add(loop.loop_counter_symbol) - loop_symbols.update(sp.sympify(loop.start).atoms(sp.Symbol)) - loop_symbols.update(sp.sympify(loop.stop).atoms(sp.Symbol)) + degrees_of_freedom.update(_degrees_of_freedom_as_string(loop.loop_counter_symbol)) + degrees_of_freedom.update(_degrees_of_freedom_as_string(loop.start)) + degrees_of_freedom.update(_degrees_of_freedom_as_string(loop.stop)) - loop_start_str = str(loop.start).replace('[', '_bracket1_').replace(']', '_bracket2_') - loop_stop_str = str(loop.stop).replace('[', '_bracket1_').replace(']', '_bracket2_') + loop_start_str = remove_brackets(str(loop.start)) + loop_stop_str = remove_brackets(str(loop.stop)) ctr_name = loop.loop_counter_name - conditions.append(f"{ctr_name} >= {loop_start_str} and {ctr_name} < {loop_stop_str}") + conditions.append(remove_brackets(f"{ctr_name} >= {loop_start_str} and {ctr_name} < {loop_stop_str}")) - symbol_names = ','.join([s.name for s in loop_symbols]) + symbol_names = ','.join(degrees_of_freedom) condition_str = ' and '.join(conditions) set_description = f"{{ [{symbol_names}] : {condition_str} }}" - return loop_symbols, isl.BasicSet(set_description) - - for loop in parents_of_type(node, ast.LoopOverCoordinate): - ctr_name = loop.loop_counter_name - lower_constraint = isl.Constraint.ineq_from_names(space, {ctr_name: 1, 1: -loop.start}) - upper_constraint = isl.Constraint.ineq_from_names(space, {ctr_name: 1, }) - - -def simplify_conditionals_new(ast_node): - for conditional in ast_node.atoms(ast.Conditional): - if conditional.condition_expr == sp.true: - conditional.parent.replace(conditional, [conditional.true_block]) - elif conditional.condition_expr == sp.false: - conditional.parent.replace(conditional, [conditional.false_block] if 
conditional.false_block else []) - else: - loop_symbols, iteration_set = isl_iteration_set(conditional) - symbols_in_condition = conditional.condition_expr.atoms(sp.Symbol) - if symbols_in_condition.issubset(loop_symbols): - symbol_names = ','.join([s.name for s in loop_symbols]) - condition_str = str(conditional.condition_expr) - condition_set = isl.BasicSet(f"{{ [{symbol_names}] : {condition_str} }}") - - intersection = iteration_set.intersect(condition_set) - if intersection.is_empty(): - conditional.parent.replace(conditional, - [conditional.false_block] if conditional.false_block else []) - elif intersection == iteration_set: - conditional.parent.replace(conditional, [conditional.true_block]) + return degrees_of_freedom, isl.BasicSet(set_description) + + +def simplify_loop_counter_dependent_conditional(conditional): + """Removes conditionals that depend on the loop counter or iteration limits if they are always true/false.""" + dofs_in_condition = _degrees_of_freedom_as_string(conditional.condition_expr) + dofs_in_loops, iteration_set = isl_iteration_set(conditional) + if dofs_in_condition.issubset(dofs_in_loops): + symbol_names = ','.join(dofs_in_loops) + condition_str = remove_brackets(str(conditional.condition_expr)) + condition_set = isl.BasicSet(f"{{ [{symbol_names}] : {condition_str} }}") + + if condition_set.is_empty(): + conditional.replace_by_false_block() + return  # already replaced above; skip the intersection checks so the node is not replaced a second time + intersection = iteration_set.intersect(condition_set) + if intersection.is_empty(): + conditional.replace_by_false_block() + elif intersection == iteration_set: + conditional.replace_by_true_block() diff --git a/llvm/kernelcreation.py b/llvm/kernelcreation.py index 4f95505dc60fb943fef667361eb4ceb208bd7be3..bb822b48131dd5ca5cc344d52a28fd2125ebe9d1 100644 --- a/llvm/kernelcreation.py +++ b/llvm/kernelcreation.py @@ -30,24 +30,3 @@ def create_kernel(assignments, function_name="kernel", type_info=None, split_gro code = insert_casts(code) code.compile = partial(make_python_function, code) return code 
- - -def create_indexed_kernel(assignments, index_fields, function_name="kernel", type_info=None, - coordinate_names=('x', 'y', 'z')): - """ - Similar to :func:`create_kernel`, but here not all cells of a field are updated but only cells with - coordinates which are stored in an index field. This traversal method can e.g. be used for boundary handling. - - The coordinates are stored in a separated index_field, which is a one dimensional array with struct data type. - This struct has to contain fields named 'x', 'y' and for 3D fields ('z'). These names are configurable with the - 'coordinate_names' parameter. The struct can have also other fields that can be read and written in the kernel, for - example boundary parameters. - - :param assignments: list of update equations or AST nodes - :param index_fields: list of index fields, i.e. 1D fields with struct data type - :param type_info: see documentation of :func:`create_kernel` - :param function_name: see documentation of :func:`create_kernel` - :param coordinate_names: name of the coordinate fields in the struct data type - :return: abstract syntax tree - """ - raise NotImplementedError diff --git a/transformations.py b/transformations.py index d7e7f160b565a24c636acf2c3fa438260225973e..e5158d55650f43b3f246c114ee666efa5c8d5621 100644 --- a/transformations.py +++ b/transformations.py @@ -1,3 +1,4 @@ +import warnings from collections import defaultdict, OrderedDict from copy import deepcopy from types import MappingProxyType @@ -19,6 +20,30 @@ def filtered_tree_iteration(node, node_type): yield from filtered_tree_iteration(arg, node_type) +def unify_shape_symbols(body, common_shape, fields): + """Replaces symbols for array sizes to ensure they are represented by the same unique symbol. + + When creating a kernel with variable array sizes, all passed arrays must have the same size. + This is ensured when the kernel is called. 
Inside the kernel this means that only on symbol has to be used instead + of one for each field. For example shape_arr1[0] and shape_arr2[0] must be equal, so they should also be + represented by the same symbol. + + Args: + body: ast node, for the kernel part where substitutions is made, is modified in-place + common_shape: shape of the field that was chosen + fields: all fields whose shapes should be replaced by common_shape + """ + substitutions = {} + for field in fields: + assert len(field.spatial_shape) == len(common_shape) + if not field.has_fixed_shape: + for common_shape_component, shape_component in zip(common_shape, field.spatial_shape): + if shape_component != common_shape_component: + substitutions[shape_component] = common_shape_component + if substitutions: + body.subs(substitutions) + + def get_common_shape(field_set): """Takes a set of pystencils Fields and returns their common spatial shape if it exists. Otherwise ValueError is raised""" @@ -47,7 +72,7 @@ def make_loop_over_domain(body, function_name, iteration_slice=None, ghost_layer """Uses :class:`pystencils.field.Field.Access` to create (multiple) loops around given AST. 
Args: - body: list of nodes + body: Block object with inner loop contents function_name: name of generated C function iteration_slice: if not None, iteration is done only over this slice of the field ghost_layers: a sequence of pairs for each coordinate with lower and upper nr of ghost layers @@ -68,7 +93,8 @@ def make_loop_over_domain(body, function_name, iteration_slice=None, ghost_layer if loop_order is None: loop_order = get_optimal_loop_ordering(fields) - shape = get_common_shape(list(fields)) + shape = get_common_shape(fields) + unify_shape_symbols(body, common_shape=shape, fields=fields) if iteration_slice is not None: iteration_slice = normalize_slice(iteration_slice, shape) @@ -580,99 +606,33 @@ def cut_loop(loop_node, cutting_points): loop_node.parent.replace(loop_node, new_loops) -def is_condition_necessary(condition, pre_condition, symbol): - """ - Determines if a logical condition of a single variable is already contained in a stronger pre_condition - so if from pre_condition follows that condition is always true, then this condition is not necessary +def simplify_conditionals(node: ast.Node, loop_counter_simplification: bool=False) -> None: + """Removes conditionals that are always true/false. 
Args: - condition: sympy relational of one variable - pre_condition: logical expression that is known to be true - symbol: the single symbol of interest - - Returns: - returns not (pre_condition => condition) where "=>" is logical implication - """ - from sympy.solvers.inequalities import reduce_rational_inequalities - from sympy.logic.boolalg import to_dnf - - def normalize_relational(e): - if isinstance(e, sp.Rel): - return e.func(e.lhs - e.rhs, 0) - else: - new_args = [normalize_relational(a) for a in e.args] - return e.func(*new_args) if new_args else e - - def to_dnf_list(expr): - result = to_dnf(expr) - if isinstance(result, sp.Or): - return [or_term.args for or_term in result.args] - elif isinstance(result, sp.And): - return [result.args] - else: - return [result] - - condition = normalize_relational(condition) - pre_condition = normalize_relational(pre_condition) - a1 = to_dnf_list(pre_condition) - a2 = to_dnf_list(condition) - t1 = reduce_rational_inequalities(to_dnf_list(sp.And(condition, pre_condition)), symbol) - t2 = reduce_rational_inequalities(to_dnf_list(pre_condition), symbol) - return t1 != t2 - - -def simplify_boolean_expression(expr, single_variable_ranges): - """Simplification of boolean expression using known ranges of variables - The singleVariableRanges parameter is a dict mapping a variable name to a sympy logical expression that - contains only this variable and defines a range for it. For example with a being a symbol - { a: sp.And(a >=0, a < 10) } + node: ast node, all descendants of this node are simplified + loop_counter_simplification: if enabled, tries to detect if a conditional is always true/false + depending on the surrounding loop. For example if the surrounding loop goes from + x=0 to 10 and the condition is x < 0, it is removed. + This analysis needs the integer set library (ISL) islpy, so it is not done by + default. 
""" - from sympy.core.relational import Relational - from sympy.logic.boolalg import to_dnf - - expr = to_dnf(expr) - - def visit(e): - if isinstance(e, Relational): - symbols = e.atoms(sp.Symbol).intersection(single_variable_ranges.keys()) - if len(symbols) == 1: - symbol = symbols.pop() - if not is_condition_necessary(e, single_variable_ranges[symbol], symbol): - return sp.true - return e - else: - new_args = [visit(a) for a in e.args] - return e.func(*new_args) if new_args else e - - return visit(expr) - - -def simplify_conditionals(node, loop_conditionals=MappingProxyType({})): - """Simplifies/Removes conditions inside loops that depend on the loop counter.""" - if isinstance(node, ast.LoopOverCoordinate): - ctr_sym = node.loop_counter_symbol - loop_conditionals = loop_conditionals.copy() - loop_conditionals[ctr_sym] = sp.And(ctr_sym >= node.start, ctr_sym < node.stop) - simplify_conditionals(node.body, loop_conditionals) - elif isinstance(node, ast.Conditional): - node.condition_expr = simplify_boolean_expression(node.condition_expr, loop_conditionals) - simplify_conditionals(node.true_block) - if node.false_block: - simplify_conditionals(node.false_block, loop_conditionals) - if node.condition_expr == sp.true: - node.parent.replace(node, [node.true_block]) - if node.condition_expr == sp.false: - node.parent.replace(node, [node.false_block] if node.false_block else []) - elif isinstance(node, ast.Block): - for a in list(node.args): - simplify_conditionals(a, loop_conditionals) - elif isinstance(node, ast.SympyAssignment): - return node - else: - raise ValueError("Can not handle node", type(node)) - - -def cleanup_blocks(node): + for conditional in node.atoms(ast.Conditional): + conditional.condition_expr = sp.simplify(conditional.condition_expr) + if conditional.condition_expr == sp.true: + conditional.parent.replace(conditional, [conditional.true_block]) + elif conditional.condition_expr == sp.false: + conditional.parent.replace(conditional, 
[conditional.false_block] if conditional.false_block else []) + elif loop_counter_simplification: + try: + # noinspection PyUnresolvedReferences + from pystencils.integer_set_analysis import simplify_loop_counter_dependent_conditional + simplify_loop_counter_dependent_conditional(conditional) + except ImportError: + warnings.warn("Integer simplifications in conditionals skipped, because ISLpy package not installed") + + +def cleanup_blocks(node: ast.Node) -> None: """Curly Brace Removal: Removes empty blocks, and replaces blocks with a single child by its child """ if isinstance(node, ast.SympyAssignment): return @@ -850,9 +810,9 @@ def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction) - inner_loop = all_inner_loops.pop() for loop in parents_of_type(inner_loop, ast.LoopOverCoordinate, include_current=True): - cut_loop(loop, [loop.stop-1]) + cut_loop(loop, [loop.stop - 1]) - simplify_conditionals(function_node.body) + simplify_conditionals(function_node.body, loop_counter_simplification=True) cleanup_blocks(function_node.body) move_constants_before_loop(function_node.body) cleanup_blocks(function_node.body) @@ -884,8 +844,10 @@ def typing_from_sympy_inspection(eqs, default_type="double"): def get_next_parent_of_type(node, parent_type): - """ - Traverses the AST nodes parents until a parent of given type was found. If no such parent is found, None is returned + """Returns the next parent node of given type or None, if root is reached. + + Traverses the AST nodes parents until a parent of given type was found. 
+ If no such parent is found, None is returned """ parent = node.parent while parent is not None: @@ -896,21 +858,24 @@ def get_next_parent_of_type(node, parent_type): def parents_of_type(node, parent_type, include_current=False): - """Similar to get_next_parent_of_type, but as generator""" + """Generator for all parent nodes of given type""" parent = node if include_current else node.parent while parent is not None: if isinstance(parent, parent_type): yield parent parent = parent.parent - return None def get_optimal_loop_ordering(fields): """ Determines the optimal loop order for a given set of fields. If the fields have different memory layout or different sizes an exception is thrown. - :param fields: sequence of fields - :return: list of coordinate ids, where the first list entry should be the outermost loop + + Args: + fields: sequence of fields + + Returns: + list of coordinate ids, where the first list entry should be the outermost loop """ assert len(fields) > 0 ref_field = next(iter(fields))