diff --git a/derivative.py b/derivative.py index af097ccb5285cb94099de26d167442428fcde0fd..e3c8722e543d56139e0fffb9ab6ac6a2f7b2f495 100644 --- a/derivative.py +++ b/derivative.py @@ -4,71 +4,7 @@ from pystencils.sympyextensions import normalizeProduct, prod def defaultDiffSortKey(d): - return str(d.ceIdx), str(d.label) - - -class DiffOperator(sp.Expr): - """ - Un-applied differential, i.e. differential operator - Its args are: - - label: the differential is w.r.t to this label / variable. - This label is mainly for display purposes (its the subscript) and to distinguish DiffOperators - If the label is '-1' no subscript is displayed - - ceIdx: expansion order index in the Chapman Enskog expansion. It is displayed as superscript. - and not displayed if set to '-1' - The DiffOperator behaves much like a variable with special name. Its main use is to be applied later, using the - DiffOperator.apply(expr, arg) which transforms 'DiffOperator's to applied 'Diff's - """ - is_commutative = True - is_number = False - is_Rational = False - - def __new__(cls, label=-1, ceIdx=-1, **kwargs): - return sp.Expr.__new__(cls, sp.sympify(label), sp.sympify(ceIdx), **kwargs) - - @property - def label(self): - return self.args[0] - - @property - def ceIdx(self): - return self.args[1] - - def _latex(self, printer, *args): - result = "{\partial" - if self.ceIdx >= 0: - result += "^{(%s)}" % (self.ceIdx,) - if self.label != -1: - result += "_{%s}" % (self.label,) - result += "}" - return result - - @staticmethod - def apply(expr, argument): - """ - Returns a new expression where each 'DiffOperator' is replaced by a 'Diff' node. - Multiplications of 'DiffOperator's are interpreted as nested application of differentiation: - i.e. DiffOperator('x')*DiffOperator('x') is a second derivative replaced by Diff(Diff(arg, x), t) - """ - def handleMul(mul): - args = normalizeProduct(mul) - diffs = [a for a in args if isinstance(a, DiffOperator)] - if len(diffs) == 0: - return mul - rest = [a for a in args if not isinstance(a, DiffOperator)] - diffs.sort(key=defaultDiffSortKey) - result = argument - for d in reversed(diffs): - result = Diff(result, label=d.label, ceIdx=d.ceIdx) - return prod(rest) * result - - expr = expr.expand() - if expr.func == sp.Mul or expr.func == sp.Pow: - return handleMul(expr) - elif expr.func == sp.Add: - return expr.func(*[handleMul(a) for a in expr.args]) - else: - return expr + return str(d.superscript), str(d.target) class Diff(sp.Expr): @@ -76,15 +12,15 @@ class Diff(sp.Expr): Sympy Node representing a derivative. The difference to sympy's built in differential is: - shortened latex representation - all simplifications have to be done manually - - each Diff has a Chapman Enskog expansion order index: 'ceIdx' + - optional marker displayed as superscript """ is_number = False is_Rational = False - def __new__(cls, argument, label=-1, ceIdx=-1, **kwargs): + def __new__(cls, argument, target=-1, superscript=-1, **kwargs): if argument == 0: return sp.Rational(0, 1) - return sp.Expr.__new__(cls, argument.expand(), sp.sympify(label), sp.sympify(ceIdx), **kwargs) + return sp.Expr.__new__(cls, argument.expand(), sp.sympify(target), sp.sympify(superscript), **kwargs) @property def is_commutative(self): @@ -105,9 +41,9 @@ class Diff(sp.Expr): """Returns a Diff node with the given 'newArg' instead of the current argument. For nested derivatives a new nested derivative is returned where the inner Diff has the 'newArg'""" if not isinstance(self.arg, Diff): - return Diff(newArg, self.label, self.ceIdx) + return Diff(newArg, self.target, self.superscript) else: - return Diff(self.arg.changeArgRecursive(newArg), self.label, self.ceIdx) + return Diff(self.arg.changeArgRecursive(newArg), self.target, self.superscript) def splitLinear(self, functions): """ @@ -132,7 +68,7 @@ class Diff(sp.Expr): if isinstance(variable, int) or variable.is_number: return 0 else: - return constant * Diff(variable, label=self.label, ceIdx=self.ceIdx) + return constant * Diff(variable, target=self.target, superscript=self.superscript) @property def arg(self): @@ -140,21 +76,21 @@ class Diff(sp.Expr): return self.args[0] @property - def label(self): + def target(self): """Subscript, usually the variable the Diff is w.r.t. """ return self.args[1] @property - def ceIdx(self): + def superscript(self): """Superscript, used as the Chapman Enskog order index""" return self.args[2] def _latex(self, printer, *args): result = "{\partial" - if self.ceIdx >= 0: - result += "^{(%s)}" % (self.ceIdx,) - if self.label != -1: - result += "_{%s}" % (printer.doprint(self.label),) + if self.superscript >= 0: + result += "^{(%s)}" % (self.superscript,) + if self.target != -1: + result += "_{%s}" % (printer.doprint(self.target),) contents = printer.doprint(self.arg) if isinstance(self.arg, int) or isinstance(self.arg, sp.Symbol) or self.arg.is_number or self.arg.func == Diff: @@ -169,8 +105,72 @@ class Diff(sp.Expr): return "D(%s)" % self.arg +class DiffOperator(sp.Expr): + """ + Un-applied differential, i.e. differential operator + Its args are: + - target: the differential is w.r.t to this variable. + This target is mainly for display purposes (its the subscript) and to distinguish DiffOperators + If the target is '-1' no subscript is displayed + - superscript: optional marker displayed as superscript + is not displayed if set to '-1' + The DiffOperator behaves much like a variable with special name. Its main use is to be applied later, using the + DiffOperator.apply(expr, arg) which transforms 'DiffOperator's to applied 'Diff's + """ + is_commutative = True + is_number = False + is_Rational = False + + def __new__(cls, target=-1, superscript=-1, **kwargs): + return sp.Expr.__new__(cls, sp.sympify(target), sp.sympify(superscript), **kwargs) + + @property + def target(self): + return self.args[0] + + @property + def superscript(self): + return self.args[1] + + def _latex(self, printer, *args): + result = "{\partial" + if self.superscript >= 0: + result += "^{(%s)}" % (self.superscript,) + if self.target != -1: + result += "_{%s}" % (self.target,) + result += "}" + return result + + @staticmethod + def apply(expr, argument): + """ + Returns a new expression where each 'DiffOperator' is replaced by a 'Diff' node. + Multiplications of 'DiffOperator's are interpreted as nested application of differentiation: + i.e. DiffOperator('x')*DiffOperator('x') is a second derivative replaced by Diff(Diff(arg, x), t) + """ + def handleMul(mul): + args = normalizeProduct(mul) + diffs = [a for a in args if isinstance(a, DiffOperator)] + if len(diffs) == 0: + return mul + rest = [a for a in args if not isinstance(a, DiffOperator)] + diffs.sort(key=defaultDiffSortKey) + result = argument + for d in reversed(diffs): + result = Diff(result, target=d.target, superscript=d.superscript) + return prod(rest) * result + + expr = expr.expand() + if expr.func == sp.Mul or expr.func == sp.Pow: + return handleMul(expr) + elif expr.func == sp.Add: + return expr.func(*[handleMul(a) for a in expr.args]) + else: + return expr + # ---------------------------------------------------------------------------------------------------------------------- + def derivativeTerms(expr): """ Returns set of all derivatives in an expression @@ -222,10 +222,10 @@ def expandUsingLinearity(expr, functions=None, constants=None): if hasattr(arg, 'func') and arg.func == sp.Add: result = 0 for a in arg.args: - result += Diff(a, label=expr.label, ceIdx=expr.ceIdx).splitLinear(functions) + result += Diff(a, target=expr.target, superscript=expr.superscript).splitLinear(functions) return result else: - diff = Diff(arg, label=expr.label, ceIdx=expr.ceIdx) + diff = Diff(arg, target=expr.target, superscript=expr.superscript) if diff == 0: return 0 else: @@ -247,7 +247,7 @@ def fullDiffExpand(expr, functions=None, constants=None): if e.func == Diff: result = 0 - diffArgs = {'label': e.label, 'ceIdx': e.ceIdx} + diffArgs = {'target': e.target, 'superscript': e.superscript} diffInner = e.args[0] diffInner = visit(diffInner) for term in diffInner.args if diffInner.func == sp.Add else [diffInner]: @@ -288,7 +288,7 @@ def normalizeDiffOrder(expression, functions=None, constants=None, sortKey=defau result = processedArg for d in reversed(nodes): - result = Diff(result, label=d.label, ceIdx=d.ceIdx) + result = Diff(result, target=d.target, superscript=d.superscript) return result else: newArgs = [visit(e) for e in expr.args] @@ -303,17 +303,17 @@ def expandUsingProductRule(expr): if isinstance(expr, Diff): arg = expandUsingProductRule(expr.args[0]) if arg.func == sp.Add: - newArgs = [Diff(e, label=expr.label, ceIdx=expr.ceIdx) + newArgs = [Diff(e, target=expr.target, superscript=expr.superscript) for e in arg.args] return sp.Add(*newArgs) if arg.func not in (sp.Mul, sp.Pow): - return Diff(arg, label=expr.label, ceIdx=expr.ceIdx) + return Diff(arg, target=expr.target, superscript=expr.superscript) else: prodList = normalizeProduct(arg) result = 0 for i in range(len(prodList)): preFactor = prod(prodList[j] for j in range(len(prodList)) if i != j) - result += preFactor * Diff(prodList[i], label=expr.label, ceIdx=expr.ceIdx) + result += preFactor * Diff(prodList[i], target=expr.target, superscript=expr.superscript) return result else: newArgs = [expandUsingProductRule(e) for e in expr.args] @@ -325,11 +325,11 @@ def combineUsingProductRule(expr): def exprToDiffDecomposition(expr): """Decomposes a sp.Add node containing CeDiffs into: - diffDict: maps (label, ceIdx) -> [ (preFactor, argument), ... ] + diffDict: maps (target, superscript) -> [ (preFactor, argument), ... ] i.e. a partial(b) ( a is prefactor, b is argument) in case of partial(a) partial(b) two entries are created (0.5 partial(a), b), (0.5 partial(b), a) """ - DiffInfo = namedtuple("DiffInfo", ["label", "ceIdx"]) + DiffInfo = namedtuple("DiffInfo", ["target", "superscript"]) class DiffSplit: def __init__(self, preFactor, argument): @@ -344,7 +344,7 @@ def combineUsingProductRule(expr): rest = 0 for term in expr.args: if isinstance(term, Diff): - diffDict[DiffInfo(term.label, term.ceIdx)].append(DiffSplit(1, term.arg)) + diffDict[DiffInfo(term.target, term.superscript)].append(DiffSplit(1, term.arg)) else: mulArgs = normalizeProduct(term) diffs = [d for d in mulArgs if isinstance(d, Diff)] @@ -355,7 +355,7 @@ def combineUsingProductRule(expr): for i, diff in enumerate(diffs): allButCurrent = [d for j, d in enumerate(diffs) if i != j] preFactor = factor * prod(allButCurrent) * sp.Rational(1, len(diffs)) - diffDict[DiffInfo(diff.label, diff.ceIdx)].append(DiffSplit(preFactor, diff.arg)) + diffDict[DiffInfo(diff.target, diff.superscript)].append(DiffSplit(preFactor, diff.arg)) return diffDict, rest @@ -369,11 +369,11 @@ def combineUsingProductRule(expr): newOtherFactor = ownFac - otherFac return newOtherFactor - def processDiffList(diffList, label, ceIdx): + def processDiffList(diffList, label, superscript): if len(diffList) == 0: return 0 elif len(diffList) == 1: - return diffList[0].preFactor * Diff(diffList[0].argument, label, ceIdx) + return diffList[0].preFactor * Diff(diffList[0].argument, label, superscript) result = 0 matches = [] @@ -383,23 +383,23 @@ def combineUsingProductRule(expr): matches.append((i, matchResult)) if len(matches) == 0: - result += diffList[0].preFactor * Diff(diffList[0].argument, label, ceIdx) + result += diffList[0].preFactor * Diff(diffList[0].argument, label, superscript) else: otherIdx, matchResult = sorted(matches, key=lambda e: sp.count_ops(e[1]))[0] newArgument = diffList[0].argument * diffList[otherIdx].argument - result += (diffList[0].preFactor / diffList[otherIdx].argument) * Diff(newArgument, label, ceIdx) + result += (diffList[0].preFactor / diffList[otherIdx].argument) * Diff(newArgument, label, superscript) if matchResult == 0: del diffList[otherIdx] else: diffList[otherIdx].preFactor = matchResult * diffList[0].argument - result += processDiffList(diffList[1:], label, ceIdx) + result += processDiffList(diffList[1:], label, superscript) return result expr = expr.expand() if isinstance(expr, sp.Add): diffDict, rest = exprToDiffDecomposition(expr) - for (label, ceIdx), diffList in diffDict.items(): - rest += processDiffList(diffList, label, ceIdx) + for (label, superscript), diffList in diffDict.items(): + rest += processDiffList(diffList, label, superscript) return rest else: newArgs = [combineUsingProductRule(e) for e in expr.args] @@ -407,12 +407,12 @@ def combineUsingProductRule(expr): def replaceDiff(expr, replacementDict): - """replacementDict: maps variable (label) to a new Differential operator""" + """replacementDict: maps variable (target) to a new Differential operator""" def visit(e): if isinstance(e, Diff): - if e.label in replacementDict: - return DiffOperator.apply(replacementDict[e.label], visit(e.arg)) + if e.target in replacementDict: + return DiffOperator.apply(replacementDict[e.target], visit(e.arg)) newArgs = [visit(arg) for arg in e.args] return e.func(*newArgs) if newArgs else e @@ -420,10 +420,10 @@ def replaceDiff(expr, replacementDict): def zeroDiffs(expr, label): - """Replaces all differentials with the given label by 0""" + """Replaces all differentials with the given target by 0""" def visit(e): if isinstance(e, Diff): - if e.label == label: + if e.target == label: return 0 newArgs = [visit(arg) for arg in e.args] return e.func(*newArgs) if newArgs else e @@ -431,11 +431,11 @@ def zeroDiffs(expr, label): def evaluateDiffs(expr, var=None): - """Replaces Diff nodes by sp.diff , the free variable is either the label (if var=None) otherwise + """Replaces Diff nodes by sp.diff , the free variable is either the target (if var=None) otherwise the specified var""" if isinstance(expr, Diff): if var is None: - var = expr.label + var = expr.target return sp.diff(evaluateDiffs(expr.arg, var), var) else: newArgs = [evaluateDiffs(arg, var) for arg in expr.args] @@ -467,7 +467,7 @@ def functionalDerivative(functional, v, constants=None): continue dummy = sp.Dummy() partialF_partialGradV = functional.subs(diffObj, dummy).diff(dummy).subs(dummy, diffObj) - gradientPart += Diff(partialF_partialGradV, label=diffObj.label, ceIdx=diffObj.ceIdx) + gradientPart += Diff(partialF_partialGradV, target=diffObj.target, superscript=diffObj.superscript) result = partialF_partialV - gradientPart return expandUsingLinearity(result, constants=constants) diff --git a/finitedifferences.py b/finitedifferences.py index 9dd1416661b79eb38b4e3efab9b12bf944169695..039a603e14c04e92bf06066ee3bd4e1032256816 100644 --- a/finitedifferences.py +++ b/finitedifferences.py @@ -331,10 +331,10 @@ class Discretization2ndOrder: order = self.__diffOrder(e) if order == 1: fa = e.args[0] - index = e.label + index = e.target return (fa.neighbor(index, 1) - fa.neighbor(index, -1)) / (2 * self.dx) elif order == 2: - indices = sorted([e.label, e.args[0].label]) + indices = sorted([e.target, e.args[0].target]) fa = e.args[0].args[0] if indices[0] == indices[1] and all(i >= 0 for i in indices): result = (-2 * fa + fa.neighbor(indices[0], -1) + fa.neighbor(indices[0], +1)) diff --git a/transformations/stage2.py b/transformations/stage2.py index ebfc5309d9fdfe590b72e17c4b111eb9fc84e52f..f46f09b0b547873a7a9b28b8c2bc10034e53d70d 100644 --- a/transformations/stage2.py +++ b/transformations/stage2.py @@ -155,7 +155,7 @@ def insert_casts(node): # elif isinstance(arg, sp.tensor.Indexed) or isinstance(arg, sp.tensor.indexed.Indexed): # node.replace(arg, ast.Indexed(arg.args, arg.base, node)) # elif isinstance(arg, sp.tensor.IndexedBase): -# node.replace(arg, arg.label) +# node.replace(arg, arg.target) # elif isinstance(arg, sp.Function): # node.replace(arg, ast.Function(arg.func, arg.args, node)) # #elif isinstance(arg, sp.containers.Tuple):