Commit 1c3e9880 authored by Nils Kohl's avatar Nils Kohl 🌝 Committed by Martin Bauer
Browse files

Added simplifications & transformations

parent 3d8dd38f
......@@ -3,6 +3,7 @@ from typing import Callable, List
from pystencils import Field
from pystencils.assignment import Assignment
from pystencils.field import AbstractField
from pystencils.simp.assignment_collection import AssignmentCollection
from pystencils.sympyextensions import subs_additive
......@@ -85,6 +86,33 @@ def add_subexpressions_for_divisions(ac: AC) -> AC:
return ac.new_with_substitutions(substitutions, True)
def add_subexpressions_for_sums(ac: AC) -> AC:
r"""Introduces subexpressions for all sums - i.e. splits addends into subexpressions."""
addends = []
def contains_sum(term):
if term.func == sp.add.Add:
return True
if term.is_Atom:
return False
return any([contains_sum(a) for a in term.args])
def search_addends(term):
if term.func == sp.add.Add:
if all([not contains_sum(a) for a in term.args]):
addends.extend(term.args)
for a in term.args:
search_addends(a)
for eq in ac.all_assignments:
search_addends(eq.rhs)
addends = [a for a in addends if not isinstance(a, sp.Symbol) or isinstance(a, AbstractField.AbstractAccess)]
new_symbol_gen = ac.subexpression_symbol_generator
substitutions = {addend: new_symbol for new_symbol, addend in zip(new_symbol_gen, addends)}
return ac.new_with_substitutions(substitutions, True, substitute_on_lhs=False)
def add_subexpressions_for_field_reads(ac: AC, subexpressions=True, main_assignments=True) -> AC:
r"""Substitutes field accesses on rhs of assignments with subexpressions
......
......@@ -702,7 +702,7 @@ def cut_loop(loop_node, cutting_points):
"""
if loop_node.step != 1:
raise NotImplementedError("Can only split loops that have a step of 1")
new_loops = []
new_loops = ast.Block([])
new_start = loop_node.start
cutting_points = list(cutting_points) + [loop_node.stop]
for new_end in cutting_points:
......@@ -1089,6 +1089,35 @@ def get_optimal_loop_ordering(fields):
return list(layout)
def get_loop_hierarchy(ast_node):
"""Determines the loop structure around a given AST node, i.e. the node has to be inside the loops.
Returns:
sequence of LoopOverCoordinate nodes, starting from outer loop to innermost loop
"""
result = []
node = ast_node
while node is not None:
node = get_next_parent_of_type(node, ast.LoopOverCoordinate)
if node:
result.append(node.coordinate_to_loop_over)
return reversed(result)
def get_loop_counter_symbol_hierarchy(astNode):
"""Determines the loop counter symbols around a given AST node.
:param astNode: the AST node
:return: list of loop counter symbols, where the first list entry is the symbol of the innermost loop
"""
result = []
node = astNode
while node is not None:
node = get_next_parent_of_type(node, ast.LoopOverCoordinate)
if node:
result.append(node.loop_counter_symbol)
return result
def replace_inner_stride_with_one(ast_node: ast.KernelFunction) -> None:
"""Replaces the stride of the innermost loop of a variable sized kernel with 1 (assumes optimal loop ordering).
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment