Skip to content
Snippets Groups Projects
integer_set_analysis.py 2.92 KiB
"""Transformations using integer sets based on ISL library"""

import sympy as sp
import islpy as isl

import pystencils.astnodes as ast
from pystencils.transformations import parents_of_type


def remove_brackets(s):
    """Return *s* with every ``[`` and ``]`` character deleted.

    Used to turn printed sympy expressions such as ``a[i]`` into identifiers
    (``ai``) that the ISL set parser accepts.
    """
    # One C-level pass: the third maketrans argument lists characters to drop.
    bracket_deletion = str.maketrans('', '', '[]')
    return s.translate(bracket_deletion)


def _degrees_of_freedom_as_string(expr):
    """Collect the free variables of *expr* as bracket-free name strings.

    Each ``sp.Indexed`` access is one degree of freedom on its own; its printed
    form has brackets stripped (``a[i]`` -> ``"ai"``). The plain symbol labelling
    an indexed base is dropped so it is not reported a second time. Every other
    ``sp.Symbol`` atom, including index symbols, is kept.

    Returns:
        A set of strings.
    """
    sympified = sp.sympify(expr)
    indexed_accesses = sympified.atoms(sp.Indexed)

    # Labels of the IndexedBase objects, e.g. the Symbol `a` inside `a[i]`.
    base_labels = set()
    for access in indexed_accesses:
        base_labels.add(access.base.args[0])

    free_names = set()
    for atom in sympified.atoms(sp.Symbol) - base_labels:
        free_names.add(remove_brackets(str(atom)))
    for access in indexed_accesses:
        free_names.add(remove_brackets(str(access)))
    return free_names


def isl_iteration_set(node: ast.Node):
    """Builds up an ISL set describing the iteration space by analysing the enclosing loops of the given node.

    Every enclosing ``ast.LoopOverCoordinate`` contributes the constraint
    ``start <= counter < stop``, and all constraints are conjoined. If there is no
    enclosing loop, the unconstrained (universe) set is returned.

    Args:
        node: AST node whose enclosing loops are analysed.

    Returns:
        Tuple ``(degrees_of_freedom, basic_set)``:

        * ``degrees_of_freedom`` is the set of variable names (loop counters and
          symbols appearing in the loop bounds) used as the set's dimensions.
        * ``basic_set`` is the corresponding ``isl.BasicSet``.

        The set's dimensions are listed in the iteration order of
        ``degrees_of_freedom``. A caller building a compatible set must join that
        same set object to get the same dimension order.

    Raises:
        NotImplementedError: if an enclosing loop has a step other than 1.
    """
    conditions = []
    degrees_of_freedom = set()

    for loop in parents_of_type(node, ast.LoopOverCoordinate):
        if loop.step != 1:
            raise NotImplementedError("Loops with strides != 1 are not yet supported.")

        # The counter and every variable in the bounds become set dimensions.
        for bound_expr in (loop.loop_counter_symbol, loop.start, loop.stop):
            degrees_of_freedom.update(_degrees_of_freedom_as_string(bound_expr))

        loop_start_str = remove_brackets(str(loop.start))
        loop_stop_str = remove_brackets(str(loop.stop))
        ctr_name = loop.loop_counter_name
        set_string_description = "{} >= {} and {} < {}".format(ctr_name, loop_start_str, ctr_name, loop_stop_str)
        conditions.append(remove_brackets(set_string_description))

    symbol_names = ','.join(degrees_of_freedom)
    if conditions:
        condition_str = ' and '.join(conditions)
        set_description = "{{ [{symbol_names}] : {condition_str} }}".format(symbol_names=symbol_names,
                                                                            condition_str=condition_str)
    else:
        # No enclosing loops: "{ [] :  }" (colon without a constraint) is not
        # valid ISL syntax, so describe the unconstrained set instead.
        set_description = "{{ [{symbol_names}] }}".format(symbol_names=symbol_names)
    return degrees_of_freedom, isl.BasicSet(set_description)


def simplify_loop_counter_dependent_conditional(conditional):
    """Removes conditionals that depend on the loop counter or iteration limits if they are always true/false.

    The check runs only when every variable of the condition is either a loop
    counter or a symbol in the bounds of an enclosing loop. The condition is then
    compared against the iteration set from ``isl_iteration_set``:

    * If no iteration satisfies it, the node is replaced by its false block.
    * If every iteration satisfies it, the node is replaced by its true block.
    * Otherwise, or if the condition uses other variables, the node is left unchanged.

    The node is modified in place, at most once; nothing is returned.
    """
    dofs_in_condition = _degrees_of_freedom_as_string(conditional.condition_expr)
    dofs_in_loops, iteration_set = isl_iteration_set(conditional)
    if not dofs_in_condition.issubset(dofs_in_loops):
        # The condition depends on variables the loop nest does not constrain.
        return

    # Joining the very set object returned above keeps the dimension order
    # identical to iteration_set. ISL matches tuple dimensions by position, not
    # by name, so a different order would pair up the wrong variables.
    symbol_names = ','.join(dofs_in_loops)
    condition_str = remove_brackets(str(conditional.condition_expr))
    condition_set = isl.BasicSet("{{ [{symbol_names}] : {condition_str} }}".format(symbol_names=symbol_names,
                                                                                   condition_str=condition_str))

    intersection = iteration_set.intersect(condition_set)
    # An empty condition_set always gives an empty intersection, so this single
    # test also covers that case. A separate check used to replace the node
    # twice.
    if intersection.is_empty():
        conditional.replace_by_false_block()
    elif intersection.is_equal(iteration_set):
        # The condition holds on the whole iteration space.
        conditional.replace_by_true_block()