Commit 755f168c authored by Martin Bauer's avatar Martin Bauer
Browse files

Fixes in field index resolve

- move_constants_before_loop: in the rare case that a symbol
  with same name exists in move target - the assignment is not moved
  previously an exception was raised in this case
- consistent naming of intermediate base pointers
parent 627ad747
...@@ -6,7 +6,7 @@ import sympy as sp ...@@ -6,7 +6,7 @@ import sympy as sp
from sympy.logic.boolalg import Boolean from sympy.logic.boolalg import Boolean
from sympy.tensor import IndexedBase from sympy.tensor import IndexedBase
from pystencils.assignment import Assignment from pystencils.assignment import Assignment
from pystencils.field import Field, FieldType, offset_component_to_direction_string from pystencils.field import Field, FieldType
from pystencils.data_types import TypedSymbol, PointerType, StructType, get_base_type, cast_func, \ from pystencils.data_types import TypedSymbol, PointerType, StructType, get_base_type, cast_func, \
pointer_arithmetic_func, get_type_of_expression, collate_types pointer_arithmetic_func, get_type_of_expression, collate_types
from pystencils.slicing import normalize_slice from pystencils.slicing import normalize_slice
...@@ -158,9 +158,9 @@ def create_intermediate_base_pointer(field_access, coordinates, previous_ptr): ...@@ -158,9 +158,9 @@ def create_intermediate_base_pointer(field_access, coordinates, previous_ptr):
>>> x, y = sp.symbols("x y") >>> x, y = sp.symbols("x y")
>>> prev_pointer = TypedSymbol("ptr", "double") >>> prev_pointer = TypedSymbol("ptr", "double")
>>> create_intermediate_base_pointer(field[1,-2](5), {0: x}, prev_pointer) >>> create_intermediate_base_pointer(field[1,-2](5), {0: x}, prev_pointer)
(ptr_E, x*fstride_myfield[0] + fstride_myfield[0]) (ptr_01, x*fstride_myfield[0] + fstride_myfield[0])
>>> create_intermediate_base_pointer(field[1,-2](5), {0: x, 1 : y }, prev_pointer) >>> create_intermediate_base_pointer(field[1,-2](5), {0: x, 1 : y }, prev_pointer)
(ptr_E_2S, x*fstride_myfield[0] + y*fstride_myfield[1] + fstride_myfield[0] - 2*fstride_myfield[1]) (ptr_01_1m2, x*fstride_myfield[0] + y*fstride_myfield[1] + fstride_myfield[0] - 2*fstride_myfield[1])
""" """
field = field_access.field field = field_access.field
offset = 0 offset = 0
...@@ -172,22 +172,20 @@ def create_intermediate_base_pointer(field_access, coordinates, previous_ptr): ...@@ -172,22 +172,20 @@ def create_intermediate_base_pointer(field_access, coordinates, previous_ptr):
if coordinate_id < field.spatial_dimensions: if coordinate_id < field.spatial_dimensions:
offset += field.strides[coordinate_id] * field_access.offsets[coordinate_id] offset += field.strides[coordinate_id] * field_access.offsets[coordinate_id]
if type(field_access.offsets[coordinate_id]) is int: if type(field_access.offsets[coordinate_id]) is int:
offset_comp = offset_component_to_direction_string(coordinate_id, field_access.offsets[coordinate_id]) name += "_%d%d" % (coordinate_id, field_access.offsets[coordinate_id])
name += "_"
name += offset_comp if offset_comp else "C"
else: else:
list_to_hash.append(field_access.offsets[coordinate_id]) list_to_hash.append(field_access.offsets[coordinate_id])
else: else:
if type(coordinate_value) is int: if type(coordinate_value) is int:
name += "_%d" % (coordinate_value,) name += "_%d%d" % (coordinate_id, coordinate_value)
else: else:
list_to_hash.append(coordinate_value) list_to_hash.append(coordinate_value)
if len(list_to_hash) > 0: if len(list_to_hash) > 0:
name += "%0.6X" % (abs(hash(tuple(list_to_hash)))) name += "_%0.6X" % (hash(tuple(list_to_hash)))
name = name.replace("-", 'm')
new_ptr = TypedSymbol(previous_ptr.name + name, previous_ptr.dtype) new_ptr = TypedSymbol(previous_ptr.name + name, previous_ptr.dtype)
return new_ptr, offset return new_ptr, offset
...@@ -503,7 +501,7 @@ def move_constants_before_loop(ast_node): ...@@ -503,7 +501,7 @@ def move_constants_before_loop(ast_node):
def get_blocks(node, result_list): def get_blocks(node, result_list):
if isinstance(node, ast.Block): if isinstance(node, ast.Block):
result_list.insert(0, node) result_list.append(node)
if isinstance(node, ast.Node): if isinstance(node, ast.Node):
for a in node.args: for a in node.args:
get_blocks(a, result_list) get_blocks(a, result_list)
...@@ -521,10 +519,13 @@ def move_constants_before_loop(ast_node): ...@@ -521,10 +519,13 @@ def move_constants_before_loop(ast_node):
exists_already = check_if_assignment_already_in_block(child, target) exists_already = check_if_assignment_already_in_block(child, target)
else: else:
exists_already = False exists_already = False
if not exists_already: if not exists_already:
target.insert_before(child, child_to_insert_before) target.insert_before(child, child_to_insert_before)
elif exists_already and exists_already.rhs == child.rhs:
pass
else: else:
assert exists_already.rhs == child.rhs, "Symbol with same name exists already" block.append(child) # don't move in this case - better would be to rename symbol
def split_inner_loop(ast_node: ast.Node, symbol_groups): def split_inner_loop(ast_node: ast.Node, symbol_groups):
...@@ -880,6 +881,7 @@ def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction) - ...@@ -880,6 +881,7 @@ def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction) -
simplify_conditionals(function_node.body, loop_counter_simplification=True) simplify_conditionals(function_node.body, loop_counter_simplification=True)
cleanup_blocks(function_node.body) cleanup_blocks(function_node.body)
move_constants_before_loop(function_node.body) move_constants_before_loop(function_node.body)
cleanup_blocks(function_node.body) cleanup_blocks(function_node.body)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment