diff --git a/cpu/kernelcreation.py b/cpu/kernelcreation.py index eb5e4851176bb82db62d77ebb0100c092d2322f4..f290942ceaf6a5b4defad6c6779b7913395caf7d 100644 --- a/cpu/kernelcreation.py +++ b/cpu/kernelcreation.py @@ -1,8 +1,5 @@ import sympy as sp from functools import partial - -from collections import defaultdict - from pystencils.astnodes import SympyAssignment, Block, LoopOverCoordinate, KernelFunction from pystencils.transformations import resolveBufferAccesses, resolveFieldAccesses, makeLoopOverDomain, \ typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop, \ diff --git a/datahandling/datahandling_interface.py b/datahandling/datahandling_interface.py index af69c67c800f39e3f907254db226650086e63b7b..b2faa156c3086c85fc0de5c788c0541de919b16c 100644 --- a/datahandling/datahandling_interface.py +++ b/datahandling/datahandling_interface.py @@ -110,7 +110,7 @@ class DataHandling(ABC): :param name: name of the array to gather :param sliceObj: slice expression of the rectangular sub-part that should be gathered :param allGather: if False only the root process receives the result, if True all processes - :return: generator expression yielding the gathered field, the gathered field does not include any ghost layers + :return: gathered field that does not include any ghost layers, or None if gathered on another process """ @abstractmethod diff --git a/datahandling/parallel_datahandling.py b/datahandling/parallel_datahandling.py index 20fd0a99fc35a9c0e7f6cfa5f575bb219e181e93..d8d558cdcabfcc996bf7631a5d50aa76ec3aa121 100644 --- a/datahandling/parallel_datahandling.py +++ b/datahandling/parallel_datahandling.py @@ -157,12 +157,16 @@ class ParallelDataHandling(DataHandling): sliceObj = tuple([slice(None, None, None)] * self.dim) if self.dim == 2: sliceObj += (0.5,) - for array in wlb.field.gatherGenerator(self.blocks, name, sliceObj, allGather): - if self.fields[name].indexDimensions == 0: - array = array[..., 0] - if self.dim == 2: - array = array[:, :, 0] - yield array + + array = wlb.field.gatherField(self.blocks, name, sliceObj, allGather) + if array is None: + return None + + if self.fields[name].indexDimensions == 0: + array = array[..., 0] + if self.dim == 2: + array = array[:, :, 0] + return array def _normalizeArrShape(self, arr, indexDimensions): if indexDimensions == 0: diff --git a/datahandling/serial_datahandling.py b/datahandling/serial_datahandling.py index ed149b9b3498ad9ca9ac30c1cb05fe8ba86ea096..1662462a865db607b9c61306be7d8e69cbd8fed0 100644 --- a/datahandling/serial_datahandling.py +++ b/datahandling/serial_datahandling.py @@ -98,6 +98,7 @@ class SerialDataHandling(DataHandling): indexDimensions = 0 layoutTuple = spatialLayoutStringToTuple(layout, self.dim) + # cpuArr is always created - since there is no createPycudaArrayWithLayout() cpuArr = createNumpyArrayWithLayout(layout=layoutTuple, **kwargs) if cpu: @@ -111,7 +112,7 @@ class SerialDataHandling(DataHandling): assert all(f.name != latexName for f in self.fields.values()), "Symbolic field with this name already exists" self.fields[name] = Field.createFixedSize(latexName, shape=kwargs['shape'], indexDimensions=indexDimensions, - dtype=kwargs['dtype'], layout=layout) + dtype=kwargs['dtype'], layout=layoutTuple) self._fieldLatexNameToDataName[latexName] = name def addCustomData(self, name, cpuCreationFunction, @@ -171,7 +172,10 @@ class SerialDataHandling(DataHandling): sliceObj = normalizeSlice(sliceObj, arr.shape[:-indDimensions] if indDimensions > 0 else arr.shape) sliceObj = tuple(s if type(s) is slice else slice(s, s + 1, None) for s in sliceObj) arr = arr[sliceObj] - yield arr + else: + arr = arr.view() + arr.flags.writeable = False + return arr def swap(self, name1, name2, gpu=False): if not gpu: @@ -216,7 +220,6 @@ class SerialDataHandling(DataHandling): return self._synchronizationFunctor(names, stencilName, 'gpu') def _synchronizationFunctor(self, names, stencil, target): - assert target in ('cpu', 'gpu') if not hasattr(names, '__len__') or type(names) is str: names = [names] @@ -224,9 +227,9 @@ class SerialDataHandling(DataHandling): filteredStencil = [] neighbors = [-1, 0, 1] - if stencil.startswith('D2'): + if (stencil is None and self.dim == 2) or (stencil is not None and stencil.startswith('D2')): directions = itertools.product(*[neighbors] * 2) - elif stencil.startswith('D3'): + elif (stencil is None and self.dim == 3) or (stencil is not None and stencil.startswith('D3')): directions = itertools.product(*[neighbors] * 3) else: raise ValueError("Invalid stencil") diff --git a/jupytersetup.py b/jupytersetup.py index 82805254a53aab261b687f28b7095d8b82ba9600..f390c9c4b2a75f1ba36707877beb53a49e68236f 100644 --- a/jupytersetup.py +++ b/jupytersetup.py @@ -1,10 +1,8 @@ -import matplotlib -import matplotlib.pyplot as plt +import pystencils.plot2d as plt import matplotlib.animation as animation from IPython.display import HTML from tempfile import NamedTemporaryFile import base64 - from IPython import get_ipython @@ -152,7 +150,7 @@ display_animation_func = None def disp(*args, **kwargs): if not display_animation_func: - raise "Call set_display_mode first" + raise Exception("Call set_display_mode first") return display_animation_func(*args, **kwargs) @@ -179,4 +177,3 @@ ipython = get_ipython() if ipython: setDisplayMode('imageupdate') ipython.magic("matplotlib inline") - matplotlib.rcParams['figure.figsize'] = (16.0, 6.0) \ No newline at end of file diff --git a/kernelcreation.py b/kernelcreation.py index 0823c017cf7e6a881d8618f2f53ab9ae849d6e42..65760903e0935912085cef52ee62f12e8579d916 100644 --- a/kernelcreation.py +++ b/kernelcreation.py @@ -3,7 +3,7 @@ from pystencils.gpucuda.indexing import indexingCreatorFromParams def createKernel(equations, target='cpu', dataType="double", iterationSlice=None, ghostLayers=None, - cpuOpenMP=True, cpuVectorizeInfo=None, + cpuOpenMP=False, cpuVectorizeInfo=None, gpuIndexing='block', gpuIndexingParams={}): """ Creates abstract syntax tree (AST) of kernel, using a list of update equations. diff --git a/plot2d.py b/plot2d.py index f0e409be47ffc646064aa26730ed773adbe7bdfd..58a37d68a14770fe67a4a6a72af86789476642fd 100644 --- a/plot2d.py +++ b/plot2d.py @@ -73,7 +73,7 @@ def vectorFieldAnimation(runFunction, step=2, rescale=True, plotSetupFunction=la field = runFunction() if rescale: maxNorm = np.max(norm(field, axis=2, ord=2)) - field /= maxNorm + field = field / maxNorm if 'scale' not in kwargs: kwargs['scale'] = 1.0 @@ -85,7 +85,7 @@ def vectorFieldAnimation(runFunction, step=2, rescale=True, plotSetupFunction=la f = np.swapaxes(f, 0, 1) if rescale: maxNorm = np.max(norm(f, axis=2, ord=2)) - f /= maxNorm + f = f / maxNorm u, v = f[::step, ::step, 0], f[::step, ::step, 1] quiverPlot.set_UVC(u, v) plotUpdateFunction()