From 2e6f3efe2f32c8a662c7bcb061a8496245b06dfd Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Sun, 22 Sep 2019 21:08:27 +0200 Subject: [PATCH] llvm: Use addressspace 1 (global memory) for nvvm_target --- pystencils/data_types.py | 4 ++-- pystencils/llvm/llvm.py | 8 ++++---- pystencils/llvm/llvmjit.py | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pystencils/data_types.py b/pystencils/data_types.py index 86ce1747..cf240e6b 100644 --- a/pystencils/data_types.py +++ b/pystencils/data_types.py @@ -300,7 +300,7 @@ def ctypes_from_llvm(data_type): raise NotImplementedError('Data type %s of %s is not supported yet' % (type(data_type), data_type)) -def to_llvm_type(data_type): +def to_llvm_type(data_type, nvvm_target=False): """ Transforms a given type into ctypes :param data_type: Subclass of Type @@ -309,7 +309,7 @@ def to_llvm_type(data_type): if not ir: raise _ir_importerror if isinstance(data_type, PointerType): - return to_llvm_type(data_type.base_type).as_pointer() + return to_llvm_type(data_type.base_type).as_pointer(1 if nvvm_target else 0) else: return to_llvm_type.map[data_type.numpy_dtype] diff --git a/pystencils/llvm/llvm.py b/pystencils/llvm/llvm.py index 4e730f04..bfccbdc2 100644 --- a/pystencils/llvm/llvm.py +++ b/pystencils/llvm/llvm.py @@ -21,13 +21,13 @@ def _call_sreg(builder, name): return builder.call(fn, ()) -def generate_llvm(ast_node, module=None, builder=None): +def generate_llvm(ast_node, module=None, builder=None, target='cpu'): """Prints the ast as llvm code.""" if module is None: module = lc.Module() if builder is None: builder = ir.IRBuilder() - printer = LLVMPrinter(module, builder) + printer = LLVMPrinter(module, builder, target=target) return printer._print(ast_node) @@ -173,7 +173,7 @@ class LLVMPrinter(Printer): parameter_type = [] parameters = func.get_parameters() for parameter in parameters: - parameter_type.append(to_llvm_type(parameter.symbol.dtype)) + parameter_type.append(to_llvm_type(parameter.symbol.dtype, nvvm_target=self.target == 'gpu')) func_type = ir.FunctionType(return_type, tuple(parameter_type)) name = func.function_name fn = ir.Function(self.module, func_type, name) @@ -307,7 +307,7 @@ class LLVMPrinter(Printer): self.builder.branch(after_block) self.builder.position_at_end(false_block) - phi = self.builder.phi(to_llvm_type(get_type_of_expression(piece))) + phi = self.builder.phi(to_llvm_type(get_type_of_expression(piece), nvvm_target=self.target == 'gpu')) for (val, block) in phi_data: phi.add_incoming(val, block) return phi diff --git a/pystencils/llvm/llvmjit.py b/pystencils/llvm/llvmjit.py index 85850944..4d85e418 100644 --- a/pystencils/llvm/llvmjit.py +++ b/pystencils/llvm/llvmjit.py @@ -101,7 +101,7 @@ def make_python_function_incomplete_params(kernel_function_node, argument_dict, def generate_and_jit(ast): target = 'gpu' if ast._backend == 'llvm_gpu' else 'cpu' - gen = generate_llvm(ast) + gen = generate_llvm(ast, target=target) if isinstance(gen, ir.Module): return compile_llvm(gen, target) else: -- GitLab