diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py index 5cb8c8d3de7a7118d7696c33298e544c09b38db3..eb5910f9c5a7aa9a990d25d681e45e63a34188f4 100644 --- a/pystencils/astnodes.py +++ b/pystencils/astnodes.py @@ -211,17 +211,18 @@ class KernelFunction(Node): return self._body, @property - def fields_accessed(self) -> Set['ResolvedFieldAccess']: + def fields_accessed(self) -> Set[Field]: """Set of Field instances: fields which are accessed inside this kernel function""" - return set(o.field for o in self.atoms(ResolvedFieldAccess)) + from pystencils.interpolation_astnodes import InterpolatorAccess + return set(o.field for o in itertools.chain(self.atoms(ResolvedFieldAccess), self.atoms(InterpolatorAccess))) @property - def fields_written(self) -> Set['ResolvedFieldAccess']: + def fields_written(self) -> Set[Field]: assignments = self.atoms(SympyAssignment) return {a.lhs.field for a in assignments if isinstance(a.lhs, ResolvedFieldAccess)} @property - def fields_read(self) -> Set['ResolvedFieldAccess']: + def fields_read(self) -> Set[Field]: assignments = self.atoms(SympyAssignment) return set().union(itertools.chain.from_iterable([f.field for f in a.rhs.free_symbols if hasattr(f, 'field')] for a in assignments)) diff --git a/pystencils_tests/test_interpolation.py b/pystencils_tests/test_interpolation.py index 02fb2c3934c2514bb804c46198e4352daa3307c5..433cc309baf85cf37c66a2398a0d530b6e88fd40 100644 --- a/pystencils_tests/test_interpolation.py +++ b/pystencils_tests/test_interpolation.py @@ -7,9 +7,11 @@ """ """ +import itertools from os.path import dirname, join import numpy as np +import pytest import sympy import pycuda.autoinit # NOQA @@ -215,19 +217,20 @@ def test_rotate_interpolation_size_change(): pyconrad.imshow(out, "small out " + address_mode) -def test_field_interpolated(): +@pytest.mark.parametrize('address_mode, target', + itertools.product(['border', 'wrap', 'clamp', 'mirror'], ['cpu', 'gpu'])) +def test_field_interpolated(address_mode, target): x_f, y_f = pystencils.fields('x,y: float64 [2d]') - for address_mode in ['border', 'wrap', 'clamp', 'mirror']: - assignments = pystencils.AssignmentCollection({ - y_f.center(): x_f.interpolated_access([0.5 * x_ + 2.7, 0.25 * y_ + 7.2], address_mode=address_mode) - }) - print(assignments) - ast = pystencils.create_kernel(assignments) - print(ast) - print(pystencils.show_code(ast)) - kernel = ast.compile() + assignments = pystencils.AssignmentCollection({ + y_f.center(): x_f.interpolated_access([0.5 * x_ + 2.7, 0.25 * y_ + 7.2], address_mode=address_mode) + }) + print(assignments) + ast = pystencils.create_kernel(assignments) + print(ast) + print(pystencils.show_code(ast)) + kernel = ast.compile() - out = np.zeros_like(lenna) - kernel(x=lenna, y=out) - pyconrad.imshow(out, "out " + address_mode) + out = np.zeros_like(lenna) + kernel(x=lenna, y=out) + pyconrad.imshow(out, "out " + address_mode)