diff --git a/pystencils/cpu/kernelcreation.py b/pystencils/cpu/kernelcreation.py index 21bc93db39691a279800b10e3eb4b500ab620f56..8870e65b374eece2901a93b694b3f1e91c8a2d25 100644 --- a/pystencils/cpu/kernelcreation.py +++ b/pystencils/cpu/kernelcreation.py @@ -1,5 +1,6 @@ from typing import Union +import sympy as sp import numpy as np import pystencils.astnodes as ast @@ -8,9 +9,9 @@ from pystencils.config import CreateKernelConfig from pystencils.enums import Target, Backend from pystencils.astnodes import Block, KernelFunction, LoopOverCoordinate, SympyAssignment from pystencils.cpu.cpujit import make_python_function -from pystencils.typing import StructType, TypedSymbol +from pystencils.typing import StructType, TypedSymbol, create_type from pystencils.typing.transformations import add_types -from pystencils.field import FieldType +from pystencils.field import Field, FieldType from pystencils.node_collection import NodeCollection from pystencils.transformations import ( filtered_tree_iteration, get_base_buffer_index, get_optimal_loop_ordering, make_loop_over_domain, @@ -33,7 +34,6 @@ def create_kernel(assignments: Union[AssignmentCollection, NodeCollection], AST node representing a function, that can be printed as C or CUDA code """ function_name = config.function_name - type_info = config.data_type iteration_slice = config.iteration_slice ghost_layers = config.ghost_layers fields_written = assignments.bound_fields @@ -44,7 +44,7 @@ def create_kernel(assignments: Union[AssignmentCollection, NodeCollection], split_groups = assignments.simplification_hints['split_groups'] assignments = assignments.all_assignments - # TODO Jan Cleanup: move add_types to create_domain_kernel or create_kernel? + # TODO Cleanup: move add_types to create_domain_kernel or create_kernel assignments = add_types(assignments, config) all_fields = fields_read.union(fields_written) @@ -61,7 +61,21 @@ def create_kernel(assignments: Union[AssignmentCollection, NodeCollection], ghost_layers=ghost_layer_info, function_name=function_name, assignments=assignments) if split_groups: - split_inner_loop(ast_node, split_groups) + type_info = config.data_type + + def type_symbol(term): + if isinstance(term, Field.Access) or isinstance(term, TypedSymbol): + return term + elif isinstance(term, sp.Symbol): + if isinstance(type_info, str) or not hasattr(type_info, '__getitem__'): + return TypedSymbol(term.name, create_type(type_info)) + else: + return TypedSymbol(term.name, type_info[term.name]) + else: + raise ValueError("Term has to be field access or symbol") + + typed_split_groups = [[type_symbol(s) for s in split_group] for split_group in split_groups] + split_inner_loop(ast_node, typed_split_groups) base_pointer_spec = [['spatialInner0'], ['spatialInner1']] if len(loop_order) >= 2 else [['spatialInner0']] base_pointer_info = {field.name: parse_base_pointer_info(base_pointer_spec, loop_order, @@ -75,6 +89,7 @@ def create_kernel(assignments: Union[AssignmentCollection, NodeCollection], if any(FieldType.is_buffer(f) for f in all_fields): resolve_buffer_accesses(ast_node, get_base_buffer_index(ast_node), read_only_fields) + # TODO think about typing resolve_field_accesses(ast_node, read_only_fields, field_to_base_pointer_info=base_pointer_info) move_constants_before_loop(ast_node) return ast_node