diff --git a/astnodes.py b/astnodes.py index 66d4ae8d0e4eaf1a394a03f43a3a0957e40d42b8..89d3a85449f86b3bd77986661e0a1a37e80b6712 100644 --- a/astnodes.py +++ b/astnodes.py @@ -20,6 +20,13 @@ class ResolvedFieldAccess(sp.Indexed): self.args[1].subs(old, new), self.field, self.offsets, self.idxCoordinateValues) + def fastSubs(self, subsDict): + if self in subsDict: + return subsDict[self] + return ResolvedFieldAccess(self.args[0].subs(subsDict), + self.args[1].subs(subsDict), + self.field, self.offsets, self.idxCoordinateValues) + def _hashable_content(self): superClassContents = super(ResolvedFieldAccess, self)._hashable_content() return superClassContents + tuple(self.offsets) + (repr(self.idxCoordinateValues), hash(self.field)) @@ -89,8 +96,23 @@ class Conditional(Node): """ assert conditionExpr.is_Boolean or conditionExpr.is_Relational self.conditionExpr = conditionExpr - self.trueBlock = trueBlock - self.falseBlock = falseBlock + + def handleChild(c): + if c is None: + return None + if not isinstance(c, Block): + c = Block([c]) + c.parent = self + return c + + self.trueBlock = handleChild(trueBlock) + self.falseBlock = handleChild(falseBlock) + + def subs(self, *args, **kwargs): + self.trueBlock.subs(*args, **kwargs) + if self.falseBlock: + self.falseBlock.subs(*args, **kwargs) + self.conditionExpr = self.conditionExpr.subs(*args, **kwargs) @property def args(self): @@ -107,7 +129,7 @@ class Conditional(Node): def undefinedSymbols(self): result = self.trueBlock.undefinedSymbols if self.falseBlock: - result = result.update(self.falseBlock.undefinedSymbols) + result.update(self.falseBlock.undefinedSymbols) result.update(self.conditionExpr.atoms(sp.Symbol)) return result @@ -243,11 +265,21 @@ class Block(Node): def insertBefore(self, newNode, insertBefore): newNode.parent = self idx = self._nodes.index(insertBefore) + + # move all assignment (definitions to the top) + if isinstance(newNode, SympyAssignment) and newNode.isDeclaration: + while idx > 0 and not (isinstance(self._nodes[idx-1], SympyAssignment) and self._nodes[idx-1].isDeclaration): + idx -= 1 self._nodes.insert(idx, newNode) def append(self, node): - node.parent = self - self._nodes.append(node) + if isinstance(node, list) or isinstance(node, tuple): + for n in node: + n.parent = self + self._nodes.append(n) + else: + node.parent = self + self._nodes.append(node) def takeChildNodes(self): tmp = self._nodes @@ -339,7 +371,6 @@ class LoopOverCoordinate(Node): elif child == self.stop: self.stop = replacement - @property def symbolsDefined(self): return set([self.loopCounterSymbol]) @@ -389,14 +420,14 @@ class LoopOverCoordinate(Node): def __str__(self): return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})\n{!s}'.format(self.loopCounterName, self.start, - self.loopCounterName, self.stop, - self.loopCounterName, self.step, - ("\t" + "\t".join(str(self.body).splitlines(True)))) + self.loopCounterName, self.stop, + self.loopCounterName, self.step, + ("\t" + "\t".join(str(self.body).splitlines(True)))) def __repr__(self): return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})'.format(self.loopCounterName, self.start, - self.loopCounterName, self.stop, - self.loopCounterName, self.step) + self.loopCounterName, self.stop, + self.loopCounterName, self.step) class SympyAssignment(Node): diff --git a/backends/cbackend.py b/backends/cbackend.py index f19dec05f806a8b3c2b93c0e7b14d026fd889817..00d69957fe301feed21162cc8e9f9c57f139eca7 100644 --- a/backends/cbackend.py +++ b/backends/cbackend.py @@ -161,7 +161,7 @@ class CBackend(object): def _print_Conditional(self, node): conditionExpr = self.sympyPrinter.doprint(node.conditionExpr) trueBlock = self._print_Block(node.trueBlock) - result = "if (%s) \n %s " % (conditionExpr, trueBlock) + result = "if (%s)\n%s " % (conditionExpr, trueBlock) if node.falseBlock: falseBlock = self._print_Block(node.falseBlock) result += "else " + falseBlock diff --git a/cpu/cpujit.py b/cpu/cpujit.py index 37229a9769704f7afe0d0fcec02daa11380e2032..69c3d8cd84b991ba86719733dfc44e80b813771f 100644 --- a/cpu/cpujit.py +++ b/cpu/cpujit.py @@ -1,4 +1,4 @@ -""" +r""" *pystencils* looks for a configuration file in JSON format at the following locations in the listed order. @@ -58,7 +58,6 @@ compiled into the shared library. Then, the same script can be run from the comp - **'objectCache'**: path to a folder where intermediate files are stored - **'clearCacheOnStart'**: when true the cache is cleared on each start of a *pystencils* script - **'sharedLibrary'**: path to a shared library file, which is created if `readFromSharedLibrary=false` - """ from __future__ import print_function import os @@ -197,7 +196,8 @@ def readConfig(): configPath, configExists = getConfigurationFilePath() config = defaultConfig.copy() if configExists: - loadedConfig = json.load(open(configPath, 'r')) + with open(configPath, 'r') as jsonConfigFile: + loadedConfig = json.load(jsonConfigFile) config = _recursiveDictUpdate(config, loadedConfig) else: createFolder(configPath, True) diff --git a/sympyextensions.py b/sympyextensions.py index 4abeb7b66097843ab1d7fdc8348d97a39acdb38c..a45ede9f89428e17c073a09bcc1e3422b07cea61 100644 --- a/sympyextensions.py +++ b/sympyextensions.py @@ -86,6 +86,8 @@ def fastSubs(term, subsDict, skip=None): def visit(expr): if skip and skip(expr): return expr + if hasattr(expr, "fastSubs"): + return expr.fastSubs(subsDict) if expr in subsDict: return subsDict[expr] if not hasattr(expr, 'args'): diff --git a/transformations/transformations.py b/transformations/transformations.py index 7cc41daf13d30b3fd9c0873d1e867f841bf32ea8..1a8b797f613181b97bd482b482c4e3070e733ff4 100644 --- a/transformations/transformations.py +++ b/transformations/transformations.py @@ -390,7 +390,12 @@ def moveConstantsBeforeLoop(astNode): if isinstance(element, ast.Block): lastBlock = element lastBlockChild = prevElement - if node.undefinedSymbols.intersection(element.symbolsDefined): + + if isinstance(element, ast.Conditional): + criticalSymbols = element.conditionExpr.atoms(sp.Symbol) + else: + criticalSymbols = element.symbolsDefined + if node.undefinedSymbols.intersection(criticalSymbols): break prevElement = element element = element.parent @@ -496,6 +501,120 @@ def splitInnerLoop(astNode, symbolGroups): outerLoop.parent.append(ast.TemporaryMemoryFree(tmpArrayPointer)) +def cutLoop(loopNode, cuttingPoints): + """Cuts loop at given cutting points, that means one loop is transformed into len(cuttingPoints)+1 new loops + that range from oldBegin to cuttingPoint[1], ..., cuttingPoint[-1] to oldEnd""" + if loopNode.step != 1: + raise NotImplementedError("Can only split loops that have a step of 1") + newLoops = [] + newStart = loopNode.start + cuttingPoints = list(cuttingPoints) + [loopNode.stop] + for newEnd in cuttingPoints: + if newEnd - newStart == 1: + newBody = deepcopy(loopNode.body) + newBody.subs({loopNode.loopCounterSymbol: newStart}) + newLoops.append(newBody) + else: + newLoop = ast.LoopOverCoordinate(deepcopy(loopNode.body), loopNode.coordinateToLoopOver, + newStart, newEnd, loopNode.step) + newLoops.append(newLoop) + newStart = newEnd + loopNode.parent.replace(loopNode, newLoops) + + +def isConditionNecessary(condition, preCondition, symbol): + """ + Determines if a logical condition of a single variable is already contained in a stronger preCondition + so if from preCondition follows that condition is always true, then this condition is not necessary + :param condition: sympy relational of one variable + :param preCondition: logical expression that is known to be true + :param symbol: the single symbol of interest + :return: returns not (preCondition => condition) where "=>" is logical implication + """ + from sympy.solvers.inequalities import reduce_rational_inequalities + from sympy.logic.boolalg import to_dnf + + def toDnfList(expr): + result = to_dnf(expr) + if isinstance(result, sp.Or): + return [orTerm.args for orTerm in result.args] + elif isinstance(result, sp.And): + return [result.args] + else: + return result + + t1 = reduce_rational_inequalities(toDnfList(sp.And(condition, preCondition)), symbol) + t2 = reduce_rational_inequalities(toDnfList(preCondition), symbol) + return t1 != t2 + + +def simplifyBooleanExpression(expr, singleVariableRanges): + """Simplification of boolean expression using known ranges of variables + The singleVariableRanges parameter is a dict mapping a variable name to a sympy logical expression that + contains only this variable and defines a range for it. For example with a being a symbol + { a: sp.And(a >=0, a < 10) } + """ + from sympy.core.relational import Relational + from sympy.logic.boolalg import to_dnf + + expr = to_dnf(expr) + + def visit(e): + if isinstance(e, Relational): + symbols = e.atoms(sp.Symbol) + if len(symbols) == 1: + symbol = symbols.pop() + if symbol in singleVariableRanges: + if not isConditionNecessary(e, singleVariableRanges[symbol], symbol): + return sp.true + return e + else: + newArgs = [visit(a) for a in e.args] + return e.func(*newArgs) if newArgs else e + + return visit(expr) + + +def simplifyConditionals(node, loopConditionals={}): + """Simplifies/Removes conditions inside loops that depend on the loop counter.""" + if isinstance(node, ast.LoopOverCoordinate): + ctrSym = node.loopCounterSymbol + loopConditionals[ctrSym] = sp.And(ctrSym >= node.start, ctrSym < node.stop) + simplifyConditionals(node.body) + del loopConditionals[ctrSym] + elif isinstance(node, ast.Conditional): + node.conditionExpr = simplifyBooleanExpression(node.conditionExpr, loopConditionals) + simplifyConditionals(node.trueBlock) + if node.falseBlock: + simplifyConditionals(node.falseBlock) + if node.conditionExpr == sp.true: + node.parent.replace(node, [node.trueBlock]) + if node.conditionExpr == sp.false: + node.parent.replace(node, [node.falseBlock] if node.falseBlock else []) + elif isinstance(node, ast.Block): + for a in list(node.args): + simplifyConditionals(a) + elif isinstance(node, ast.SympyAssignment): + return node + else: + raise ValueError("Can not handle node", type(node)) + + +def cleanupBlocks(node): + """Curly Brace Removal: Removes empty blocks, and replaces blocks with a single child by its child """ + if isinstance(node, ast.SympyAssignment): + return + elif isinstance(node, ast.Block): + for a in list(node.args): + cleanupBlocks(a) + if len(node.args) <= 1 and isinstance(node.parent, ast.Block): + node.parent.replace(node, node.args) + return + else: + for a in node.args: + cleanupBlocks(a) + + def symbolNameToVariableName(symbolName): """Replaces characters which are allowed in sympy symbol names but not in C/C++ variable names""" return symbolName.replace("^", "_") @@ -546,17 +665,23 @@ def typeAllEquations(eqs, typeForSymbol): else: assert False, "Expected a symbol as left-hand-side" - typedEquations = [] - for eq in eqs: - if isinstance(eq, sp.Eq) or isinstance(eq, ast.SympyAssignment): - newLhs = processLhs(eq.lhs) - newRhs = processRhs(eq.rhs) - typedEquations.append(ast.SympyAssignment(newLhs, newRhs)) + def visit(object): + if isinstance(object, list) or isinstance(object, tuple): + return [visit(e) for e in object] + if isinstance(object, sp.Eq) or isinstance(object, ast.SympyAssignment): + newLhs = processLhs(object.lhs) + newRhs = processRhs(object.rhs) + return ast.SympyAssignment(newLhs, newRhs) + elif isinstance(object, ast.Conditional): + falseBlock = None if object.falseBlock is None else visit(object.falseBlock) + return ast.Conditional(processRhs(object.conditionExpr), + trueBlock=visit(object.trueBlock), falseBlock=falseBlock) + elif isinstance(object, ast.Block): + return ast.Block([visit(e) for e in object.args]) else: - assert isinstance(eq, ast.Node), "Only equations and ast nodes are allowed in input" - typedEquations.append(eq) + return object - typedEquations = typedEquations + typedEquations = visit(eqs) return fieldsRead, fieldsWritten, typedEquations