Commit ca31d3cc authored by Martin Bauer's avatar Martin Bauer
Browse files

Additional tests

parent 0d6422fb
...@@ -17,12 +17,12 @@ class Node: ...@@ -17,12 +17,12 @@ class Node:
@property @property
def args(self) -> List[NodeOrExpr]: def args(self) -> List[NodeOrExpr]:
"""Returns all arguments/children of this node.""" """Returns all arguments/children of this node."""
return [] raise NotImplementedError()
@property @property
def symbols_defined(self) -> Set[sp.Symbol]: def symbols_defined(self) -> Set[sp.Symbol]:
"""Set of symbols which are defined by this node.""" """Set of symbols which are defined by this node."""
return set() raise NotImplementedError()
@property @property
def undefined_symbols(self) -> Set[sp.Symbol]: def undefined_symbols(self) -> Set[sp.Symbol]:
......
...@@ -70,10 +70,12 @@ def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "ke ...@@ -70,10 +70,12 @@ def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "ke
split_inner_loop(code, typed_split_groups) split_inner_loop(code, typed_split_groups)
base_pointer_spec = [['spatialInner0'], ['spatialInner1']] if len(loop_order) >= 2 else [['spatialInner0']] base_pointer_spec = [['spatialInner0'], ['spatialInner1']] if len(loop_order) >= 2 else [['spatialInner0']]
base_pointer_info = {field.name: parse_base_pointer_info(base_pointer_spec, loop_order, field) base_pointer_info = {field.name: parse_base_pointer_info(base_pointer_spec, loop_order,
field.spatial_dimensions, field.index_dimensions)
for field in fields_without_buffers} for field in fields_without_buffers}
buffer_base_pointer_info = {field.name: parse_base_pointer_info([['spatialInner0']], [0], field) buffer_base_pointer_info = {field.name: parse_base_pointer_info([['spatialInner0']], [0],
field.spatial_dimensions, field.index_dimensions)
for field in buffers} for field in buffers}
base_pointer_info.update(buffer_base_pointer_info) base_pointer_info.update(buffer_base_pointer_info)
......
...@@ -61,6 +61,7 @@ def make_python_function(kernel_function_node, argument_dict=None): ...@@ -61,6 +61,7 @@ def make_python_function(kernel_function_node, argument_dict=None):
# cuda.Context.synchronize() # useful for debugging, to get errors right after kernel was called # cuda.Context.synchronize() # useful for debugging, to get errors right after kernel was called
wrapper.ast = kernel_function_node wrapper.ast = kernel_function_node
wrapper.parameters = kernel_function_node.parameters wrapper.parameters = kernel_function_node.parameters
wrapper.num_regs = func.num_regs
return wrapper return wrapper
......
...@@ -56,8 +56,10 @@ def create_cuda_kernel(assignments, function_name="kernel", type_info=None, inde ...@@ -56,8 +56,10 @@ def create_cuda_kernel(assignments, function_name="kernel", type_info=None, inde
ast = KernelFunction(block, function_name=function_name, ghost_layers=ghost_layers, backend='gpucuda') ast = KernelFunction(block, function_name=function_name, ghost_layers=ghost_layers, backend='gpucuda')
ast.global_variables.update(indexing.index_variables) ast.global_variables.update(indexing.index_variables)
base_pointer_info = [['spatialInner0']] base_pointer_spec = [['spatialInner0']]
base_pointer_infos = {f.name: parse_base_pointer_info(base_pointer_info, [2, 1, 0], f) for f in all_fields} base_pointer_info = {f.name: parse_base_pointer_info(base_pointer_spec, [2, 1, 0],
f.spatial_dimensions, f.index_dimensions)
for f in all_fields}
coord_mapping = {f.name: cell_idx_symbols for f in all_fields} coord_mapping = {f.name: cell_idx_symbols for f in all_fields}
...@@ -71,7 +73,7 @@ def create_cuda_kernel(assignments, function_name="kernel", type_info=None, inde ...@@ -71,7 +73,7 @@ def create_cuda_kernel(assignments, function_name="kernel", type_info=None, inde
base_buffer_index += var * stride base_buffer_index += var * stride
resolve_buffer_accesses(ast, base_buffer_index, read_only_fields) resolve_buffer_accesses(ast, base_buffer_index, read_only_fields)
resolve_field_accesses(ast, read_only_fields, field_to_base_pointer_info=base_pointer_infos, resolve_field_accesses(ast, read_only_fields, field_to_base_pointer_info=base_pointer_info,
field_to_fixed_coordinates=coord_mapping) field_to_fixed_coordinates=coord_mapping)
substitute_array_accesses_with_constants(ast) substitute_array_accesses_with_constants(ast)
...@@ -131,13 +133,15 @@ def created_indexed_cuda_kernel(assignments, index_fields, function_name="kernel ...@@ -131,13 +133,15 @@ def created_indexed_cuda_kernel(assignments, index_fields, function_name="kernel
ast.global_variables.update(indexing.index_variables) ast.global_variables.update(indexing.index_variables)
coord_mapping = indexing.coordinates coord_mapping = indexing.coordinates
base_pointer_info = [['spatialInner0']] base_pointer_spec = [['spatialInner0']]
base_pointer_infos = {f.name: parse_base_pointer_info(base_pointer_info, [2, 1, 0], f) for f in all_fields} base_pointer_info = {f.name: parse_base_pointer_info(base_pointer_spec, [2, 1, 0],
f.spatial_dimensions, f.index_dimensions)
for f in all_fields}
coord_mapping = {f.name: coord_mapping for f in index_fields} coord_mapping = {f.name: coord_mapping for f in index_fields}
coord_mapping.update({f.name: coordinate_typed_symbols for f in non_index_fields}) coord_mapping.update({f.name: coordinate_typed_symbols for f in non_index_fields})
resolve_field_accesses(ast, read_only_fields, field_to_fixed_coordinates=coord_mapping, resolve_field_accesses(ast, read_only_fields, field_to_fixed_coordinates=coord_mapping,
field_to_base_pointer_info=base_pointer_infos) field_to_base_pointer_info=base_pointer_info)
substitute_array_accesses_with_constants(ast) substitute_array_accesses_with_constants(ast)
# add the function which determines #blocks and #threads as additional member to KernelFunction node # add the function which determines #blocks and #threads as additional member to KernelFunction node
......
...@@ -191,7 +191,7 @@ def create_intermediate_base_pointer(field_access, coordinates, previous_ptr): ...@@ -191,7 +191,7 @@ def create_intermediate_base_pointer(field_access, coordinates, previous_ptr):
return new_ptr, offset return new_ptr, offset
def parse_base_pointer_info(base_pointer_specification, loop_order, field): def parse_base_pointer_info(base_pointer_specification, loop_order, spatial_dimensions, index_dimensions):
""" """
Creates base pointer specification for :func:`resolve_field_accesses` function. Creates base pointer specification for :func:`resolve_field_accesses` function.
...@@ -210,10 +210,16 @@ def parse_base_pointer_info(base_pointer_specification, loop_order, field): ...@@ -210,10 +210,16 @@ def parse_base_pointer_info(base_pointer_specification, loop_order, field):
Args: Args:
base_pointer_specification: nested list with above specifications base_pointer_specification: nested list with above specifications
loop_order: list with ordering of loops from outer to inner loop_order: list with ordering of loops from outer to inner
field: spatial_dimensions: number of spatial dimensions
index_dimensions: number of index dimensions
Returns: Returns:
list of tuples that can be passed to :func:`resolve_field_accesses` list of tuples that can be passed to :func:`resolve_field_accesses`
Examples:
>>> parse_base_pointer_info([['spatialOuter0'], ['index0']], loop_order=[2,1,0],
... spatial_dimensions=3, index_dimensions=1)
[[0], [3], [1, 2]]
""" """
result = [] result = []
specified_coordinates = set() specified_coordinates = set()
...@@ -222,7 +228,7 @@ def parse_base_pointer_info(base_pointer_specification, loop_order, field): ...@@ -222,7 +228,7 @@ def parse_base_pointer_info(base_pointer_specification, loop_order, field):
new_group = [] new_group = []
def add_new_element(elem): def add_new_element(elem):
if elem >= field.spatial_dimensions + field.index_dimensions: if elem >= spatial_dimensions + index_dimensions:
raise ValueError("Coordinate %d does not exist" % (elem,)) raise ValueError("Coordinate %d does not exist" % (elem,))
new_group.append(elem) new_group.append(elem)
if elem in specified_coordinates: if elem in specified_coordinates:
...@@ -240,19 +246,19 @@ def parse_base_pointer_info(base_pointer_specification, loop_order, field): ...@@ -240,19 +246,19 @@ def parse_base_pointer_info(base_pointer_specification, loop_order, field):
index = int(element[len("Outer"):]) index = int(element[len("Outer"):])
add_new_element(loop_order[-index]) add_new_element(loop_order[-index])
elif element == "all": elif element == "all":
for i in range(field.spatial_dimensions): for i in range(spatial_dimensions):
add_new_element(i) add_new_element(i)
else: else:
raise ValueError("Could not parse " + element) raise ValueError("Could not parse " + element)
elif element.startswith("index"): elif element.startswith("index"):
index = int(element[len("index"):]) index = int(element[len("index"):])
add_new_element(field.spatial_dimensions + index) add_new_element(spatial_dimensions + index)
else: else:
raise ValueError("Unknown specification %s" % (element,)) raise ValueError("Unknown specification %s" % (element,))
result.append(new_group) result.append(new_group)
all_coordinates = set(range(field.spatial_dimensions + field.index_dimensions)) all_coordinates = set(range(spatial_dimensions + index_dimensions))
rest = all_coordinates - specified_coordinates rest = all_coordinates - specified_coordinates
if rest: if rest:
result.append(list(rest)) result.append(list(rest))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment