Commit bec1010f authored by Stephan Seitz's avatar Stephan Seitz
Browse files

llvm: Implement LLVMPrinter._print_ThreadIndexingSymbol

parent a4b64edf
...@@ -24,10 +24,10 @@ class ThreadIndexingSymbol(TypedSymbol): ...@@ -24,10 +24,10 @@ class ThreadIndexingSymbol(TypedSymbol):
__xnew_cached_ = staticmethod(cacheit(__new_stage2__)) __xnew_cached_ = staticmethod(cacheit(__new_stage2__))
BLOCK_IDX = [ThreadIndexingSymbol("blockIdx." + coord, create_type("int")) for coord in ('x', 'y', 'z')] BLOCK_IDX = [ThreadIndexingSymbol("blockIdx." + coord, create_type("int32")) for coord in ('x', 'y', 'z')]
THREAD_IDX = [ThreadIndexingSymbol("threadIdx." + coord, create_type("int")) for coord in ('x', 'y', 'z')] THREAD_IDX = [ThreadIndexingSymbol("threadIdx." + coord, create_type("int32")) for coord in ('x', 'y', 'z')]
BLOCK_DIM = [ThreadIndexingSymbol("blockDim." + coord, create_type("int")) for coord in ('x', 'y', 'z')] BLOCK_DIM = [ThreadIndexingSymbol("blockDim." + coord, create_type("int32")) for coord in ('x', 'y', 'z')]
GRID_DIM = [ThreadIndexingSymbol("gridDim." + coord, create_type("int")) for coord in ('x', 'y', 'z')] GRID_DIM = [ThreadIndexingSymbol("gridDim." + coord, create_type("int32")) for coord in ('x', 'y', 'z')]
class AbstractIndexing(abc.ABC): class AbstractIndexing(abc.ABC):
......
...@@ -3,7 +3,7 @@ from pystencils.transformations import insert_casts ...@@ -3,7 +3,7 @@ from pystencils.transformations import insert_casts
def create_kernel(assignments, function_name="kernel", type_info=None, split_groups=(), def create_kernel(assignments, function_name="kernel", type_info=None, split_groups=(),
iteration_slice=None, ghost_layers=None): iteration_slice=None, ghost_layers=None, target='cpu'):
""" """
Creates an abstract syntax tree for a kernel function, by taking a list of update rules. Creates an abstract syntax tree for a kernel function, by taking a list of update rules.
...@@ -25,9 +25,20 @@ def create_kernel(assignments, function_name="kernel", type_info=None, split_gro ...@@ -25,9 +25,20 @@ def create_kernel(assignments, function_name="kernel", type_info=None, split_gro
:return: :class:`pystencils.ast.KernelFunction` node :return: :class:`pystencils.ast.KernelFunction` node
""" """
from pystencils.cpu import create_kernel if target == 'cpu':
code = create_kernel(assignments, function_name, type_info, split_groups, iteration_slice, ghost_layers) from pystencils.cpu import create_kernel
code = create_kernel(assignments, function_name, type_info, split_groups, iteration_slice, ghost_layers)
elif target == 'gpu':
from pystencils.gpucuda.kernelcreation import create_cuda_kernel
code = create_cuda_kernel(assignments,
function_name,
type_info,
iteration_slice=iteration_slice,
ghost_layers=ghost_layers)
else:
NotImplementedError()
code.body = insert_casts(code.body) code.body = insert_casts(code.body)
code._compile_function = make_python_function code._compile_function = make_python_function
code._backend = 'llvm' code._backend = 'llvm'
return code return code
import functools import functools
import llvmlite.ir as ir import llvmlite.ir as ir
import llvmlite.llvmpy.core as lc
import sympy as sp import sympy as sp
from sympy import Indexed, S from sympy import Indexed, S
from sympy.printing.printer import Printer from sympy.printing.printer import Printer
...@@ -12,10 +13,18 @@ from pystencils.data_types import ( ...@@ -12,10 +13,18 @@ from pystencils.data_types import (
from pystencils.llvm.control_flow import Loop from pystencils.llvm.control_flow import Loop
# From Numba
def _call_sreg(builder, name):
module = builder.module
fnty = lc.Type.function(lc.Type.int(), ())
fn = module.get_or_insert_function(fnty, name=name)
return builder.call(fn, ())
def generate_llvm(ast_node, module=None, builder=None): def generate_llvm(ast_node, module=None, builder=None):
"""Prints the ast as llvm code.""" """Prints the ast as llvm code."""
if module is None: if module is None:
module = ir.Module() module = lc.Module()
if builder is None: if builder is None:
builder = ir.IRBuilder() builder = ir.IRBuilder()
printer = LLVMPrinter(module, builder) printer = LLVMPrinter(module, builder)
...@@ -330,3 +339,19 @@ class LLVMPrinter(Printer): ...@@ -330,3 +339,19 @@ class LLVMPrinter(Printer):
mro = "None" mro = "None"
raise TypeError("Unsupported type for LLVM JIT conversion: Expression:\"%s\", Type:\"%s\", MRO:%s" raise TypeError("Unsupported type for LLVM JIT conversion: Expression:\"%s\", Type:\"%s\", MRO:%s"
% (expr, type(expr), mro)) % (expr, type(expr), mro))
# from: https://llvm.org/docs/NVPTXUsage.html#nvptx-intrinsics
INDEXING_FUNCTION_MAPPING = {
'blockIdx': 'llvm.nvvm.read.ptx.sreg.ctaid',
'threadIdx': 'llvm.nvvm.read.ptx.sreg.tid',
'blockDim': 'llvm.nvvm.read.ptx.sreg.ntid',
'gridDim': 'llvm.nvvm.read.ptx.sreg.nctaid'
}
def _print_ThreadIndexingSymbol(self, node):
symbol_name: str = node.name
function_name, dimension = tuple(symbol_name.split("."))
function_name = self.INDEXING_FUNCTION_MAPPING[function_name]
name = f"{function_name}.{dimension}"
return self.builder.zext(_call_sreg(self.builder, name), self.integer)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment