From bec1010fa058a277f624ca2241c711921b0b1d70 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Sun, 22 Sep 2019 00:05:36 +0200
Subject: [PATCH] llvm: Implement LLVMPrinter._print_ThreadIndexingSymbol

---
 pystencils/gpucuda/indexing.py    |  8 ++++----
 pystencils/llvm/kernelcreation.py | 17 ++++++++++++++---
 pystencils/llvm/llvm.py           | 27 ++++++++++++++++++++++++++-
 3 files changed, 44 insertions(+), 8 deletions(-)

diff --git a/pystencils/gpucuda/indexing.py b/pystencils/gpucuda/indexing.py
index 4c8701b2..eb212119 100644
--- a/pystencils/gpucuda/indexing.py
+++ b/pystencils/gpucuda/indexing.py
@@ -24,10 +24,10 @@ class ThreadIndexingSymbol(TypedSymbol):
     __xnew_cached_ = staticmethod(cacheit(__new_stage2__))
 
 
-BLOCK_IDX = [ThreadIndexingSymbol("blockIdx." + coord, create_type("int")) for coord in ('x', 'y', 'z')]
-THREAD_IDX = [ThreadIndexingSymbol("threadIdx." + coord, create_type("int")) for coord in ('x', 'y', 'z')]
-BLOCK_DIM = [ThreadIndexingSymbol("blockDim." + coord, create_type("int")) for coord in ('x', 'y', 'z')]
-GRID_DIM = [ThreadIndexingSymbol("gridDim." + coord, create_type("int")) for coord in ('x', 'y', 'z')]
+BLOCK_IDX = [ThreadIndexingSymbol("blockIdx." + coord, create_type("int32")) for coord in ('x', 'y', 'z')]
+THREAD_IDX = [ThreadIndexingSymbol("threadIdx." + coord, create_type("int32")) for coord in ('x', 'y', 'z')]
+BLOCK_DIM = [ThreadIndexingSymbol("blockDim." + coord, create_type("int32")) for coord in ('x', 'y', 'z')]
+GRID_DIM = [ThreadIndexingSymbol("gridDim." + coord, create_type("int32")) for coord in ('x', 'y', 'z')]
 
 
 class AbstractIndexing(abc.ABC):
diff --git a/pystencils/llvm/kernelcreation.py b/pystencils/llvm/kernelcreation.py
index 38ac7fe6..04e4292f 100644
--- a/pystencils/llvm/kernelcreation.py
+++ b/pystencils/llvm/kernelcreation.py
@@ -3,7 +3,7 @@ from pystencils.transformations import insert_casts
 
 
 def create_kernel(assignments, function_name="kernel", type_info=None, split_groups=(),
-                  iteration_slice=None, ghost_layers=None):
+                  iteration_slice=None, ghost_layers=None, target='cpu'):
     """
     Creates an abstract syntax tree for a kernel function, by taking a list of update rules.
 
@@ -25,9 +25,20 @@ def create_kernel(assignments, function_name="kernel", type_info=None, split_gro
 
     :return: :class:`pystencils.ast.KernelFunction` node
     """
-    from pystencils.cpu import create_kernel
-    code = create_kernel(assignments, function_name, type_info, split_groups, iteration_slice, ghost_layers)
+    if target == 'cpu':
+        from pystencils.cpu import create_kernel
+        code = create_kernel(assignments, function_name, type_info, split_groups, iteration_slice, ghost_layers)
+    elif target == 'gpu':  # note: split_groups is not forwarded on this path
+        from pystencils.gpucuda.kernelcreation import create_cuda_kernel
+        code = create_cuda_kernel(assignments,
+                                  function_name,
+                                  type_info,
+                                  iteration_slice=iteration_slice,
+                                  ghost_layers=ghost_layers)
+    else:  # must raise here: `code` is unbound for any other target
+        raise NotImplementedError("Unsupported target for the LLVM backend: %r" % (target,))
     code.body = insert_casts(code.body)
     code._compile_function = make_python_function
     code._backend = 'llvm'
+
     return code
diff --git a/pystencils/llvm/llvm.py b/pystencils/llvm/llvm.py
index edbae21c..de02fdec 100644
--- a/pystencils/llvm/llvm.py
+++ b/pystencils/llvm/llvm.py
@@ -1,6 +1,7 @@
 import functools
 
 import llvmlite.ir as ir
+import llvmlite.llvmpy.core as lc
 import sympy as sp
 from sympy import Indexed, S
 from sympy.printing.printer import Printer
@@ -12,10 +13,18 @@ from pystencils.data_types import (
 from pystencils.llvm.control_flow import Loop
 
 
+# From Numba
+def _call_sreg(builder, name):
+    """Call the zero-argument function ``name``, got or inserted in the builder's module."""
+    sreg_fn_type = lc.Type.function(lc.Type.int(), ())
+    sreg_fn = builder.module.get_or_insert_function(sreg_fn_type, name=name)
+    return builder.call(sreg_fn, ())
+
+
 def generate_llvm(ast_node, module=None, builder=None):
     """Prints the ast as llvm code."""
     if module is None:
-        module = ir.Module()
+        module = lc.Module()
     if builder is None:
         builder = ir.IRBuilder()
     printer = LLVMPrinter(module, builder)
@@ -330,3 +339,19 @@ class LLVMPrinter(Printer):
             mro = "None"
         raise TypeError("Unsupported type for LLVM JIT conversion: Expression:\"%s\", Type:\"%s\", MRO:%s"
                         % (expr, type(expr), mro))
+
+    # from: https://llvm.org/docs/NVPTXUsage.html#nvptx-intrinsics
+    INDEXING_FUNCTION_MAPPING = {
+        'blockIdx': 'llvm.nvvm.read.ptx.sreg.ctaid',
+        'threadIdx': 'llvm.nvvm.read.ptx.sreg.tid',
+        'blockDim': 'llvm.nvvm.read.ptx.sreg.ntid',
+        'gridDim': 'llvm.nvvm.read.ptx.sreg.nctaid'
+    }
+
+    def _print_ThreadIndexingSymbol(self, node):
+        """Emit a call to the intrinsic mapped from a ``<prefix>.<dim>`` name, zero-extended."""
+        prefix, dimension = node.name.split(".")
+        intrinsic_name = f"{self.INDEXING_FUNCTION_MAPPING[prefix]}.{dimension}"
+        sreg_value = _call_sreg(self.builder, intrinsic_name)
+
+        return self.builder.zext(sreg_value, self.integer)
-- 
GitLab