Skip to content
Snippets Groups Projects
Commit 6aeab41e authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Quality of life improvements for astnodes

parent 0f33273b
Branches
Tags
No related merge requests found
import collections.abc
import itertools
import uuid import uuid
from typing import Any, List, Optional, Sequence, Set, Union from typing import Any, List, Optional, Sequence, Set, Union
...@@ -33,7 +35,7 @@ class Node: ...@@ -33,7 +35,7 @@ class Node:
raise NotImplementedError() raise NotImplementedError()
def subs(self, subs_dict) -> None: def subs(self, subs_dict) -> None:
"""Inplace! substitute, similar to sympy's but modifies the AST inplace.""" """Inplace! Substitute, similar to sympy's but modifies the AST inplace."""
for a in self.args: for a in self.args:
a.subs(subs_dict) a.subs(subs_dict)
...@@ -102,7 +104,8 @@ class Conditional(Node): ...@@ -102,7 +104,8 @@ class Conditional(Node):
result = self.true_block.undefined_symbols result = self.true_block.undefined_symbols
if self.false_block: if self.false_block:
result.update(self.false_block.undefined_symbols) result.update(self.false_block.undefined_symbols)
result.update(self.condition_expr.atoms(sp.Symbol)) if hasattr(self.condition_expr, 'atoms'):
result.update(self.condition_expr.atoms(sp.Symbol))
return result return result
def __str__(self): def __str__(self):
...@@ -212,9 +215,16 @@ class KernelFunction(Node): ...@@ -212,9 +215,16 @@ class KernelFunction(Node):
"""Set of Field instances: fields which are accessed inside this kernel function""" """Set of Field instances: fields which are accessed inside this kernel function"""
return set(o.field for o in self.atoms(ResolvedFieldAccess)) return set(o.field for o in self.atoms(ResolvedFieldAccess))
def fields_written(self): @property
assigments = self.atoms(SympyAssignment) def fields_written(self) -> Set['ResolvedFieldAccess']:
return {a.lhs.field for a in assigments if isinstance(a.lhs, ResolvedFieldAccess)} assignments = self.atoms(SympyAssignment)
return {a.lhs.field for a in assignments if isinstance(a.lhs, ResolvedFieldAccess)}
@property
def fields_read(self) -> Set['ResolvedFieldAccess']:
assignments = self.atoms(SympyAssignment)
return set().union(itertools.chain.from_iterable([f.field for f in a.rhs.free_symbols if hasattr(f, 'field')]
for a in assignments))
def get_parameters(self) -> Sequence['KernelFunction.Parameter']: def get_parameters(self) -> Sequence['KernelFunction.Parameter']:
"""Returns list of parameters for this function. """Returns list of parameters for this function.
...@@ -283,8 +293,15 @@ class Block(Node): ...@@ -283,8 +293,15 @@ class Block(Node):
a.subs(subs_dict) a.subs(subs_dict)
def insert_front(self, node): def insert_front(self, node):
node.parent = self if isinstance(node, collections.abc.Iterable):
self._nodes.insert(0, node) node = list(node)
for n in node:
n.parent = self
self._nodes = node + self._nodes
else:
node.parent = self
self._nodes.insert(0, node)
def insert_before(self, new_node, insert_before): def insert_before(self, new_node, insert_before):
new_node.parent = self new_node.parent = self
...@@ -485,7 +502,7 @@ class SympyAssignment(Node): ...@@ -485,7 +502,7 @@ class SympyAssignment(Node):
def __init__(self, lhs_symbol, rhs_expr, is_const=True): def __init__(self, lhs_symbol, rhs_expr, is_const=True):
super(SympyAssignment, self).__init__(parent=None) super(SympyAssignment, self).__init__(parent=None)
self._lhs_symbol = lhs_symbol self._lhs_symbol = lhs_symbol
self.rhs = rhs_expr self.rhs = sp.simplify(rhs_expr)
self._is_const = is_const self._is_const = is_const
self._is_declaration = self.__is_declaration() self._is_declaration = self.__is_declaration()
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment