Commit 1eacfdad authored by Martin Bauer's avatar Martin Bauer
Browse files

New transformations for staggered field traversal

- loop cutting
- simplification of conditionals inside loop
parent 74b69826
...@@ -20,6 +20,13 @@ class ResolvedFieldAccess(sp.Indexed): ...@@ -20,6 +20,13 @@ class ResolvedFieldAccess(sp.Indexed):
self.args[1].subs(old, new), self.args[1].subs(old, new),
self.field, self.offsets, self.idxCoordinateValues) self.field, self.offsets, self.idxCoordinateValues)
def fastSubs(self, subsDict):
if self in subsDict:
return subsDict[self]
return ResolvedFieldAccess(self.args[0].subs(subsDict),
self.field, self.offsets, self.idxCoordinateValues)
def _hashable_content(self): def _hashable_content(self):
superClassContents = super(ResolvedFieldAccess, self)._hashable_content() superClassContents = super(ResolvedFieldAccess, self)._hashable_content()
return superClassContents + tuple(self.offsets) + (repr(self.idxCoordinateValues), hash(self.field)) return superClassContents + tuple(self.offsets) + (repr(self.idxCoordinateValues), hash(self.field))
...@@ -89,8 +96,23 @@ class Conditional(Node): ...@@ -89,8 +96,23 @@ class Conditional(Node):
""" """
assert conditionExpr.is_Boolean or conditionExpr.is_Relational assert conditionExpr.is_Boolean or conditionExpr.is_Relational
self.conditionExpr = conditionExpr self.conditionExpr = conditionExpr
self.trueBlock = trueBlock
self.falseBlock = falseBlock def handleChild(c):
if c is None:
return None
if not isinstance(c, Block):
c = Block([c])
c.parent = self
return c
self.trueBlock = handleChild(trueBlock)
self.falseBlock = handleChild(falseBlock)
def subs(self, *args, **kwargs):
self.trueBlock.subs(*args, **kwargs)
if self.falseBlock:
self.falseBlock.subs(*args, **kwargs)
self.conditionExpr = self.conditionExpr.subs(*args, **kwargs)
@property @property
def args(self): def args(self):
...@@ -107,7 +129,7 @@ class Conditional(Node): ...@@ -107,7 +129,7 @@ class Conditional(Node):
def undefinedSymbols(self): def undefinedSymbols(self):
result = self.trueBlock.undefinedSymbols result = self.trueBlock.undefinedSymbols
if self.falseBlock: if self.falseBlock:
result = result.update(self.falseBlock.undefinedSymbols) result.update(self.falseBlock.undefinedSymbols)
result.update(self.conditionExpr.atoms(sp.Symbol)) result.update(self.conditionExpr.atoms(sp.Symbol))
return result return result
...@@ -243,11 +265,21 @@ class Block(Node): ...@@ -243,11 +265,21 @@ class Block(Node):
def insertBefore(self, newNode, insertBefore): def insertBefore(self, newNode, insertBefore):
newNode.parent = self newNode.parent = self
idx = self._nodes.index(insertBefore) idx = self._nodes.index(insertBefore)
# move all assignment (definitions to the top)
if isinstance(newNode, SympyAssignment) and newNode.isDeclaration:
while idx > 0 and not (isinstance(self._nodes[idx-1], SympyAssignment) and self._nodes[idx-1].isDeclaration):
idx -= 1
self._nodes.insert(idx, newNode) self._nodes.insert(idx, newNode)
def append(self, node): def append(self, node):
node.parent = self if isinstance(node, list) or isinstance(node, tuple):
self._nodes.append(node) for n in node:
n.parent = self
node.parent = self
def takeChildNodes(self): def takeChildNodes(self):
tmp = self._nodes tmp = self._nodes
...@@ -339,7 +371,6 @@ class LoopOverCoordinate(Node): ...@@ -339,7 +371,6 @@ class LoopOverCoordinate(Node):
elif child == self.stop: elif child == self.stop:
self.stop = replacement self.stop = replacement
@property @property
def symbolsDefined(self): def symbolsDefined(self):
return set([self.loopCounterSymbol]) return set([self.loopCounterSymbol])
...@@ -389,14 +420,14 @@ class LoopOverCoordinate(Node): ...@@ -389,14 +420,14 @@ class LoopOverCoordinate(Node):
def __str__(self): def __str__(self):
return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})\n{!s}'.format(self.loopCounterName, self.start, return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})\n{!s}'.format(self.loopCounterName, self.start,
self.loopCounterName, self.stop, self.loopCounterName, self.stop,
self.loopCounterName, self.step, self.loopCounterName, self.step,
("\t" + "\t".join(str(self.body).splitlines(True)))) ("\t" + "\t".join(str(self.body).splitlines(True))))
def __repr__(self): def __repr__(self):
return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})'.format(self.loopCounterName, self.start, return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})'.format(self.loopCounterName, self.start,
self.loopCounterName, self.stop, self.loopCounterName, self.stop,
self.loopCounterName, self.step) self.loopCounterName, self.step)
class SympyAssignment(Node): class SympyAssignment(Node):
...@@ -161,7 +161,7 @@ class CBackend(object): ...@@ -161,7 +161,7 @@ class CBackend(object):
def _print_Conditional(self, node): def _print_Conditional(self, node):
conditionExpr = self.sympyPrinter.doprint(node.conditionExpr) conditionExpr = self.sympyPrinter.doprint(node.conditionExpr)
trueBlock = self._print_Block(node.trueBlock) trueBlock = self._print_Block(node.trueBlock)
result = "if (%s) \n %s " % (conditionExpr, trueBlock) result = "if (%s)\n%s " % (conditionExpr, trueBlock)
if node.falseBlock: if node.falseBlock:
falseBlock = self._print_Block(node.falseBlock) falseBlock = self._print_Block(node.falseBlock)
result += "else " + falseBlock result += "else " + falseBlock
""" r"""
*pystencils* looks for a configuration file in JSON format at the following locations in the listed order. *pystencils* looks for a configuration file in JSON format at the following locations in the listed order.
...@@ -58,7 +58,6 @@ compiled into the shared library. Then, the same script can be run from the comp ...@@ -58,7 +58,6 @@ compiled into the shared library. Then, the same script can be run from the comp
- **'objectCache'**: path to a folder where intermediate files are stored - **'objectCache'**: path to a folder where intermediate files are stored
- **'clearCacheOnStart'**: when true the cache is cleared on each start of a *pystencils* script - **'clearCacheOnStart'**: when true the cache is cleared on each start of a *pystencils* script
- **'sharedLibrary'**: path to a shared library file, which is created if `readFromSharedLibrary=false` - **'sharedLibrary'**: path to a shared library file, which is created if `readFromSharedLibrary=false`
""" """
from __future__ import print_function from __future__ import print_function
import os import os
...@@ -197,7 +196,8 @@ def readConfig(): ...@@ -197,7 +196,8 @@ def readConfig():
configPath, configExists = getConfigurationFilePath() configPath, configExists = getConfigurationFilePath()
config = defaultConfig.copy() config = defaultConfig.copy()
if configExists: if configExists:
loadedConfig = json.load(open(configPath, 'r')) with open(configPath, 'r') as jsonConfigFile:
loadedConfig = json.load(jsonConfigFile)
config = _recursiveDictUpdate(config, loadedConfig) config = _recursiveDictUpdate(config, loadedConfig)
else: else:
createFolder(configPath, True) createFolder(configPath, True)
...@@ -86,6 +86,8 @@ def fastSubs(term, subsDict, skip=None): ...@@ -86,6 +86,8 @@ def fastSubs(term, subsDict, skip=None):
def visit(expr): def visit(expr):
if skip and skip(expr): if skip and skip(expr):
return expr return expr
if hasattr(expr, "fastSubs"):
return expr.fastSubs(subsDict)
if expr in subsDict: if expr in subsDict:
return subsDict[expr] return subsDict[expr]
if not hasattr(expr, 'args'): if not hasattr(expr, 'args'):
...@@ -390,7 +390,12 @@ def moveConstantsBeforeLoop(astNode): ...@@ -390,7 +390,12 @@ def moveConstantsBeforeLoop(astNode):
if isinstance(element, ast.Block): if isinstance(element, ast.Block):
lastBlock = element lastBlock = element
lastBlockChild = prevElement lastBlockChild = prevElement
if node.undefinedSymbols.intersection(element.symbolsDefined):
if isinstance(element, ast.Conditional):
criticalSymbols = element.conditionExpr.atoms(sp.Symbol)
criticalSymbols = element.symbolsDefined
if node.undefinedSymbols.intersection(criticalSymbols):
break break
prevElement = element prevElement = element
element = element.parent element = element.parent
...@@ -496,6 +501,120 @@ def splitInnerLoop(astNode, symbolGroups): ...@@ -496,6 +501,120 @@ def splitInnerLoop(astNode, symbolGroups):
outerLoop.parent.append(ast.TemporaryMemoryFree(tmpArrayPointer)) outerLoop.parent.append(ast.TemporaryMemoryFree(tmpArrayPointer))
def cutLoop(loopNode, cuttingPoints):
"""Cuts loop at given cutting points, that means one loop is transformed into len(cuttingPoints)+1 new loops
that range from oldBegin to cuttingPoint[1], ..., cuttingPoint[-1] to oldEnd"""
if loopNode.step != 1:
raise NotImplementedError("Can only split loops that have a step of 1")
newLoops = []
newStart = loopNode.start
cuttingPoints = list(cuttingPoints) + [loopNode.stop]
for newEnd in cuttingPoints:
if newEnd - newStart == 1:
newBody = deepcopy(loopNode.body)
newBody.subs({loopNode.loopCounterSymbol: newStart})
newLoop = ast.LoopOverCoordinate(deepcopy(loopNode.body), loopNode.coordinateToLoopOver,
newStart, newEnd, loopNode.step)
newStart = newEnd
loopNode.parent.replace(loopNode, newLoops)
def isConditionNecessary(condition, preCondition, symbol):
Determines if a logical condition of a single variable is already contained in a stronger preCondition
so if from preCondition follows that condition is always true, then this condition is not necessary
:param condition: sympy relational of one variable
:param preCondition: logical expression that is known to be true
:param symbol: the single symbol of interest
:return: returns not (preCondition => condition) where "=>" is logical implication
from sympy.solvers.inequalities import reduce_rational_inequalities
from sympy.logic.boolalg import to_dnf
def toDnfList(expr):
result = to_dnf(expr)
if isinstance(result, sp.Or):
return [orTerm.args for orTerm in result.args]
elif isinstance(result, sp.And):
return [result.args]
return result
t1 = reduce_rational_inequalities(toDnfList(sp.And(condition, preCondition)), symbol)
t2 = reduce_rational_inequalities(toDnfList(preCondition), symbol)
return t1 != t2
def simplifyBooleanExpression(expr, singleVariableRanges):
"""Simplification of boolean expression using known ranges of variables
The singleVariableRanges parameter is a dict mapping a variable name to a sympy logical expression that
contains only this variable and defines a range for it. For example with a being a symbol
{ a: sp.And(a >=0, a < 10) }
from sympy.core.relational import Relational
from sympy.logic.boolalg import to_dnf
expr = to_dnf(expr)
def visit(e):
if isinstance(e, Relational):
symbols = e.atoms(sp.Symbol)
if len(symbols) == 1:
symbol = symbols.pop()
if symbol in singleVariableRanges:
if not isConditionNecessary(e, singleVariableRanges[symbol], symbol):
return sp.true
return e
newArgs = [visit(a) for a in e.args]
return e.func(*newArgs) if newArgs else e
return visit(expr)
def simplifyConditionals(node, loopConditionals={}):
"""Simplifies/Removes conditions inside loops that depend on the loop counter."""
if isinstance(node, ast.LoopOverCoordinate):
ctrSym = node.loopCounterSymbol
loopConditionals[ctrSym] = sp.And(ctrSym >= node.start, ctrSym < node.stop)
del loopConditionals[ctrSym]
elif isinstance(node, ast.Conditional):
node.conditionExpr = simplifyBooleanExpression(node.conditionExpr, loopConditionals)
if node.falseBlock:
if node.conditionExpr == sp.true:
node.parent.replace(node, [node.trueBlock])
if node.conditionExpr == sp.false:
node.parent.replace(node, [node.falseBlock] if node.falseBlock else [])
elif isinstance(node, ast.Block):
for a in list(node.args):
elif isinstance(node, ast.SympyAssignment):
return node
raise ValueError("Can not handle node", type(node))
def cleanupBlocks(node):
"""Curly Brace Removal: Removes empty blocks, and replaces blocks with a single child by its child """
if isinstance(node, ast.SympyAssignment):
elif isinstance(node, ast.Block):
for a in list(node.args):
if len(node.args) <= 1 and isinstance(node.parent, ast.Block):
node.parent.replace(node, node.args)
for a in node.args:
def symbolNameToVariableName(symbolName): def symbolNameToVariableName(symbolName):
"""Replaces characters which are allowed in sympy symbol names but not in C/C++ variable names""" """Replaces characters which are allowed in sympy symbol names but not in C/C++ variable names"""
return symbolName.replace("^", "_") return symbolName.replace("^", "_")
...@@ -546,17 +665,23 @@ def typeAllEquations(eqs, typeForSymbol): ...@@ -546,17 +665,23 @@ def typeAllEquations(eqs, typeForSymbol):
else: else:
assert False, "Expected a symbol as left-hand-side" assert False, "Expected a symbol as left-hand-side"
typedEquations = [] def visit(object):
for eq in eqs: if isinstance(object, list) or isinstance(object, tuple):
if isinstance(eq, sp.Eq) or isinstance(eq, ast.SympyAssignment): return [visit(e) for e in object]
newLhs = processLhs(eq.lhs) if isinstance(object, sp.Eq) or isinstance(object, ast.SympyAssignment):
newRhs = processRhs(eq.rhs) newLhs = processLhs(object.lhs)
typedEquations.append(ast.SympyAssignment(newLhs, newRhs)) newRhs = processRhs(object.rhs)
return ast.SympyAssignment(newLhs, newRhs)
elif isinstance(object, ast.Conditional):
falseBlock = None if object.falseBlock is None else visit(object.falseBlock)
return ast.Conditional(processRhs(object.conditionExpr),
trueBlock=visit(object.trueBlock), falseBlock=falseBlock)
elif isinstance(object, ast.Block):
return ast.Block([visit(e) for e in object.args])
else: else:
assert isinstance(eq, ast.Node), "Only equations and ast nodes are allowed in input" return object
typedEquations = typedEquations typedEquations = visit(eqs)
return fieldsRead, fieldsWritten, typedEquations return fieldsRead, fieldsWritten, typedEquations
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment