Commit ca31d3cc authored by Martin Bauer's avatar Martin Bauer
Browse files

Additional tests

parent 0d6422fb
......@@ -17,12 +17,12 @@ class Node:
@property
def args(self) -> List[NodeOrExpr]:
"""Returns all arguments/children of this node."""
return []
raise NotImplementedError()
@property
def symbols_defined(self) -> Set[sp.Symbol]:
"""Set of symbols which are defined by this node."""
return set()
raise NotImplementedError()
@property
def undefined_symbols(self) -> Set[sp.Symbol]:
......
......@@ -70,10 +70,12 @@ def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "ke
split_inner_loop(code, typed_split_groups)
base_pointer_spec = [['spatialInner0'], ['spatialInner1']] if len(loop_order) >= 2 else [['spatialInner0']]
base_pointer_info = {field.name: parse_base_pointer_info(base_pointer_spec, loop_order, field)
base_pointer_info = {field.name: parse_base_pointer_info(base_pointer_spec, loop_order,
field.spatial_dimensions, field.index_dimensions)
for field in fields_without_buffers}
buffer_base_pointer_info = {field.name: parse_base_pointer_info([['spatialInner0']], [0], field)
buffer_base_pointer_info = {field.name: parse_base_pointer_info([['spatialInner0']], [0],
field.spatial_dimensions, field.index_dimensions)
for field in buffers}
base_pointer_info.update(buffer_base_pointer_info)
......
......@@ -61,6 +61,7 @@ def make_python_function(kernel_function_node, argument_dict=None):
# cuda.Context.synchronize() # useful for debugging, to get errors right after kernel was called
wrapper.ast = kernel_function_node
wrapper.parameters = kernel_function_node.parameters
wrapper.num_regs = func.num_regs
return wrapper
......
......@@ -56,8 +56,10 @@ def create_cuda_kernel(assignments, function_name="kernel", type_info=None, inde
ast = KernelFunction(block, function_name=function_name, ghost_layers=ghost_layers, backend='gpucuda')
ast.global_variables.update(indexing.index_variables)
base_pointer_info = [['spatialInner0']]
base_pointer_infos = {f.name: parse_base_pointer_info(base_pointer_info, [2, 1, 0], f) for f in all_fields}
base_pointer_spec = [['spatialInner0']]
base_pointer_info = {f.name: parse_base_pointer_info(base_pointer_spec, [2, 1, 0],
f.spatial_dimensions, f.index_dimensions)
for f in all_fields}
coord_mapping = {f.name: cell_idx_symbols for f in all_fields}
......@@ -71,7 +73,7 @@ def create_cuda_kernel(assignments, function_name="kernel", type_info=None, inde
base_buffer_index += var * stride
resolve_buffer_accesses(ast, base_buffer_index, read_only_fields)
resolve_field_accesses(ast, read_only_fields, field_to_base_pointer_info=base_pointer_infos,
resolve_field_accesses(ast, read_only_fields, field_to_base_pointer_info=base_pointer_info,
field_to_fixed_coordinates=coord_mapping)
substitute_array_accesses_with_constants(ast)
......@@ -131,13 +133,15 @@ def created_indexed_cuda_kernel(assignments, index_fields, function_name="kernel
ast.global_variables.update(indexing.index_variables)
coord_mapping = indexing.coordinates
base_pointer_info = [['spatialInner0']]
base_pointer_infos = {f.name: parse_base_pointer_info(base_pointer_info, [2, 1, 0], f) for f in all_fields}
base_pointer_spec = [['spatialInner0']]
base_pointer_info = {f.name: parse_base_pointer_info(base_pointer_spec, [2, 1, 0],
f.spatial_dimensions, f.index_dimensions)
for f in all_fields}
coord_mapping = {f.name: coord_mapping for f in index_fields}
coord_mapping.update({f.name: coordinate_typed_symbols for f in non_index_fields})
resolve_field_accesses(ast, read_only_fields, field_to_fixed_coordinates=coord_mapping,
field_to_base_pointer_info=base_pointer_infos)
field_to_base_pointer_info=base_pointer_info)
substitute_array_accesses_with_constants(ast)
# add the function which determines #blocks and #threads as additional member to KernelFunction node
......
......@@ -191,7 +191,7 @@ def create_intermediate_base_pointer(field_access, coordinates, previous_ptr):
return new_ptr, offset
def parse_base_pointer_info(base_pointer_specification, loop_order, field):
def parse_base_pointer_info(base_pointer_specification, loop_order, spatial_dimensions, index_dimensions):
"""
Creates base pointer specification for :func:`resolve_field_accesses` function.
......@@ -210,10 +210,16 @@ def parse_base_pointer_info(base_pointer_specification, loop_order, field):
Args:
base_pointer_specification: nested list with above specifications
loop_order: list with ordering of loops from outer to inner
field:
spatial_dimensions: number of spatial dimensions
index_dimensions: number of index dimensions
Returns:
list of tuples that can be passed to :func:`resolve_field_accesses`
Examples:
>>> parse_base_pointer_info([['spatialOuter0'], ['index0']], loop_order=[2,1,0],
... spatial_dimensions=3, index_dimensions=1)
[[0], [3], [1, 2]]
"""
result = []
specified_coordinates = set()
......@@ -222,7 +228,7 @@ def parse_base_pointer_info(base_pointer_specification, loop_order, field):
new_group = []
def add_new_element(elem):
if elem >= field.spatial_dimensions + field.index_dimensions:
if elem >= spatial_dimensions + index_dimensions:
raise ValueError("Coordinate %d does not exist" % (elem,))
new_group.append(elem)
if elem in specified_coordinates:
......@@ -240,19 +246,19 @@ def parse_base_pointer_info(base_pointer_specification, loop_order, field):
index = int(element[len("Outer"):])
add_new_element(loop_order[-index])
elif element == "all":
for i in range(field.spatial_dimensions):
for i in range(spatial_dimensions):
add_new_element(i)
else:
raise ValueError("Could not parse " + element)
elif element.startswith("index"):
index = int(element[len("index"):])
add_new_element(field.spatial_dimensions + index)
add_new_element(spatial_dimensions + index)
else:
raise ValueError("Unknown specification %s" % (element,))
result.append(new_group)
all_coordinates = set(range(field.spatial_dimensions + field.index_dimensions))
all_coordinates = set(range(spatial_dimensions + index_dimensions))
rest = all_coordinates - specified_coordinates
if rest:
result.append(list(rest))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment