diff --git a/pystencils/kerncraft_coupling/kerncraft_interface.py b/pystencils/kerncraft_coupling/kerncraft_interface.py index 8e8bcc618a53e8ceddd9671e4e9826da66fd71e6..2d0859f7e78b113077189960cab0404b373a1cda 100644 --- a/pystencils/kerncraft_coupling/kerncraft_interface.py +++ b/pystencils/kerncraft_coupling/kerncraft_interface.py @@ -3,10 +3,13 @@ import fcntl from collections import defaultdict from tempfile import TemporaryDirectory import textwrap +import itertools +import string from jinja2 import Environment, PackageLoader, StrictUndefined, Template import sympy as sp from kerncraft.kerncraft import KernelCode +from kerncraft.kernel import symbol_pos_int from kerncraft.machinemodel import MachineModel from pystencils.astnodes import \ @@ -75,10 +78,6 @@ class PyStencilsKerncraftKernel(KernelCode): cur_node = cur_node.parent self._loop_stack = list(reversed(self._loop_stack)) - # Data sources & destinations - self.sources = defaultdict(list) - self.destinations = defaultdict(list) - def get_layout_tuple(f): if f.has_fixed_shape: return get_layout_from_strides(f.strides) @@ -88,23 +87,37 @@ class PyStencilsKerncraftKernel(KernelCode): layout_list.insert(0 if assumed_layout == 'SoA' else -1, max(layout_list) + 1) return layout_list + # Variables (arrays) and Constants (scalar sizes) + const_names_iter = itertools.product(string.ascii_uppercase, repeat=1) + constants_reversed = {} + fields_accessed = self.kernel_ast.fields_accessed + for field in fields_accessed: + layout = get_layout_tuple(field) + permuted_shape = list(field.shape[i] for i in layout) + # Replace shape dimensions with constant variables (necessary for layer condition + # analysis) + for i, d in enumerate(permuted_shape): + if d not in self.constants.values(): + const_symbol = symbol_pos_int(''.join(next(const_names_iter))) + self.set_constant(const_symbol, d) + constants_reversed[d] = const_symbol + permuted_shape[i] = constants_reversed[d] + self.set_variable(field.name, (str(field.dtype),), tuple(permuted_shape)) + + # Data sources & destinations + self.sources = defaultdict(list) + self.destinations = defaultdict(list) + reads, writes = search_resolved_field_accesses_in_ast(inner_loop) for accesses, target_dict in [(reads, self.sources), (writes, self.destinations)]: for fa in accesses: - coord = [sp.Symbol(LoopOverCoordinate.get_loop_counter_name(i), positive=True, integer=True) + off + coord = [symbol_pos_int(LoopOverCoordinate.get_loop_counter_name(i)) + off for i, off in enumerate(fa.offsets)] coord += list(fa.idx_coordinate_values) layout = get_layout_tuple(fa.field) permuted_coord = [sp.sympify(coord[i]) for i in layout] target_dict[fa.field.name].append(permuted_coord) - # Variables (arrays) - fields_accessed = self.kernel_ast.fields_accessed - for field in fields_accessed: - layout = get_layout_tuple(field) - permuted_shape = list(field.shape[i] for i in layout) - self.set_variable(field.name, (str(field.dtype),), tuple(permuted_shape)) - # Scalars may be safely ignored # for param in self.kernel_ast.get_parameters(): # if not param.is_field_parameter: diff --git a/pystencils_tests/test_kerncraft_coupling.py b/pystencils_tests/test_kerncraft_coupling.py index 754604f1ed3aa06fb4138fb46fc918801c51049b..533cc954ed0c73fe7053c05115711353a13ee814 100644 --- a/pystencils_tests/test_kerncraft_coupling.py +++ b/pystencils_tests/test_kerncraft_coupling.py @@ -165,18 +165,3 @@ def test_benchmark(): timeloop_time = timeloop.benchmark(number_of_time_steps_for_estimation=1) np.testing.assert_almost_equal(c_benchmark_run, timeloop_time, decimal=4) - - -@pytest.mark.kerncraft -def test_kerncraft_generic_field(): - machine_file_path = INPUT_FOLDER / "Example_SandyBridgeEP_E5-2680.yml" - machine = MachineModel(path_to_yaml=machine_file_path) - - a = fields('a: double[3D]') - b = fields('b: double[3D]') - s = sp.Symbol("s") - rhs = a[0, -1, 0] + a[0, 1, 0] + a[-1, 0, 0] + a[1, 0, 0] + a[0, 0, -1] + a[0, 0, 1] - - update_rule = Assignment(b[0, 0, 0], s * rhs) - ast = create_kernel([update_rule]) - k = PyStencilsKerncraftKernel(ast, machine, debug_print=True)