diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py index 5cb8c8d3de7a7118d7696c33298e544c09b38db3..eb5910f9c5a7aa9a990d25d681e45e63a34188f4 100644 --- a/pystencils/astnodes.py +++ b/pystencils/astnodes.py @@ -211,17 +211,18 @@ class KernelFunction(Node): return self._body, @property - def fields_accessed(self) -> Set['ResolvedFieldAccess']: + def fields_accessed(self) -> Set[Field]: """Set of Field instances: fields which are accessed inside this kernel function""" - return set(o.field for o in self.atoms(ResolvedFieldAccess)) + from pystencils.interpolation_astnodes import InterpolatorAccess + return set(o.field for o in itertools.chain(self.atoms(ResolvedFieldAccess), self.atoms(InterpolatorAccess))) @property - def fields_written(self) -> Set['ResolvedFieldAccess']: + def fields_written(self) -> Set[Field]: assignments = self.atoms(SympyAssignment) return {a.lhs.field for a in assignments if isinstance(a.lhs, ResolvedFieldAccess)} @property - def fields_read(self) -> Set['ResolvedFieldAccess']: + def fields_read(self) -> Set[Field]: assignments = self.atoms(SympyAssignment) return set().union(itertools.chain.from_iterable([f.field for f in a.rhs.free_symbols if hasattr(f, 'field')] for a in assignments))