Commit ac6ced35 authored by Markus Holzer

Merge branch 'loop_step' into 'master'

Fixes for buffers in loops with step size > 1

See merge request pycodegen/pystencils!252
parents 3dd60595 67c87a99
@@ -529,6 +529,14 @@ class CustomSympyPrinter(CCodePrinter):
             return res + '.0f' if '.' not in res else res + 'f'
         elif dtype.numpy_dtype == np.float64:
             return res + '.0' if '.' not in res else res
+        elif dtype.is_int():
+            tokens = res.split('.')
+            if len(tokens) == 1:
+                return res
+            elif int(tokens[1]) != 0:
+                raise ValueError(f"Cannot print non-integer number {res} as an integer.")
+            else:
+                return tokens[0]
         else:
             return res
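The added branch only has to turn literals such as "4.0" into "4" and reject genuine fractions. A minimal standalone sketch of that logic (plain Python, not the printer itself):

def print_int_literal(res: str) -> str:
    tokens = res.split('.')
    if len(tokens) == 1:           # already an integer literal, e.g. "4"
        return res
    if int(tokens[1]) != 0:        # e.g. "4.5" cannot be printed as an integer
        raise ValueError(f"Cannot print non-integer number {res} as an integer.")
    return tokens[0]               # "4.0" -> "4"

assert print_int_literal("4") == "4"
assert print_int_literal("4.0") == "4"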
@@ -19,7 +19,7 @@ AssignmentOrAstNodeList = List[Union[Assignment, ast.Node]]
 def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "kernel", type_info='double',
                   split_groups=(), iteration_slice=None, ghost_layers=None,
-                  skip_independence_check=False) -> KernelFunction:
+                  skip_independence_check=False, allow_double_writes=False) -> KernelFunction:
     """Creates an abstract syntax tree for a kernel function, by taking a list of update rules.
     Loops are created according to the field accesses in the equations.
@@ -40,6 +40,9 @@ def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "ke
                       all dimensions
         skip_independence_check: don't check that loop iterations are independent. This is needed e.g. for
                                  periodicity kernel, that access the field outside the iteration bounds. Use with care!
+        allow_double_writes: If True, don't check if every field is only written at a single location. This is required
+                             for example for kernels that are compiled with loop step sizes > 1, that handle multiple
+                             cells at once. Use with care!

     Returns:
         AST node representing a function, that can be printed as C or CUDA code
@@ -55,7 +58,8 @@ def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "ke
         else:
             raise ValueError("Term has to be field access or symbol")

-    fields_read, fields_written, assignments = add_types(assignments, type_info, not skip_independence_check)
+    fields_read, fields_written, assignments = add_types(
+        assignments, type_info, not skip_independence_check, check_double_write_condition=not allow_double_writes)
     all_fields = fields_read.union(fields_written)
     read_only_fields = set([f.name for f in fields_read - fields_written])
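For context, a hypothetical call that needs the new flag, e.g. a kernel that writes two cells per iteration. The field setup and the import path are assumptions; only the keyword argument comes from the hunk above:

import pystencils as ps
from pystencils.cpu.kernelcreation import create_kernel  # assumed location of the function shown above

src, dst = ps.fields("src, dst: double[2D]")
# Writing dst at two different offsets would normally trip the single-write check;
# allow_double_writes=True disables exactly that check.
update = [ps.Assignment(dst[0, 0], src[0, 0]),
          ps.Assignment(dst[1, 0], src[1, 0])]
ast = create_kernel(update, allow_double_writes=True)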
@@ -20,13 +20,21 @@ class IntegerFunctionTwoArgsMixIn(sp.Function):
         for a in args:
             try:
-                type = get_type_of_expression(a)
-                if not type.is_int():
-                    raise ValueError("Argument to integer function is not an int but " + str(type))
+                dtype = get_type_of_expression(a)
+                if not dtype.is_int():
+                    raise ValueError("Argument to integer function is not an int but " + str(dtype))
             except NotImplementedError:
                 raise ValueError("Integer functions can only be constructed with typed expressions")

         return super().__new__(cls, *args)

+    def _eval_evalf(self, *pargs, **kwargs):
+        arg1 = self.args[0].evalf(*pargs, **kwargs) if hasattr(self.args[0], 'evalf') else self.args[0]
+        arg2 = self.args[1].evalf(*pargs, **kwargs) if hasattr(self.args[1], 'evalf') else self.args[1]
+        return self._eval_op(arg1, arg2)
+
+    def _eval_op(self, arg1, arg2):
+        return self
+

 # noinspection PyPep8Naming
 class bitwise_xor(IntegerFunctionTwoArgsMixIn):
@@ -55,7 +63,9 @@ class bitwise_or(IntegerFunctionTwoArgsMixIn):

 # noinspection PyPep8Naming
 class int_div(IntegerFunctionTwoArgsMixIn):
-    pass
+
+    def _eval_op(self, arg1, arg2):
+        return int(arg1 // arg2)


 # noinspection PyPep8Naming
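The two hooks above let SymPy evaluate such functions numerically: evalf() first evaluates both arguments and then delegates to _eval_op, which int_div overrides with floor division. A toy SymPy sketch of the same pattern (not the pystencils class, which additionally type-checks its arguments):

import sympy as sp

class floor_div(sp.Function):
    # evalf hook: evaluate both arguments, then apply the concrete operation
    def _eval_evalf(self, *pargs, **kwargs):
        arg1 = self.args[0].evalf(*pargs, **kwargs) if hasattr(self.args[0], 'evalf') else self.args[0]
        arg2 = self.args[1].evalf(*pargs, **kwargs) if hasattr(self.args[1], 'evalf') else self.args[1]
        return self._eval_op(arg1, arg2)

    def _eval_op(self, arg1, arg2):
        return sp.Integer(int(arg1) // int(arg2))

i = sp.Symbol('i')
expr = floor_div(i, 2)            # stays symbolic
print(expr.subs(i, 7).evalf())    # evaluates numerically to 3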
@@ -20,6 +20,7 @@ from pystencils.field import AbstractField, Field, FieldType
 from pystencils.kernelparameters import FieldPointerSymbol
 from pystencils.simp.assignment_collection import AssignmentCollection
 from pystencils.slicing import normalize_slice
+from pystencils.integer_functions import int_div


 class NestedScopes:
@@ -357,20 +358,23 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None):
         assert len(loops) == len(parents_of_innermost_loop)
         assert all(l1 is l2 for l1, l2 in zip(loops, parents_of_innermost_loop))

-        loop_iterations = [(l.stop - l.start) / l.step for l in loops]
-        loop_counters = [l.loop_counter_symbol for l in loops]
+        actual_sizes = [int_div((l.stop - l.start), l.step) for l in loops]
+        actual_steps = [int_div((l.loop_counter_symbol - l.start), l.step) for l in loops]
+    else:
+        actual_sizes = loop_iterations
+        actual_steps = loop_counters

     field_accesses = ast_node.atoms(AbstractField.AbstractAccess)
     buffer_accesses = {fa for fa in field_accesses if FieldType.is_buffer(fa.field)}
-    loop_counters = [v * len(buffer_accesses) for v in loop_counters]
+    buffer_index_size = len(buffer_accesses)

-    base_buffer_index = loop_counters[0]
-    stride = 1
-    for idx, var in enumerate(loop_counters[1:]):
-        cur_stride = loop_iterations[idx]
-        stride *= int(cur_stride) if isinstance(cur_stride, float) else cur_stride
-        base_buffer_index += var * stride
-    return base_buffer_index
+    base_buffer_index = actual_steps[0]
+    actual_stride = 1
+    for idx, actual_step in enumerate(actual_steps[1:]):
+        cur_stride = actual_sizes[idx]
+        actual_stride *= int(cur_stride) if isinstance(cur_stride, float) else cur_stride
+        base_buffer_index += actual_stride * actual_step
+    return base_buffer_index * buffer_index_size


 def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=set()):
@@ -933,7 +937,7 @@ class KernelConstraintsCheck:
             self.scopes.access_symbol(rhs)


-def add_types(eqs, type_for_symbol, check_independence_condition):
+def add_types(eqs, type_for_symbol, check_independence_condition, check_double_write_condition=True):
     """Traverses AST and replaces every :class:`sympy.Symbol` by a :class:`pystencils.typedsymbol.TypedSymbol`.

     Additionally returns sets of all fields which are read/written
@@ -951,7 +955,8 @@ def add_types(eqs, type_for_symbol, check_independence_condition):
     if isinstance(type_for_symbol, (str, type)) or not hasattr(type_for_symbol, '__getitem__'):
         type_for_symbol = typing_from_sympy_inspection(eqs, type_for_symbol)

-    check = KernelConstraintsCheck(type_for_symbol, check_independence_condition)
+    check = KernelConstraintsCheck(type_for_symbol, check_independence_condition,
+                                   check_double_write_condition=check_double_write_condition)

     def visit(obj):
         if isinstance(obj, (list, tuple)):
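The rewritten get_base_buffer_index maps the loop counters into the packed (strided) iteration space: per dimension the position is (counter - start) / step and the extent is (stop - start) / step, and the resulting linear index is finally multiplied by the number of buffered values per cell. A plain-Python sketch of that formula with a worked example (names are illustrative, not pystencils API):

def base_buffer_index(counters, starts, stops, steps, values_per_cell):
    # position of the current cell within the packed iteration space, per dimension
    actual_steps = [(c - lo) // s for c, lo, s in zip(counters, starts, steps)]
    # extent of the packed iteration space, per dimension
    actual_sizes = [(hi - lo) // s for lo, hi, s in zip(starts, stops, steps)]

    index = actual_steps[0]
    stride = 1
    for size, step in zip(actual_sizes[:-1], actual_steps[1:]):
        stride *= size
        index += stride * step
    return index * values_per_cell

# 8x8 iteration space, every second cell in each direction (step 2), 19 values per cell:
# cell (4, 6) is packed cell (2, 3) -> linear index (2 + 4*3) * 19 = 266
print(base_buffer_index((4, 6), (0, 0), (8, 8), (2, 2), 19))  # 266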
@@ -2,7 +2,7 @@

 import numpy as np

-from pystencils import Assignment, Field, FieldType, create_kernel
+from pystencils import Assignment, Field, FieldType, create_kernel, make_slice
 from pystencils.field import create_numpy_array_with_layout, layout_string_to_tuple
 from pystencils.slicing import (
     add_ghost_layers, get_ghost_region_slice, get_slice_before_ghost_layer)
@@ -186,3 +186,49 @@ def test_field_layouts():
         unpack_code = create_kernel(unpack_eqs, data_type={'dst_field': dst_arr.dtype, 'buffer': buffer.dtype})
         unpack_kernel = unpack_code.compile()
         unpack_kernel(buffer=bufferArr, dst_field=dst_arr)
+
+
+def test_iteration_slices():
+    num_cell_values = 19
+    fields = _generate_fields(num_directions=num_cell_values)
+    for (src_arr, dst_arr, bufferArr) in fields:
+        src_field = Field.create_from_numpy_array("src_field", src_arr, index_dimensions=1)
+        dst_field = Field.create_from_numpy_array("dst_field", dst_arr, index_dimensions=1)
+        buffer = Field.create_generic("buffer", spatial_dimensions=1, index_dimensions=1,
+                                      field_type=FieldType.BUFFER, dtype=src_arr.dtype)
+
+        pack_eqs = []
+        # Since we are packing all cell values for all cells, then
+        # the buffer index is equivalent to the field index
+        for idx in range(num_cell_values):
+            eq = Assignment(buffer(idx), src_field(idx))
+            pack_eqs.append(eq)
+
+        dim = src_field.spatial_dimensions
+
+        # Pack only the leftmost slice, only every second cell
+        pack_slice = (slice(None, None, 2),) * (dim - 1) + (0,)
+
+        # Fill the entire array with data
+        src_arr[(slice(None, None, 1),) * dim] = np.arange(num_cell_values)
+        dst_arr.fill(0.0)
+
+        pack_code = create_kernel(pack_eqs, iteration_slice=pack_slice, data_type={'src_field': src_arr.dtype, 'buffer': buffer.dtype})
+        pack_kernel = pack_code.compile()
+        pack_kernel(buffer=bufferArr, src_field=src_arr)
+
+        unpack_eqs = []
+        for idx in range(num_cell_values):
+            eq = Assignment(dst_field(idx), buffer(idx))
+            unpack_eqs.append(eq)
+
+        unpack_code = create_kernel(unpack_eqs, iteration_slice=pack_slice, data_type={'dst_field': dst_arr.dtype, 'buffer': buffer.dtype})
+        unpack_kernel = unpack_code.compile()
+        unpack_kernel(buffer=bufferArr, dst_field=dst_arr)
+
+        # Check if only every second entry of the leftmost slice has been copied
+        np.testing.assert_equal(dst_arr[pack_slice], src_arr[pack_slice])
+        np.testing.assert_equal(dst_arr[(slice(1, None, 2),) * (dim - 1) + (0,)], 0.0)
+        np.testing.assert_equal(dst_arr[(slice(None, None, 1),) * (dim - 1) + (slice(1, None),)], 0.0)
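As a side note, the newly imported make_slice can express the same strided iteration slice more compactly; for the 2D case exercised by the test, the two notations are equivalent (a small sketch, assuming dim == 2):

from pystencils import make_slice

pack_slice = (slice(None, None, 2),) * 1 + (0,)  # as written in the test for dim == 2
assert pack_slice == make_slice[::2, 0]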