Skip to content
Snippets Groups Projects
Commit 32c4acb5 authored by Jan Hönig's avatar Jan Hönig
Browse files

Fixing pipelines.

parent 5024a79b
1 merge request!292Rebase of pystencils Type System
from typing import Union from typing import Union
import sympy as sp
import numpy as np import numpy as np
import pystencils.astnodes as ast import pystencils.astnodes as ast
...@@ -8,9 +9,9 @@ from pystencils.config import CreateKernelConfig ...@@ -8,9 +9,9 @@ from pystencils.config import CreateKernelConfig
from pystencils.enums import Target, Backend from pystencils.enums import Target, Backend
from pystencils.astnodes import Block, KernelFunction, LoopOverCoordinate, SympyAssignment from pystencils.astnodes import Block, KernelFunction, LoopOverCoordinate, SympyAssignment
from pystencils.cpu.cpujit import make_python_function from pystencils.cpu.cpujit import make_python_function
from pystencils.typing import StructType, TypedSymbol from pystencils.typing import StructType, TypedSymbol, create_type
from pystencils.typing.transformations import add_types from pystencils.typing.transformations import add_types
from pystencils.field import FieldType from pystencils.field import Field, FieldType
from pystencils.node_collection import NodeCollection from pystencils.node_collection import NodeCollection
from pystencils.transformations import ( from pystencils.transformations import (
filtered_tree_iteration, get_base_buffer_index, get_optimal_loop_ordering, make_loop_over_domain, filtered_tree_iteration, get_base_buffer_index, get_optimal_loop_ordering, make_loop_over_domain,
...@@ -33,7 +34,6 @@ def create_kernel(assignments: Union[AssignmentCollection, NodeCollection], ...@@ -33,7 +34,6 @@ def create_kernel(assignments: Union[AssignmentCollection, NodeCollection],
AST node representing a function, that can be printed as C or CUDA code AST node representing a function, that can be printed as C or CUDA code
""" """
function_name = config.function_name function_name = config.function_name
type_info = config.data_type
iteration_slice = config.iteration_slice iteration_slice = config.iteration_slice
ghost_layers = config.ghost_layers ghost_layers = config.ghost_layers
fields_written = assignments.bound_fields fields_written = assignments.bound_fields
...@@ -44,7 +44,7 @@ def create_kernel(assignments: Union[AssignmentCollection, NodeCollection], ...@@ -44,7 +44,7 @@ def create_kernel(assignments: Union[AssignmentCollection, NodeCollection],
split_groups = assignments.simplification_hints['split_groups'] split_groups = assignments.simplification_hints['split_groups']
assignments = assignments.all_assignments assignments = assignments.all_assignments
# TODO Jan Cleanup: move add_types to create_domain_kernel or create_kernel? # TODO Cleanup: move add_types to create_domain_kernel or create_kernel
assignments = add_types(assignments, config) assignments = add_types(assignments, config)
all_fields = fields_read.union(fields_written) all_fields = fields_read.union(fields_written)
...@@ -61,7 +61,21 @@ def create_kernel(assignments: Union[AssignmentCollection, NodeCollection], ...@@ -61,7 +61,21 @@ def create_kernel(assignments: Union[AssignmentCollection, NodeCollection],
ghost_layers=ghost_layer_info, function_name=function_name, assignments=assignments) ghost_layers=ghost_layer_info, function_name=function_name, assignments=assignments)
if split_groups: if split_groups:
split_inner_loop(ast_node, split_groups) type_info = config.data_type
def type_symbol(term):
if isinstance(term, Field.Access) or isinstance(term, TypedSymbol):
return term
elif isinstance(term, sp.Symbol):
if isinstance(type_info, str) or not hasattr(type_info, '__getitem__'):
return TypedSymbol(term.name, create_type(type_info))
else:
return TypedSymbol(term.name, type_info[term.name])
else:
raise ValueError("Term has to be field access or symbol")
typed_split_groups = [[type_symbol(s) for s in split_group] for split_group in split_groups]
split_inner_loop(ast_node, typed_split_groups)
base_pointer_spec = [['spatialInner0'], ['spatialInner1']] if len(loop_order) >= 2 else [['spatialInner0']] base_pointer_spec = [['spatialInner0'], ['spatialInner1']] if len(loop_order) >= 2 else [['spatialInner0']]
base_pointer_info = {field.name: parse_base_pointer_info(base_pointer_spec, loop_order, base_pointer_info = {field.name: parse_base_pointer_info(base_pointer_spec, loop_order,
...@@ -75,6 +89,7 @@ def create_kernel(assignments: Union[AssignmentCollection, NodeCollection], ...@@ -75,6 +89,7 @@ def create_kernel(assignments: Union[AssignmentCollection, NodeCollection],
if any(FieldType.is_buffer(f) for f in all_fields): if any(FieldType.is_buffer(f) for f in all_fields):
resolve_buffer_accesses(ast_node, get_base_buffer_index(ast_node), read_only_fields) resolve_buffer_accesses(ast_node, get_base_buffer_index(ast_node), read_only_fields)
# TODO think about typing
resolve_field_accesses(ast_node, read_only_fields, field_to_base_pointer_info=base_pointer_info) resolve_field_accesses(ast_node, read_only_fields, field_to_base_pointer_info=base_pointer_info)
move_constants_before_loop(ast_node) move_constants_before_loop(ast_node)
return ast_node return ast_node
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment