An error occurred while loading the file. Please try again.
-
Martin Bauer authored01ab38e8
integer_set_analysis.py 2.92 KiB
"""Transformations based on integer-set analysis using the ISL library (via islpy)."""
import sympy as sp
import islpy as isl
import pystencils.astnodes as ast
from pystencils.transformations import parents_of_type
def remove_brackets(s):
    """Return *s* with every ``[`` and ``]`` character deleted.

    In this module it is applied to printed sympy expressions so that indexed
    accesses such as ``name[0]`` become plain identifiers (``name0``).
    """
    return s.translate({ord(bracket): None for bracket in '[]'})
def _degrees_of_freedom_as_string(expr):
    """Collect the names of the free quantities in *expr* as bracket-free strings.

    Both plain symbols and ``sp.Indexed`` accesses are collected. Symbols that
    are the label of an indexed base (``args[0]`` of ``Indexed.base``) are
    dropped, because the indexed access itself is collected instead. Every name
    is passed through ``remove_brackets``, so ``f[0]`` is reported as ``f0``.

    Args:
        expr: anything accepted by ``sp.sympify``.

    Returns:
        A set of strings.
    """
    expr = sp.sympify(expr)
    indexed_accesses = expr.atoms(sp.Indexed)
    base_labels = {access.base.args[0] for access in indexed_accesses}
    free_quantities = {sym for sym in expr.atoms(sp.Symbol) if sym not in base_labels}
    free_quantities |= indexed_accesses
    return {remove_brackets(str(quantity)) for quantity in free_quantities}
def isl_iteration_set(node: ast.Node):
    """Build an ISL set describing the iteration space of the loops enclosing *node*.

    Every enclosing ``ast.LoopOverCoordinate`` contributes the constraint
    ``start <= counter < stop``. The loop counters and every quantity appearing
    in the loop bounds become set dimensions. Their names have square brackets
    removed, so that indexed accesses like ``_size_f[0]`` are valid ISL names.

    Args:
        node: AST node whose enclosing loops are analysed.

    Returns:
        A tuple ``(degrees_of_freedom, basic_set)``. ``degrees_of_freedom`` is
        the set of dimension names. ``basic_set`` is the ``isl.BasicSet`` whose
        tuple lists those names in the iteration order of that same set object.
        Callers that build compatible sets should join this returned set
        directly, so that the dimension order matches.

    Raises:
        NotImplementedError: if an enclosing loop has a step other than 1.
    """
    conditions = []
    degrees_of_freedom = set()
    for loop in parents_of_type(node, ast.LoopOverCoordinate):
        if loop.step != 1:
            raise NotImplementedError("Loops with strides != 1 are not yet supported.")

        degrees_of_freedom.update(_degrees_of_freedom_as_string(loop.loop_counter_symbol))
        degrees_of_freedom.update(_degrees_of_freedom_as_string(loop.start))
        degrees_of_freedom.update(_degrees_of_freedom_as_string(loop.stop))

        loop_start_str = remove_brackets(str(loop.start))
        loop_stop_str = remove_brackets(str(loop.stop))
        ctr_name = loop.loop_counter_name
        set_string_description = "{} >= {} and {} < {}".format(ctr_name, loop_start_str, ctr_name, loop_stop_str)
        # Strip brackets from the whole constraint (including the counter name) so it
        # uses the same spelling as the names in degrees_of_freedom.
        conditions.append(remove_brackets(set_string_description))

    symbol_names = ','.join(degrees_of_freedom)
    if conditions:
        condition_str = ' and '.join(conditions)
        set_description = "{{ [{symbol_names}] : {condition_str} }}".format(symbol_names=symbol_names,
                                                                            condition_str=condition_str)
    else:
        # No enclosing loops: "{ [] : }" (an empty constraint list) is not valid ISL
        # syntax, so describe the unconstrained zero-dimensional set instead.
        set_description = "{{ [{symbol_names}] }}".format(symbol_names=symbol_names)
    return degrees_of_freedom, isl.BasicSet(set_description)
def simplify_loop_counter_dependent_conditional(conditional):
    """Remove a conditional whose outcome is fixed over its enclosing loops' iteration space.

    The analysis is attempted only when every free quantity in the condition is
    a loop counter or appears in a loop bound. Otherwise the conditional is left
    untouched. The conditional is modified in place:

    - If no iteration satisfies the condition (this includes a condition that is
      unsatisfiable on its own), it is replaced by its false block.
    - If every iteration satisfies the condition, it is replaced by its true block.

    Args:
        conditional: node exposing ``condition_expr``, ``replace_by_false_block()``
            and ``replace_by_true_block()``. Presumably an ``ast.Conditional``;
            confirm against callers.
    """
    dofs_in_condition = _degrees_of_freedom_as_string(conditional.condition_expr)
    dofs_in_loops, iteration_set = isl_iteration_set(conditional)
    if not dofs_in_condition.issubset(dofs_in_loops):
        # The condition depends on quantities the loop bounds say nothing about.
        return

    # Join the very set object returned by isl_iteration_set. That function built
    # iteration_set's tuple by joining the same set, so the dimension orders match.
    symbol_names = ','.join(dofs_in_loops)
    condition_str = remove_brackets(str(conditional.condition_expr))
    condition_set = isl.BasicSet("{{ [{symbol_names}] : {condition_str} }}".format(symbol_names=symbol_names,
                                                                                   condition_str=condition_str))

    # Decide exactly once: the conditional is detached from its parent by the first
    # replacement, so a second replace_* call must not happen.
    if iteration_set.intersect(condition_set).is_empty():
        conditional.replace_by_false_block()
    elif iteration_set.is_subset(condition_set):
        conditional.replace_by_true_block()