From 6aeab41e96206410fc57797ced12b8be91394669 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Tue, 24 Sep 2019 14:11:56 +0200
Subject: [PATCH] Quality of life improvements for astnodes

---
 pystencils/astnodes.py | 33 +++++++++++++++++++++++++--------
 1 file changed, 25 insertions(+), 8 deletions(-)

diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py
index b9ad28e2..69ce607d 100644
--- a/pystencils/astnodes.py
+++ b/pystencils/astnodes.py
@@ -1,3 +1,5 @@
+import collections.abc
+import itertools
 import uuid
 from typing import Any, List, Optional, Sequence, Set, Union
 
@@ -33,7 +35,7 @@ class Node:
         raise NotImplementedError()
 
     def subs(self, subs_dict) -> None:
-        """Inplace! substitute, similar to sympy's but modifies the AST inplace."""
+        """Inplace! Substitute, similar to sympy's but modifies the AST inplace."""
         for a in self.args:
             a.subs(subs_dict)
 
@@ -102,7 +104,8 @@ class Conditional(Node):
         result = self.true_block.undefined_symbols
         if self.false_block:
             result.update(self.false_block.undefined_symbols)
-        result.update(self.condition_expr.atoms(sp.Symbol))
+        if hasattr(self.condition_expr, 'atoms'):
+            result.update(self.condition_expr.atoms(sp.Symbol))
         return result
 
     def __str__(self):
@@ -212,9 +215,16 @@ class KernelFunction(Node):
         """Set of Field instances: fields which are accessed inside this kernel function"""
         return set(o.field for o in self.atoms(ResolvedFieldAccess))
 
-    def fields_written(self):
-        assigments = self.atoms(SympyAssignment)
-        return {a.lhs.field for a in assigments if isinstance(a.lhs, ResolvedFieldAccess)}
+    @property
+    def fields_written(self) -> Set['ResolvedFieldAccess']:
+        assignments = self.atoms(SympyAssignment)
+        return {a.lhs.field for a in assignments if isinstance(a.lhs, ResolvedFieldAccess)}
+
+    @property
+    def fields_read(self) -> Set['ResolvedFieldAccess']:
+        assignments = self.atoms(SympyAssignment)
+        return set().union(itertools.chain.from_iterable([f.field for f in a.rhs.free_symbols if hasattr(f, 'field')]
+                                                         for a in assignments))
 
     def get_parameters(self) -> Sequence['KernelFunction.Parameter']:
         """Returns list of parameters for this function.
@@ -283,8 +293,15 @@ class Block(Node):
             a.subs(subs_dict)
 
     def insert_front(self, node):
-        node.parent = self
-        self._nodes.insert(0, node)
+        if isinstance(node, collections.abc.Iterable):
+            node = list(node)
+            for n in node:
+                n.parent = self
+
+            self._nodes = node + self._nodes
+        else:
+            node.parent = self
+            self._nodes.insert(0, node)
 
     def insert_before(self, new_node, insert_before):
         new_node.parent = self
@@ -485,7 +502,7 @@ class SympyAssignment(Node):
     def __init__(self, lhs_symbol, rhs_expr, is_const=True):
         super(SympyAssignment, self).__init__(parent=None)
         self._lhs_symbol = lhs_symbol
-        self.rhs = rhs_expr
+        self.rhs = sp.simplify(rhs_expr)
         self._is_const = is_const
         self._is_declaration = self.__is_declaration()
 
-- 
GitLab