From 443527ae8f93b686ac5499bd45af6e1b52bc4f08 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Mon, 30 Sep 2019 18:46:57 +0200
Subject: [PATCH] Add ConditionalFieldAccess (Field.Access after out-of-bounds
 check)

---
 pystencils/astnodes.py                        | 25 +++++++
 pystencils/backends/cbackend.py               |  3 +
 .../test_conditional_field_access.py          | 69 +++++++++++++++++++
 3 files changed, 97 insertions(+)
 create mode 100644 pystencils_tests/test_conditional_field_access.py

diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py
index bfed79dbe..47f1fd7d1 100644
--- a/pystencils/astnodes.py
+++ b/pystencils/astnodes.py
@@ -744,3 +744,28 @@ class EmptyLine(Node):
 
     def __repr__(self):
         return self.__str__()
+
+
+class ConditionalFieldAccess(sp.Function):
+    """
+    :class:`pystencils.Field.Access` that is only executed if a certain condition is met.
+    Can be used, for instance, for out-of-bound checks.
+    """
+
+    def __new__(cls, field_access, outofbounds_condition, outofbounds_value=0):
+        return sp.Function.__new__(cls, field_access, outofbounds_condition, sp.S(outofbounds_value))
+
+    @property
+    def access(self):
+        return self.args[0]
+
+    @property
+    def outofbounds_condition(self):
+        return self.args[1]
+
+    @property
+    def outofbounds_value(self):
+        return self.args[2]
+
+    def __getnewargs__(self):
+        return self.access, self.outofbounds_condition, self.outofbounds_value
diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py
index a50f00b24..7248846b3 100644
--- a/pystencils/backends/cbackend.py
+++ b/pystencils/backends/cbackend.py
@@ -426,6 +426,9 @@ class CustomSympyPrinter(CCodePrinter):
         )
         return code
 
+    def _print_ConditionalFieldAccess(self, node):
+        return self._print(sp.Piecewise((node.outofbounds_value, node.outofbounds_condition), (node.access, True)))
+
     _print_Max = C89CodePrinter._print_Max
     _print_Min = C89CodePrinter._print_Min
 
diff --git a/pystencils_tests/test_conditional_field_access.py b/pystencils_tests/test_conditional_field_access.py
new file mode 100644
index 000000000..f68b34679
--- /dev/null
+++ b/pystencils_tests/test_conditional_field_access.py
@@ -0,0 +1,69 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
+#
+# Distributed under terms of the GPLv3 license.
+
+"""
+
+"""
+import itertools
+
+import numpy as np
+import pytest
+import sympy as sp
+
+import pystencils as ps
+from pystencils import Field, x_vector
+from pystencils.astnodes import ConditionalFieldAccess
+from pystencils.simp import sympy_cse
+
+
+def add_fixed_constant_boundary_handling(assignments, with_cse):
+
+    common_shape = next(iter(set().union(itertools.chain.from_iterable(
+        [a.atoms(Field.Access) for a in assignments]
+    )))).field.spatial_shape
+    ndim = len(common_shape)
+
+    def is_out_of_bound(access, shape):
+        return sp.Or(*[sp.Or(a < 0, a >= s) for a, s in zip(access, shape)])
+
+    safe_assignments = [ps.Assignment(
+        assignment.lhs, assignment.rhs.subs({
+            a: ConditionalFieldAccess(a, is_out_of_bound(sp.Matrix(a.offsets) + x_vector(ndim), common_shape))
+            for a in assignment.rhs.atoms(Field.Access) if not a.is_absolute_access
+        })) for assignment in assignments.all_assignments]
+
+    subs = [{a: ConditionalFieldAccess(a, is_out_of_bound(
+        sp.Matrix(a.offsets) + x_vector(ndim), common_shape))
+        for a in assignment.rhs.atoms(Field.Access) if not a.is_absolute_access
+    } for assignment in assignments.all_assignments]
+    print(subs)
+
+    if with_cse:
+        safe_assignments = sympy_cse(ps.AssignmentCollection(safe_assignments))
+        return safe_assignments
+    else:
+        return ps.AssignmentCollection(safe_assignments)
+
+
+@pytest.mark.parametrize('with_cse', (False, 'with_cse'))
+def test_boundary_check(with_cse):
+
+    f, g = ps.fields("f, g : [2D]")
+    stencil = ps.Assignment(g[0, 0],
+                            (f[1, 0] + f[-1, 0] + f[0, 1] + f[0, -1]) / 4)
+
+    f_arr = np.random.rand(1000, 1000)
+    g_arr = np.zeros_like(f_arr)
+    # kernel(f=f_arr, g=g_arr)
+
+    assignments = add_fixed_constant_boundary_handling(ps.AssignmentCollection([stencil]), with_cse)
+
+    print(assignments)
+    kernel_checked = ps.create_kernel(assignments, ghost_layers=0).compile()
+    print(ps.show_code(kernel_checked))
+
+    # No SEGFAULT, please!!
+    kernel_checked(f=f_arr, g=g_arr)
-- 
GitLab