Commit ca72adf4 authored by Martin Bauer's avatar Martin Bauer
Browse files

walberla integration + bugfix in GPU block indexing

parent bbd8d348
......@@ -2,7 +2,7 @@ import sympy as sp
from collections import namedtuple
from sympy.core import S
from typing import Set
from sympy.printing.ccode import C89CodePrinter
try:
from sympy.printing.ccode import C99CodePrinter as CCodePrinter
except ImportError:
......@@ -233,7 +233,7 @@ class CustomSympyPrinter(CCodePrinter):
return result.replace("\n", "")
def _print_Function(self, expr):
function_map = {
infix_functions = {
bitwise_xor: '^',
bit_shift_right: '>>',
bit_shift_left: '<<',
......@@ -248,11 +248,8 @@ class CustomSympyPrinter(CCodePrinter):
return self._typed_number(arg, data_type)
else:
return "*((%s)(& %s))" % (PointerType(data_type), self._print(arg))
elif expr.func == modulo_floor:
assert all(get_type_of_expression(e).is_int() for e in expr.args)
return "({dtype})({0} / {1}) * {1}".format(*expr.args, dtype=get_type_of_expression(expr.args[0]))
elif expr.func in function_map:
return "(%s %s %s)" % (self._print(expr.args[0]), function_map[expr.func], self._print(expr.args[1]))
elif expr.func in infix_functions:
return "(%s %s %s)" % (self._print(expr.args[0]), infix_functions[expr.func], self._print(expr.args[1]))
else:
return super(CustomSympyPrinter, self)._print_Function(expr)
......@@ -268,6 +265,9 @@ class CustomSympyPrinter(CCodePrinter):
else:
return res
_print_Max = C89CodePrinter._print_Max
_print_Min = C89CodePrinter._print_Min
# noinspection PyPep8Naming
class VectorizedCustomSympyPrinter(CustomSympyPrinter):
......
......@@ -300,7 +300,7 @@ def get_type_of_expression(expr):
elif isinstance(expr, TypedSymbol):
return expr.dtype
elif isinstance(expr, sp.Symbol):
raise ValueError("All symbols inside this expression have to be typed!")
raise ValueError("All symbols inside this expression have to be typed! ", str(expr))
elif isinstance(expr, cast_func):
return expr.args[1]
elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
......
......@@ -2,6 +2,7 @@ import abc
from typing import Tuple # noqa
import sympy as sp
from pystencils.astnodes import Conditional, Block
from pystencils.integer_functions import div_ceil
from pystencils.slicing import normalize_slice
from pystencils.data_types import TypedSymbol, create_type
from functools import partial
......@@ -70,10 +71,10 @@ class BlockIndexing(AbstractIndexing):
iteration_slice: slice that defines rectangular subarea which is iterated over
permute_block_size_dependent_on_layout: if True the block_size is permuted such that the fastest coordinate
gets the largest amount of threads
compile_time_block_size: compile in concrete block size, otherwise the cuda variable 'blockDim' is used
"""
def __init__(self, field, iteration_slice=None,
block_size=(16, 16, 1), permute_block_size_dependent_on_layout=True):
block_size=(16, 16, 1), permute_block_size_dependent_on_layout=True, compile_time_block_size=False):
if field.spatial_dimensions > 3:
raise NotImplementedError("This indexing scheme supports at most 3 spatial dimensions")
......@@ -83,16 +84,18 @@ class BlockIndexing(AbstractIndexing):
if AUTO_BLOCK_SIZE_LIMITING:
block_size = self.limit_block_size_to_device_maximum(block_size)
self._blockSize = block_size
self._block_size = block_size
self._iterationSlice = normalize_slice(iteration_slice, field.spatial_shape)
self._dim = field.spatial_dimensions
self._symbolicShape = [e if isinstance(e, sp.Basic) else None for e in field.spatial_shape]
self._compile_time_block_size = compile_time_block_size
@property
def coordinates(self):
offsets = _get_start_from_slice(self._iterationSlice)
block_size = self._block_size if self._compile_time_block_size else BLOCK_DIM
coordinates = [block_index * bs + thread_idx + off
for block_index, bs, thread_idx, off in zip(BLOCK_IDX, self._blockSize, THREAD_IDX, offsets)]
for block_index, bs, thread_idx, off in zip(BLOCK_IDX, block_size, THREAD_IDX, offsets)]
return coordinates[:self._dim]
......@@ -102,13 +105,16 @@ class BlockIndexing(AbstractIndexing):
widths = [end - start for start, end in zip(_get_start_from_slice(self._iterationSlice),
_get_end_from_slice(self._iterationSlice, arr_shape))]
widths = sp.Matrix(widths).subs(substitution_dict)
extend_bs = (1,) * (3 - len(self._block_size))
block_size = self._block_size + extend_bs
if not self._compile_time_block_size:
block_size = [sp.Min(bs, shape) for bs, shape in zip(block_size, widths)]
grid = tuple(sp.ceiling(length / block_size)
for length, block_size in zip(widths, self._blockSize)) # type: : Tuple[int, ...]
extend_bs = (1,) * (3 - len(self._blockSize))
grid = tuple(div_ceil(length, block_size)
for length, block_size in zip(widths, block_size))
extend_gr = (1,) * (3 - len(grid))
return {'block': self._blockSize + extend_bs,
return {'block': block_size,
'grid': grid + extend_gr}
def guard(self, kernel_content, arr_shape):
......
......@@ -70,3 +70,32 @@ class modulo_ceil(sp.Function):
assert dtype.is_int()
code = "(({0}) % ({1}) == 0 ? {0} : (({dtype})(({0}) / ({1}))+1) * ({1}))"
return code.format(print_func(self.args[0]), print_func(self.args[1]), dtype=dtype)
# noinspection PyPep8Naming
class div_ceil(sp.Function):
"""Integer division that is always rounded up
Examples:
>>> div_ceil(9, 4)
3
>>> div_ceil(8, 4)
2
>>> from pystencils import TypedSymbol
>>> a, b = TypedSymbol("a", "int64"), TypedSymbol("b", "int32")
>>> div_ceil(a, b).to_c(str)
'( (a) % (b) == 0 ? (int64_t)(a) / (int64_t)(b) : ( (int64_t)(a) / (int64_t)(b) ) +1 )'
"""
nargs = 2
def __new__(cls, integer, divisor):
if is_integer_sequence((integer, divisor)):
return integer // divisor if integer % divisor == 0 else (integer // divisor) + 1
else:
return super().__new__(cls, integer, divisor)
def to_c(self, print_func):
dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1])))
assert dtype.is_int()
code = "( ({0}) % ({1}) == 0 ? ({dtype})({0}) / ({dtype})({1}) : ( ({dtype})({0}) / ({dtype})({1}) ) +1 )"
return code.format(print_func(self.args[0]), print_func(self.args[1]), dtype=dtype)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment