Commit 6aeab41e authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Quality of life improvements for astnodes

parent 0f33273b
import collections.abc
import itertools
import uuid
from typing import Any, List, Optional, Sequence, Set, Union
......@@ -33,7 +35,7 @@ class Node:
raise NotImplementedError()
def subs(self, subs_dict) -> None:
"""Inplace! substitute, similar to sympy's but modifies the AST inplace."""
"""Inplace! Substitute, similar to sympy's but modifies the AST inplace."""
for a in self.args:
a.subs(subs_dict)
......@@ -102,6 +104,7 @@ class Conditional(Node):
result = self.true_block.undefined_symbols
if self.false_block:
result.update(self.false_block.undefined_symbols)
if hasattr(self.condition_expr, 'atoms'):
result.update(self.condition_expr.atoms(sp.Symbol))
return result
......@@ -212,9 +215,16 @@ class KernelFunction(Node):
"""Set of Field instances: fields which are accessed inside this kernel function"""
return set(o.field for o in self.atoms(ResolvedFieldAccess))
def fields_written(self):
assigments = self.atoms(SympyAssignment)
return {a.lhs.field for a in assigments if isinstance(a.lhs, ResolvedFieldAccess)}
@property
def fields_written(self) -> Set['ResolvedFieldAccess']:
assignments = self.atoms(SympyAssignment)
return {a.lhs.field for a in assignments if isinstance(a.lhs, ResolvedFieldAccess)}
@property
def fields_read(self) -> Set['ResolvedFieldAccess']:
assignments = self.atoms(SympyAssignment)
return set().union(itertools.chain.from_iterable([f.field for f in a.rhs.free_symbols if hasattr(f, 'field')]
for a in assignments))
def get_parameters(self) -> Sequence['KernelFunction.Parameter']:
"""Returns list of parameters for this function.
......@@ -283,6 +293,13 @@ class Block(Node):
a.subs(subs_dict)
def insert_front(self, node):
if isinstance(node, collections.abc.Iterable):
node = list(node)
for n in node:
n.parent = self
self._nodes = node + self._nodes
else:
node.parent = self
self._nodes.insert(0, node)
......@@ -485,7 +502,7 @@ class SympyAssignment(Node):
def __init__(self, lhs_symbol, rhs_expr, is_const=True):
super(SympyAssignment, self).__init__(parent=None)
self._lhs_symbol = lhs_symbol
self.rhs = rhs_expr
self.rhs = sp.simplify(rhs_expr)
self._is_const = is_const
self._is_declaration = self.__is_declaration()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment