diff --git a/cpu/kernelcreation.py b/cpu/kernelcreation.py index 8ae94cbbfbd52984f1bcafc824b7007e8960ce71..dab631a97f19572e6f9275d0f4ffdba0ac68a665 100644 --- a/cpu/kernelcreation.py +++ b/cpu/kernelcreation.py @@ -126,7 +126,7 @@ def createIndexedKernel(listOfEquations, indexFields, functionName="kernel", typ return ast -def addOpenMP(astNode, schedule="static", numThreads=None): +def addOpenMP(astNode, schedule="static", numThreads=True): """ Parallelizes the outer loop with OpenMP @@ -134,13 +134,37 @@ def addOpenMP(astNode, schedule="static", numThreads=None): :param schedule: OpenMP scheduling policy e.g. 'static' or 'dynamic' :param numThreads: explicitly specify number of threads """ + if not numThreads: + return + assert type(astNode) is ast.KernelFunction body = astNode.body - threadsClause = "" if numThreads is None else " num_threads(%s)" % (numThreads,) + threadsClause = "" if numThreads else " num_threads(%s)" % (numThreads,) wrapperBlock = ast.PragmaBlock('#pragma omp parallel' + threadsClause, body.takeChildNodes()) body.append(wrapperBlock) outerLoops = [l for l in body.atoms(ast.LoopOverCoordinate) if l.isOutermostLoop] assert outerLoops, "No outer loop found" assert len(outerLoops) <= 1, "More than one outer loop found. Which one should be parallelized?" - outerLoops[0].prefixLines.append("#pragma omp for schedule(%s)" % (schedule,)) + loopToParallelize = outerLoops[0] + try: + loopRange = int(loopToParallelize.stop - loopToParallelize.start) + except TypeError: + loopRange = None + + if numThreads is None: + import multiprocessing + numThreads = multiprocessing.cpu_count() + + if loopRange is not None and loopRange < numThreads: + containedLoops = [l for l in loopToParallelize.body.args if isinstance(l, LoopOverCoordinate)] + if len(containedLoops) == 1: + containedLoop = containedLoops[0] + try: + containedLoopRange = int(containedLoop.stop - containedLoop.start) + if containedLoopRange > loopRange: + loopToParallelize = containedLoop + except TypeError: + pass + + loopToParallelize.prefixLines.append("#pragma omp for schedule(%s)" % (schedule,))