Skip to content
Snippets Groups Projects
Commit afd34623 authored by Martin Bauer's avatar Martin Bauer
Browse files

Ported most of the tutorials/demos/benchmarks

parent 44815e98
Branches
Tags
No related merge requests found
......@@ -102,7 +102,7 @@ class DataHandling(ABC):
"""
@abstractmethod
def gatherArray(self, name, sliceObj=None, allGather=False):
def gatherArray(self, name, sliceObj=None, allGather=False, ghostLayers=False):
"""
Gathers part of the domain on a local process. Whenever possible use 'access' instead, since this method copies
the distributed data to a single process which is inefficient and may exhaust the available memory
......@@ -110,6 +110,7 @@ class DataHandling(ABC):
:param name: name of the array to gather
:param sliceObj: slice expression of the rectangular sub-part that should be gathered
:param allGather: if False only the root process receives the result, if True all processes
:param ghostLayers: number of outer ghost layers to include (only available for serial data handlings)
:return: gathered field that does not include any ghost layers, or None if gathered on another process
"""
......
......@@ -4,7 +4,7 @@ from pystencils.datahandling.datahandling_interface import DataHandling
from pystencils.parallel.blockiteration import slicedBlockIteration, blockIteration
from pystencils.utils import DotDict
import waLBerla as wlb
import warnings
class ParallelDataHandling(DataHandling):
GPU_DATA_PREFIX = "gpu_"
......@@ -152,7 +152,11 @@ class ParallelDataHandling(DataHandling):
else:
yield from blockIteration(self.blocks, ghostLayers, self.dim, prefix)
def gatherArray(self, name, sliceObj=None, allGather=False):
def gatherArray(self, name, sliceObj=None, allGather=False, ghostLayers=False):
if ghostLayers is not False:
warnings.warn("gatherArray with ghost layers is only supported in serial datahandling. "
"Array without ghost layers is returned")
if sliceObj is None:
sliceObj = tuple([slice(None, None, None)] * self.dim)
if self.dim == 2:
......
......@@ -16,13 +16,6 @@ except ImportError:
class SerialDataHandling(DataHandling):
class _PassThroughContextManager:
def __init__(self, arr):
self.arr = arr
def __enter__(self, *args, **kwargs):
return self.arr
def __init__(self, domainSize, defaultGhostLayers=1, defaultLayout='SoA', periodicity=False):
"""
Creates a data handling for single node simulations
......@@ -98,7 +91,6 @@ class SerialDataHandling(DataHandling):
indexDimensions = 0
layoutTuple = spatialLayoutStringToTuple(layout, self.dim)
# cpuArr is always created - since there is no createPycudaArrayWithLayout()
cpuArr = createNumpyArrayWithLayout(layout=layoutTuple, **kwargs)
if cpu:
......@@ -162,11 +154,15 @@ class SerialDataHandling(DataHandling):
offset = tuple(s.start - ghostLayers for s in sliceObj)
yield SerialBlock(iterDict, offset, sliceObj)
def gatherArray(self, name, sliceObj=None, **kwargs):
gls = self._fieldInformation[name]['ghostLayers']
def gatherArray(self, name, sliceObj=None, ghostLayers=False, **kwargs):
glToRemove = self._fieldInformation[name]['ghostLayers']
if isinstance(ghostLayers, int):
glToRemove -= ghostLayers
if ghostLayers is True:
glToRemove = 0
arr = self.cpuArrays[name]
indDimensions = self.fields[name].indexDimensions
arr = removeGhostLayers(arr, indexDimensions=indDimensions, ghostLayers=gls)
arr = removeGhostLayers(arr, indexDimensions=indDimensions, ghostLayers=glToRemove)
if sliceObj is not None:
sliceObj = normalizeSlice(sliceObj, arr.shape[:-indDimensions] if indDimensions > 0 else arr.shape)
......
......@@ -4,6 +4,7 @@ from IPython.display import HTML
from tempfile import NamedTemporaryFile
import base64
from IPython import get_ipython
import sympy as sp
def log_progress(sequence, every=None, size=None, name='Items'):
......@@ -176,4 +177,7 @@ def setDisplayMode(mode):
ipython = get_ipython()
if ipython:
setDisplayMode('imageupdate')
ipython.magic("config InlineBackend.rc = { }")
ipython.magic("matplotlib inline")
plt.rc('figure', figsize=(16, 6))
sp.init_printing()
......@@ -9,6 +9,14 @@ class SliceMaker(object):
makeSlice = SliceMaker()
class SlicedGetter(object):
def __init__(self, functionReturningArray):
self._functionReturningArray = functionReturningArray
def __getitem__(self, item):
return self._functionReturningArray(item)
def normalizeSlice(slices, sizes):
"""Converts slices with floating point and/or negative entries to integer slices"""
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment