From fb25d97acc74c90ce2e47adc71b460decc354fc0 Mon Sep 17 00:00:00 2001 From: Arttu Miettinen Date: Mon, 29 Mar 2021 12:30:09 +0300 Subject: [PATCH] Adds support for larger indices in indexed kernels in CPU mode by changing int32 indices to int64 indices. --- lbmpy/advanced_streaming/indexing.py | 2 +- lbmpy/boundaries/boundaryhandling.py | 10 ++++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/lbmpy/advanced_streaming/indexing.py b/lbmpy/advanced_streaming/indexing.py index 6548235..6445bb9 100644 --- a/lbmpy/advanced_streaming/indexing.py +++ b/lbmpy/advanced_streaming/indexing.py @@ -219,7 +219,7 @@ class NeighbourOffsetArrays(CustomCodeNode): def _offset_symbols(dim): return [TypedSymbol(f"neighbour_offset_{d}", create_type(np.int64)) for d in ['x', 'y', 'z'][:dim]] - def __init__(self, stencil, offsets_dtype=np.int32): + def __init__(self, stencil, offsets_dtype=np.int64): offsets_dtype = create_type(offsets_dtype) dim = len(stencil[0]) diff --git a/lbmpy/boundaries/boundaryhandling.py b/lbmpy/boundaries/boundaryhandling.py index 7ebfad5..5bf5de4 100644 --- a/lbmpy/boundaries/boundaryhandling.py +++ b/lbmpy/boundaries/boundaryhandling.py @@ -172,13 +172,19 @@ class LbmWeightInfo(CustomCodeNode): super(LbmWeightInfo, self).__init__(code, symbols_read=set(), symbols_defined={w_sym}) # end class LbmWeightInfo +def wider_type(typ): + if typ == np.int32: + return np.int64 + if typ == np.uint32: + return np.uint64 + return typ def create_lattice_boltzmann_boundary_kernel(pdf_field, index_field, lb_method, boundary_functor, prev_timestep=Timestep.BOTH, streaming_pattern='pull', target='cpu', **kernel_creation_args): - index_dtype = index_field.dtype.numpy_dtype.fields['dir'][0] - offsets_dtype = index_field.dtype.numpy_dtype.fields['x'][0] + index_dtype = wider_type(index_field.dtype.numpy_dtype.fields['dir'][0]) + offsets_dtype = wider_type(index_field.dtype.numpy_dtype.fields['x'][0]) indexing = BetweenTimestepsIndexing( pdf_field, lb_method.stencil, prev_timestep, streaming_pattern, index_dtype, offsets_dtype) -- 2.25.1