Commit d7ad30ff authored by Martin Bauer's avatar Martin Bauer
Browse files

Started with Support for AST nodes in assignment collection

- allow AST nodes in assignment collection to e.g. put RNG nodes in
  LB method
parent 45701461
...@@ -4,7 +4,23 @@ from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, ...@@ -4,7 +4,23 @@ from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set,
import sympy as sp import sympy as sp
from pystencils.assignment import Assignment from pystencils.assignment import Assignment
from pystencils.sympyextensions import count_operations, fast_subs, sort_assignments_topologically from pystencils.astnodes import Node
from pystencils.sympyextensions import count_operations, fast_subs
def transform_rhs(assignment_list, transformation, *args, **kwargs):
"""Applies a transformation function on the rhs of each element of the passed assignment list
If the list also contains other object, like AST nodes, these are ignored.
Additional parameters are passed to the transformation function"""
return [Assignment(a.lhs, transformation(a.rhs, *args, **kwargs)) if isinstance(a, Assignment) else a
for a in assignment_list]
def transform_lhs_and_rhs(assignment_list, transformation, *args, **kwargs):
return [Assignment(transformation(a.lhs, *args, **kwargs),
transformation(a.rhs, *args, **kwargs))
if isinstance(a, Assignment) else a
for a in assignment_list]
class AssignmentCollection: class AssignmentCollection:
...@@ -205,17 +221,15 @@ class AssignmentCollection: ...@@ -205,17 +221,15 @@ class AssignmentCollection:
Returns: Returns:
New AssignmentCollection where substitutions have been applied, self is not altered. New AssignmentCollection where substitutions have been applied, self is not altered.
""" """
if substitute_on_lhs: transform = transform_lhs_and_rhs if substitute_on_lhs else transform_rhs
new_subexpressions = [fast_subs(eq, substitutions) for eq in self.subexpressions] transformed_subexpressions = transform(self.subexpressions, fast_subs, substitutions)
new_equations = [fast_subs(eq, substitutions) for eq in self.main_assignments] transformed_assignments = transform(self.main_assignments, fast_subs, substitutions)
else:
new_subexpressions = [Assignment(eq.lhs, fast_subs(eq.rhs, substitutions)) for eq in self.subexpressions]
new_equations = [Assignment(eq.lhs, fast_subs(eq.rhs, substitutions)) for eq in self.main_assignments]
if add_substitutions_as_subexpressions: if add_substitutions_as_subexpressions:
new_subexpressions = [Assignment(b, a) for a, b in substitutions.items()] + new_subexpressions transformed_subexpressions = [Assignment(b, a) for a, b in
new_subexpressions = sort_assignments_topologically(new_subexpressions) substitutions.items()] + transformed_subexpressions
return self.copy(new_equations, new_subexpressions) transformed_subexpressions = sort_assignments_topologically(transformed_subexpressions)
return self.copy(transformed_assignments, transformed_subexpressions)
def new_merged(self, other: 'AssignmentCollection') -> 'AssignmentCollection': def new_merged(self, other: 'AssignmentCollection') -> 'AssignmentCollection':
"""Returns a new collection which contains self and other. Subexpressions are renamed if they clash.""" """Returns a new collection which contains self and other. Subexpressions are renamed if they clash."""
...@@ -405,3 +419,22 @@ class SymbolGen: ...@@ -405,3 +419,22 @@ class SymbolGen:
name = "{}_{}".format(self._symbol, self._ctr) name = "{}_{}".format(self._symbol, self._ctr)
self._ctr += 1 self._ctr += 1
return sp.Symbol(name) return sp.Symbol(name)
def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]:
"""Sorts assignments in topological order, such that symbols used on rhs occur first on a lhs"""
edges = []
for c1, e1 in enumerate(assignments):
if isinstance(e1, Assignment):
symbols = [e1.lhs]
elif isinstance(e1, Node):
symbols = e1.symbols_defined
else:
symbols = []
for lhs in symbols:
for c2, e2 in enumerate(assignments):
if isinstance(e2, Assignment) and lhs in e2.rhs.free_symbols:
edges.append((c1, c2))
elif isinstance(e2, Node) and lhs in e2.undefined_symbols:
edges.append((c1, c2))
return [assignments[i] for i in sp.topological_sort((range(len(assignments)), edges))]
...@@ -4,7 +4,7 @@ import sympy as sp ...@@ -4,7 +4,7 @@ import sympy as sp
from pystencils.assignment import Assignment from pystencils.assignment import Assignment
from pystencils.field import AbstractField, Field from pystencils.field import AbstractField, Field
from pystencils.simp.assignment_collection import AssignmentCollection from pystencils.simp.assignment_collection import AssignmentCollection, transform_rhs
from pystencils.sympyextensions import subs_additive from pystencils.sympyextensions import subs_additive
AC = AssignmentCollection AC = AssignmentCollection
...@@ -128,15 +128,14 @@ def add_subexpressions_for_field_reads(ac: AC, subexpressions=True, main_assignm ...@@ -128,15 +128,14 @@ def add_subexpressions_for_field_reads(ac: AC, subexpressions=True, main_assignm
if main_assignments: if main_assignments:
for assignment in ac.main_assignments: for assignment in ac.main_assignments:
field_reads.update(assignment.rhs.atoms(Field.Access)) field_reads.update(assignment.rhs.atoms(Field.Access))
substitutions = {fa: sp.Dummy() for fa in field_reads} substitutions = {fa: next(ac.subexpression_symbol_generator) for fa in field_reads}
return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True, substitute_on_lhs=False) return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True, substitute_on_lhs=False)
def apply_to_all_assignments(operation: Callable[[sp.Expr], sp.Expr]) -> Callable[[AC], AC]: def apply_to_all_assignments(operation: Callable[[sp.Expr], sp.Expr]) -> Callable[[AC], AC]:
"""Applies sympy expand operation to all equations in collection.""" """Applies sympy expand operation to all equations in collection."""
def f(assignment_collection: AC) -> AC: def f(ac: AC) -> AC:
result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in assignment_collection.main_assignments] return ac.copy(transform_rhs(ac.main_assignments, operation))
return assignment_collection.copy(result)
f.__name__ = operation.__name__ f.__name__ = operation.__name__
return f return f
...@@ -144,7 +143,6 @@ def apply_to_all_assignments(operation: Callable[[sp.Expr], sp.Expr]) -> Callabl ...@@ -144,7 +143,6 @@ def apply_to_all_assignments(operation: Callable[[sp.Expr], sp.Expr]) -> Callabl
def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]) -> Callable[[AC], AC]: def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]) -> Callable[[AC], AC]:
"""Applies the given operation on all subexpressions of the AC.""" """Applies the given operation on all subexpressions of the AC."""
def f(ac: AC) -> AC: def f(ac: AC) -> AC:
result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in ac.subexpressions] return ac.copy(ac.main_assignments, transform_rhs(ac.subexpressions, operation))
return ac.copy(ac.main_assignments, result)
f.__name__ = operation.__name__ f.__name__ = operation.__name__
return f return f
...@@ -573,12 +573,6 @@ def get_symmetric_part(expr: sp.Expr, symbols: Iterable[sp.Symbol]) -> sp.Expr: ...@@ -573,12 +573,6 @@ def get_symmetric_part(expr: sp.Expr, symbols: Iterable[sp.Symbol]) -> sp.Expr:
return sp.Rational(1, 2) * (expr + expr.subs(substitution_dict)) return sp.Rational(1, 2) * (expr + expr.subs(substitution_dict))
def sort_assignments_topologically(assignments: Sequence[Assignment]) -> List[Assignment]:
"""Sorts assignments in topological order, such that symbols used on rhs occur first on a lhs"""
res = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in assignments])
return [Assignment(a, b) for a, b in res]
class SymbolCreator: class SymbolCreator:
def __getattribute__(self, name): def __getattribute__(self, name):
return sp.Symbol(name) return sp.Symbol(name)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment