Commit f13701df authored by Martin Bauer's avatar Martin Bauer
Browse files

WIP: ISL based integer condition optimization

parent e31f1062
......@@ -481,12 +481,12 @@ class SympyAssignment(Node):
raise ValueError('%s is not in args of %s' % (replacement, self.__class__))
def __repr__(self):
return repr(self.lhs) + " = " + repr(self.rhs)
return repr(self.lhs) + " " + repr(self.rhs)
def _repr_html_(self):
printed_lhs = sp.latex(self.lhs)
printed_rhs = sp.latex(self.rhs)
return f"${printed_lhs} = {printed_rhs}$"
return f"${printed_lhs} \leftarrow {printed_rhs}$"
class ResolvedFieldAccess(sp.Indexed):
......
......@@ -43,17 +43,22 @@ def create_cuda_kernel(assignments, function_name="kernel", type_info=None, inde
-ghost_layers[i][1] if ghost_layers[i][1] > 0 else None))
indexing = indexing_creator(field=list(fields_without_buffers)[0], iteration_slice=iteration_slice)
coord_mapping = indexing.coordinates
cell_idx_assignments = [SympyAssignment(LoopOverCoordinate.get_loop_counter_symbol(i), value)
for i, value in enumerate(coord_mapping)]
cell_idx_symbols = [LoopOverCoordinate.get_loop_counter_symbol(i) for i, _ in enumerate(coord_mapping)]
assignments = cell_idx_assignments + assignments
block = Block(assignments)
block = indexing.guard(block, common_shape)
ast = KernelFunction(block, function_name=function_name, ghost_layers=ghost_layers, backend='gpucuda')
ast.global_variables.update(indexing.index_variables)
coord_mapping = indexing.coordinates
base_pointer_info = [['spatialInner0']]
base_pointer_infos = {f.name: parse_base_pointer_info(base_pointer_info, [2, 1, 0], f) for f in all_fields}
coord_mapping = {f.name: coord_mapping for f in all_fields}
coord_mapping = {f.name: cell_idx_symbols for f in all_fields}
loop_vars = [num_buffer_accesses * i for i in indexing.coordinates]
loop_strides = list(fields_without_buffers)[0].shape
......@@ -102,11 +107,11 @@ def created_indexed_cuda_kernel(assignments, index_fields, function_name="kernel
spatial_coordinates = list(spatial_coordinates)[0]
def get_coordinate_symbol_assignment(name):
for index_field in index_fields:
assert isinstance(index_field.dtype, StructType), "Index fields have to have a struct data type"
data_type = index_field.dtype
for ind_f in index_fields:
assert isinstance(ind_f.dtype, StructType), "Index fields have to have a struct data type"
data_type = ind_f.dtype
if data_type.has_element(name):
rhs = index_field[0](name)
rhs = ind_f[0](name)
lhs = TypedSymbol(name, BasicType(data_type.get_element_type(name)))
return SympyAssignment(lhs, rhs)
raise ValueError("Index %s not found in any of the passed index fields" % (name,))
......
"""Transformations using integer sets based on ISL library"""
import sympy as sp
import islpy as isl
from typing import Tuple
import pystencils.astnodes as ast
from pystencils.transformations import parents_of_type
#context = isl.Context()
"""
- find all Condition nodes
- check if they depend on integers only
- create ISL space containing names of all loop symbols (counter and bounds) and all integers in Conditional expression
- build up pre-condition set by iteration over each enclosing loop add ISL constraints
- build up ISL space for condition
- if pre_condition_set.intersect(conditional_set) == pre_condition_set
always use True condition
elif pre_condition_set.intersect(conditional_set).is_empty():
always use False condition
"""
def isl_iteration_set(node: ast.Node):
"""Builds up an ISL set describing the iteration space by analysing the enclosing loops of the given node. """
conditions = []
loop_symbols = set()
for loop in parents_of_type(node, ast.LoopOverCoordinate):
if loop.step != 1:
raise NotImplementedError("Loops with strides != 1 are not yet supported.")
loop_symbols.add(loop.loop_counter_symbol)
loop_symbols.update(sp.sympify(loop.start).atoms(sp.Symbol))
loop_symbols.update(sp.sympify(loop.stop).atoms(sp.Symbol))
loop_start_str = str(loop.start).replace('[', '_bracket1_').replace(']', '_bracket2_')
loop_stop_str = str(loop.stop).replace('[', '_bracket1_').replace(']', '_bracket2_')
ctr_name = loop.loop_counter_name
conditions.append(f"{ctr_name} >= {loop_start_str} and {ctr_name} < {loop_stop_str}")
symbol_names = ','.join([s.name for s in loop_symbols])
condition_str = ' and '.join(conditions)
set_description = f"{{ [{symbol_names}] : {condition_str} }}"
return loop_symbols, isl.BasicSet(set_description)
for loop in parents_of_type(node, ast.LoopOverCoordinate):
ctr_name = loop.loop_counter_name
lower_constraint = isl.Constraint.ineq_from_names(space, {ctr_name: 1, 1: -loop.start})
upper_constraint = isl.Constraint.ineq_from_names(space, {ctr_name: 1, })
def simplify_conditionals_new(ast_node):
for conditional in ast_node.atoms(ast.Conditional):
if conditional.condition_expr == sp.true:
conditional.parent.replace(conditional, [conditional.true_block])
elif conditional.condition_expr == sp.false:
conditional.parent.replace(conditional, [conditional.false_block] if conditional.false_block else [])
else:
loop_symbols, iteration_set = isl_iteration_set(conditional)
symbols_in_condition = conditional.condition_expr.atoms(sp.Symbol)
if symbols_in_condition.issubset(loop_symbols):
symbol_names = ','.join([s.name for s in loop_symbols])
condition_str = str(conditional.condition_expr)
condition_set = isl.BasicSet(f"{{ [{symbol_names}] : {condition_str} }}")
intersection = iteration_set.intersect(condition_set)
if intersection.is_empty():
conditional.parent.replace(conditional,
[conditional.false_block] if conditional.false_block else [])
elif intersection == iteration_set:
conditional.parent.replace(conditional, [conditional.true_block])
from types import MappingProxyType
import sympy as sp
from pystencils.assignment import Assignment
from pystencils.astnodes import LoopOverCoordinate, Conditional, Block, SympyAssignment
from pystencils.assignment_collection import AssignmentCollection
from pystencils.gpucuda.indexing import indexing_creator_from_params
from pystencils.transformations import remove_conditionals_in_staggered_kernel
def create_kernel(equations, target='cpu', data_type="double", iteration_slice=None, ghost_layers=None,
......@@ -104,3 +108,42 @@ def create_indexed_kernel(assignments, index_fields, target='cpu', data_type="do
return ast
else:
raise ValueError("Unknown target %s. Has to be either 'cpu' or 'gpu'" % (target,))
def create_staggered_kernel(staggered_field, expressions, subexpressions=(), target='cpu', **kwargs):
"""Kernel that updates a staggered field.
Args:
staggered_field: field that has one index coordinate and
where e.g. f[0,0](0) is interpreted as value at the left cell boundary, f[1,0](0) the right cell
boundary and f[0,0](1) the southern cell boundary etc.
expressions: sequence of expressions of length dim, defining how the east, southern, (bottom) cell boundary
should be update
subexpressions: optional sequence of Assignments, that define subexpressions used in the main expressions
target: 'cpu' or 'gpu'
kwargs: passed directly to create_kernel, iteration slice and ghost_layers parameters are not allowed
Returns:
AST
"""
assert 'iteration_slice' not in kwargs and 'ghost_layers' not in kwargs
assert staggered_field.index_dimensions == 1, 'Staggered field must have exactly one index dimension'
dim = staggered_field.spatial_dimensions
counters = [LoopOverCoordinate.get_loop_counter_symbol(i) for i in range(dim)]
conditions = [counters[i] < staggered_field.shape[i] - 1 for i in range(dim)]
assert len(expressions) == dim
final_assignments = []
for d in range(dim):
cond = sp.And(*[conditions[i] for i in range(dim) if d != i])
a_coll = AssignmentCollection([Assignment(staggered_field(d), expressions[d])], list(subexpressions))
a_coll = a_coll.new_filtered([staggered_field(d)])
sp_assignments = [SympyAssignment(a.lhs, a.rhs) for a in a_coll.all_assignments]
final_assignments.append(Conditional(cond, Block(sp_assignments)))
ghost_layers = [(1, 0)] * dim
ast = create_kernel(final_assignments, ghost_layers=ghost_layers, target=target, **kwargs)
if target == 'cpu':
remove_conditionals_in_staggered_kernel(ast)
return ast
......@@ -582,16 +582,27 @@ def cut_loop(loop_node, cutting_points):
def is_condition_necessary(condition, pre_condition, symbol):
"""
Determines if a logical condition of a single variable is already contained in a stronger preCondition
so if from preCondition follows that condition is always true, then this condition is not necessary
:param condition: sympy relational of one variable
:param pre_condition: logical expression that is known to be true
:param symbol: the single symbol of interest
:return: returns not (preCondition => condition) where "=>" is logical implication
Determines if a logical condition of a single variable is already contained in a stronger pre_condition
so if from pre_condition follows that condition is always true, then this condition is not necessary
Args:
condition: sympy relational of one variable
pre_condition: logical expression that is known to be true
symbol: the single symbol of interest
Returns:
returns not (pre_condition => condition) where "=>" is logical implication
"""
from sympy.solvers.inequalities import reduce_rational_inequalities
from sympy.logic.boolalg import to_dnf
def normalize_relational(e):
if isinstance(e, sp.Rel):
return e.func(e.lhs - e.rhs, 0)
else:
new_args = [normalize_relational(a) for a in e.args]
return e.func(*new_args) if new_args else e
def to_dnf_list(expr):
result = to_dnf(expr)
if isinstance(result, sp.Or):
......@@ -599,8 +610,12 @@ def is_condition_necessary(condition, pre_condition, symbol):
elif isinstance(result, sp.And):
return [result.args]
else:
return result
return [result]
condition = normalize_relational(condition)
pre_condition = normalize_relational(pre_condition)
a1 = to_dnf_list(pre_condition)
a2 = to_dnf_list(condition)
t1 = reduce_rational_inequalities(to_dnf_list(sp.And(condition, pre_condition)), symbol)
t2 = reduce_rational_inequalities(to_dnf_list(pre_condition), symbol)
return t1 != t2
......@@ -619,12 +634,11 @@ def simplify_boolean_expression(expr, single_variable_ranges):
def visit(e):
if isinstance(e, Relational):
symbols = e.atoms(sp.Symbol)
symbols = e.atoms(sp.Symbol).intersection(single_variable_ranges.keys())
if len(symbols) == 1:
symbol = symbols.pop()
if symbol in single_variable_ranges:
if not is_condition_necessary(e, single_variable_ranges[symbol], symbol):
return sp.true
if not is_condition_necessary(e, single_variable_ranges[symbol], symbol):
return sp.true
return e
else:
new_args = [visit(a) for a in e.args]
......@@ -635,24 +649,23 @@ def simplify_boolean_expression(expr, single_variable_ranges):
def simplify_conditionals(node, loop_conditionals=MappingProxyType({})):
"""Simplifies/Removes conditions inside loops that depend on the loop counter."""
loop_conditionals = loop_conditionals.copy()
if isinstance(node, ast.LoopOverCoordinate):
ctr_sym = node.loop_counter_symbol
loop_conditionals = loop_conditionals.copy()
loop_conditionals[ctr_sym] = sp.And(ctr_sym >= node.start, ctr_sym < node.stop)
simplify_conditionals(node.body)
del loop_conditionals[ctr_sym]
simplify_conditionals(node.body, loop_conditionals)
elif isinstance(node, ast.Conditional):
node.condition_expr = simplify_boolean_expression(node.condition_expr, loop_conditionals)
simplify_conditionals(node.true_block)
if node.false_block:
simplify_conditionals(node.false_block)
simplify_conditionals(node.false_block, loop_conditionals)
if node.condition_expr == sp.true:
node.parent.replace(node, [node.true_block])
if node.condition_expr == sp.false:
node.parent.replace(node, [node.false_block] if node.false_block else [])
elif isinstance(node, ast.Block):
for a in list(node.args):
simplify_conditionals(a)
simplify_conditionals(a, loop_conditionals)
elif isinstance(node, ast.SympyAssignment):
return node
else:
......@@ -829,6 +842,22 @@ def insert_casts(node):
return node.func(*args)
def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction) -> None:
"""Removes conditionals of a kernel that iterates over staggered positions by splitting the loops at last element"""
all_inner_loops = [l for l in function_node.atoms(ast.LoopOverCoordinate) if l.is_innermost_loop]
assert len(all_inner_loops) == 1, "Transformation works only on kernels with exactly one inner loop"
inner_loop = all_inner_loops.pop()
for loop in parents_of_type(inner_loop, ast.LoopOverCoordinate, include_current=True):
cut_loop(loop, [loop.stop-1])
simplify_conditionals(function_node.body)
cleanup_blocks(function_node.body)
move_constants_before_loop(function_node.body)
cleanup_blocks(function_node.body)
# --------------------------------------- Helper Functions -------------------------------------------------------------
......@@ -836,9 +865,12 @@ def typing_from_sympy_inspection(eqs, default_type="double"):
"""
Creates a default symbol name to type mapping.
If a sympy Boolean is assigned to a symbol it is assumed to be 'bool' otherwise the default type, usually ('double')
:param eqs: list of equations
:param default_type: the type for non-boolean symbols
:return: dictionary, mapping symbol name to type
Args:
eqs: list of equations
default_type: the type for non-boolean symbols
Returns:
dictionary, mapping symbol name to type
"""
result = defaultdict(lambda: default_type)
for eq in eqs:
......@@ -863,6 +895,16 @@ def get_next_parent_of_type(node, parent_type):
return None
def parents_of_type(node, parent_type, include_current=False):
"""Similar to get_next_parent_of_type, but as generator"""
parent = node if include_current else node.parent
while parent is not None:
if isinstance(parent, parent_type):
yield parent
parent = parent.parent
return None
def get_optimal_loop_ordering(fields):
"""
Determines the optimal loop order for a given set of fields.
......@@ -886,9 +928,10 @@ def get_optimal_loop_ordering(fields):
def get_loop_hierarchy(ast_node):
"""Determines the loop structure around a given AST node.
:param ast_node: the AST node
:return: list of coordinate ids, where the first list entry is the innermost loop
"""Determines the loop structure around a given AST node, i.e. the node has to be inside the loops.
Returns:
sequence of LoopOverCoordinate nodes, starting from outer loop to innermost loop
"""
result = []
node = ast_node
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment