Commit 2ca7c47a authored by Martin Bauer's avatar Martin Bauer
Browse files

lbmpy: worked on moment-based methods

parent da5b4d93
from lbmpy.stencils import inverseDirection
from pystencils import Field
import abc
# ------------------------------------------------ Interface -----------------------------------------------------------
class PdfFieldAccessor:
"""
Defines how data is read and written in an LBM time step.
Examples for PdfFieldAccessors are
- stream pull using two fields (source/destination)
- inplace collision access, without streaming
- esoteric twist single field update
-
"""
@abc.abstractmethod
def read(self, field, stencil):
"""Returns sequence of field accesses for all stencil values where pdfs are read from"""
pass
@abc.abstractmethod
def write(self, field, stencil):
"""Returns sequence of field accesses for all stencil values where pdfs are written to"""
pass
# ----------------------------------------------- Implementation -------------------------------------------------------
class CollideOnlyInplaceAccessor(PdfFieldAccessor):
@staticmethod
def read(field, stencil):
return [field(i) for i in range(len(stencil))]
@staticmethod
def write(field, stencil):
return [field(i) for i in range(len(stencil))]
class StreamPullTwoFieldsAccessor(PdfFieldAccessor):
@staticmethod
def read(field, stencil):
return [field[inverseDirection(d)](i) for i, d in enumerate(stencil)]
@staticmethod
def write(field, stencil):
return [field(i) for i in range(len(stencil))]
class AABBEvenTimeStepAccessor(PdfFieldAccessor):
@staticmethod
def read(field, stencil):
return [field(i) for i in range(len(stencil))]
@staticmethod
def write(field, stencil):
return [field(stencil.index(inverseDirection(d))) for d in stencil]
class AABBOddTimeStepAccessor(PdfFieldAccessor):
@staticmethod
def read(field, stencil):
res = []
for i, d in enumerate(stencil):
invDir = inverseDirection(d)
fieldAccess = field[invDir](stencil.index(invDir))
res.append(fieldAccess)
return
@staticmethod
def write(field, stencil):
return [field[d](i) for i, d in enumerate(stencil)]
class EsotericTwistAccessor(PdfFieldAccessor):
@staticmethod
def read(field, stencil):
result = []
for i, direction in enumerate(stencil):
direction = inverseDirection(direction)
neighborOffset = tuple([-e if e <= 0 else 0 for e in direction])
result.append(field[neighborOffset](i))
return result
@staticmethod
def write(field, stencil):
result = []
for i, direction in enumerate(stencil):
neighborOffset = tuple([e if e >= 0 else 0 for e in direction])
inverseIndex = stencil.index(inverseDirection(direction))
result.append(field[neighborOffset](inverseIndex))
return result
# -------------------------------------------- Visualization -----------------------------------------------------------
def visualizeFieldMapping(axes, stencil, fieldMapping, color='b'):
from lbmpy.gridvisualization import Grid
grid = Grid(3, 3)
grid.fillWithDefaultArrows()
for fieldAccess, direction in zip(fieldMapping, stencil):
fieldPosition = stencil[fieldAccess.index[0]]
neighbor = fieldAccess.offsets
grid.addArrow((1 + neighbor[0], 1 + neighbor[1]),
arrowPosition=fieldPosition, arrowDirection=direction, color=color)
grid.draw(axes)
def visualizePdfFieldAccessor(pdfFieldAccessor, figure=None):
from lbmpy.stencils import getStencil
if figure is None:
import matplotlib.pyplot as plt
figure = plt.gcf()
stencil = getStencil('D2Q9')
figure.patch.set_facecolor('white')
field = Field.createGeneric('f', spatialDimensions=2, indexDimensions=1)
preCollisionAccesses = pdfFieldAccessor.read(field, stencil)
postCollisionAccesses = pdfFieldAccessor.write(field, stencil)
axLeft = figure.add_subplot(1, 2, 1)
axRight = figure.add_subplot(1, 2, 2)
visualizeFieldMapping(axLeft, stencil, preCollisionAccesses, color='k')
visualizeFieldMapping(axRight, stencil, postCollisionAccesses, color='r')
axLeft.set_title("Read")
axRight.set_title("Write")
import matplotlib.patches as patches
class Grid:
"""Visualizes a 2D LBM grid with matplotlib by drawing cells and pdf arrows"""
def __init__(self, xCells, yCells):
"""Create a new grid with the given number of cells in x (horizontal) and y (vertical) direction"""
self._xCells = xCells
self._yCells = yCells
self._patches = []
for x in range(xCells):
for y in range(yCells):
self._patches.append(patches.Rectangle((x, y), 1.0, 1.0, fill=False, linewidth=3, color='#bbbbbb'))
self._cellBoundaries = dict() # mapping cell to rectangle patch
self._arrows = dict() # mapping (cell, direction) tuples to arrow patches
def addCellBoundary(self, cell, **kwargs):
"""Draws a rectangle around a single cell. Keyword arguments are passed to the matplotlib Rectangle patch"""
if 'fill' not in kwargs: kwargs['fill'] = False
if 'linewidth' not in kwargs: kwargs['linewidth'] = 3
if 'color' not in kwargs: kwargs['#bbbbbb']
self._cellBoundaries[cell] = patches.Rectangle(cell, 1.0, 1.0, **kwargs)
def addCellBoundaries(self, **kwargs):
"""Draws a rectangle around all cells. Keyword arguments are passed to the matplotlib Rectangle patch"""
for x in range(self._xCells):
for y in range(self._yCells):
self.addCellBoundary((x, y), **kwargs)
def addArrow(self, cell, arrowPosition, arrowDirection, **kwargs):
"""
Draws an arrow in a cell. If an arrow exists already at this position, it is replaced.
:param cell: cell coordinate as tuple (x,y)
:param arrowPosition: each cell has 9 possible positions specified as tuple e.g. upper left (-1, 1)
:param arrowDirection: direction of the arrow as (x,y) tuple
:param kwargs: arguments passed directly to the FancyArrow patch of matplotlib
"""
cellMidpoint = (0.5 + cell[0], 0.5 + cell[1])
if 'width' not in kwargs: kwargs['width'] = 0.005
if 'color' not in kwargs: kwargs['color'] = 'k'
if arrowPosition == (0, 0):
del kwargs['width']
self._arrows[(cell, arrowPosition)] = patches.Circle(cellMidpoint, radius=0.03, **kwargs)
else:
arrowMidpoint = (cellMidpoint[0] + arrowPosition[0] * 0.25,
cellMidpoint[1] + arrowPosition[1] * 0.25)
length = 0.75
arrowStart = (arrowMidpoint[0] - arrowDirection[0] * 0.25 * length,
arrowMidpoint[1] - arrowDirection[1] * 0.25 * length)
patch = patches.FancyArrow(arrowStart[0], arrowStart[1],
0.25 * length * arrowDirection[0],
0.25 * length * arrowDirection[1],
**kwargs)
self._arrows[(cell, arrowPosition)] = patch
def fillWithDefaultArrows(self, **kwargs):
"""Fills the complete grid with the default pdf arrows"""
for x in range(self._xCells):
for y in range(self._yCells):
for dx in [-1, 0, 1]:
for dy in [-1, 0, 1]:
if 'color' not in kwargs: kwargs['color'] = '#bbbbbb'
if 'width' not in kwargs: kwargs['width'] = 0.006
self.addArrow((x, y), (dx, dy), (dx, dy), **kwargs)
def draw(self, ax):
"""Draw the grid into a given matplotlib axes object"""
for p in self._patches:
ax.add_patch(p)
for arrowPatch in self._arrows.values():
ax.add_patch(arrowPatch)
offset = 0.1
ax.set_xlim(-offset, self._xCells+offset)
ax.set_xlim(-offset, self._xCells + offset)
ax.set_ylim(-offset, self._yCells + offset)
ax.set_aspect('equal')
ax.set_axis_off()
import sympy as sp
from collections import defaultdict
def createLbmSplitGroups(lbmCollisionEqs):
"""
Creates split groups for LBM collision equations. For details about split groups see
:func:`pystencils.transformation.splitInnerLoop` .
The split groups are added as simplification hint 'splitGroups'
Split groups are created in the following way: Opposing directions are put into a single group.
The velocity subexpressions are pre-computed as well as all subexpressions which are used in all
non-center collision equations, and depend on at least one pdf.
Required simplification hints:
- velocity: sequence of velocity symbols
"""
sh = lbmCollisionEqs.simplificationHints
assert 'velocity' in sh, "Needs simplification hint 'velocity': Sequence of velocity symbols"
pdfSymbols = lbmCollisionEqs.method.preCollisionPdfSymbols
stencil = lbmCollisionEqs.method.stencil
importantSubExpressions = {e.lhs for e in lbmCollisionEqs.subexpressions
if pdfSymbols.intersection(lbmCollisionEqs.getDependentSymbols([e.lhs]))}
for eq in lbmCollisionEqs.mainEquations[1:]:
importantSubExpressions.intersection_update(eq.rhs.atoms(sp.Symbol))
subexpressionsToPreCompute = list(sh['velocity']) + list(importantSubExpressions)
splitGroups = [subexpressionsToPreCompute, ]
directionGroups = defaultdict(list)
dim = len(stencil[0])
for direction, eq in zip(stencil, lbmCollisionEqs.mainEquations):
if direction == tuple([0]*dim):
splitGroups[0].append(eq.lhs)
continue
inverseDir = tuple([-i for i in direction])
if inverseDir in directionGroups:
directionGroups[inverseDir].append(eq.lhs)
else:
directionGroups[direction].append(eq.lhs)
splitGroups += directionGroups.values()
return splitGroups
import abc
import sympy as sp
from pystencils.equationcollection import EquationCollection
class LbmCollisionRule(EquationCollection):
def __init__(self, lbmMethod, *args, **kwargs):
super(LbmCollisionRule, self).__init__(*args, **kwargs)
self.method = lbmMethod
class AbstractLbmMethod(metaclass=abc.ABCMeta):
......@@ -47,5 +54,6 @@ class AbstractLbmMethod(metaclass=abc.ABCMeta):
@abc.abstractmethod
def getCollisionRule(self):
"""Returns an equation collection defining the collision operator."""
"""Returns an LbmCollisionRule i.e. an equation collection with a reference to the method.
This collision rule defines the collision operator."""
......@@ -151,7 +151,16 @@ class DensityVelocityComputation(AbstractConservedQuantityComputation):
nameToSymbol = {'density': self._symbolOrder0,
'velocity': self._symbolsOrder1}
return eqColl.extract({nameToSymbol[e] for e in outputQuantityNames})
symbolsToExtract = set()
for e in outputQuantityNames:
symbol = nameToSymbol[e]
if hasattr(symbol, "__len__"):
symbolsToExtract.update(symbol)
else:
symbolsToExtract.add(symbol)
return eqColl.extract(symbolsToExtract)
def __repr__(self):
return "ConservedValueComputation for %s" % (", " .join(self.conservedQuantities.keys()),)
......@@ -226,7 +235,7 @@ def divideFirstOrderMomentsByRho(equationCollection, dim):
rho = oldEqs[0].lhs
newFirstOrderMomentEq = [sp.Eq(eq.lhs, eq.rhs / rho) for eq in oldEqs[1:dim+1]]
newEqs = [oldEqs[0]] + newFirstOrderMomentEq + oldEqs[dim+1:]
return equationCollection.newWithAdditionalSubexpressions(newEqs, [])
return equationCollection.copy(newEqs)
def addDensityOffset(equationCollection, offset=sp.Rational(1, 1)):
......@@ -235,7 +244,7 @@ def addDensityOffset(equationCollection, offset=sp.Rational(1, 1)):
"""
oldEqs = equationCollection.mainEquations
newDensity = sp.Eq(oldEqs[0].lhs, oldEqs[0].rhs + offset)
return equationCollection.newWithAdditionalSubexpressions([newDensity] + oldEqs[1:], [])
return equationCollection.copy([newDensity] + oldEqs[1:])
def applyForceModelShift(shiftMemberName, dim, equationCollection, forceModel, compressible, reverse=False):
......@@ -254,7 +263,7 @@ def applyForceModelShift(shiftMemberName, dim, equationCollection, forceModel, c
velOffsets = [-v for v in velOffsets]
shiftedVelocityEqs = [sp.Eq(oldEq.lhs, oldEq.rhs + offset) for oldEq, offset in zip(oldVelEqs, velOffsets)]
newEqs = [oldEqs[0]] + shiftedVelocityEqs + oldEqs[dim + 1:]
return equationCollection.newWithAdditionalSubexpressions(newEqs, [])
return equationCollection.copy(newEqs)
else:
return equationCollection
......
......@@ -2,14 +2,14 @@ import sympy as sp
import collections
from collections import namedtuple, OrderedDict, defaultdict
from lbmpy.stencils import stencilsHaveSameEntries, getStencil
from lbmpy.maxwellian_equilibrium import getMomentsOfDiscreteMaxwellianEquilibrium, \
getMomentsOfContinuousMaxwellianEquilibrium
from lbmpy.methods.abstractlbmmethod import AbstractLbmMethod
from lbmpy.methods.abstractlbmmethod import AbstractLbmMethod, LbmCollisionRule
from lbmpy.methods.conservedquantitycomputation import AbstractConservedQuantityComputation, DensityVelocityComputation
from lbmpy.moments import MOMENT_SYMBOLS, momentMatrix, exponentsToPolynomialRepresentations, isShearMoment, \
momentsUpToComponentOrder, isEven, gramSchmidt, getOrder
from pystencils.equationcollection import EquationCollection
from pystencils.sympyextensions import commonDenominator
from lbmpy.moments import MOMENT_SYMBOLS, momentMatrix, isShearMoment, \
isEven, gramSchmidt, getOrder, getDefaultMomentSetForStencil
from pystencils.sympyextensions import commonDenominator, replaceAdditive
RelaxationInfo = namedtuple('Relaxationinfo', ['equilibriumValue', 'relaxationRate'])
......@@ -91,6 +91,10 @@ class MomentBasedLbmMethod(AbstractLbmMethod):
</tr>\n""".format(**vals)
return table.format(content=content, nb='style="border:none"')
@property
def moments(self):
return self._moments
@property
def zerothOrderEquilibriumMomentSymbol(self, ):
return self._conservedQuantityComputation.definedSymbols(order=0)[1]
......@@ -107,7 +111,12 @@ class MomentBasedLbmMethod(AbstractLbmMethod):
def _computeWeights(self):
replacements = self._conservedQuantityComputation.defaultValues
eqColl = self.getEquilibrium().newWithSubstitutionsApplied(replacements).insertSubexpressions()
eqColl = self.getEquilibrium().copyWithSubstitutionsApplied(replacements).insertSubexpressions()
newMainEqs = [sp.Eq(e.lhs,
replaceAdditive(e.rhs, 1, sum(self.preCollisionPdfSymbols), requiredMatchReplacement=1.0))
for e in eqColl.mainEquations]
eqColl = eqColl.copy(newMainEqs)
weights = []
for eq in eqColl.mainEquations:
value = eq.rhs.expand()
......@@ -144,7 +153,7 @@ class MomentBasedLbmMethod(AbstractLbmMethod):
if self._forceModel is not None:
forceModelTerms = self._forceModel(self)
newEqs = [sp.Eq(eq.lhs, eq.rhs + fmt) for eq, fmt in zip(eqColl.mainEquations, forceModelTerms)]
eqColl = eqColl.newWithAdditionalSubexpressions(newEqs, [])
eqColl = eqColl.copy(newEqs)
return eqColl
@property
......@@ -165,11 +174,10 @@ class MomentBasedLbmMethod(AbstractLbmMethod):
simplificationHints = eqValueEqs.simplificationHints
simplificationHints.update(self._conservedQuantityComputation.definedSymbols())
simplificationHints['relaxationRates'] = D.atoms(sp.Symbol)
simplificationHints['stencil'] = self.stencil
allSubexpressions = relaxationRateSubExpressions + eqValueEqs.subexpressions + eqValueEqs.mainEquations
return EquationCollection(collisionEqs, allSubexpressions,
simplificationHints)
return LbmCollisionRule(self, collisionEqs, allSubexpressions,
simplificationHints)
@staticmethod
def _generateRelaxationMatrix(relaxationMatrix):
......@@ -257,8 +265,8 @@ def createWithDiscreteMaxwellianEqMoments(stencil, momentToRelaxationRateDict, c
:return: :class:`lbmpy.methods.MomentBasedLbmMethod` instance
"""
momToRrDict = OrderedDict(momentToRelaxationRateDict)
assert len(momToRrDict) == len(
stencil), "The number of moments has to be the same as the number of stencil entries"
assert len(momToRrDict) == len(stencil), \
"The number of moments has to be the same as the number of stencil entries"
densityVelocityComputation = DensityVelocityComputation(stencil, compressible, forceModel)
eqMoments = getMomentsOfDiscreteMaxwellianEquilibrium(stencil, list(momToRrDict.keys()), c_s_sq=sp.Rational(1, 3),
......@@ -281,8 +289,7 @@ def createWithContinuousMaxwellianEqMoments(stencil, momentToRelaxationRateDict,
stencil), "The number of moments has to be the same as the number of stencil entries"
dim = len(stencil[0])
densityVelocityComputation = DensityVelocityComputation(stencil, True, forceModel)
eqMoments = getMomentsOfContinuousMaxwellianEquilibrium(list(momToRrDict.keys()), stencil, dim,
c_s_sq=sp.Rational(1, 3),
eqMoments = getMomentsOfContinuousMaxwellianEquilibrium(list(momToRrDict.keys()), dim, c_s_sq=sp.Rational(1, 3),
order=equilibriumAccuracyOrder)
rrDict = OrderedDict([(mom, RelaxationInfo(eqMom, rr))
for mom, rr, eqMom in zip(momToRrDict.keys(), momToRrDict.values(), eqMoments)])
......@@ -307,8 +314,7 @@ def createSRT(stencil, relaxationRate, compressible=False, forceModel=None, equi
:param equilibriumAccuracyOrder: approximation order of macroscopic velocity :math:`\mathbf{u}` in the equilibrium
:return: :class:`lbmpy.methods.MomentBasedLbmMethod` instance
"""
dim = len(stencil[0])
moments = exponentsToPolynomialRepresentations(momentsUpToComponentOrder(2, dim=dim))
moments = getDefaultMomentSetForStencil(stencil)
rrDict = {m: relaxationRate for m in moments}
return createWithDiscreteMaxwellianEqMoments(stencil, rrDict, compressible, forceModel, equilibriumAccuracyOrder)
......@@ -324,8 +330,7 @@ def createTRT(stencil, relaxationRateEvenMoments, relaxationRateOddMoments, comp
two relaxation rates: one for even moments (determines viscosity) and one for odd moments.
If unsure how to choose the odd relaxation rate, use the function :func:`lbmpy.methods.createTRTWithMagicNumber`.
"""
dim = len(stencil[0])
moments = exponentsToPolynomialRepresentations(momentsUpToComponentOrder(2, dim=dim))
moments = getDefaultMomentSetForStencil(stencil)
rrDict = {m: relaxationRateEvenMoments if isEven(m) else relaxationRateOddMoments for m in moments}
return createWithDiscreteMaxwellianEqMoments(stencil, rrDict, compressible, forceModel, equilibriumAccuracyOrder)
......@@ -363,18 +368,16 @@ def createOrthogonalMRT(stencil, relaxationRateGetter=None, compressible=False,
if relaxationRateGetter is None:
relaxationRateGetter = defaultRelaxationRateNames()
Q = len(stencil)
D = len(stencil[0])
x, y, z = MOMENT_SYMBOLS
one = sp.Rational(1, 1)
momentToRelaxationRateDict = OrderedDict()
if (D, Q) == (2, 9):
moments = exponentsToPolynomialRepresentations(momentsUpToComponentOrder(2, dim=D))
if stencilsHaveSameEntries(stencil, getStencil("D2Q9")):
moments = getDefaultMomentSetForStencil(stencil)
orthogonalMoments = gramSchmidt(moments, stencil)
orthogonalMomentsScaled = [e * commonDenominator(e) for e in orthogonalMoments]
nestedMoments = list(sortMomentsIntoGroupsOfSameOrder(orthogonalMomentsScaled).values())
elif (D, Q) == (3, 15):
elif stencilsHaveSameEntries(stencil, getStencil("D3Q15")):
sq = x ** 2 + y ** 2 + z ** 2
nestedMoments = [
[one, x, y, z], # [0, 3, 5, 7]
......@@ -384,7 +387,7 @@ def createOrthogonalMRT(stencil, relaxationRateGetter=None, compressible=False,
[(3 * sq - 5) * x, (3 * sq - 5) * y, (3 * sq - 5) * z], # [4, 6, 8]
[x * y * z]
]
elif (D, Q) == (3, 19):
elif stencilsHaveSameEntries(stencil, getStencil("D3Q19")):
sq = x ** 2 + y ** 2 + z ** 2
nestedMoments = [
[one, x, y, z], # [0, 3, 5, 7]
......@@ -395,7 +398,7 @@ def createOrthogonalMRT(stencil, relaxationRateGetter=None, compressible=False,
[(2 * sq - 3) * (3 * x ** 2 - sq), (2 * sq - 3) * (y ** 2 - z ** 2)], # [10, 12]
[(y ** 2 - z ** 2) * x, (z ** 2 - x ** 2) * y, (x ** 2 - y ** 2) * z] # [16, 17, 18]
]
elif (D, Q) == (3, 27):
elif stencilsHaveSameEntries(stencil, getStencil("D3Q27")):
xsq, ysq, zsq = x ** 2, y ** 2, z ** 2
allMoments = [
sp.Rational(1, 1), # 0
......
......@@ -22,7 +22,9 @@ def replaceSecondOrderVelocityProducts(lbmCollisionEqs):
for i, s in enumerate(lbmCollisionEqs.mainEquations):
newRhs = replaceSecondOrderProducts(s.rhs, u, positive=None, replaceMixed=substitutions)
result.append(sp.Eq(s.lhs, newRhs))
return lbmCollisionEqs.newWithAdditionalSubexpressions(result, substitutions)
res = lbmCollisionEqs.copy(result)
res.subexpressions += substitutions
return res
def factorRelaxationRates(lbmCollisionEqs):
......@@ -40,7 +42,7 @@ def factorRelaxationRates(lbmCollisionEqs):
for rp in sh['relaxationRates']:
newRhs = newRhs.collect(rp)
result.append(sp.Eq(s.lhs, newRhs))
return lbmCollisionEqs.newWithAdditionalSubexpressions(result, [])
return lbmCollisionEqs.copy(result)
def factorDensityAfterFactoringRelaxationTimes(lbmCollisionEqs):
......@@ -65,7 +67,7 @@ def factorDensityAfterFactoringRelaxationTimes(lbmCollisionEqs):
coeff = newRhs.coeff(rp)
newRhs = newRhs.subs(coeff, coeff.collect(rho))
result.append(sp.Eq(s.lhs, newRhs))
return lbmCollisionEqs.newWithAdditionalSubexpressions(result, [])
return lbmCollisionEqs.copy(result)
def replaceDensityAndVelocity(lbmCollisionEqs):
......@@ -88,7 +90,7 @@ def replaceDensityAndVelocity(lbmCollisionEqs):
for replacement in substitutions:
newRhs = replaceAdditive(newRhs, replacement.lhs, replacement.rhs, requiredMatchReplacement=0.5)
result.append(sp.Eq(s.lhs, newRhs))
return lbmCollisionEqs.newWithAdditionalSubexpressions(result, [])
return lbmCollisionEqs.copy(result)
def replaceCommonQuadraticAndConstantTerm(lbmCollisionEqs):
......@@ -106,9 +108,8 @@ def replaceCommonQuadraticAndConstantTerm(lbmCollisionEqs):
assert 'density' in sh, "Needs simplification hint 'density': Symbol for density"
assert 'velocity' in sh, "Needs simplification hint 'velocity': Sequence of velocity symbols"
assert 'relaxationRates' in sh, "Needs simplification hint 'relaxationRates': Set of symbolic relaxation rates"
assert 'stencil' in sh, "Needs simplification hint 'stencil': Sequence of discrete velocities"
stencil = sh['stencil']
stencil = lbmCollisionEqs.method.stencil
assert sum([abs(e) for e in stencil[0]]) == 0, "Works only if first stencil entry is the center direction"
f_eq_common = __getCommonQuadraticAndConstantTerms(lbmCollisionEqs)
......@@ -118,7 +119,9 @@ def replaceCommonQuadraticAndConstantTerm(lbmCollisionEqs):
for s in lbmCollisionEqs.mainEquations:
newRhs = replaceAdditive(s.rhs, f_eq_common.lhs, f_eq_common.rhs, requiredMatchReplacement=0.5)
result.append(sp.Eq(s.lhs, newRhs))
return lbmCollisionEqs.newWithAdditionalSubexpressions(result, [f_eq_common])
res = lbmCollisionEqs.copy(result)
res.subexpressions.append(f_eq_common)
return res
else:
return lbmCollisionEqs
......@@ -129,14 +132,13 @@ def cseInOpposingDirections(lbmCollisionEqs):
Required simplification hints:
- relaxationRates: set of symbolic relaxation rates
- stencil:
- postCollisionPdfSymbols: sequence of symbols
"""
sh = lbmCollisionEqs.simplificationHints
assert 'stencil' in sh, "Needs simplification hint 'stencil': Sequence of discrete velocities"
assert 'relaxationRates' in sh, "Needs simplification hint 'relaxationRates': Set of symbolic relaxation rates"
updateRules = lbmCollisionEqs.mainEquations
stencil = sh['stencil']
stencil = lbmCollisionEqs.method.stencil
relaxationRates = sh['relaxationRates']
replacementSymbolGenerator = lbmCollisionEqs.subexpressionSymbolNameGenerator
......@@ -193,7 +195,10 @@ def cseInOpposingDirections(lbmCollisionEqs):
for term, substitutedVar in newCoefficientSubstitutions.items():
substitutions.append(sp.Eq(substitutedVar, term))
return lbmCollisionEqs.newWithAdditionalSubexpressions(result, substitutions)
result.sort(key=lambda e: lbmCollisionEqs.method.postCollisionPdfSymbols.index(e.lhs))
res = lbmCollisionEqs.copy(result)
res.subexpressions += substitutions
return res
# -------------------------------------- Helper Functions --------------------------------------------------------------
......@@ -202,7 +207,7 @@ def __getCommonQuadraticAndConstantTerms(lbmCollisionEqs):
"""Determines a common subexpression useful for most LBM model often called f_eq_common.
It contains the quadratic and constant terms of the center update rule."""
sh = lbmCollisionEqs.simplificationHints
stencil = sh['stencil']
stencil = lbmCollisionEqs.method.stencil
relaxationRates = sh['relaxationRates']
dim = len(stencil[0])
......
......@@ -116,6 +116,15 @@ def momentMultiplicity(exponentTuple):
return result
def pickRepresentativeMoments(moments):
"""Picks the representative i.e. of each permutation group only one is kept"""
toRemove = []
for m in moments:
permutations = list(momentPermutations(m))
toRemove += permutations[1:]
return set(moments) - set(toRemove)
def momentPermutations(exponentTuple):
"""Returns all (unique) permutations of the given tuple"""
return __uniquePermutations(exponentTuple)
......@@ -351,6 +360,40 @@ def gramSchmidt(moments, stencil, weights=None):
return moments