Commit bec1010f authored by Stephan Seitz's avatar Stephan Seitz
Browse files

llvm: Implement LLVMPrinter._print_ThreadIndexingSymbol

parent a4b64edf
......@@ -24,10 +24,10 @@ class ThreadIndexingSymbol(TypedSymbol):
__xnew_cached_ = staticmethod(cacheit(__new_stage2__))
BLOCK_IDX = [ThreadIndexingSymbol("blockIdx." + coord, create_type("int")) for coord in ('x', 'y', 'z')]
THREAD_IDX = [ThreadIndexingSymbol("threadIdx." + coord, create_type("int")) for coord in ('x', 'y', 'z')]
BLOCK_DIM = [ThreadIndexingSymbol("blockDim." + coord, create_type("int")) for coord in ('x', 'y', 'z')]
GRID_DIM = [ThreadIndexingSymbol("gridDim." + coord, create_type("int")) for coord in ('x', 'y', 'z')]
BLOCK_IDX = [ThreadIndexingSymbol("blockIdx." + coord, create_type("int32")) for coord in ('x', 'y', 'z')]
THREAD_IDX = [ThreadIndexingSymbol("threadIdx." + coord, create_type("int32")) for coord in ('x', 'y', 'z')]
BLOCK_DIM = [ThreadIndexingSymbol("blockDim." + coord, create_type("int32")) for coord in ('x', 'y', 'z')]
GRID_DIM = [ThreadIndexingSymbol("gridDim." + coord, create_type("int32")) for coord in ('x', 'y', 'z')]
class AbstractIndexing(abc.ABC):
......@@ -3,7 +3,7 @@ from pystencils.transformations import insert_casts
def create_kernel(assignments, function_name="kernel", type_info=None, split_groups=(),
iteration_slice=None, ghost_layers=None):
iteration_slice=None, ghost_layers=None, target='cpu'):
Creates an abstract syntax tree for a kernel function, by taking a list of update rules.
......@@ -25,9 +25,20 @@ def create_kernel(assignments, function_name="kernel", type_info=None, split_gro
:return: :class:`pystencils.ast.KernelFunction` node
if target == 'cpu':
from pystencils.cpu import create_kernel
code = create_kernel(assignments, function_name, type_info, split_groups, iteration_slice, ghost_layers)
elif target == 'gpu':
from pystencils.gpucuda.kernelcreation import create_cuda_kernel
code = create_cuda_kernel(assignments,
code.body = insert_casts(code.body)
code._compile_function = make_python_function
code._backend = 'llvm'
return code
import functools
import as ir
import llvmlite.llvmpy.core as lc
import sympy as sp
from sympy import Indexed, S
from sympy.printing.printer import Printer
......@@ -12,10 +13,18 @@ from pystencils.data_types import (
from pystencils.llvm.control_flow import Loop
# From Numba
def _call_sreg(builder, name):
module = builder.module
fnty = lc.Type.function(, ())
fn = module.get_or_insert_function(fnty, name=name)
return, ())
def generate_llvm(ast_node, module=None, builder=None):
"""Prints the ast as llvm code."""
if module is None:
module = ir.Module()
module = lc.Module()
if builder is None:
builder = ir.IRBuilder()
printer = LLVMPrinter(module, builder)
......@@ -330,3 +339,19 @@ class LLVMPrinter(Printer):
mro = "None"
raise TypeError("Unsupported type for LLVM JIT conversion: Expression:\"%s\", Type:\"%s\", MRO:%s"
% (expr, type(expr), mro))
# from:
'blockIdx': '',
'threadIdx': '',
'blockDim': '',
'gridDim': ''
def _print_ThreadIndexingSymbol(self, node):
symbol_name: str =
function_name, dimension = tuple(symbol_name.split("."))
function_name = self.INDEXING_FUNCTION_MAPPING[function_name]
name = f"{function_name}.{dimension}"
return self.builder.zext(_call_sreg(self.builder, name), self.integer)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment