Commit 44815e98 authored by Martin Bauer's avatar Martin Bauer
Browse files

Fixes for new LBStep / Datahandling

parent 1d1a192a
import sympy as sp
from functools import partial
from collections import defaultdict
from pystencils.astnodes import SympyAssignment, Block, LoopOverCoordinate, KernelFunction
from pystencils.transformations import resolveBufferAccesses, resolveFieldAccesses, makeLoopOverDomain, \
typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop, \
......
......@@ -110,7 +110,7 @@ class DataHandling(ABC):
:param name: name of the array to gather
:param sliceObj: slice expression of the rectangular sub-part that should be gathered
:param allGather: if False only the root process receives the result, if True all processes
:return: generator expression yielding the gathered field, the gathered field does not include any ghost layers
:return: gathered field that does not include any ghost layers, or None if gathered on another process
"""
@abstractmethod
......
......@@ -157,12 +157,16 @@ class ParallelDataHandling(DataHandling):
sliceObj = tuple([slice(None, None, None)] * self.dim)
if self.dim == 2:
sliceObj += (0.5,)
for array in wlb.field.gatherGenerator(self.blocks, name, sliceObj, allGather):
if self.fields[name].indexDimensions == 0:
array = array[..., 0]
if self.dim == 2:
array = array[:, :, 0]
yield array
array = wlb.field.gatherField(self.blocks, name, sliceObj, allGather)
if array is None:
return None
if self.fields[name].indexDimensions == 0:
array = array[..., 0]
if self.dim == 2:
array = array[:, :, 0]
return array
def _normalizeArrShape(self, arr, indexDimensions):
if indexDimensions == 0:
......
......@@ -98,6 +98,7 @@ class SerialDataHandling(DataHandling):
indexDimensions = 0
layoutTuple = spatialLayoutStringToTuple(layout, self.dim)
# cpuArr is always created - since there is no createPycudaArrayWithLayout()
cpuArr = createNumpyArrayWithLayout(layout=layoutTuple, **kwargs)
if cpu:
......@@ -111,7 +112,7 @@ class SerialDataHandling(DataHandling):
assert all(f.name != latexName for f in self.fields.values()), "Symbolic field with this name already exists"
self.fields[name] = Field.createFixedSize(latexName, shape=kwargs['shape'], indexDimensions=indexDimensions,
dtype=kwargs['dtype'], layout=layout)
dtype=kwargs['dtype'], layout=layoutTuple)
self._fieldLatexNameToDataName[latexName] = name
def addCustomData(self, name, cpuCreationFunction,
......@@ -171,7 +172,10 @@ class SerialDataHandling(DataHandling):
sliceObj = normalizeSlice(sliceObj, arr.shape[:-indDimensions] if indDimensions > 0 else arr.shape)
sliceObj = tuple(s if type(s) is slice else slice(s, s + 1, None) for s in sliceObj)
arr = arr[sliceObj]
yield arr
else:
arr = arr.view()
arr.flags.writeable = False
return arr
def swap(self, name1, name2, gpu=False):
if not gpu:
......@@ -216,7 +220,6 @@ class SerialDataHandling(DataHandling):
return self._synchronizationFunctor(names, stencilName, 'gpu')
def _synchronizationFunctor(self, names, stencil, target):
assert target in ('cpu', 'gpu')
if not hasattr(names, '__len__') or type(names) is str:
names = [names]
......@@ -224,9 +227,9 @@ class SerialDataHandling(DataHandling):
filteredStencil = []
neighbors = [-1, 0, 1]
if stencil.startswith('D2'):
if (stencil is None and self.dim == 2) or (stencil is not None and stencil.startswith('D2')):
directions = itertools.product(*[neighbors] * 2)
elif stencil.startswith('D3'):
elif (stencil is None and self.dim == 3) or (stencil is not None and stencil.startswith('D3')):
directions = itertools.product(*[neighbors] * 3)
else:
raise ValueError("Invalid stencil")
......
import matplotlib
import matplotlib.pyplot as plt
import pystencils.plot2d as plt
import matplotlib.animation as animation
from IPython.display import HTML
from tempfile import NamedTemporaryFile
import base64
from IPython import get_ipython
......@@ -152,7 +150,7 @@ display_animation_func = None
def disp(*args, **kwargs):
if not display_animation_func:
raise "Call set_display_mode first"
raise Exception("Call set_display_mode first")
return display_animation_func(*args, **kwargs)
......@@ -179,4 +177,3 @@ ipython = get_ipython()
if ipython:
setDisplayMode('imageupdate')
ipython.magic("matplotlib inline")
matplotlib.rcParams['figure.figsize'] = (16.0, 6.0)
\ No newline at end of file
......@@ -3,7 +3,7 @@ from pystencils.gpucuda.indexing import indexingCreatorFromParams
def createKernel(equations, target='cpu', dataType="double", iterationSlice=None, ghostLayers=None,
cpuOpenMP=True, cpuVectorizeInfo=None,
cpuOpenMP=False, cpuVectorizeInfo=None,
gpuIndexing='block', gpuIndexingParams={}):
"""
Creates abstract syntax tree (AST) of kernel, using a list of update equations.
......
......@@ -73,7 +73,7 @@ def vectorFieldAnimation(runFunction, step=2, rescale=True, plotSetupFunction=la
field = runFunction()
if rescale:
maxNorm = np.max(norm(field, axis=2, ord=2))
field /= maxNorm
field = field / maxNorm
if 'scale' not in kwargs:
kwargs['scale'] = 1.0
......@@ -85,7 +85,7 @@ def vectorFieldAnimation(runFunction, step=2, rescale=True, plotSetupFunction=la
f = np.swapaxes(f, 0, 1)
if rescale:
maxNorm = np.max(norm(f, axis=2, ord=2))
f /= maxNorm
f = f / maxNorm
u, v = f[::step, ::step, 0], f[::step, ::step, 1]
quiverPlot.set_UVC(u, v)
plotUpdateFunction()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment