From bec1010fa058a277f624ca2241c711921b0b1d70 Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Sun, 22 Sep 2019 00:05:36 +0200 Subject: [PATCH] llvm: Implement LLVMPrinter._print_ThreadIndexingSymbol --- pystencils/gpucuda/indexing.py | 8 ++++---- pystencils/llvm/kernelcreation.py | 17 ++++++++++++++--- pystencils/llvm/llvm.py | 27 ++++++++++++++++++++++++++- 3 files changed, 44 insertions(+), 8 deletions(-) diff --git a/pystencils/gpucuda/indexing.py b/pystencils/gpucuda/indexing.py index 4c8701b2..eb212119 100644 --- a/pystencils/gpucuda/indexing.py +++ b/pystencils/gpucuda/indexing.py @@ -24,10 +24,10 @@ class ThreadIndexingSymbol(TypedSymbol): __xnew_cached_ = staticmethod(cacheit(__new_stage2__)) -BLOCK_IDX = [ThreadIndexingSymbol("blockIdx." + coord, create_type("int")) for coord in ('x', 'y', 'z')] -THREAD_IDX = [ThreadIndexingSymbol("threadIdx." + coord, create_type("int")) for coord in ('x', 'y', 'z')] -BLOCK_DIM = [ThreadIndexingSymbol("blockDim." + coord, create_type("int")) for coord in ('x', 'y', 'z')] -GRID_DIM = [ThreadIndexingSymbol("gridDim." + coord, create_type("int")) for coord in ('x', 'y', 'z')] +BLOCK_IDX = [ThreadIndexingSymbol("blockIdx." + coord, create_type("int32")) for coord in ('x', 'y', 'z')] +THREAD_IDX = [ThreadIndexingSymbol("threadIdx." + coord, create_type("int32")) for coord in ('x', 'y', 'z')] +BLOCK_DIM = [ThreadIndexingSymbol("blockDim." + coord, create_type("int32")) for coord in ('x', 'y', 'z')] +GRID_DIM = [ThreadIndexingSymbol("gridDim." + coord, create_type("int32")) for coord in ('x', 'y', 'z')] class AbstractIndexing(abc.ABC): diff --git a/pystencils/llvm/kernelcreation.py b/pystencils/llvm/kernelcreation.py index 38ac7fe6..04e4292f 100644 --- a/pystencils/llvm/kernelcreation.py +++ b/pystencils/llvm/kernelcreation.py @@ -3,7 +3,7 @@ from pystencils.transformations import insert_casts def create_kernel(assignments, function_name="kernel", type_info=None, split_groups=(), - iteration_slice=None, ghost_layers=None): + iteration_slice=None, ghost_layers=None, target='cpu'): """ Creates an abstract syntax tree for a kernel function, by taking a list of update rules. @@ -25,9 +25,20 @@ def create_kernel(assignments, function_name="kernel", type_info=None, split_gro :return: :class:`pystencils.ast.KernelFunction` node """ - from pystencils.cpu import create_kernel - code = create_kernel(assignments, function_name, type_info, split_groups, iteration_slice, ghost_layers) + if target == 'cpu': + from pystencils.cpu import create_kernel + code = create_kernel(assignments, function_name, type_info, split_groups, iteration_slice, ghost_layers) + elif target == 'gpu': + from pystencils.gpucuda.kernelcreation import create_cuda_kernel + code = create_cuda_kernel(assignments, + function_name, + type_info, + iteration_slice=iteration_slice, + ghost_layers=ghost_layers) + else: + NotImplementedError() code.body = insert_casts(code.body) code._compile_function = make_python_function code._backend = 'llvm' + return code diff --git a/pystencils/llvm/llvm.py b/pystencils/llvm/llvm.py index edbae21c..de02fdec 100644 --- a/pystencils/llvm/llvm.py +++ b/pystencils/llvm/llvm.py @@ -1,6 +1,7 @@ import functools import llvmlite.ir as ir +import llvmlite.llvmpy.core as lc import sympy as sp from sympy import Indexed, S from sympy.printing.printer import Printer @@ -12,10 +13,18 @@ from pystencils.data_types import ( from pystencils.llvm.control_flow import Loop +# From Numba +def _call_sreg(builder, name): + module = builder.module + fnty = lc.Type.function(lc.Type.int(), ()) + fn = module.get_or_insert_function(fnty, name=name) + return builder.call(fn, ()) + + def generate_llvm(ast_node, module=None, builder=None): """Prints the ast as llvm code.""" if module is None: - module = ir.Module() + module = lc.Module() if builder is None: builder = ir.IRBuilder() printer = LLVMPrinter(module, builder) @@ -330,3 +339,19 @@ class LLVMPrinter(Printer): mro = "None" raise TypeError("Unsupported type for LLVM JIT conversion: Expression:\"%s\", Type:\"%s\", MRO:%s" % (expr, type(expr), mro)) + + # from: https://llvm.org/docs/NVPTXUsage.html#nvptx-intrinsics + INDEXING_FUNCTION_MAPPING = { + 'blockIdx': 'llvm.nvvm.read.ptx.sreg.ctaid', + 'threadIdx': 'llvm.nvvm.read.ptx.sreg.tid', + 'blockDim': 'llvm.nvvm.read.ptx.sreg.ntid', + 'gridDim': 'llvm.nvvm.read.ptx.sreg.nctaid' + } + + def _print_ThreadIndexingSymbol(self, node): + symbol_name: str = node.name + function_name, dimension = tuple(symbol_name.split(".")) + function_name = self.INDEXING_FUNCTION_MAPPING[function_name] + name = f"{function_name}.{dimension}" + + return self.builder.zext(_call_sreg(self.builder, name), self.integer) -- GitLab