diff --git a/transformations/transformations.py b/transformations/transformations.py index 73bc8361dbf0a5089433aea7d16ccf62cfd235fb..299201f0cdf39c9b7fa2bd0ddbcac9eec6832496 100644 --- a/transformations/transformations.py +++ b/transformations/transformations.py @@ -1,6 +1,7 @@ from collections import defaultdict, OrderedDict from operator import attrgetter from copy import deepcopy +import functools import sympy as sp from sympy.logic.boolalg import Boolean @@ -22,6 +23,9 @@ def filteredTreeIteration(node, nodeType): def fastSubs(term, subsDict): """Similar to sympy subs function. This version is much faster for big substitution dictionaries than sympy version""" + if type(term) is sp.Matrix: + return term.copy().applyfunc(functools.partial(fastSubs, subsDict=subsDict)) + def visit(expr): if expr in subsDict: return subsDict[expr] @@ -293,6 +297,7 @@ def substituteArrayAccessesWithConstants(astNode): for a in astNode.args: substituteArrayAccessesWithConstants(a) + def resolveBufferAccesses(astNode, baseBufferIndex, readOnlyFieldNames=set()): def visitSympyExpr(expr, enclosingBlock, sympyAssignment): if isinstance(expr, Field.Access):