Skip to content
Snippets Groups Projects
Commit 3dd60595 authored by Michael Kuron's avatar Michael Kuron :mortar_board:
Browse files

Merge branch 'type_boundary' into 'master'

Use int64 for indexing

See merge request !251
parents e8b9fa9c c8ce1744
No related merge requests found
...@@ -83,7 +83,7 @@ class OpenClSympyPrinter(CudaSympyPrinter): ...@@ -83,7 +83,7 @@ class OpenClSympyPrinter(CudaSympyPrinter):
function_name, dimension = tuple(symbol_name.split(".")) function_name, dimension = tuple(symbol_name.split("."))
dimension = self.DIMENSION_MAPPING[dimension] dimension = self.DIMENSION_MAPPING[dimension]
function_name = self.INDEXING_FUNCTION_MAPPING[function_name] function_name = self.INDEXING_FUNCTION_MAPPING[function_name]
return f"(int) {function_name}({dimension})" return f"(int64_t) {function_name}({dimension})"
def _print_TextureAccess(self, node): def _print_TextureAccess(self, node):
raise NotImplementedError() raise NotImplementedError()
......
...@@ -430,7 +430,7 @@ class BoundaryOffsetInfo(CustomCodeNode): ...@@ -430,7 +430,7 @@ class BoundaryOffsetInfo(CustomCodeNode):
inverse_dir = tuple([-i for i in direction]) inverse_dir = tuple([-i for i in direction])
inv_dirs.append(str(stencil.index(inverse_dir))) inv_dirs.append(str(stencil.index(inverse_dir)))
code += "const int %s [] = { %s };\n" % (self.INV_DIR_SYMBOL.name, ", ".join(inv_dirs)) code += "const int64_t %s [] = { %s };\n" % (self.INV_DIR_SYMBOL.name, ", ".join(inv_dirs))
offset_symbols = BoundaryOffsetInfo._offset_symbols(dim) offset_symbols = BoundaryOffsetInfo._offset_symbols(dim)
super(BoundaryOffsetInfo, self).__init__(code, symbols_read=set(), super(BoundaryOffsetInfo, self).__init__(code, symbols_read=set(),
symbols_defined=set(offset_symbols + [self.INV_DIR_SYMBOL])) symbols_defined=set(offset_symbols + [self.INV_DIR_SYMBOL]))
...@@ -439,13 +439,12 @@ class BoundaryOffsetInfo(CustomCodeNode): ...@@ -439,13 +439,12 @@ class BoundaryOffsetInfo(CustomCodeNode):
def _offset_symbols(dim): def _offset_symbols(dim):
return [TypedSymbol(f"c{d}", create_type(np.int64)) for d in ['x', 'y', 'z'][:dim]] return [TypedSymbol(f"c{d}", create_type(np.int64)) for d in ['x', 'y', 'z'][:dim]]
INV_DIR_SYMBOL = TypedSymbol("invdir", "int") INV_DIR_SYMBOL = TypedSymbol("invdir", np.int64)
def create_boundary_kernel(field, index_field, stencil, boundary_functor, target='cpu', **kernel_creation_args): def create_boundary_kernel(field, index_field, stencil, boundary_functor, target='cpu', **kernel_creation_args):
elements = [BoundaryOffsetInfo(stencil)] elements = [BoundaryOffsetInfo(stencil)]
index_arr_dtype = index_field.dtype.numpy_dtype dir_symbol = TypedSymbol("dir", np.int64)
dir_symbol = TypedSymbol("dir", index_arr_dtype.fields['dir'][0])
elements += [Assignment(dir_symbol, index_field[0]('dir'))] elements += [Assignment(dir_symbol, index_field[0]('dir'))]
elements += boundary_functor(field, direction_symbol=dir_symbol, index_field=index_field) elements += boundary_functor(field, direction_symbol=dir_symbol, index_field=index_field)
return create_indexed_kernel(elements, [index_field], target=target, **kernel_creation_args) return create_indexed_kernel(elements, [index_field], target=target, **kernel_creation_args)
from typing import List, Union from typing import List, Union
import sympy as sp import sympy as sp
import numpy as np
import pystencils.astnodes as ast import pystencils.astnodes as ast
from pystencils.assignment import Assignment from pystencils.assignment import Assignment
from pystencils.astnodes import Block, KernelFunction, LoopOverCoordinate, SympyAssignment from pystencils.astnodes import Block, KernelFunction, LoopOverCoordinate, SympyAssignment
from pystencils.cpu.cpujit import make_python_function from pystencils.cpu.cpujit import make_python_function
from pystencils.data_types import BasicType, StructType, TypedSymbol, create_type from pystencils.data_types import StructType, TypedSymbol, create_type
from pystencils.field import Field, FieldType from pystencils.field import Field, FieldType
from pystencils.transformations import ( from pystencils.transformations import (
add_types, filtered_tree_iteration, get_base_buffer_index, get_optimal_loop_ordering, add_types, filtered_tree_iteration, get_base_buffer_index, get_optimal_loop_ordering,
...@@ -127,7 +128,7 @@ def create_indexed_kernel(assignments: AssignmentOrAstNodeList, index_fields, fu ...@@ -127,7 +128,7 @@ def create_indexed_kernel(assignments: AssignmentOrAstNodeList, index_fields, fu
data_type = idx_field.dtype data_type = idx_field.dtype
if data_type.has_element(name): if data_type.has_element(name):
rhs = idx_field[0](name) rhs = idx_field[0](name)
lhs = TypedSymbol(name, BasicType(data_type.get_element_type(name))) lhs = TypedSymbol(name, np.int64)
return SympyAssignment(lhs, rhs) return SympyAssignment(lhs, rhs)
raise ValueError(f"Index {name} not found in any of the passed index fields") raise ValueError(f"Index {name} not found in any of the passed index fields")
......
import numpy as np
from pystencils.astnodes import Block, KernelFunction, LoopOverCoordinate, SympyAssignment from pystencils.astnodes import Block, KernelFunction, LoopOverCoordinate, SympyAssignment
from pystencils.data_types import BasicType, StructType, TypedSymbol from pystencils.data_types import StructType, TypedSymbol
from pystencils.field import Field, FieldType from pystencils.field import Field, FieldType
from pystencils.gpucuda.cudajit import make_python_function from pystencils.gpucuda.cudajit import make_python_function
from pystencils.gpucuda.indexing import BlockIndexing from pystencils.gpucuda.indexing import BlockIndexing
...@@ -129,7 +131,7 @@ def created_indexed_cuda_kernel(assignments, ...@@ -129,7 +131,7 @@ def created_indexed_cuda_kernel(assignments,
data_type = ind_f.dtype data_type = ind_f.dtype
if data_type.has_element(name): if data_type.has_element(name):
rhs = ind_f[0](name) rhs = ind_f[0](name)
lhs = TypedSymbol(name, BasicType(data_type.get_element_type(name))) lhs = TypedSymbol(name, np.int64)
return SympyAssignment(lhs, rhs) return SympyAssignment(lhs, rhs)
raise ValueError(f"Index {name} not found in any of the passed index fields") raise ValueError(f"Index {name} not found in any of the passed index fields")
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment