Commit 7a94740d authored by Martin Bauer's avatar Martin Bauer
Browse files

Pass field information (shape,stride) as single elements instead of arr

- small (length < 5) arrays with shape and stride information had to be
  memcpy'd to the GPU before every kernel call
- instead of passing the information as arrays, the single elements are
  passed
- leads to more function arguments, but simplifies GPU kernel calls

-> changes in all backends required
parent 99aef3f8
......@@ -2,8 +2,9 @@ import sympy as sp
from sympy.tensor import IndexedBase
from pystencils.field import Field
from pystencils.data_types import TypedSymbol, create_type, cast_func
from pystencils.kernelparameters import FieldStrideSymbol, FieldPointerSymbol, FieldShapeSymbol
from pystencils.sympyextensions import fast_subs
from typing import List, Set, Optional, Union, Any
from typing import List, Set, Optional, Union, Any, Sequence
NodeOrExpr = Union['Node', sp.Expr]
......@@ -120,62 +121,48 @@ class Conditional(Node):
class KernelFunction(Node):
class Argument:
def __init__(self, name, dtype, symbol, kernel_function_node):
from pystencils.transformations import symbol_name_to_variable_name
self.name = name
self.dtype = dtype
self.is_field_ptr_argument = False
self.is_field_shape_argument = False
self.is_field_stride_argument = False
self.is_field_argument = False
self.field_name = ""
self.coordinate = None
self.symbol = symbol
if name.startswith(Field.DATA_PREFIX):
self.is_field_ptr_argument = True
self.is_field_argument = True
self.field_name = name[len(Field.DATA_PREFIX):]
elif name.startswith(Field.SHAPE_PREFIX):
self.is_field_shape_argument = True
self.is_field_argument = True
self.field_name = name[len(Field.SHAPE_PREFIX):]
elif name.startswith(Field.STRIDE_PREFIX):
self.is_field_stride_argument = True
self.is_field_argument = True
self.field_name = name[len(Field.STRIDE_PREFIX):]
self.field = None
if self.is_field_argument:
field_map = {symbol_name_to_variable_name(f.name): f for f in kernel_function_node.fields_accessed}
self.field = field_map[self.field_name]
def __lt__(self, other):
def score(l):
if l.is_field_ptr_argument:
return -4
elif l.is_field_shape_argument:
return -3
elif l.is_field_stride_argument:
return -2
return 0
if score(self) < score(other):
return True
elif score(self) == score(other):
return self.name < other.name
else:
return False
class Parameter:
"""Function parameter.
Each undefined symbol in a `KernelFunction` node becomes a parameter to the function.
Parameters are either symbols introduced by the user that never occur on the left hand side of an
Assignment, or are related to fields/arrays passed to the function.
A parameter consists of the typed symbol (symbol property). For field related parameters this is a symbol
defined in pystencils.kernelparameters.
If the parameter is related to one or multiple fields, these fields are referenced in the fields property.
"""
def __init__(self, symbol, fields):
self.symbol = symbol # type: TypedSymbol
self.fields = fields # type: Sequence[Field]
def __repr__(self):
return '<{0} {1}>'.format(self.dtype, self.name)
return repr(self.symbol)
@property
def is_field_stride(self):
return isinstance(self.symbol, FieldStrideSymbol)
@property
def is_field_shape(self):
return isinstance(self.symbol, FieldShapeSymbol)
@property
def is_field_pointer(self):
return isinstance(self.symbol, FieldPointerSymbol)
@property
def is_field_parameter(self):
return self.is_field_pointer or self.is_field_shape or self.is_field_stride
@property
def field_name(self):
return self.fields[0].name
def __init__(self, body, ghost_layers=None, function_name="kernel", backend=""):
super(KernelFunction, self).__init__()
self._body = body
body.parent = self
self._parameters = None
self.function_name = function_name
self._body.parent = self
self.compile = None
......@@ -193,11 +180,6 @@ class KernelFunction(Node):
def undefined_symbols(self):
return set()
@property
def parameters(self):
self._update_parameters()
return self._parameters
@property
def body(self):
return self._body
......@@ -207,24 +189,37 @@ class KernelFunction(Node):
return [self._body]
@property
def fields_accessed(self):
def fields_accessed(self) -> Set['ResolvedFieldAccess']:
"""Set of Field instances: fields which are accessed inside this kernel function"""
return set(o.field for o in self.atoms(ResolvedFieldAccess))
def _update_parameters(self):
undefined_symbols = self._body.undefined_symbols - self.global_variables
self._parameters = [KernelFunction.Argument(s.name, s.dtype, s, self) for s in undefined_symbols]
def get_parameters(self) -> Sequence['KernelFunction.Parameter']:
"""Returns list of parameters for this function.
This function is expensive, cache the result where possible!
"""
field_map = {f.name: f for f in self.fields_accessed}
def get_fields(symbol):
if hasattr(symbol, 'field_name'):
return field_map[symbol.field_name],
elif hasattr(symbol, 'field_names'):
return tuple(field_map[fn] for fn in symbol.field_names)
return ()
self._parameters.sort()
argument_symbols = self._body.undefined_symbols - self.global_variables
parameters = [self.Parameter(symbol, get_fields(symbol)) for symbol in argument_symbols]
parameters.sort(key=lambda p: p.symbol.name)
return parameters
def __str__(self):
self._update_parameters()
return '{0} {1}({2})\n{3}'.format(type(self).__name__, self.function_name, self.parameters,
params = [p.symbol for p in self.get_parameters()]
return '{0} {1}({2})\n{3}'.format(type(self).__name__, self.function_name, params,
("\t" + "\t".join(str(self.body).splitlines(True))))
def __repr__(self):
self._update_parameters()
return '{0} {1}({2})'.format(type(self).__name__, self.function_name, self.parameters)
params = [p.symbol for p in self.get_parameters()]
return '{0} {1}({2})'.format(type(self).__name__, self.function_name, params)
class Block(Node):
......
......@@ -124,7 +124,7 @@ class CBackend:
raise NotImplementedError("CBackend does not support node of type " + str(type(node)))
def _print_KernelFunction(self, node):
function_arguments = ["%s %s" % (str(s.dtype), s.name) for s in node.parameters]
function_arguments = ["%s %s" % (str(s.symbol.dtype), s.symbol.name) for s in node.get_parameters()]
func_declaration = "FUNC_PREFIX void %s(%s)" % (node.function_name, ", ".join(function_arguments))
if self._signatureOnly:
return func_declaration
......
......@@ -54,9 +54,10 @@ def __shortened(node):
if isinstance(node, LoopOverCoordinate):
return "Loop over dim %d" % (node.coordinate_to_loop_over,)
elif isinstance(node, KernelFunction):
params = [f.name for f in node.fields_accessed]
params += [p.name for p in node.parameters if not p.is_field_argument]
return "Func: %s (%s)" % (node.function_name, ",".join(params))
params = node.get_parameters()
param_names = [p.field_name for p in params if p.is_field_pointer]
param_names += [p.symbol.name for p in params if not p.is_field_parameter]
return "Func: %s (%s)" % (node.function_name, ",".join(param_names))
elif isinstance(node, SympyAssignment):
return repr(node.lhs)
elif isinstance(node, Block):
......
......@@ -6,6 +6,7 @@ from pystencils.backends.cbackend import CustomCppCode
from pystencils.boundaries.createindexlist import numpy_data_type_for_boundary_object, create_boundary_index_array
from pystencils.cache import memorycache
from pystencils.data_types import create_type
from pystencils.kernelparameters import FieldPointerSymbol
DEFAULT_FLAG_TYPE = np.uint32
......@@ -204,9 +205,9 @@ class BoundaryHandling:
for b_obj, idx_arr in b[self._index_array_name].boundary_object_to_index_list.items():
kwargs[self._field_name] = b[self._field_name]
kwargs['indexField'] = idx_arr
data_used_in_kernel = (p.field_name
data_used_in_kernel = (p.fields[0].name
for p in self._boundary_object_to_boundary_info[b_obj].kernel.parameters
if p.is_field_ptr_argument and p.field_name not in kwargs)
if isinstance(p.symbol, FieldPointerSymbol) and p.fields[0].name not in kwargs)
kwargs.update({name: b[name] for name in data_used_in_kernel})
self._boundary_object_to_boundary_info[b_obj].kernel(**kwargs)
......@@ -220,9 +221,9 @@ class BoundaryHandling:
arguments = kwargs.copy()
arguments[self._field_name] = b[self._field_name]
arguments['indexField'] = idx_arr
data_used_in_kernel = (p.field_name
data_used_in_kernel = (p.fields[0].name
for p in self._boundary_object_to_boundary_info[b_obj].kernel.parameters
if p.is_field_ptr_argument and p.field_name not in arguments)
if isinstance(p.symbol, FieldPointerSymbol) and p.fields[0].name not in arguments)
arguments.update({name: b[name] for name in data_used_in_kernel if name not in arguments})
kernel = self._boundary_object_to_boundary_info[b_obj].kernel
......
......@@ -55,10 +55,11 @@ import numpy as np
import subprocess
from appdirs import user_config_dir, user_cache_dir
from collections import OrderedDict
from pystencils.kernelparameters import FieldPointerSymbol, FieldStrideSymbol, FieldShapeSymbol
from pystencils.utils import recursive_dict_update
from sysconfig import get_paths
from pystencils import FieldType, Field
from pystencils.data_types import get_base_type
from pystencils import FieldType
from pystencils.backends.cbackend import generate_c, get_headers
from pystencils.utils import file_handle_for_atomic_write, atomic_file_write
......@@ -311,8 +312,8 @@ def equal_size_check(fields):
return ""
ref_field = fields[0]
cond = ["({field.name}_shape[{i}] == {ref_field.name}_shape[{i}])".format(ref_field=ref_field,
field=field_to_test, i=i)
cond = ["(buffer_{field.name}.shape[{i}] == buffer_{ref_field.name}.shape[{i}])".format(ref_field=ref_field,
field=field_to_test, i=i)
for field_to_test in fields[1:]
for i in range(fields[0].spatial_dimensions)]
cond = " && ".join(cond)
......@@ -326,60 +327,45 @@ def create_function_boilerplate_code(parameter_info, name, insert_checks=True):
variable_sized_normal_fields = set()
variable_sized_index_fields = set()
for arg in parameter_info:
if arg.is_field_argument:
if arg.is_field_ptr_argument:
pre_call_code += template_extract_array.format(name=arg.field_name)
post_call_code += template_release_buffer.format(name=arg.field_name)
parameters.append("({dtype} *)buffer_{name}.buf".format(dtype=str(arg.field.dtype),
name=arg.field_name))
shapes = ", ".join(["buffer_{name}.shape[{i}]".format(name=arg.field_name, i=i)
for i in range(len(arg.field.strides))])
shape_type = get_base_type(Field.SHAPE_DTYPE)
pre_call_code += "{type} {name}_shape[] = {{ {elements} }};\n".format(type=shape_type,
name=arg.field_name,
elements=shapes)
item_size = get_base_type(arg.dtype).numpy_dtype.itemsize
strides = ["buffer_{name}.strides[{i}] / {bytes}".format(i=i, name=arg.field_name, bytes=item_size)
for i in range(len(arg.field.strides))]
strides = ", ".join(strides)
strides_type = get_base_type(Field.STRIDE_DTYPE)
pre_call_code += "{type} {name}_strides[] = {{ {elements} }};\n".format(type=strides_type,
name=arg.field_name,
elements=strides)
if insert_checks and arg.field.has_fixed_shape:
shape_cond = ["{name}_shape[{i}] == {s}".format(s=s, name=arg.field_name, i=i)
for i, s in enumerate(arg.field.spatial_shape)]
shape_cond = " && ".join(shape_cond)
pre_call_code += template_check_array.format(cond=shape_cond, what="shape", name=arg.field.name,
expected=str(arg.field.shape))
strides_cond = ["({name}_strides[{i}] == {s} || {name}_shape[{i}]<=1)".format(s=s, i=i,
name=arg.field_name)
for i, s in enumerate(arg.field.spatial_strides)]
strides_cond = " && ".join(strides_cond)
expected_strides_str = str([e * item_size for e in arg.field.strides])
pre_call_code += template_check_array.format(cond=strides_cond, what="strides", name=arg.field.name,
expected=expected_strides_str)
if insert_checks and not arg.field.has_fixed_shape:
if FieldType.is_generic(arg.field):
variable_sized_normal_fields.add(arg.field)
elif FieldType.is_indexed(arg.field):
variable_sized_index_fields.add(arg.field)
elif arg.is_field_shape_argument:
parameters.append("{name}_shape".format(name=arg.field_name))
elif arg.is_field_stride_argument:
parameters.append("{name}_strides".format(name=arg.field_name))
for param in parameter_info:
if param.is_field_pointer:
field = param.fields[0]
pre_call_code += template_extract_array.format(name=field.name)
post_call_code += template_release_buffer.format(name=field.name)
parameters.append("({dtype} *)buffer_{name}.buf".format(dtype=str(field.dtype), name=field.name))
if insert_checks and field.has_fixed_shape:
shape_cond = ["buffer_{name}.shape[{i}] == {s}".format(s=s, name=field.name, i=i)
for i, s in enumerate(field.spatial_shape)]
shape_cond = " && ".join(shape_cond)
pre_call_code += template_check_array.format(cond=shape_cond, what="shape", name=field.name,
expected=str(field.shape))
item_size = field.dtype.numpy_dtype.itemsize
expected_strides = [e * item_size for e in field.spatial_strides]
stride_check_code = "(buffer_{name}.strides[{i}] == {s} || buffer_{name}.shape[{i}]<=1)"
strides_cond = " && ".join([stride_check_code.format(s=s, i=i, name=field.name)
for i, s in enumerate(expected_strides)])
pre_call_code += template_check_array.format(cond=strides_cond, what="strides", name=field.name,
expected=str(expected_strides))
if insert_checks and not field.has_fixed_shape:
if FieldType.is_generic(field):
variable_sized_normal_fields.add(field)
elif FieldType.is_indexed(field):
variable_sized_index_fields.add(field)
elif param.is_field_stride:
field = param.fields[0]
item_size = field.dtype.numpy_dtype.itemsize
parameters.append("buffer_{name}.strides[{i}] / {bytes}".format(bytes=item_size, i=param.symbol.coordinate,
name=field.name))
elif param.is_field_shape:
parameters.append("buffer_{name}.shape[{i}]".format(i=param.symbol.coordinate, name=param.field_name))
else:
extract_function, target_type = type_mapping[arg.dtype.numpy_dtype.type]
extract_function, target_type = type_mapping[param.symbol.dtype.numpy_dtype.type]
pre_call_code += template_extract_scalar.format(extract_function=extract_function, target_type=target_type,
name=arg.name)
parameters.append(arg.name)
name=param.symbol.name)
parameters.append(param.symbol.name)
pre_call_code += equal_size_check(variable_sized_normal_fields)
pre_call_code += equal_size_check(variable_sized_index_fields)
......@@ -449,7 +435,7 @@ class ExtensionModuleCode:
old_name = ast.function_name
ast.function_name = "kernel_" + name
print(generate_c(ast), file=file)
print(create_function_boilerplate_code(ast.parameters, name), file=file)
print(create_function_boilerplate_code(ast.get_parameters(), name), file=file)
ast.function_name = old_name
print(create_module_boilerplate_code(self.module_name, self._function_names), file=file)
......@@ -525,4 +511,4 @@ def compile_and_load(ast):
lib_file = compile_module(code, code_hash_str, base_dir=cache_config['object_cache'])
result = load_kernel_from_file(code_hash_str, ast.function_name, lib_file)
return KernelWrapper(result, ast.parameters, ast)
return KernelWrapper(result, ast.get_parameters(), ast)
......@@ -62,7 +62,7 @@ def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "ke
body = ast.Block(assignments)
loop_order = get_optimal_loop_ordering(fields_without_buffers)
ast_node = make_loop_over_domain(body, function_name, iteration_slice=iteration_slice,
ghost_layers=ghost_layers, loop_order=loop_order)
ghost_layers=ghost_layers, loop_order=loop_order)
ast_node.target = 'cpu'
if split_groups:
......
......@@ -88,8 +88,7 @@ class TypedSymbol(sp.Symbol):
return self._dtype
def _hashable_content(self):
super_class_contents = list(super(TypedSymbol, self)._hashable_content())
return tuple(super_class_contents + [hash(self._dtype)])
return super()._hashable_content(), hash(self._dtype)
def __getnewargs__(self):
return self.name, self.dtype
......
......@@ -3,6 +3,7 @@ import warnings
from pystencils import Field
from pystencils.datahandling.datahandling_interface import DataHandling
from pystencils.datahandling.blockiteration import sliced_block_iteration, block_iteration
from pystencils.kernelparameters import FieldPointerSymbol
from pystencils.utils import DotDict
# noinspection PyPep8Naming
import waLBerla as wlb
......@@ -228,9 +229,9 @@ class ParallelDataHandling(DataHandling):
else:
name_map = self._field_name_to_cpu_data_name
to_array = wlb.field.toArray
data_used_in_kernel = [(name_map[p.field_name], self.fields[p.field_name])
data_used_in_kernel = [(name_map[p.symbol.field_name], self.fields[p.symbol.field_name])
for p in kernel_function.parameters if
p.is_field_ptr_argument and p.field_name not in kwargs]
isinstance(p.symbol, FieldPointerSymbol) and p.symbol.field_name not in kwargs]
result = []
for block in self.blocks:
......
......@@ -4,9 +4,9 @@ from typing import Tuple, Sequence, Optional, List, Set
import numpy as np
import sympy as sp
from sympy.core.cache import cacheit
from sympy.tensor import IndexedBase
from pystencils.alignedarray import aligned_empty
from pystencils.data_types import TypedSymbol, create_type, create_composite_type_from_string, StructType
from pystencils.data_types import create_type, create_composite_type_from_string, StructType
from pystencils.kernelparameters import FieldShapeSymbol, FieldStrideSymbol
from pystencils.stencils import offset_to_direction_string, direction_string_to_offset
from pystencils.sympyextensions import is_integer_sequence
......@@ -182,15 +182,14 @@ class Field:
index_dimensions = len(index_shape)
if isinstance(layout, str):
layout = spatial_layout_string_to_tuple(layout, dim=spatial_dimensions)
shape_symbol = IndexedBase(TypedSymbol(Field.SHAPE_PREFIX + field_name, Field.SHAPE_DTYPE), shape=(1,))
stride_symbol = IndexedBase(TypedSymbol(Field.STRIDE_PREFIX + field_name, Field.STRIDE_DTYPE), shape=(1,))
total_dimensions = spatial_dimensions + index_dimensions
if index_shape is None or len(index_shape) == 0:
shape = tuple([shape_symbol[i] for i in range(total_dimensions)])
shape = tuple([FieldShapeSymbol([field_name], i) for i in range(total_dimensions)])
else:
shape = tuple([shape_symbol[i] for i in range(spatial_dimensions)] + list(index_shape))
shape = tuple([FieldShapeSymbol([field_name], i) for i in range(spatial_dimensions)] + list(index_shape))
strides = tuple([stride_symbol[i] for i in range(total_dimensions)])
strides = tuple([FieldStrideSymbol(field_name, i) for i in range(total_dimensions)])
np_data_type = np.dtype(dtype)
if np_data_type.fields is not None:
......@@ -390,13 +389,6 @@ class Field:
return False
return self.hashable_contents() == other.hashable_contents()
PREFIX = "f"
STRIDE_PREFIX = PREFIX + "stride_"
SHAPE_PREFIX = PREFIX + "shape_"
STRIDE_DTYPE = create_composite_type_from_string("const int *")
SHAPE_DTYPE = create_composite_type_from_string("const int *")
DATA_PREFIX = PREFIX + "d_"
# noinspection PyAttributeOutsideInit,PyUnresolvedReferences
class Access(sp.Symbol):
"""Class representing a relative access into a `Field`.
......
import numpy as np
from pystencils.backends.cbackend import generate_c
from pystencils.transformations import symbol_name_to_variable_name
from pystencils.kernelparameters import FieldPointerSymbol, FieldStrideSymbol, FieldShapeSymbol
from pystencils.sympyextensions import symbol_name_to_variable_name
from pystencils.data_types import StructType, get_base_type
from pystencils.field import FieldType
......@@ -33,7 +34,7 @@ def make_python_function(kernel_function_node, argument_dict=None):
mod = SourceModule(code, options=["-w", "-std=c++11", "-Wno-deprecated-gpu-targets"])
func = mod.get_function(kernel_function_node.function_name)
parameters = kernel_function_node.parameters
parameters = kernel_function_node.get_parameters()
cache = {}
cache_values = []
......@@ -60,40 +61,37 @@ def make_python_function(kernel_function_node, argument_dict=None):
func(*args, **block_and_thread_numbers)
# cuda.Context.synchronize() # useful for debugging, to get errors right after kernel was called
wrapper.ast = kernel_function_node
wrapper.parameters = kernel_function_node.parameters
wrapper.parameters = kernel_function_node.get_parameters()
wrapper.num_regs = func.num_regs
return wrapper
def _build_numpy_argument_list(parameters, argument_dict):
import pycuda.driver as cuda
argument_dict = {symbol_name_to_variable_name(k): v for k, v in argument_dict.items()}
result = []
for arg in parameters:
if arg.is_field_argument:
field = argument_dict[arg.field_name]
if arg.is_field_ptr_argument:
actual_type = field.dtype
expected_type = arg.dtype.base_type.numpy_dtype
if expected_type != actual_type:
raise ValueError("Data type mismatch for field '%s'. Expected '%s' got '%s'." %
(arg.field_name, expected_type, actual_type))
result.append(field)
elif arg.is_field_stride_argument:
dtype = get_base_type(arg.dtype).numpy_dtype
stride_arr = np.array(field.strides, dtype=dtype) // field.dtype.itemsize
result.append(cuda.In(stride_arr))
elif arg.is_field_shape_argument:
dtype = get_base_type(arg.dtype).numpy_dtype
shape_arr = np.array(field.shape, dtype=dtype)
result.append(cuda.In(shape_arr))
else:
assert False
for param in parameters:
if param.is_field_pointer:
array = argument_dict[param.field_name]
actual_type = array.dtype
expected_type = param.fields[0].dtype.numpy_dtype
if expected_type != actual_type:
raise ValueError("Data type mismatch for field '%s'. Expected '%s' got '%s'." %
(param.field_name, expected_type, actual_type))
result.append(array)
elif param.is_field_stride:
cast_to_dtype = param.symbol.dtype.numpy_dtype.type
array = argument_dict[param.field_name]
stride = cast_to_dtype(array.strides[param.symbol.coordinate] // array.dtype.itemsize)
result.append(stride)
elif param.is_field_shape:
cast_to_dtype = param.symbol.dtype.numpy_dtype.type
array = argument_dict[param.field_name]
result.append(cast_to_dtype(array.shape[param.symbol.coordinate]))
else:
param = argument_dict[arg.name]
expected_type = arg.dtype.numpy_dtype
result.append(expected_type.type(param))
expected_type = param.symbol.dtype.numpy_dtype
result.append(expected_type.type(argument_dict[param.symbol.name]))
assert len(result) == len(parameters)
return result
......@@ -106,34 +104,35 @@ def _check_arguments(parameter_specification, argument_dict):
argument_dict = {symbol_name_to_variable_name(k): v for k, v in argument_dict.items()}
array_shapes = set()
index_arr_shapes = set()
for arg in parameter_specification:
if arg.is_field_argument:
for param in parameter_specification:
if isinstance(param.symbol, FieldPointerSymbol):
symbolic_field = param.fields[0]
try:
field_arr = argument_dict[arg.field_name]
field_arr = argument_dict[symbolic_field.name]
except KeyError:
raise KeyError("Missing field parameter for kernel call " + arg.field_name)
symbolic_field = arg.field
if arg.is_field_ptr_argument:
if symbolic_field.has_fixed_shape:
symbolic_field_shape = tuple(int(i) for i in symbolic_field.shape)
if isinstance(symbolic_field.dtype, StructType):
symbolic_field_shape = symbolic_field_shape[:-1]
if symbolic_field_shape != field_arr.shape:
raise ValueError("Passed array '%s' has shape %s which does not match expected shape %s" %
(arg.field_name, str(field_arr.shape), str(symbolic_field.shape)))
if symbolic_field.has_fixed_shape:
symbolic_field_strides = tuple(int(i) * field_arr.dtype.itemsize for i in symbolic_field.strides)
if isinstance(symbolic_field.dtype, StructType):
symbolic_field_strides = symbolic_field_strides[:-1]
if symbolic_field_strides != field_arr.strides:
raise ValueError("Passed array '%s' has strides %s which does not match expected strides %s" %
(arg.field_name, str(field_arr.strides), str(symbolic_field_strides)))
if FieldType.is_indexed(symbolic_field):
index_arr_shapes.add(field_arr.shape[:symbolic_field.spatial_dimensions])
elif FieldType.is_generic(symbolic_field):
array_shapes.add(field_arr.shape[:symbolic_field.spatial_dimensions])
raise KeyError("Missing field parameter for kernel call " + str(symbolic_field))
if symbolic_field.has_fixed_shape:
symbolic_field_shape = tuple(int(i) for i in symbolic_field.shape)
if isinstance(symbolic_field.dtype, StructType):
symbolic_field_shape = symbolic_field_shape[:-1]
if symbolic_field_shape != field_arr.shape:
raise ValueError("Passed array '%s' has shape %s which does not match expected shape %s" %
(symbolic_field.name, str(field_arr.shape), str(symbolic_field.shape)))
if symbolic_field.has_fixed_shape:
symbolic_field_strides = tuple(int(i) * field_arr.dtype.itemsize for i in symbolic_field.strides)
if isinstance(symbolic_field.dtype, StructType):
symbolic_field_strides = symbolic_field_strides[:-1]
if symbolic_field_strides != field_arr.strides:
raise ValueError("Passed array '%s' has strides %s which does not match expected strides %s" %
(symbolic_field.name, str(field_arr.strides), str(symbolic_field_strides)))
if FieldType.is_indexed(symbolic_field):
index_arr_shapes.add(field_arr.shape[:symbolic_field.spatial_dimensions])
elif FieldType.is_generic(symbolic_field):
array_shapes.add(field_arr.shape[:symbolic_field.spatial_dimensions])