diff --git a/transformations.py b/transformations.py index 946886b314448a336a6e823cc8c108156dee4a55..6bdfcb20e48a79b9b854b01dff352fb7f5415a8a 100644 --- a/transformations.py +++ b/transformations.py @@ -6,7 +6,7 @@ import sympy as sp from sympy.logic.boolalg import Boolean from sympy.tensor import IndexedBase from pystencils.assignment import Assignment -from pystencils.field import Field, FieldType, offset_component_to_direction_string +from pystencils.field import Field, FieldType from pystencils.data_types import TypedSymbol, PointerType, StructType, get_base_type, cast_func, \ pointer_arithmetic_func, get_type_of_expression, collate_types from pystencils.slicing import normalize_slice @@ -158,9 +158,9 @@ def create_intermediate_base_pointer(field_access, coordinates, previous_ptr): >>> x, y = sp.symbols("x y") >>> prev_pointer = TypedSymbol("ptr", "double") >>> create_intermediate_base_pointer(field[1,-2](5), {0: x}, prev_pointer) - (ptr_E, x*fstride_myfield[0] + fstride_myfield[0]) + (ptr_01, x*fstride_myfield[0] + fstride_myfield[0]) >>> create_intermediate_base_pointer(field[1,-2](5), {0: x, 1 : y }, prev_pointer) - (ptr_E_2S, x*fstride_myfield[0] + y*fstride_myfield[1] + fstride_myfield[0] - 2*fstride_myfield[1]) + (ptr_01_1m2, x*fstride_myfield[0] + y*fstride_myfield[1] + fstride_myfield[0] - 2*fstride_myfield[1]) """ field = field_access.field offset = 0 @@ -172,22 +172,20 @@ def create_intermediate_base_pointer(field_access, coordinates, previous_ptr): if coordinate_id < field.spatial_dimensions: offset += field.strides[coordinate_id] * field_access.offsets[coordinate_id] if type(field_access.offsets[coordinate_id]) is int: - offset_comp = offset_component_to_direction_string(coordinate_id, field_access.offsets[coordinate_id]) - name += "_" - name += offset_comp if offset_comp else "C" + name += "_%d%d" % (coordinate_id, field_access.offsets[coordinate_id]) else: list_to_hash.append(field_access.offsets[coordinate_id]) else: if type(coordinate_value) is int: - name += "_%d" % (coordinate_value,) + name += "_%d%d" % (coordinate_id, coordinate_value) else: list_to_hash.append(coordinate_value) if len(list_to_hash) > 0: - name += "%0.6X" % (abs(hash(tuple(list_to_hash)))) + name += "_%0.6X" % (hash(tuple(list_to_hash))) + name = name.replace("-", 'm') new_ptr = TypedSymbol(previous_ptr.name + name, previous_ptr.dtype) - return new_ptr, offset @@ -503,7 +501,7 @@ def move_constants_before_loop(ast_node): def get_blocks(node, result_list): if isinstance(node, ast.Block): - result_list.insert(0, node) + result_list.append(node) if isinstance(node, ast.Node): for a in node.args: get_blocks(a, result_list) @@ -521,10 +519,13 @@ def move_constants_before_loop(ast_node): exists_already = check_if_assignment_already_in_block(child, target) else: exists_already = False + if not exists_already: target.insert_before(child, child_to_insert_before) + elif exists_already and exists_already.rhs == child.rhs: + pass else: - assert exists_already.rhs == child.rhs, "Symbol with same name exists already" + block.append(child) # don't move in this case - better would be to rename symbol def split_inner_loop(ast_node: ast.Node, symbol_groups): @@ -880,6 +881,7 @@ def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction) - simplify_conditionals(function_node.body, loop_counter_simplification=True) cleanup_blocks(function_node.body) + move_constants_before_loop(function_node.body) cleanup_blocks(function_node.body)