Commit ac6ced35 authored by Markus Holzer's avatar Markus Holzer
Browse files

Merge branch 'loop_step' into 'master'

Fixes for buffers in loops with step size > 1

See merge request pycodegen/pystencils!252
parents 3dd60595 67c87a99
Pipeline #32537 passed with stages
in 44 minutes and 48 seconds
......@@ -529,6 +529,14 @@ class CustomSympyPrinter(CCodePrinter):
return res + '.0f' if '.' not in res else res + 'f'
elif dtype.numpy_dtype == np.float64:
return res + '.0' if '.' not in res else res
elif dtype.is_int():
tokens = res.split('.')
if len(tokens) == 1:
return res
elif int(tokens[1]) != 0:
raise ValueError(f"Cannot print non-integer number {res} as an integer.")
return tokens[0]
return res
......@@ -19,7 +19,7 @@ AssignmentOrAstNodeList = List[Union[Assignment, ast.Node]]
def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "kernel", type_info='double',
split_groups=(), iteration_slice=None, ghost_layers=None,
skip_independence_check=False) -> KernelFunction:
skip_independence_check=False, allow_double_writes=False) -> KernelFunction:
"""Creates an abstract syntax tree for a kernel function, by taking a list of update rules.
Loops are created according to the field accesses in the equations.
......@@ -40,6 +40,9 @@ def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "ke
all dimensions
skip_independence_check: don't check that loop iterations are independent. This is needed e.g. for
periodicity kernel, that access the field outside the iteration bounds. Use with care!
allow_double_writes: If True, don't check if every field is only written at a single location. This is required
for example for kernels that are compiled with loop step sizes > 1, that handle multiple
cells at once. Use with care!
AST node representing a function, that can be printed as C or CUDA code
......@@ -55,7 +58,8 @@ def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "ke
raise ValueError("Term has to be field access or symbol")
fields_read, fields_written, assignments = add_types(assignments, type_info, not skip_independence_check)
fields_read, fields_written, assignments = add_types(
assignments, type_info, not skip_independence_check, check_double_write_condition=not allow_double_writes)
all_fields = fields_read.union(fields_written)
read_only_fields = set([ for f in fields_read - fields_written])
......@@ -20,13 +20,21 @@ class IntegerFunctionTwoArgsMixIn(sp.Function):
for a in args:
type = get_type_of_expression(a)
if not type.is_int():
raise ValueError("Argument to integer function is not an int but " + str(type))
dtype = get_type_of_expression(a)
if not dtype.is_int():
raise ValueError("Argument to integer function is not an int but " + str(dtype))
except NotImplementedError:
raise ValueError("Integer functions can only be constructed with typed expressions")
return super().__new__(cls, *args)
def _eval_evalf(self, *pargs, **kwargs):
arg1 = self.args[0].evalf(*pargs, **kwargs) if hasattr(self.args[0], 'evalf') else self.args[0]
arg2 = self.args[1].evalf(*pargs, **kwargs) if hasattr(self.args[1], 'evalf') else self.args[1]
return self._eval_op(arg1, arg2)
def _eval_op(self, arg1, arg2):
return self
# noinspection PyPep8Naming
class bitwise_xor(IntegerFunctionTwoArgsMixIn):
......@@ -55,7 +63,9 @@ class bitwise_or(IntegerFunctionTwoArgsMixIn):
# noinspection PyPep8Naming
class int_div(IntegerFunctionTwoArgsMixIn):
def _eval_op(self, arg1, arg2):
return int(arg1 // arg2)
# noinspection PyPep8Naming
......@@ -20,6 +20,7 @@ from pystencils.field import AbstractField, Field, FieldType
from pystencils.kernelparameters import FieldPointerSymbol
from pystencils.simp.assignment_collection import AssignmentCollection
from pystencils.slicing import normalize_slice
from pystencils.integer_functions import int_div
class NestedScopes:
......@@ -357,20 +358,23 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None):
assert len(loops) == len(parents_of_innermost_loop)
assert all(l1 is l2 for l1, l2 in zip(loops, parents_of_innermost_loop))
loop_iterations = [(l.stop - l.start) / l.step for l in loops]
loop_counters = [l.loop_counter_symbol for l in loops]
actual_sizes = [int_div((l.stop - l.start), l.step) for l in loops]
actual_steps = [int_div((l.loop_counter_symbol - l.start), l.step) for l in loops]
actual_sizes = loop_iterations
actual_steps = loop_counters
field_accesses = ast_node.atoms(AbstractField.AbstractAccess)
buffer_accesses = {fa for fa in field_accesses if FieldType.is_buffer(fa.field)}
loop_counters = [v * len(buffer_accesses) for v in loop_counters]
buffer_index_size = len(buffer_accesses)
base_buffer_index = loop_counters[0]
stride = 1
for idx, var in enumerate(loop_counters[1:]):
cur_stride = loop_iterations[idx]
stride *= int(cur_stride) if isinstance(cur_stride, float) else cur_stride
base_buffer_index += var * stride
return base_buffer_index
base_buffer_index = actual_steps[0]
actual_stride = 1
for idx, actual_step in enumerate(actual_steps[1:]):
cur_stride = actual_sizes[idx]
actual_stride *= int(cur_stride) if isinstance(cur_stride, float) else cur_stride
base_buffer_index += actual_stride * actual_step
return base_buffer_index * buffer_index_size
def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=set()):
......@@ -933,7 +937,7 @@ class KernelConstraintsCheck:
def add_types(eqs, type_for_symbol, check_independence_condition):
def add_types(eqs, type_for_symbol, check_independence_condition, check_double_write_condition=True):
"""Traverses AST and replaces every :class:`sympy.Symbol` by a :class:`pystencils.typedsymbol.TypedSymbol`.
Additionally returns sets of all fields which are read/written
......@@ -951,7 +955,8 @@ def add_types(eqs, type_for_symbol, check_independence_condition):
if isinstance(type_for_symbol, (str, type)) or not hasattr(type_for_symbol, '__getitem__'):
type_for_symbol = typing_from_sympy_inspection(eqs, type_for_symbol)
check = KernelConstraintsCheck(type_for_symbol, check_independence_condition)
check = KernelConstraintsCheck(type_for_symbol, check_independence_condition,
def visit(obj):
if isinstance(obj, (list, tuple)):
......@@ -2,7 +2,7 @@
import numpy as np
from pystencils import Assignment, Field, FieldType, create_kernel
from pystencils import Assignment, Field, FieldType, create_kernel, make_slice
from pystencils.field import create_numpy_array_with_layout, layout_string_to_tuple
from pystencils.slicing import (
add_ghost_layers, get_ghost_region_slice, get_slice_before_ghost_layer)
......@@ -186,3 +186,49 @@ def test_field_layouts():
unpack_code = create_kernel(unpack_eqs, data_type={'dst_field': dst_arr.dtype, 'buffer': buffer.dtype})
unpack_kernel = unpack_code.compile()
unpack_kernel(buffer=bufferArr, dst_field=dst_arr)
def test_iteration_slices():
num_cell_values = 19
fields = _generate_fields(num_directions=num_cell_values)
for (src_arr, dst_arr, bufferArr) in fields:
src_field = Field.create_from_numpy_array("src_field", src_arr, index_dimensions=1)
dst_field = Field.create_from_numpy_array("dst_field", dst_arr, index_dimensions=1)
buffer = Field.create_generic("buffer", spatial_dimensions=1, index_dimensions=1,
field_type=FieldType.BUFFER, dtype=src_arr.dtype)
pack_eqs = []
# Since we are packing all cell values for all cells, then
# the buffer index is equivalent to the field index
for idx in range(num_cell_values):
eq = Assignment(buffer(idx), src_field(idx))
dim = src_field.spatial_dimensions
# Pack only the leftmost slice, only every second cell
pack_slice = (slice(None, None, 2),) * (dim-1) + (0, )
# Fill the entire array with data
src_arr[ (slice(None, None, 1),) * dim] = np.arange(num_cell_values)
pack_code = create_kernel(pack_eqs, iteration_slice=pack_slice, data_type={'src_field': src_arr.dtype, 'buffer': buffer.dtype})
pack_kernel = pack_code.compile()
pack_kernel(buffer=bufferArr, src_field=src_arr)
unpack_eqs = []
for idx in range(num_cell_values):
eq = Assignment(dst_field(idx), buffer(idx))
unpack_code = create_kernel(unpack_eqs, iteration_slice=pack_slice, data_type={'dst_field': dst_arr.dtype, 'buffer': buffer.dtype})
unpack_kernel = unpack_code.compile()
unpack_kernel(buffer=bufferArr, dst_field=dst_arr)
# Check if only every second entry of the leftmost slice has been copied
np.testing.assert_equal(dst_arr[pack_slice], src_arr[pack_slice])
np.testing.assert_equal(dst_arr[(slice(1, None, 2),) * (dim-1) + (0,)], 0.0)
np.testing.assert_equal(dst_arr[(slice(None, None, 1),) * (dim-1) + (slice(1,None),)], 0.0)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment