Commit 32c4acb5 authored by Jan Hönig's avatar Jan Hönig
Browse files

Fixing pipelines.

parent 5024a79b
from typing import Union
import sympy as sp
import numpy as np
import pystencils.astnodes as ast
......@@ -8,9 +9,9 @@ from pystencils.config import CreateKernelConfig
from pystencils.enums import Target, Backend
from pystencils.astnodes import Block, KernelFunction, LoopOverCoordinate, SympyAssignment
from pystencils.cpu.cpujit import make_python_function
from pystencils.typing import StructType, TypedSymbol
from pystencils.typing import StructType, TypedSymbol, create_type
from pystencils.typing.transformations import add_types
from pystencils.field import FieldType
from pystencils.field import Field, FieldType
from pystencils.node_collection import NodeCollection
from pystencils.transformations import (
filtered_tree_iteration, get_base_buffer_index, get_optimal_loop_ordering, make_loop_over_domain,
......@@ -33,7 +34,6 @@ def create_kernel(assignments: Union[AssignmentCollection, NodeCollection],
AST node representing a function, that can be printed as C or CUDA code
function_name = config.function_name
type_info = config.data_type
iteration_slice = config.iteration_slice
ghost_layers = config.ghost_layers
fields_written = assignments.bound_fields
......@@ -44,7 +44,7 @@ def create_kernel(assignments: Union[AssignmentCollection, NodeCollection],
split_groups = assignments.simplification_hints['split_groups']
assignments = assignments.all_assignments
# TODO Jan Cleanup: move add_types to create_domain_kernel or create_kernel?
# TODO Cleanup: move add_types to create_domain_kernel or create_kernel
assignments = add_types(assignments, config)
all_fields = fields_read.union(fields_written)
......@@ -61,7 +61,21 @@ def create_kernel(assignments: Union[AssignmentCollection, NodeCollection],
ghost_layers=ghost_layer_info, function_name=function_name, assignments=assignments)
if split_groups:
split_inner_loop(ast_node, split_groups)
type_info = config.data_type
def type_symbol(term):
if isinstance(term, Field.Access) or isinstance(term, TypedSymbol):
return term
elif isinstance(term, sp.Symbol):
if isinstance(type_info, str) or not hasattr(type_info, '__getitem__'):
return TypedSymbol(, create_type(type_info))
return TypedSymbol(, type_info[])
raise ValueError("Term has to be field access or symbol")
typed_split_groups = [[type_symbol(s) for s in split_group] for split_group in split_groups]
split_inner_loop(ast_node, typed_split_groups)
base_pointer_spec = [['spatialInner0'], ['spatialInner1']] if len(loop_order) >= 2 else [['spatialInner0']]
base_pointer_info = { parse_base_pointer_info(base_pointer_spec, loop_order,
......@@ -75,6 +89,7 @@ def create_kernel(assignments: Union[AssignmentCollection, NodeCollection],
if any(FieldType.is_buffer(f) for f in all_fields):
resolve_buffer_accesses(ast_node, get_base_buffer_index(ast_node), read_only_fields)
# TODO think about typing
resolve_field_accesses(ast_node, read_only_fields, field_to_base_pointer_info=base_pointer_info)
return ast_node
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment