diff --git a/derivative.py b/derivative.py new file mode 100644 index 0000000000000000000000000000000000000000..af097ccb5285cb94099de26d167442428fcde0fd --- /dev/null +++ b/derivative.py @@ -0,0 +1,473 @@ +import sympy as sp +from collections import namedtuple, defaultdict +from pystencils.sympyextensions import normalizeProduct, prod + + +def defaultDiffSortKey(d): + return str(d.ceIdx), str(d.label) + + +class DiffOperator(sp.Expr): + """ + Un-applied differential, i.e. differential operator + Its args are: + - label: the differential is w.r.t to this label / variable. + This label is mainly for display purposes (its the subscript) and to distinguish DiffOperators + If the label is '-1' no subscript is displayed + - ceIdx: expansion order index in the Chapman Enskog expansion. It is displayed as superscript. + and not displayed if set to '-1' + The DiffOperator behaves much like a variable with special name. Its main use is to be applied later, using the + DiffOperator.apply(expr, arg) which transforms 'DiffOperator's to applied 'Diff's + """ + is_commutative = True + is_number = False + is_Rational = False + + def __new__(cls, label=-1, ceIdx=-1, **kwargs): + return sp.Expr.__new__(cls, sp.sympify(label), sp.sympify(ceIdx), **kwargs) + + @property + def label(self): + return self.args[0] + + @property + def ceIdx(self): + return self.args[1] + + def _latex(self, printer, *args): + result = "{\partial" + if self.ceIdx >= 0: + result += "^{(%s)}" % (self.ceIdx,) + if self.label != -1: + result += "_{%s}" % (self.label,) + result += "}" + return result + + @staticmethod + def apply(expr, argument): + """ + Returns a new expression where each 'DiffOperator' is replaced by a 'Diff' node. + Multiplications of 'DiffOperator's are interpreted as nested application of differentiation: + i.e. DiffOperator('x')*DiffOperator('x') is a second derivative replaced by Diff(Diff(arg, x), t) + """ + def handleMul(mul): + args = normalizeProduct(mul) + diffs = [a for a in args if isinstance(a, DiffOperator)] + if len(diffs) == 0: + return mul + rest = [a for a in args if not isinstance(a, DiffOperator)] + diffs.sort(key=defaultDiffSortKey) + result = argument + for d in reversed(diffs): + result = Diff(result, label=d.label, ceIdx=d.ceIdx) + return prod(rest) * result + + expr = expr.expand() + if expr.func == sp.Mul or expr.func == sp.Pow: + return handleMul(expr) + elif expr.func == sp.Add: + return expr.func(*[handleMul(a) for a in expr.args]) + else: + return expr + + +class Diff(sp.Expr): + """ + Sympy Node representing a derivative. The difference to sympy's built in differential is: + - shortened latex representation + - all simplifications have to be done manually + - each Diff has a Chapman Enskog expansion order index: 'ceIdx' + """ + is_number = False + is_Rational = False + + def __new__(cls, argument, label=-1, ceIdx=-1, **kwargs): + if argument == 0: + return sp.Rational(0, 1) + return sp.Expr.__new__(cls, argument.expand(), sp.sympify(label), sp.sympify(ceIdx), **kwargs) + + @property + def is_commutative(self): + anyNonCommutative = any(not s.is_commutative for s in self.atoms(sp.Symbol)) + if anyNonCommutative: + return False + else: + return True + + def getArgRecursive(self): + """Returns the argument the derivative acts on, for nested derivatives the inner argument is returned""" + if not isinstance(self.arg, Diff): + return self.arg + else: + return self.arg.getArgRecursive() + + def changeArgRecursive(self, newArg): + """Returns a Diff node with the given 'newArg' instead of the current argument. For nested derivatives + a new nested derivative is returned where the inner Diff has the 'newArg'""" + if not isinstance(self.arg, Diff): + return Diff(newArg, self.label, self.ceIdx) + else: + return Diff(self.arg.changeArgRecursive(newArg), self.label, self.ceIdx) + + def splitLinear(self, functions): + """ + Applies linearity property of Diff: i.e. 'Diff(c*a+b)' is transformed to 'c * Diff(a) + Diff(b)' + The parameter functions is a list of all symbols that are considered functions, not constants. + For the example above: functions=[a, b] + """ + constant, variable = 1, 1 + + if self.arg.func != sp.Mul: + constant, variable = 1, self.arg + else: + for factor in normalizeProduct(self.arg): + if factor in functions or isinstance(factor, Diff): + variable *= factor + else: + constant *= factor + + if isinstance(variable, sp.Symbol) and variable not in functions: + return 0 + + if isinstance(variable, int) or variable.is_number: + return 0 + else: + return constant * Diff(variable, label=self.label, ceIdx=self.ceIdx) + + @property + def arg(self): + """Expression the derivative acts on""" + return self.args[0] + + @property + def label(self): + """Subscript, usually the variable the Diff is w.r.t. """ + return self.args[1] + + @property + def ceIdx(self): + """Superscript, used as the Chapman Enskog order index""" + return self.args[2] + + def _latex(self, printer, *args): + result = "{\partial" + if self.ceIdx >= 0: + result += "^{(%s)}" % (self.ceIdx,) + if self.label != -1: + result += "_{%s}" % (printer.doprint(self.label),) + + contents = printer.doprint(self.arg) + if isinstance(self.arg, int) or isinstance(self.arg, sp.Symbol) or self.arg.is_number or self.arg.func == Diff: + result += " " + contents + else: + result += " (" + contents + ") " + + result += "}" + return result + + def __str__(self): + return "D(%s)" % self.arg + + +# ---------------------------------------------------------------------------------------------------------------------- + +def derivativeTerms(expr): + """ + Returns set of all derivatives in an expression + this is different from `expr.atoms(Diff)` when nested derivatives are in the expression, + since this function only returns the outer derivatives + """ + result = set() + + def visit(e): + if isinstance(e, Diff): + result.add(e) + else: + for a in e.args: + visit(a) + visit(expr) + return result + + +def collectDerivatives(expr): + """Rewrites expression into a sum of distinct derivatives with prefactors""" + return expr.collect(derivativeTerms(expr)) + + +def createNestedDiff(*args, arg=None): + """Shortcut to create nested derivatives""" + assert arg is not None + args = sorted(args, reverse=True) + res = arg + for i in args: + res = Diff(res, i) + return res + + +def expandUsingLinearity(expr, functions=None, constants=None): + """ + Expands all derivative nodes by applying Diff.splitLinear + :param expr: expression containing derivatives + :param functions: sequence of symbols that are considered functions and can not be pulled before the derivative. + if None, all symbols are viewed as functions + :param constants: sequence of symbols which are considered constants and can be pulled before the derivative + """ + if functions is None: + functions = expr.atoms(sp.Symbol) + if constants is not None: + functions.difference_update(constants) + + if isinstance(expr, Diff): + arg = expandUsingLinearity(expr.arg, functions) + if hasattr(arg, 'func') and arg.func == sp.Add: + result = 0 + for a in arg.args: + result += Diff(a, label=expr.label, ceIdx=expr.ceIdx).splitLinear(functions) + return result + else: + diff = Diff(arg, label=expr.label, ceIdx=expr.ceIdx) + if diff == 0: + return 0 + else: + return diff.splitLinear(functions) + else: + newArgs = [expandUsingLinearity(e, functions) for e in expr.args] + result = sp.expand(expr.func(*newArgs) if newArgs else expr) + return result + + +def fullDiffExpand(expr, functions=None, constants=None): + if functions is None: + functions = expr.atoms(sp.Symbol) + if constants is not None: + functions.difference_update(constants) + + def visit(e): + e = e.expand() + + if e.func == Diff: + result = 0 + diffArgs = {'label': e.label, 'ceIdx': e.ceIdx} + diffInner = e.args[0] + diffInner = visit(diffInner) + for term in diffInner.args if diffInner.func == sp.Add else [diffInner]: + independentTerms = 1 + dependentTerms = [] + for factor in normalizeProduct(term): + if factor in functions or isinstance(factor, Diff): + dependentTerms.append(factor) + else: + independentTerms *= factor + for i in range(len(dependentTerms)): + dependentTerm = dependentTerms[i] + otherDependentTerms = dependentTerms[:i] + dependentTerms[i+1:] + processedDiff = normalizeDiffOrder(Diff(dependentTerm, **diffArgs)) + result += independentTerms * prod(otherDependentTerms) * processedDiff + return result + else: + newArgs = [visit(arg) for arg in e.args] + return e.func(*newArgs) if newArgs else e + + if isinstance(expr, sp.Matrix): + return expr.applyfunc(visit) + else: + return visit(expr) + + +def normalizeDiffOrder(expression, functions=None, constants=None, sortKey=defaultDiffSortKey): + """Assumes order of differentiation can be exchanged. Changes the order of nested Diffs to a standard order defined + by the sorting key 'sortKey' such that the derivative terms can be further simplified """ + def visit(expr): + if isinstance(expr, Diff): + nodes = [expr] + while isinstance(nodes[-1].arg, Diff): + nodes.append(nodes[-1].arg) + + processedArg = visit(nodes[-1].arg) + nodes.sort(key=sortKey) + + result = processedArg + for d in reversed(nodes): + result = Diff(result, label=d.label, ceIdx=d.ceIdx) + return result + else: + newArgs = [visit(e) for e in expr.args] + return expr.func(*newArgs) if newArgs else expr + + expression = expandUsingLinearity(expression.expand(), functions, constants).expand() + return visit(expression) + + +def expandUsingProductRule(expr): + """Fully expands all derivatives by applying product rule""" + if isinstance(expr, Diff): + arg = expandUsingProductRule(expr.args[0]) + if arg.func == sp.Add: + newArgs = [Diff(e, label=expr.label, ceIdx=expr.ceIdx) + for e in arg.args] + return sp.Add(*newArgs) + if arg.func not in (sp.Mul, sp.Pow): + return Diff(arg, label=expr.label, ceIdx=expr.ceIdx) + else: + prodList = normalizeProduct(arg) + result = 0 + for i in range(len(prodList)): + preFactor = prod(prodList[j] for j in range(len(prodList)) if i != j) + result += preFactor * Diff(prodList[i], label=expr.label, ceIdx=expr.ceIdx) + return result + else: + newArgs = [expandUsingProductRule(e) for e in expr.args] + return expr.func(*newArgs) if newArgs else expr + + +def combineUsingProductRule(expr): + """Inverse product rule""" + + def exprToDiffDecomposition(expr): + """Decomposes a sp.Add node containing CeDiffs into: + diffDict: maps (label, ceIdx) -> [ (preFactor, argument), ... ] + i.e. a partial(b) ( a is prefactor, b is argument) + in case of partial(a) partial(b) two entries are created (0.5 partial(a), b), (0.5 partial(b), a) + """ + DiffInfo = namedtuple("DiffInfo", ["label", "ceIdx"]) + + class DiffSplit: + def __init__(self, preFactor, argument): + self.preFactor = preFactor + self.argument = argument + + def __repr__(self): + return str((self.preFactor, self.argument)) + + assert isinstance(expr, sp.Add) + diffDict = defaultdict(list) + rest = 0 + for term in expr.args: + if isinstance(term, Diff): + diffDict[DiffInfo(term.label, term.ceIdx)].append(DiffSplit(1, term.arg)) + else: + mulArgs = normalizeProduct(term) + diffs = [d for d in mulArgs if isinstance(d, Diff)] + factor = prod(d for d in mulArgs if not isinstance(d, Diff)) + if len(diffs) == 0: + rest += factor + else: + for i, diff in enumerate(diffs): + allButCurrent = [d for j, d in enumerate(diffs) if i != j] + preFactor = factor * prod(allButCurrent) * sp.Rational(1, len(diffs)) + diffDict[DiffInfo(diff.label, diff.ceIdx)].append(DiffSplit(preFactor, diff.arg)) + + return diffDict, rest + + def matchDiffSplits(own, other): + ownFac = own.preFactor / other.argument + otherFac = other.preFactor / own.argument + + if sp.count_ops(ownFac) > sp.count_ops(own.preFactor) or sp.count_ops(otherFac) > sp.count_ops(other.preFactor): + return None + + newOtherFactor = ownFac - otherFac + return newOtherFactor + + def processDiffList(diffList, label, ceIdx): + if len(diffList) == 0: + return 0 + elif len(diffList) == 1: + return diffList[0].preFactor * Diff(diffList[0].argument, label, ceIdx) + + result = 0 + matches = [] + for i in range(1, len(diffList)): + matchResult = matchDiffSplits(diffList[i], diffList[0]) + if matchResult is not None: + matches.append((i, matchResult)) + + if len(matches) == 0: + result += diffList[0].preFactor * Diff(diffList[0].argument, label, ceIdx) + else: + otherIdx, matchResult = sorted(matches, key=lambda e: sp.count_ops(e[1]))[0] + newArgument = diffList[0].argument * diffList[otherIdx].argument + result += (diffList[0].preFactor / diffList[otherIdx].argument) * Diff(newArgument, label, ceIdx) + if matchResult == 0: + del diffList[otherIdx] + else: + diffList[otherIdx].preFactor = matchResult * diffList[0].argument + result += processDiffList(diffList[1:], label, ceIdx) + return result + + expr = expr.expand() + if isinstance(expr, sp.Add): + diffDict, rest = exprToDiffDecomposition(expr) + for (label, ceIdx), diffList in diffDict.items(): + rest += processDiffList(diffList, label, ceIdx) + return rest + else: + newArgs = [combineUsingProductRule(e) for e in expr.args] + return expr.func(*newArgs) if newArgs else expr + + +def replaceDiff(expr, replacementDict): + """replacementDict: maps variable (label) to a new Differential operator""" + + def visit(e): + if isinstance(e, Diff): + if e.label in replacementDict: + return DiffOperator.apply(replacementDict[e.label], visit(e.arg)) + newArgs = [visit(arg) for arg in e.args] + return e.func(*newArgs) if newArgs else e + + return visit(expr) + + +def zeroDiffs(expr, label): + """Replaces all differentials with the given label by 0""" + def visit(e): + if isinstance(e, Diff): + if e.label == label: + return 0 + newArgs = [visit(arg) for arg in e.args] + return e.func(*newArgs) if newArgs else e + return visit(expr) + + +def evaluateDiffs(expr, var=None): + """Replaces Diff nodes by sp.diff , the free variable is either the label (if var=None) otherwise + the specified var""" + if isinstance(expr, Diff): + if var is None: + var = expr.label + return sp.diff(evaluateDiffs(expr.arg, var), var) + else: + newArgs = [evaluateDiffs(arg, var) for arg in expr.args] + return expr.func(*newArgs) if newArgs else expr + + +def functionalDerivative(functional, v, constants=None): + """ + Computes functional derivative of functional with respect to v using Euler-Lagrange equation + + .. math :: + + \frac{\delta F}{\delta v} = + \frac{\partial F}{\partial v} - \nabla \cdot \frac{\partial F}{\partial \nabla v} + + - assumes that gradients are represented by Diff() node (from Chapman Enskog module) + - Diff(Diff(r)) represents the divergence of r + - the constants parameter is a list with symbols not affected by the derivative. This is used for simplification + of the derivative terms. + """ + diffs = functional.atoms(Diff) + nonDiffPart = functional.subs({d: sp.Dummy() for d in diffs}) + + partialF_partialV = sp.diff(nonDiffPart, v) + + gradientPart = 0 + for diffObj in diffs: + if diffObj.args[0] != v: + continue + dummy = sp.Dummy() + partialF_partialGradV = functional.subs(diffObj, dummy).diff(dummy).subs(dummy, diffObj) + gradientPart += Diff(partialF_partialGradV, label=diffObj.label, ceIdx=diffObj.ceIdx) + + result = partialF_partialV - gradientPart + return expandUsingLinearity(result, constants=constants) diff --git a/finitedifferences.py b/finitedifferences.py index 9f1152d151e9806b72497ec16d81221536829cda..9dd1416661b79eb38b4e3efab9b12bf944169695 100644 --- a/finitedifferences.py +++ b/finitedifferences.py @@ -1,7 +1,10 @@ import numpy as np import sympy as sp + +from pystencils.equationcollection import EquationCollection from pystencils.field import Field from pystencils.transformations import fastSubs +from pystencils.derivative import Diff def grad(var, dim=3): @@ -185,10 +188,14 @@ class Advection(sp.Function): def advection(advectedScalar, velocityField, idx=None): """Advection term: divergence( velocityField * advectedScalar )""" + if isinstance(advectedScalar, Field): + firstArg = advectedScalar.center + elif isinstance(advectedScalar, Field.Access): + firstArg = advectedScalar + else: + raise ValueError("Advected scalar has to be a pystencils Field or Field.Access") - assert isinstance(advectedScalar, Field), "Advected scalar has to be a pystencils.Field" - - args = [advectedScalar.center, velocityField if not isinstance(velocityField, Field) else velocityField.center] + args = [firstArg, velocityField if not isinstance(velocityField, Field) else velocityField.center] if idx is not None: args.append(idx) return Advection(*args) @@ -236,8 +243,14 @@ class Diffusion(sp.Function): def diffusion(scalar, diffusionCoeff, idx=None): - assert isinstance(scalar, Field), "Advected scalar has to be a pystencils.Field" - args = [scalar.center, diffusionCoeff if not isinstance(diffusionCoeff, Field) else diffusionCoeff.center] + if isinstance(scalar, Field): + firstArg = scalar.center + elif isinstance(scalar, Field.Access): + firstArg = scalar + else: + raise ValueError("Diffused scalar has to be a pystencils Field or Field.Access") + + args = [firstArg, diffusionCoeff if not isinstance(diffusionCoeff, Field) else diffusionCoeff.center] if idx is not None: args.append(idx) return Diffusion(*args) @@ -261,7 +274,12 @@ class Transient(sp.Function): def transient(scalar, idx=None): - args = [scalar.center] + if isinstance(scalar, Field): + args = [scalar.center] + elif isinstance(scalar, Field.Access): + args = [scalar] + else: + raise ValueError("Scalar has to be a pystencils Field or Field.Access") if idx is not None: args.append(idx) return Transient(*args) @@ -272,6 +290,13 @@ class Discretization2ndOrder: self.dx = dx self.dt = dt + @staticmethod + def __diffOrder(e): + if not isinstance(e, Diff): + return 0 + else: + return 1 + Discretization2ndOrder.__diffOrder(e.args[0]) + def _discretize_diffusion(self, expr): result = 0 for c in range(expr.dim): @@ -296,11 +321,44 @@ class Discretization2ndOrder: return self._discretize_diffusion(e) elif isinstance(e, Advection): return self._discretize_advection(e) + elif isinstance(e, Diff): + return self._discretize_diff(e) else: newArgs = [self._discretizeSpatial(a) for a in e.args] return e.func(*newArgs) if newArgs else e + def _discretize_diff(self, e): + order = self.__diffOrder(e) + if order == 1: + fa = e.args[0] + index = e.label + return (fa.neighbor(index, 1) - fa.neighbor(index, -1)) / (2 * self.dx) + elif order == 2: + indices = sorted([e.label, e.args[0].label]) + fa = e.args[0].args[0] + if indices[0] == indices[1] and all(i >= 0 for i in indices): + result = (-2 * fa + fa.neighbor(indices[0], -1) + fa.neighbor(indices[0], +1)) + elif indices[0] == indices[1]: + result = 0 + for d in range(fa.field.spatialDimensions): + result += (-2 * fa + fa.neighbor(d, -1) + fa.neighbor(d, +1)) + else: + assert all(i >= 0 for i in indices) + offsets = [(1, 1), [-1, 1], [1, -1], [-1, -1]] + result = sum(o1*o2 * fa.neighbor(indices[0], o1).neighbor(indices[1], o2) for o1, o2 in offsets) / 4 + return result / (self.dx**2) + else: + raise NotImplementedError("Term contains derivatives of order > 2") + def __call__(self, expr): + if isinstance(expr, list): + return [self(e) for e in expr] + elif isinstance(expr, sp.Matrix): + return expr.applyfunc(self.__call__) + elif isinstance(expr, EquationCollection): + return expr.copy(mainEquations=[e for e in expr.mainEquations], + subexpressions=[e for e in expr.subexpressions]) + transientTerms = expr.atoms(Transient) if len(transientTerms) == 0: return self._discretizeSpatial(expr)