Commit 99aef3f8 authored by Martin Bauer's avatar Martin Bauer
Browse files

Refactored buffer treatment

- put all buffer related stuff into separate functions
- should be functionally equivalent
parent f83ceab3
...@@ -3,7 +3,7 @@ from functools import partial ...@@ -3,7 +3,7 @@ from functools import partial
from pystencils.astnodes import SympyAssignment, Block, LoopOverCoordinate, KernelFunction from pystencils.astnodes import SympyAssignment, Block, LoopOverCoordinate, KernelFunction
from pystencils.transformations import resolve_buffer_accesses, resolve_field_accesses, make_loop_over_domain, \ from pystencils.transformations import resolve_buffer_accesses, resolve_field_accesses, make_loop_over_domain, \
add_types, get_optimal_loop_ordering, parse_base_pointer_info, move_constants_before_loop, \ add_types, get_optimal_loop_ordering, parse_base_pointer_info, move_constants_before_loop, \
split_inner_loop, substitute_array_accesses_with_constants split_inner_loop, substitute_array_accesses_with_constants, get_base_buffer_index
from pystencils.data_types import TypedSymbol, BasicType, StructType, create_type from pystencils.data_types import TypedSymbol, BasicType, StructType, create_type
from pystencils.field import Field, FieldType from pystencils.field import Field, FieldType
import pystencils.astnodes as ast import pystencils.astnodes as ast
...@@ -61,13 +61,13 @@ def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "ke ...@@ -61,13 +61,13 @@ def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "ke
body = ast.Block(assignments) body = ast.Block(assignments)
loop_order = get_optimal_loop_ordering(fields_without_buffers) loop_order = get_optimal_loop_ordering(fields_without_buffers)
code, loop_strides, loop_vars = make_loop_over_domain(body, function_name, iteration_slice=iteration_slice, ast_node = make_loop_over_domain(body, function_name, iteration_slice=iteration_slice,
ghost_layers=ghost_layers, loop_order=loop_order) ghost_layers=ghost_layers, loop_order=loop_order)
code.target = 'cpu' ast_node.target = 'cpu'
if split_groups: if split_groups:
typed_split_groups = [[type_symbol(s) for s in split_group] for split_group in split_groups] typed_split_groups = [[type_symbol(s) for s in split_group] for split_group in split_groups]
split_inner_loop(code, typed_split_groups) split_inner_loop(ast_node, typed_split_groups)
base_pointer_spec = [['spatialInner0'], ['spatialInner1']] if len(loop_order) >= 2 else [['spatialInner0']] base_pointer_spec = [['spatialInner0'], ['spatialInner1']] if len(loop_order) >= 2 else [['spatialInner0']]
base_pointer_info = {field.name: parse_base_pointer_info(base_pointer_spec, loop_order, base_pointer_info = {field.name: parse_base_pointer_info(base_pointer_spec, loop_order,
...@@ -79,20 +79,13 @@ def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "ke ...@@ -79,20 +79,13 @@ def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "ke
for field in buffers} for field in buffers}
base_pointer_info.update(buffer_base_pointer_info) base_pointer_info.update(buffer_base_pointer_info)
base_buffer_index = loop_vars[0] if any(FieldType.is_buffer(f) for f in all_fields):
stride = 1 resolve_buffer_accesses(ast_node, get_base_buffer_index(ast_node), read_only_fields)
for idx, var in enumerate(loop_vars[1:]): resolve_field_accesses(ast_node, read_only_fields, field_to_base_pointer_info=base_pointer_info)
cur_stride = loop_strides[idx] substitute_array_accesses_with_constants(ast_node)
stride *= int(cur_stride) if isinstance(cur_stride, float) else cur_stride move_constants_before_loop(ast_node)
base_buffer_index += var * stride ast_node.compile = partial(make_python_function, ast_node)
return ast_node
resolve_buffer_accesses(code, base_buffer_index, read_only_fields)
resolve_field_accesses(code, read_only_fields, field_to_base_pointer_info=base_pointer_info)
substitute_array_accesses_with_constants(code)
move_constants_before_loop(code)
code.compile = partial(make_python_function, code)
return code
def create_indexed_kernel(assignments: AssignmentOrAstNodeList, index_fields, function_name="kernel", def create_indexed_kernel(assignments: AssignmentOrAstNodeList, index_fields, function_name="kernel",
......
...@@ -2,7 +2,8 @@ from functools import partial ...@@ -2,7 +2,8 @@ from functools import partial
from pystencils.gpucuda.indexing import BlockIndexing from pystencils.gpucuda.indexing import BlockIndexing
from pystencils.transformations import resolve_field_accesses, add_types, parse_base_pointer_info, \ from pystencils.transformations import resolve_field_accesses, add_types, parse_base_pointer_info, \
get_common_shape, substitute_array_accesses_with_constants, resolve_buffer_accesses, unify_shape_symbols get_common_shape, substitute_array_accesses_with_constants, resolve_buffer_accesses, unify_shape_symbols, \
get_base_buffer_index
from pystencils.astnodes import Block, KernelFunction, SympyAssignment, LoopOverCoordinate from pystencils.astnodes import Block, KernelFunction, SympyAssignment, LoopOverCoordinate
from pystencils.data_types import TypedSymbol, BasicType, StructType from pystencils.data_types import TypedSymbol, BasicType, StructType
from pystencils import Field, FieldType from pystencils import Field, FieldType
...@@ -63,16 +64,11 @@ def create_cuda_kernel(assignments, function_name="kernel", type_info=None, inde ...@@ -63,16 +64,11 @@ def create_cuda_kernel(assignments, function_name="kernel", type_info=None, inde
coord_mapping = {f.name: cell_idx_symbols for f in all_fields} coord_mapping = {f.name: cell_idx_symbols for f in all_fields}
loop_vars = [num_buffer_accesses * i for i in indexing.coordinates]
loop_strides = list(fields_without_buffers)[0].shape loop_strides = list(fields_without_buffers)[0].shape
base_buffer_index = loop_vars[0] if any(FieldType.is_buffer(f) for f in all_fields):
stride = 1 resolve_buffer_accesses(ast, get_base_buffer_index(ast, indexing.coordinates, loop_strides), read_only_fields)
for idx, var in enumerate(loop_vars[1:]):
stride *= loop_strides[idx]
base_buffer_index += var * stride
resolve_buffer_accesses(ast, base_buffer_index, read_only_fields)
resolve_field_accesses(ast, read_only_fields, field_to_base_pointer_info=base_pointer_info, resolve_field_accesses(ast, read_only_fields, field_to_base_pointer_info=base_pointer_info,
field_to_fixed_coordinates=coord_mapping) field_to_fixed_coordinates=coord_mapping)
......
...@@ -73,20 +73,6 @@ def get_common_shape(field_set): ...@@ -73,20 +73,6 @@ def get_common_shape(field_set):
return shape return shape
def get_field_accesses(expr, result=set()):
if isinstance(expr, Field.Access):
result.add(expr)
for o in expr.offsets:
get_field_accesses(o, result)
for i in expr.index:
get_field_accesses(i, result)
elif hasattr(expr, 'atoms'):
new_accesses = expr.atoms(Field.Access)
result.update(new_accesses)
for a in new_accesses:
get_field_accesses(a, result)
def make_loop_over_domain(body, function_name, iteration_slice=None, ghost_layers=None, loop_order=None): def make_loop_over_domain(body, function_name, iteration_slice=None, ghost_layers=None, loop_order=None):
"""Uses :class:`pystencils.field.Field.Access` to create (multiple) loops around given AST. """Uses :class:`pystencils.field.Field.Access` to create (multiple) loops around given AST.
...@@ -103,14 +89,12 @@ def make_loop_over_domain(body, function_name, iteration_slice=None, ghost_layer ...@@ -103,14 +89,12 @@ def make_loop_over_domain(body, function_name, iteration_slice=None, ghost_layer
:class:`LoopOverCoordinate` instance with nested loops, ordered according to field layouts :class:`LoopOverCoordinate` instance with nested loops, ordered according to field layouts
""" """
# find correct ordering by inspecting participating FieldAccesses # find correct ordering by inspecting participating FieldAccesses
field_accesses = set() field_accesses = body.atoms(Field.Access)
get_field_accesses(body, field_accesses)
field_accesses = {e for e in field_accesses if not e.is_absolute_access} field_accesses = {e for e in field_accesses if not e.is_absolute_access}
# exclude accesses to buffers from field_list, because buffers are treated separately # exclude accesses to buffers from field_list, because buffers are treated separately
field_list = [e.field for e in field_accesses if not FieldType.is_buffer(e.field)] field_list = [e.field for e in field_accesses if not FieldType.is_buffer(e.field)]
fields = set(field_list) fields = set(field_list)
num_buffer_accesses = len(field_accesses) - len(field_list)
if loop_order is None: if loop_order is None:
loop_order = get_optimal_loop_ordering(fields) loop_order = get_optimal_loop_ordering(fields)
...@@ -127,11 +111,6 @@ def make_loop_over_domain(body, function_name, iteration_slice=None, ghost_layer ...@@ -127,11 +111,6 @@ def make_loop_over_domain(body, function_name, iteration_slice=None, ghost_layer
if isinstance(ghost_layers, int): if isinstance(ghost_layers, int):
ghost_layers = [(ghost_layers, ghost_layers)] * len(loop_order) ghost_layers = [(ghost_layers, ghost_layers)] * len(loop_order)
def get_loop_stride(loop_begin, loop_end, step):
return (loop_end - loop_begin) / step
loop_strides = []
loop_vars = []
current_body = body current_body = body
for i, loop_coordinate in enumerate(reversed(loop_order)): for i, loop_coordinate in enumerate(reversed(loop_order)):
if iteration_slice is None: if iteration_slice is None:
...@@ -139,24 +118,19 @@ def make_loop_over_domain(body, function_name, iteration_slice=None, ghost_layer ...@@ -139,24 +118,19 @@ def make_loop_over_domain(body, function_name, iteration_slice=None, ghost_layer
end = shape[loop_coordinate] - ghost_layers[loop_coordinate][1] end = shape[loop_coordinate] - ghost_layers[loop_coordinate][1]
new_loop = ast.LoopOverCoordinate(current_body, loop_coordinate, begin, end, 1) new_loop = ast.LoopOverCoordinate(current_body, loop_coordinate, begin, end, 1)
current_body = ast.Block([new_loop]) current_body = ast.Block([new_loop])
loop_strides.append(get_loop_stride(begin, end, 1))
loop_vars.append(new_loop.loop_counter_symbol)
else: else:
slice_component = iteration_slice[loop_coordinate] slice_component = iteration_slice[loop_coordinate]
if type(slice_component) is slice: if type(slice_component) is slice:
sc = slice_component sc = slice_component
new_loop = ast.LoopOverCoordinate(current_body, loop_coordinate, sc.start, sc.stop, sc.step) new_loop = ast.LoopOverCoordinate(current_body, loop_coordinate, sc.start, sc.stop, sc.step)
current_body = ast.Block([new_loop]) current_body = ast.Block([new_loop])
loop_strides.append(get_loop_stride(sc.start, sc.stop, sc.step))
loop_vars.append(new_loop.loop_counter_symbol)
else: else:
assignment = ast.SympyAssignment(ast.LoopOverCoordinate.get_loop_counter_symbol(loop_coordinate), assignment = ast.SympyAssignment(ast.LoopOverCoordinate.get_loop_counter_symbol(loop_coordinate),
sp.sympify(slice_component)) sp.sympify(slice_component))
current_body.insert_front(assignment) current_body.insert_front(assignment)
loop_vars = [num_buffer_accesses * var for var in loop_vars]
ast_node = ast.KernelFunction(current_body, ghost_layers=ghost_layers, function_name=function_name, backend='cpu') ast_node = ast.KernelFunction(current_body, ghost_layers=ghost_layers, function_name=function_name, backend='cpu')
return ast_node, loop_strides, loop_vars return ast_node
def create_intermediate_base_pointer(field_access, coordinates, previous_ptr): def create_intermediate_base_pointer(field_access, coordinates, previous_ptr):
...@@ -341,7 +315,43 @@ def substitute_array_accesses_with_constants(ast_node): ...@@ -341,7 +315,43 @@ def substitute_array_accesses_with_constants(ast_node):
substitute_array_accesses_with_constants(a) substitute_array_accesses_with_constants(a)
def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None):
"""Used for buffer fields to determine the linearized index of the buffer dependent on loop counter symbols.
Args:
ast_node: ast before any field accesses are resolved
loop_counters: for CPU kernels: leave to default 'None' (can be determined from loop nodes)
for GPU kernels: list of 'loop counters' from inner to outer loop
loop_iterations: number of iterations of each loop from inner to outer, for CPU kernels leave to default
Returns:
base buffer index - required by 'resolve_buffer_accesses' function
"""
if loop_counters is None or loop_iterations is None:
loops = [l for l in filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, ast.SympyAssignment)]
loops.reverse()
parents_of_innermost_loop = list(parents_of_type(loops[0], ast.LoopOverCoordinate, include_current=True))
assert len(loops) == len(parents_of_innermost_loop)
assert all(l1 is l2 for l1, l2 in zip(loops, parents_of_innermost_loop))
loop_iterations = [(l.stop - l.start) / l.step for l in loops]
loop_counters = [l.loop_counter_symbol for l in loops]
field_accesses = ast_node.atoms(Field.Access)
buffer_accesses = {fa for fa in field_accesses if FieldType.is_buffer(fa.field)}
loop_counters = [v * len(buffer_accesses) for v in loop_counters]
base_buffer_index = loop_counters[0]
stride = 1
for idx, var in enumerate(loop_counters[1:]):
cur_stride = loop_iterations[idx]
stride *= int(cur_stride) if isinstance(cur_stride, float) else cur_stride
base_buffer_index += var * stride
return base_buffer_index
def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=set()): def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=set()):
def visit_sympy_expr(expr, enclosing_block, sympy_assignment): def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
if isinstance(expr, Field.Access): if isinstance(expr, Field.Access):
field_access = expr field_access = expr
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment