From 6aeab41e96206410fc57797ced12b8be91394669 Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Tue, 24 Sep 2019 14:11:56 +0200 Subject: [PATCH] Quality of life improvements for astnodes --- pystencils/astnodes.py | 33 +++++++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py index b9ad28e2..69ce607d 100644 --- a/pystencils/astnodes.py +++ b/pystencils/astnodes.py @@ -1,3 +1,5 @@ +import collections.abc +import itertools import uuid from typing import Any, List, Optional, Sequence, Set, Union @@ -33,7 +35,7 @@ class Node: raise NotImplementedError() def subs(self, subs_dict) -> None: - """Inplace! substitute, similar to sympy's but modifies the AST inplace.""" + """Inplace! Substitute, similar to sympy's but modifies the AST inplace.""" for a in self.args: a.subs(subs_dict) @@ -102,7 +104,8 @@ class Conditional(Node): result = self.true_block.undefined_symbols if self.false_block: result.update(self.false_block.undefined_symbols) - result.update(self.condition_expr.atoms(sp.Symbol)) + if hasattr(self.condition_expr, 'atoms'): + result.update(self.condition_expr.atoms(sp.Symbol)) return result def __str__(self): @@ -212,9 +215,16 @@ class KernelFunction(Node): """Set of Field instances: fields which are accessed inside this kernel function""" return set(o.field for o in self.atoms(ResolvedFieldAccess)) - def fields_written(self): - assigments = self.atoms(SympyAssignment) - return {a.lhs.field for a in assigments if isinstance(a.lhs, ResolvedFieldAccess)} + @property + def fields_written(self) -> Set['ResolvedFieldAccess']: + assignments = self.atoms(SympyAssignment) + return {a.lhs.field for a in assignments if isinstance(a.lhs, ResolvedFieldAccess)} + + @property + def fields_read(self) -> Set['ResolvedFieldAccess']: + assignments = self.atoms(SympyAssignment) + return set().union(itertools.chain.from_iterable([f.field for f in a.rhs.free_symbols if hasattr(f, 'field')] + for a in assignments)) def get_parameters(self) -> Sequence['KernelFunction.Parameter']: """Returns list of parameters for this function. @@ -283,8 +293,15 @@ class Block(Node): a.subs(subs_dict) def insert_front(self, node): - node.parent = self - self._nodes.insert(0, node) + if isinstance(node, collections.abc.Iterable): + node = list(node) + for n in node: + n.parent = self + + self._nodes = node + self._nodes + else: + node.parent = self + self._nodes.insert(0, node) def insert_before(self, new_node, insert_before): new_node.parent = self @@ -485,7 +502,7 @@ class SympyAssignment(Node): def __init__(self, lhs_symbol, rhs_expr, is_const=True): super(SympyAssignment, self).__init__(parent=None) self._lhs_symbol = lhs_symbol - self.rhs = rhs_expr + self.rhs = sp.simplify(rhs_expr) self._is_const = is_const self._is_declaration = self.__is_declaration() -- GitLab