diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py index bfed79dbe57baab66a5cd0dc3d35297535aa5fb5..47f1fd7d1d1715bf85e326e485aad8231dadcdfe 100644 --- a/pystencils/astnodes.py +++ b/pystencils/astnodes.py @@ -744,3 +744,28 @@ class EmptyLine(Node): def __repr__(self): return self.__str__() + + +class ConditionalFieldAccess(sp.Function): + """ + :class:`pystencils.Field.Access` that is only executed if a certain condition is met. + Can be used, for instance, for out-of-bound checks. + """ + + def __new__(cls, field_access, outofbounds_condition, outofbounds_value=0): + return sp.Function.__new__(cls, field_access, outofbounds_condition, sp.S(outofbounds_value)) + + @property + def access(self): + return self.args[0] + + @property + def outofbounds_condition(self): + return self.args[1] + + @property + def outofbounds_value(self): + return self.args[2] + + def __getnewargs__(self): + return self.access, self.outofbounds_condition, self.outofbounds_value diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index a50f00b2406da664620d19fc874603d826f5971a..7248846b312c920768e1f5b68af65aa21cfb46b8 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -426,6 +426,9 @@ class CustomSympyPrinter(CCodePrinter): ) return code + def _print_ConditionalFieldAccess(self, node): + return self._print(sp.Piecewise((node.outofbounds_value, node.outofbounds_condition), (node.access, True))) + _print_Max = C89CodePrinter._print_Max _print_Min = C89CodePrinter._print_Min diff --git a/pystencils_tests/test_conditional_field_access.py b/pystencils_tests/test_conditional_field_access.py new file mode 100644 index 0000000000000000000000000000000000000000..f68b34679ae5de0996a09a30921c85ec49bece58 --- /dev/null +++ b/pystencils_tests/test_conditional_field_access.py @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +# +# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de> +# +# Distributed under terms of the GPLv3 license. + +""" + +""" +import itertools + +import numpy as np +import pytest +import sympy as sp + +import pystencils as ps +from pystencils import Field, x_vector +from pystencils.astnodes import ConditionalFieldAccess +from pystencils.simp import sympy_cse + + +def add_fixed_constant_boundary_handling(assignments, with_cse): + + common_shape = next(iter(set().union(itertools.chain.from_iterable( + [a.atoms(Field.Access) for a in assignments] + )))).field.spatial_shape + ndim = len(common_shape) + + def is_out_of_bound(access, shape): + return sp.Or(*[sp.Or(a < 0, a >= s) for a, s in zip(access, shape)]) + + safe_assignments = [ps.Assignment( + assignment.lhs, assignment.rhs.subs({ + a: ConditionalFieldAccess(a, is_out_of_bound(sp.Matrix(a.offsets) + x_vector(ndim), common_shape)) + for a in assignment.rhs.atoms(Field.Access) if not a.is_absolute_access + })) for assignment in assignments.all_assignments] + + subs = [{a: ConditionalFieldAccess(a, is_out_of_bound( + sp.Matrix(a.offsets) + x_vector(ndim), common_shape)) + for a in assignment.rhs.atoms(Field.Access) if not a.is_absolute_access + } for assignment in assignments.all_assignments] + print(subs) + + if with_cse: + safe_assignments = sympy_cse(ps.AssignmentCollection(safe_assignments)) + return safe_assignments + else: + return ps.AssignmentCollection(safe_assignments) + + +@pytest.mark.parametrize('with_cse', (False, 'with_cse')) +def test_boundary_check(with_cse): + + f, g = ps.fields("f, g : [2D]") + stencil = ps.Assignment(g[0, 0], + (f[1, 0] + f[-1, 0] + f[0, 1] + f[0, -1]) / 4) + + f_arr = np.random.rand(1000, 1000) + g_arr = np.zeros_like(f_arr) + # kernel(f=f_arr, g=g_arr) + + assignments = add_fixed_constant_boundary_handling(ps.AssignmentCollection([stencil]), with_cse) + + print(assignments) + kernel_checked = ps.create_kernel(assignments, ghost_layers=0).compile() + print(ps.show_code(kernel_checked)) + + # No SEGFAULT, please!! + kernel_checked(f=f_arr, g=g_arr)