Commit 6aeab41e authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Quality of life improvements for astnodes

parent 0f33273b
import itertools
import uuid
from typing import Any, List, Optional, Sequence, Set, Union
......@@ -33,7 +35,7 @@ class Node:
raise NotImplementedError()
def subs(self, subs_dict) -> None:
"""Inplace! substitute, similar to sympy's but modifies the AST inplace."""
"""Inplace! Substitute, similar to sympy's but modifies the AST inplace."""
for a in self.args:
......@@ -102,7 +104,8 @@ class Conditional(Node):
result = self.true_block.undefined_symbols
if self.false_block:
if hasattr(self.condition_expr, 'atoms'):
return result
def __str__(self):
......@@ -212,9 +215,16 @@ class KernelFunction(Node):
"""Set of Field instances: fields which are accessed inside this kernel function"""
return set(o.field for o in self.atoms(ResolvedFieldAccess))
def fields_written(self):
assigments = self.atoms(SympyAssignment)
return {a.lhs.field for a in assigments if isinstance(a.lhs, ResolvedFieldAccess)}
def fields_written(self) -> Set['ResolvedFieldAccess']:
assignments = self.atoms(SympyAssignment)
return {a.lhs.field for a in assignments if isinstance(a.lhs, ResolvedFieldAccess)}
def fields_read(self) -> Set['ResolvedFieldAccess']:
assignments = self.atoms(SympyAssignment)
return set().union(itertools.chain.from_iterable([f.field for f in a.rhs.free_symbols if hasattr(f, 'field')]
for a in assignments))
def get_parameters(self) -> Sequence['KernelFunction.Parameter']:
"""Returns list of parameters for this function.
......@@ -283,8 +293,15 @@ class Block(Node):
def insert_front(self, node):
node.parent = self
self._nodes.insert(0, node)
if isinstance(node,
node = list(node)
for n in node:
n.parent = self
self._nodes = node + self._nodes
node.parent = self
self._nodes.insert(0, node)
def insert_before(self, new_node, insert_before):
new_node.parent = self
......@@ -485,7 +502,7 @@ class SympyAssignment(Node):
def __init__(self, lhs_symbol, rhs_expr, is_const=True):
super(SympyAssignment, self).__init__(parent=None)
self._lhs_symbol = lhs_symbol
self.rhs = rhs_expr
self.rhs = sp.simplify(rhs_expr)
self._is_const = is_const
self._is_declaration = self.__is_declaration()
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment