Commit ef924b18 authored by Martin Bauer's avatar Martin Bauer
Browse files

Code Cleanup

- assignment collection
- sympyextensions
parent c43672d2
......@@ -5,6 +5,7 @@ from pystencils.kernelcreation import createKernel, createIndexedKernel
from pystencils.display_utils import showCode, toDot
from pystencils.assignment_collection import AssignmentCollection
from pystencils.assignment import Assignment
from pystencils.sympyextensions import SymbolCreator
__all__ = ['Field', 'FieldType', 'extractCommonSubexpressions',
'TypedSymbol',
......@@ -12,4 +13,5 @@ __all__ = ['Field', 'FieldType', 'extractCommonSubexpressions',
'createKernel', 'createIndexedKernel',
'showCode', 'toDot',
'AssignmentCollection',
'Assignment']
'Assignment',
'SymbolCreator']
import sympy as sp
from typing import Callable, List
from pystencils import Assignment, AssignmentCollection
from pystencils.sympyextensions import replaceAdditive
from pystencils.sympyextensions import subs_additive
def sympyCseOnEquationList(eqs):
ec = AssignmentCollection(eqs, [])
return sympyCSE(ec).allEquations
def sympy_cse_on_assignment_list(assignments: List[Assignment]) -> List[Assignment]:
"""Extracts common subexpressions from a list of assignments."""
ec = AssignmentCollection(assignments, [])
return sympy_cse(ec).all_assignments
def sympyCSE(assignment_collection):
"""
Searches for common subexpressions inside the equation collection, in both the existing subexpressions as well
as the equations themselves. It uses the sympy subexpression detection to do this. Return a new equation collection
def sympy_cse(ac: AssignmentCollection) -> AssignmentCollection:
"""Searches for common subexpressions inside the equation collection.
Searches is done in both the existing subexpressions as well as the assignments themselves.
It uses the sympy subexpression detection to do this. Return a new equation collection
with the additional subexpressions found
"""
symbolGen = assignment_collection.subexpressionSymbolNameGenerator
replacements, newEq = sp.cse(assignment_collection.subexpressions + assignment_collection.mainAssignments,
symbols=symbolGen)
replacementEqs = [Assignment(*r) for r in replacements]
symbol_gen = ac.subexpression_symbol_generator
replacements, new_eq = sp.cse(ac.subexpressions + ac.main_assignments,
symbols=symbol_gen)
replacement_eqs = [Assignment(*r) for r in replacements]
modifiedSubexpressions = newEq[:len(assignment_collection.subexpressions)]
modifiedUpdateEquations = newEq[len(assignment_collection.subexpressions):]
modified_subexpressions = new_eq[:len(ac.subexpressions)]
modified_update_equations = new_eq[len(ac.subexpressions):]
newSubexpressions = replacementEqs + modifiedSubexpressions
topologicallySortedPairs = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in newSubexpressions])
newSubexpressions = [Assignment(a[0], a[1]) for a in topologicallySortedPairs]
new_subexpressions = replacement_eqs + modified_subexpressions
topologically_sorted_pairs = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in new_subexpressions])
new_subexpressions = [Assignment(a[0], a[1]) for a in topologically_sorted_pairs]
return assignment_collection.copy(modifiedUpdateEquations, newSubexpressions)
return ac.copy(modified_update_equations, new_subexpressions)
def applyOnAllEquations(assignment_collection, operation):
def apply_to_all_assignments(assignment_collection: AssignmentCollection,
operation: Callable[[sp.Expr], sp.Expr]) -> AssignmentCollection:
"""Applies sympy expand operation to all equations in collection"""
result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in assignment_collection.mainAssignments]
result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in assignment_collection.main_assignments]
return assignment_collection.copy(result)
def applyOnAllSubexpressions(assignment_collection, operation):
result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in assignment_collection.subexpressions]
return assignment_collection.copy(assignment_collection.mainAssignments, result)
def apply_on_all_subexpressions(ac: AssignmentCollection,
operation: Callable[[sp.Expr], sp.Expr]) -> AssignmentCollection:
result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in ac.subexpressions]
return ac.copy(ac.main_assignments, result)
def subexpressionSubstitutionInExistingSubexpressions(assignment_collection):
def subexpression_substitution_in_existing_subexpressions(ac: AssignmentCollection) -> AssignmentCollection:
"""Goes through the subexpressions list and replaces the term in the following subexpressions"""
result = []
for outerCtr, s in enumerate(assignment_collection.subexpressions):
newRhs = s.rhs
for outerCtr, s in enumerate(ac.subexpressions):
new_rhs = s.rhs
for innerCtr in range(outerCtr):
subExpr = assignment_collection.subexpressions[innerCtr]
newRhs = replaceAdditive(newRhs, subExpr.lhs, subExpr.rhs, requiredMatchReplacement=1.0)
newRhs = newRhs.subs(subExpr.rhs, subExpr.lhs)
result.append(Assignment(s.lhs, newRhs))
sub_expr = ac.subexpressions[innerCtr]
new_rhs = subs_additive(new_rhs, sub_expr.lhs, sub_expr.rhs, required_match_replacement=1.0)
new_rhs = new_rhs.subs(sub_expr.rhs, sub_expr.lhs)
result.append(Assignment(s.lhs, new_rhs))
return assignment_collection.copy(assignment_collection.mainAssignments, result)
return ac.copy(ac.main_assignments, result)
def subexpressionSubstitutionInmainAssignments(assignment_collection):
"""Replaces already existing subexpressions in the equations of the assignment_collection"""
def subexpression_substitution_in_main_assignments(ac: AssignmentCollection) -> AssignmentCollection:
"""Replaces already existing subexpressions in the equations of the assignment_collection."""
result = []
for s in assignment_collection.mainAssignments:
newRhs = s.rhs
for subExpr in assignment_collection.subexpressions:
newRhs = replaceAdditive(newRhs, subExpr.lhs, subExpr.rhs, requiredMatchReplacement=1.0)
result.append(Assignment(s.lhs, newRhs))
return assignment_collection.copy(result)
for s in ac.main_assignments:
new_rhs = s.rhs
for subExpr in ac.subexpressions:
new_rhs = subs_additive(new_rhs, subExpr.lhs, subExpr.rhs, required_match_replacement=1.0)
result.append(Assignment(s.lhs, new_rhs))
return ac.copy(result)
def addSubexpressionsForDivisions(assignment_collection):
def add_subexpressions_for_divisions(ac: AssignmentCollection) -> AssignmentCollection:
"""Introduces subexpressions for all divisions which have no constant in the denominator.
e.g. :math:`\frac{1}{x}` is replaced, :math:`\frac{1}{3}` is not replaced."""
For example :math:`\frac{1}{x}` is replaced, :math:`\frac{1}{3}` is not replaced.
"""
divisors = set()
def searchDivisors(term):
def search_divisors(term):
if term.func == sp.Pow:
if term.exp.is_integer and term.exp.is_number and term.exp < 0:
divisors.add(term)
else:
for a in term.args:
searchDivisors(a)
search_divisors(a)
for eq in assignment_collection.allEquations:
searchDivisors(eq.rhs)
for eq in ac.all_assignments:
search_divisors(eq.rhs)
newSymbolGen = assignment_collection.subexpressionSymbolNameGenerator
substitutions = {divisor: newSymbol for newSymbol, divisor in zip(newSymbolGen, divisors)}
return assignment_collection.copyWithSubstitutionsApplied(substitutions, True)
new_symbol_gen = ac.subexpression_symbol_generator
substitutions = {divisor: newSymbol for newSymbol, divisor in zip(new_symbol_gen, divisors)}
return ac.new_with_substitutions(substitutions, True)
import sympy as sp
from collections import namedtuple
from typing import Callable, Any, Optional, Sequence
from pystencils.assignment_collection.assignment_collection import AssignmentCollection
class SimplificationStrategy(object):
"""
A simplification strategy is an ordered collection of simplification rules.
"""A simplification strategy is an ordered collection of simplification rules.
Each simplification is a function taking an equation collection, and returning a new simplified
equation collection. The strategy can nicely print intermediate simplification stages and results
to Jupyter notebooks.
......@@ -13,10 +15,11 @@ class SimplificationStrategy(object):
def __init__(self):
self._rules = []
def add(self, rule):
"""
Adds the given simplification rule to the end of the collection.
:param rule: function that taking one equation collection and returning a (simplified) equation collection
def add(self, rule: Callable[[AssignmentCollection], AssignmentCollection]) -> None:
"""Adds the given simplification rule to the end of the collection.
Args:
rule: function that rewrites/simplifies an assignment collection
"""
self._rules.append(rule)
......@@ -24,19 +27,20 @@ class SimplificationStrategy(object):
def rules(self):
return self._rules
def apply(self, updateRule):
"""Applies all simplification rules to the given equation collection"""
def apply(self, assignment_collection: AssignmentCollection) -> AssignmentCollection:
"""Runs all rules on the given assignment collection."""
for t in self._rules:
updateRule = t(updateRule)
return updateRule
assignment_collection = t(assignment_collection)
return assignment_collection
def __call__(self, assignment_collection):
def __call__(self, assignment_collection: AssignmentCollection) -> AssignmentCollection:
"""Same as apply"""
return self.apply(assignment_collection)
def createSimplificationReport(self, assignment_collection):
"""
Returns a simplification report containing the number of operations at each simplification stage, together
def create_simplification_report(self, assignment_collection: AssignmentCollection) -> Any:
"""Creates a report to be displayed as HTML in a Jupyter notebook.
The simplification report contains the number of operations at each simplification stage together
with the run-time the simplification took.
"""
......@@ -60,70 +64,83 @@ class SimplificationStrategy(object):
return result
def _repr_html_(self):
htmlTable = '<table style="border:none">'
htmlTable += "<tr><th>Name</th><th>Runtime</th><th>Adds</th><th>Muls</th><th>Divs</th><th>Total</th></tr>"
html_table = '<table style="border:none">'
html_table += "<tr><th>Name</th>" \
"<th>Runtime</th>" \
"<th>Adds</th>" \
"<th>Muls</th>" \
"<th>Divs</th>" \
"<th>Total</th></tr>"
line = "<tr><td>{simplificationName}</td>" \
"<td>{runtime}</td> <td>{adds}</td> <td>{muls}</td> <td>{divs}</td> <td>{total}</td> </tr>"
for e in self.elements:
htmlTable += line.format(**e._asdict())
htmlTable += "</table>"
return htmlTable
# noinspection PyProtectedMember
html_table += line.format(**e._asdict())
html_table += "</table>"
return html_table
import timeit
report = Report()
op = assignment_collection.operationCount
op = assignment_collection.operation_count
total = op['adds'] + op['muls'] + op['divs']
report.add(ReportElement("OriginalTerm", '-', op['adds'], op['muls'], op['divs'], total))
for t in self._rules:
startTime = timeit.default_timer()
start_time = timeit.default_timer()
assignment_collection = t(assignment_collection)
endTime = timeit.default_timer()
op = assignment_collection.operationCount
timeStr = "%.2f ms" % ((endTime - startTime) * 1000,)
end_time = timeit.default_timer()
op = assignment_collection.operation_count
time_str = "%.2f ms" % ((end_time - start_time) * 1000,)
total = op['adds'] + op['muls'] + op['divs']
report.add(ReportElement(t.__name__, timeStr, op['adds'], op['muls'], op['divs'], total))
report.add(ReportElement(t.__name__, time_str, op['adds'], op['muls'], op['divs'], total))
return report
def showIntermediateResults(self, assignment_collection, symbols=None):
def show_intermediate_results(self, assignment_collection: AssignmentCollection,
symbols: Optional[Sequence[sp.Symbol]] = None) -> Any:
"""Shows the assignment collection after the application of each rule as HTML report for Jupyter notebook.
Args:
assignment_collection: the collection to apply the rules to
symbols: if not None, only the assignments are shown that have one of these symbols as left hand side
"""
class IntermediateResults:
def __init__(self, strategy, eqColl, resSyms):
def __init__(self, strategy, collection, restrict_symbols):
self.strategy = strategy
self.assignment_collection = eqColl
self.restrictSymbols = resSyms
self.assignment_collection = collection
self.restrict_symbols = restrict_symbols
def __str__(self):
def printEqCollection(title, eqColl):
def print_assignment_collection(title, c):
text = title
if self.restrictSymbols:
text += "\n".join([str(e) for e in eqColl.get(self.restrictSymbols)])
if self.restrict_symbols:
text += "\n".join([str(e) for e in c.get(self.restrict_symbols)])
else:
text += (" " * 3 + (" " * 3).join(str(eqColl).splitlines(True)))
text += (" " * 3 + (" " * 3).join(str(c).splitlines(True)))
return text
result = printEqCollection("Initial Version", self.assignment_collection)
eqColl = self.assignment_collection
result = print_assignment_collection("Initial Version", self.assignment_collection)
collection = self.assignment_collection
for rule in self.strategy.rules:
eqColl = rule(eqColl)
result += printEqCollection(rule.__name__, eqColl)
collection = rule(collection)
result += print_assignment_collection(rule.__name__, collection)
return result
def _repr_html_(self):
def printEqCollection(title, eqColl):
def print_assignment_collection(title, c):
text = '<h5 style="padding-bottom:10px">%s</h5> <div style="padding-left:20px;">' % (title, )
if self.restrictSymbols:
text += "\n".join(["$$" + sp.latex(e) + '$$' for e in eqColl.get(self.restrictSymbols)])
if self.restrict_symbols:
text += "\n".join(["$$" + sp.latex(e) + '$$' for e in c.get(self.restrict_symbols)])
else:
text += eqColl._repr_html_()
# noinspection PyProtectedMember
text += c._repr_html_()
text += "</div>"
return text
result = printEqCollection("Initial Version", self.assignment_collection)
eqColl = self.assignment_collection
result = print_assignment_collection("Initial Version", self.assignment_collection)
collection = self.assignment_collection
for rule in self.strategy.rules:
eqColl = rule(eqColl)
result += printEqCollection(rule.__name__, eqColl)
collection = rule(collection)
result += print_assignment_collection(rule.__name__, collection)
return result
return IntermediateResults(self, assignment_collection, symbols)
......
......@@ -2,7 +2,7 @@ import sympy as sp
from sympy.tensor import IndexedBase
from pystencils.field import Field
from pystencils.data_types import TypedSymbol, createType, castFunc
from pystencils.sympyextensions import fastSubs
from pystencils.sympyextensions import fast_subs
class Node(object):
......@@ -275,11 +275,11 @@ class Block(Node):
@property
def undefinedSymbols(self):
result = set()
definedSymbols = set()
defined_symbols = set()
for a in self.args:
result.update(a.undefinedSymbols)
definedSymbols.update(a.symbolsDefined)
return result - definedSymbols
defined_symbols.update(a.symbolsDefined)
return result - defined_symbols
def __str__(self):
return "Block " + ''.join('{!s}\n'.format(node) for node in self._nodes)
......@@ -426,8 +426,8 @@ class SympyAssignment(Node):
self._isDeclaration = False
def subs(self, *args, **kwargs):
self.lhs = fastSubs(self.lhs, *args, **kwargs)
self.rhs = fastSubs(self.rhs, *args, **kwargs)
self.lhs = fast_subs(self.lhs, *args, **kwargs)
self.rhs = fast_subs(self.rhs, *args, **kwargs)
@property
def args(self):
......@@ -494,11 +494,11 @@ class ResolvedFieldAccess(sp.Indexed):
self.args[1].subs(old, new),
self.field, self.offsets, self.idxCoordinateValues)
def fastSubs(self, subsDict):
if self in subsDict:
return subsDict[self]
return ResolvedFieldAccess(self.args[0].subs(subsDict),
self.args[1].subs(subsDict),
def fast_subs(self, substitutions):
if self in substitutions:
return substitutions[self]
return ResolvedFieldAccess(self.args[0].subs(substitutions),
self.args[1].subs(substitutions),
self.field, self.offsets, self.idxCoordinateValues)
def _hashable_content(self):
......
import sympy as sp
from collections import namedtuple, defaultdict
from pystencils.sympyextensions import normalizeProduct, prod
from pystencils.sympyextensions import normalize_product, prod
def defaultDiffSortKey(d):
......@@ -57,7 +57,7 @@ class Diff(sp.Expr):
if self.arg.func != sp.Mul:
constant, variable = 1, self.arg
else:
for factor in normalizeProduct(self.arg):
for factor in normalize_product(self.arg):
if factor in functions or isinstance(factor, Diff):
variable *= factor
else:
......@@ -150,7 +150,7 @@ class DiffOperator(sp.Expr):
i.e. DiffOperator('x')*DiffOperator('x') is a second derivative replaced by Diff(Diff(arg, x), t)
"""
def handleMul(mul):
args = normalizeProduct(mul)
args = normalize_product(mul)
diffs = [a for a in args if isinstance(a, DiffOperator)]
if len(diffs) == 0:
return mul * argument if applyToConstants else mul
......@@ -254,7 +254,7 @@ def fullDiffExpand(expr, functions=None, constants=None):
for term in diffInner.args if diffInner.func == sp.Add else [diffInner]:
independentTerms = 1
dependentTerms = []
for factor in normalizeProduct(term):
for factor in normalize_product(term):
if factor in functions or isinstance(factor, Diff):
dependentTerms.append(factor)
else:
......@@ -310,7 +310,7 @@ def expandUsingProductRule(expr):
if arg.func not in (sp.Mul, sp.Pow):
return Diff(arg, target=expr.target, superscript=expr.superscript)
else:
prodList = normalizeProduct(arg)
prodList = normalize_product(arg)
result = 0
for i in range(len(prodList)):
preFactor = prod(prodList[j] for j in range(len(prodList)) if i != j)
......@@ -347,7 +347,7 @@ def combineUsingProductRule(expr):
if isinstance(term, Diff):
diffDict[DiffInfo(term.target, term.superscript)].append(DiffSplit(1, term.arg))
else:
mulArgs = normalizeProduct(term)
mulArgs = normalize_product(term)
diffs = [d for d in mulArgs if isinstance(d, Diff)]
factor = prod(d for d in mulArgs if not isinstance(d, Diff))
if len(diffs) == 0:
......
......@@ -8,7 +8,7 @@ from sympy.tensor import IndexedBase
from pystencils.assignment import Assignment
from pystencils.alignedarray import aligned_empty
from pystencils.data_types import TypedSymbol, createType, createCompositeTypeFromString, StructType
from pystencils.sympyextensions import isIntegerSequence
from pystencils.sympyextensions import is_integer_sequence
class FieldType(Enum):
......@@ -221,7 +221,7 @@ class Field(object):
@property
def hasFixedShape(self):
return isIntegerSequence(self.shape)
return is_integer_sequence(self.shape)
@property
def indexShape(self):
......@@ -229,7 +229,7 @@ class Field(object):
@property
def hasFixedIndexShape(self):
return isIntegerSequence(self.indexShape)
return is_integer_sequence(self.indexShape)
@property
def spatialStrides(self):
......
......@@ -3,7 +3,7 @@ import sympy as sp
from pystencils.assignment_collection import AssignmentCollection
from pystencils.field import Field
from pystencils.transformations import fastSubs
from pystencils.sympyextensions import fast_subs
from pystencils.derivative import Diff
......@@ -103,7 +103,7 @@ def discretizeStaggered(term, symbolsToFieldDict, coordinate, coordinateOffset,
neighborGrad = (field[up+offset](i) - field[down+offset](i)) / (2 * dx)
substitutions[grad(s)[d]] = (centerGrad + neighborGrad) / 2
return fastSubs(term, substitutions)
return fast_subs(term, substitutions)
def discretizeDivergence(vectorTerm, symbolsToFieldDict, dx):
......@@ -356,7 +356,7 @@ class Discretization2ndOrder:
elif isinstance(expr, sp.Matrix):
return expr.applyfunc(self.__call__)
elif isinstance(expr, AssignmentCollection):
return expr.copy(mainAssignments=[e for e in expr.mainAssignments],
return expr.copy(main_assignments=[e for e in expr.main_assignments],
subexpressions=[e for e in expr.subexpressions])
transientTerms = expr.atoms(Transient)
......
......@@ -12,7 +12,7 @@ from kerncraft.iaca import iaca_analyse_instrumented_binary, iaca_instrumentatio
from pystencils.kerncraft_coupling.generate_benchmark import generateBenchmark
from pystencils.astnodes import LoopOverCoordinate, SympyAssignment, ResolvedFieldAccess
from pystencils.field import getLayoutFromStrides
from pystencils.sympyextensions import countNumberOfOperationsInAst
from pystencils.sympyextensions import count_operations_in_ast
from pystencils.utils import DotDict
......@@ -78,7 +78,7 @@ class PyStencilsKerncraftKernel(kerncraft.kernel.Kernel):
self.datatype = list(self.variables.values())[0][0]
# flops
operationCount = countNumberOfOperationsInAst(innerLoop)
operationCount = count_operations_in_ast(innerLoop)
self._flops = {
'+': operationCount['adds'],
'*': operationCount['muls'],
......
......@@ -33,9 +33,9 @@ def createKernel(equations, target='cpu', dataType="double", iterationSlice=None
# ---- Normalizing parameters
splitGroups = ()
if isinstance(equations, AssignmentCollection):
if 'splitGroups' in equations.simplificationHints:
splitGroups = equations.simplificationHints['splitGroups']
equations = equations.allEquations
if 'splitGroups' in equations.simplification_hints:
splitGroups = equations.simplification_hints['splitGroups']
equations = equations.all_assignments
# ---- Creating ast
if target == 'cpu':
......@@ -84,7 +84,7 @@ def createIndexedKernel(equations, indexFields, target='cpu', dataType="double",
"""
if isinstance(equations, AssignmentCollection):
equations = equations.allEquations
equations = equations.all_assignments
if target == 'cpu':
from pystencils.cpu import createIndexedKernel
from pystencils.cpu import addOpenMP
......
This diff is collapsed.
......@@ -21,22 +21,6 @@ def filteredTreeIteration(node, nodeType):
yield from filteredTreeIteration(arg, nodeType)
def fastSubs(term, subsDict):
"""Similar to sympy subs function.
This version is much faster for big substitution dictionaries than sympy version"""
if type(term) is sp.Matrix:
return term.copy().applyfunc(functools.partial(fastSubs, subsDict=subsDict))
def visit(expr):
if expr in subsDict:
return subsDict[expr]
if not hasattr(expr, 'args'):
return expr
paramList = [visit(a) for a in expr.args]
return expr if not paramList else expr.func(*paramList)
return visit(term)
def getCommonShape(fieldSet):
"""Takes a set of pystencils Fields and returns their common spatial shape if it exists. Otherwise
ValueError is raised"""
......
import sympy as sp
import warnings
from pystencils.sympyextensions import fastSubs
from pystencils.sympyextensions import fast_subs
from pystencils.transformations import filteredTreeIteration
from pystencils.data_types import TypedSymbol, VectorType, BasicType, getTypeOfExpression, castFunc, collateTypes, \
PointerType
......@@ -97,7 +97,7 @@ def insertVectorCasts(astNode):
substitutionDict = {}
for asmt in filteredTreeIteration(astNode, ast.SympyAssignment):
subsExpr = fastSubs(asmt.rhs, substitutionDict, skip=lambda e: isinstance(e, ast.ResolvedFieldAccess))
subsExpr = fast_subs(asmt.rhs, substitutionDict, skip=lambda e: isinstance(e, ast.ResolvedFieldAccess))
asmt.rhs = visitExpr(subsExpr)
rhsType = getTypeOfExpression(asmt.rhs)
if isinstance(asmt.lhs, TypedSymbol):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment