From 330ffcf3a8dc07a948f44fd6b007091f0b25eed1 Mon Sep 17 00:00:00 2001 From: Martin Bauer <martin.bauer@fau.de> Date: Fri, 21 Apr 2017 13:34:52 +0200 Subject: [PATCH] Simpler OpenMP handling & OpenMP parallel boundaries - periodic kernels not yet OpenMP parallel --- cpu/kernelcreation.py | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/cpu/kernelcreation.py b/cpu/kernelcreation.py index 8ae94cbbf..dab631a97 100644 --- a/cpu/kernelcreation.py +++ b/cpu/kernelcreation.py @@ -126,7 +126,7 @@ def createIndexedKernel(listOfEquations, indexFields, functionName="kernel", typ return ast -def addOpenMP(astNode, schedule="static", numThreads=None): +def addOpenMP(astNode, schedule="static", numThreads=True): """ Parallelizes the outer loop with OpenMP @@ -134,13 +134,37 @@ def addOpenMP(astNode, schedule="static", numThreads=None): :param schedule: OpenMP scheduling policy e.g. 'static' or 'dynamic' :param numThreads: explicitly specify number of threads """ + if not numThreads: + return + assert type(astNode) is ast.KernelFunction body = astNode.body - threadsClause = "" if numThreads is None else " num_threads(%s)" % (numThreads,) + threadsClause = "" if numThreads else " num_threads(%s)" % (numThreads,) wrapperBlock = ast.PragmaBlock('#pragma omp parallel' + threadsClause, body.takeChildNodes()) body.append(wrapperBlock) outerLoops = [l for l in body.atoms(ast.LoopOverCoordinate) if l.isOutermostLoop] assert outerLoops, "No outer loop found" assert len(outerLoops) <= 1, "More than one outer loop found. Which one should be parallelized?" - outerLoops[0].prefixLines.append("#pragma omp for schedule(%s)" % (schedule,)) + loopToParallelize = outerLoops[0] + try: + loopRange = int(loopToParallelize.stop - loopToParallelize.start) + except TypeError: + loopRange = None + + if numThreads is None: + import multiprocessing + numThreads = multiprocessing.cpu_count() + + if loopRange is not None and loopRange < numThreads: + containedLoops = [l for l in loopToParallelize.body.args if isinstance(l, LoopOverCoordinate)] + if len(containedLoops) == 1: + containedLoop = containedLoops[0] + try: + containedLoopRange = int(containedLoop.stop - containedLoop.start) + if containedLoopRange > loopRange: + loopToParallelize = containedLoop + except TypeError: + pass + + loopToParallelize.prefixLines.append("#pragma omp for schedule(%s)" % (schedule,)) -- GitLab