Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# -*- coding: utf-8 -*-
#
# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
#
# Distributed under terms of the GPLv3 license.
"""
"""
import itertools
import numpy as np
import pytest
import sympy as sp
import pystencils as ps
from pystencils import Field, x_vector
from pystencils.astnodes import ConditionalFieldAccess
from pystencils.simp import sympy_cse
def add_fixed_constant_boundary_handling(assignments, with_cse):
    """Wrap every relative field access on the right-hand sides in a bounds check.

    Each non-absolute ``Field.Access`` found on the right-hand side of an
    assignment is replaced by a ``ConditionalFieldAccess`` whose condition is
    true when the accessed coordinate (access offsets + ``x_vector``) lies
    outside ``[0, shape)`` in at least one dimension.

    NOTE(review): the reference shape is taken from one arbitrary field access
    of the collection, so this presumably assumes that all fields share the
    same spatial shape -- confirm against callers.

    Args:
        assignments: collection of assignments; must be iterable and provide
            ``all_assignments`` (e.g. ``ps.AssignmentCollection``).
        with_cse: if truthy, the result is passed through ``sympy_cse``.

    Returns:
        The guarded assignments as ``ps.AssignmentCollection`` or, when
        ``with_cse`` is truthy, whatever ``sympy_cse`` returns for it.

    Raises:
        ValueError: if ``assignments`` contains no field access at all, so no
            reference shape can be determined.
    """
    all_accesses = set().union(itertools.chain.from_iterable(
        a.atoms(Field.Access) for a in assignments
    ))
    reference_access = next(iter(all_accesses), None)
    if reference_access is None:
        raise ValueError("Cannot add boundary handling: assignments contain no field accesses")
    common_shape = reference_access.field.spatial_shape
    ndim = len(common_shape)

    def is_out_of_bound(access, shape):
        # True as soon as one coordinate is below 0 or at/after the extent.
        return sp.Or(*[sp.Or(a < 0, a >= s) for a, s in zip(access, shape)])

    # One substitution dict per assignment: relative access -> guarded access.
    # Absolute accesses are left untouched.
    subs = [{
        a: ConditionalFieldAccess(a, is_out_of_bound(sp.Matrix(a.offsets) + x_vector(ndim), common_shape))
        for a in assignment.rhs.atoms(Field.Access) if not a.is_absolute_access
    } for assignment in assignments.all_assignments]
    # Diagnostic output of the generated substitutions.
    print(subs)

    # Reuse the dicts built above instead of computing them a second time.
    safe_assignments = [
        ps.Assignment(assignment.lhs, assignment.rhs.subs(substitution))
        for assignment, substitution in zip(assignments.all_assignments, subs)
    ]

    if with_cse:
        return sympy_cse(ps.AssignmentCollection(safe_assignments))
    return ps.AssignmentCollection(safe_assignments)
@pytest.mark.parametrize('with_cse', (False, 'with_cse'))
def test_boundary_check(with_cse):
    """Run a bounds-checked 4-neighbour average without ghost layers.

    The test has no assertions; it only checks that the kernel built from the
    guarded assignments compiles and executes on full-size arrays.
    """
    src_field, dst_field = ps.fields("f, g : [2D]")

    # Sum of the four direct neighbours (east, west, north, south).
    neighbour_sum = (src_field[1, 0] + src_field[-1, 0] + src_field[0, 1] + src_field[0, -1])
    stencil = ps.Assignment(dst_field[0, 0], neighbour_sum / 4)

    source_data = np.random.rand(1000, 1000)
    target_data = np.zeros_like(source_data)

    assignments = add_fixed_constant_boundary_handling(
        ps.AssignmentCollection([stencil]), with_cse)
    print(assignments)

    # ghost_layers=0: the kernel iterates up to the array border, so the
    # neighbour accesses rely on the inserted bounds checks.
    kernel_checked = ps.create_kernel(assignments, ghost_layers=0).compile()

    # Must complete without a segmentation fault.
    kernel_checked(f=source_data, g=target_data)