integer_set_analysis.py 2.75 KB
Newer Older
1
2
3
"""Transformations using integer sets based on ISL library"""

import islpy as isl
Martin Bauer's avatar
Martin Bauer committed
4
import sympy as sp
5
6

import pystencils.astnodes as ast
Jan Hönig's avatar
Jan Hönig committed
7
from pystencils.typing import parents_of_type
8
from pystencils.backends.cbackend import CustomSympyPrinter
9
10


11
12
13
14
15
16
17
18
19
20
21
def remove_brackets(s):
    """Return ``s`` with every ``[`` and ``]`` character deleted.

    Used to turn printed indexed expressions such as ``a[i]`` into plain
    identifiers (``ai``) before they are embedded into ISL set descriptions.
    """
    # A translation table that maps both bracket characters to "delete".
    bracket_deletion = str.maketrans('', '', '[]')
    return s.translate(bracket_deletion)


def _degrees_of_freedom_as_string(expr):
    """Collect the free quantities of ``expr`` as bracket-free name strings.

    Every ``sp.Symbol`` occurring in the (sympified) expression counts, except
    the labels of indexed bases. Those labels are replaced by the complete
    ``sp.Indexed`` atoms themselves. Each collected atom is converted with
    ``str`` and passed through ``remove_brackets``.

    Returns:
        A ``set`` of strings.
    """
    sympified = sp.sympify(expr)
    indexed_atoms = sympified.atoms(sp.Indexed)
    # Labels of the IndexedBase objects, e.g. the ``a`` in ``a[i]``.
    base_labels = {atom.base.args[0] for atom in indexed_atoms}
    free_atoms = {sym for sym in sympified.atoms(sp.Symbol) if sym not in base_labels}
    free_atoms |= indexed_atoms
    return {remove_brackets(str(atom)) for atom in free_atoms}
22
23
24
25
26


def isl_iteration_set(node: ast.Node):
    """Builds up an ISL set describing the iteration space by analysing the enclosing loops of the given node.

    Every enclosing ``ast.LoopOverCoordinate`` contributes the constraint
    ``counter >= start and counter < stop``. All constraints are joined with ``and``.

    Args:
        node: AST node whose enclosing loops define the iteration space.

    Returns:
        A tuple ``(degrees_of_freedom, basic_set)``:

        - ``degrees_of_freedom`` is the set of bracket-free names of the loop
          counters and of every symbol appearing in the loop bounds.
        - ``basic_set`` is an ``isl.BasicSet`` over those names.

        The set's tuple dimensions follow the iteration order of
        ``degrees_of_freedom``. Callers building compatible ISL sets should join
        this very set object so that dimensions line up.

    Raises:
        NotImplementedError: if an enclosing loop has a step other than 1.
    """
    conditions = []
    degrees_of_freedom = set()

    for loop in parents_of_type(node, ast.LoopOverCoordinate):
        if loop.step != 1:
            raise NotImplementedError("Loops with strides != 1 are not yet supported.")

        # Counter and both bounds are all dimensions of the resulting ISL set.
        for loop_part in (loop.loop_counter_symbol, loop.start, loop.stop):
            degrees_of_freedom.update(_degrees_of_freedom_as_string(loop_part))

        loop_start_str = remove_brackets(str(loop.start))
        loop_stop_str = remove_brackets(str(loop.stop))
        ctr_name = loop.loop_counter_name
        set_string_description = f"{ctr_name} >= {loop_start_str} and {ctr_name} < {loop_stop_str}"
        conditions.append(remove_brackets(set_string_description))

    symbol_names = ','.join(degrees_of_freedom)
    if conditions:
        condition_str = ' and '.join(conditions)
        set_description = f"{{ [{symbol_names}] : {condition_str} }}"
    else:
        # No enclosing loops: "{ [] :  }" (an empty constraint after ':') is not a
        # well-formed ISL description, so describe the unconstrained set instead.
        set_description = f"{{ [{symbol_names}] }}"
    return degrees_of_freedom, isl.BasicSet(set_description)


def simplify_loop_counter_dependent_conditional(conditional):
    """Removes conditionals that depend on the loop counter or iteration limits if they are always true/false.

    The condition is translated into an ISL set over the same degrees of freedom
    as the iteration space of the enclosing loops. The result depends on how the
    condition relates to that iteration space:

    - If no point of the iteration space satisfies the condition (or the
      condition is unsatisfiable), the conditional is replaced by its false block.
    - If every point satisfies it, the conditional is replaced by its true block.
    - Otherwise the conditional is left unchanged.

    Conditions that mention names not covered by the enclosing loops are also
    left untouched.

    Args:
        conditional: the conditional AST node to simplify in place.
    """
    dofs_in_condition = _degrees_of_freedom_as_string(conditional.condition_expr)
    dofs_in_loops, iteration_set = isl_iteration_set(conditional)
    if not dofs_in_condition.issubset(dofs_in_loops):
        # The condition depends on quantities the loop analysis knows nothing about.
        return

    # Join the same set object isl_iteration_set joined, so the tuple dimensions
    # of both ISL sets appear in the same order.
    symbol_names = ','.join(dofs_in_loops)
    condition_str = CustomSympyPrinter().doprint(conditional.condition_expr)
    condition_str = remove_brackets(condition_str)
    condition_set = isl.BasicSet(f"{{ [{symbol_names}] : {condition_str} }}")

    if condition_set.is_empty():
        conditional.replace_by_false_block()
        # The conditional has been replaced already. Stop here: the intersection
        # below would be empty as well and would trigger a second replacement.
        return

    intersection = iteration_set.intersect(condition_set)
    if intersection.is_empty():
        conditional.replace_by_false_block()
    elif intersection == iteration_set:
        conditional.replace_by_true_block()