Commit 2e6f3efe authored by Stephan Seitz
Browse files

llvm: Use addressspace 1 (global memory) for nvvm_target

parent a526fe47
......@@ -300,7 +300,7 @@ def ctypes_from_llvm(data_type):
raise NotImplementedError('Data type %s of %s is not supported yet' % (type(data_type), data_type))
def to_llvm_type(data_type):
def to_llvm_type(data_type, nvvm_target=False):
"""
Transforms a given type into ctypes
:param data_type: Subclass of Type
......@@ -309,7 +309,7 @@ def to_llvm_type(data_type):
if not ir:
raise _ir_importerror
if isinstance(data_type, PointerType):
return to_llvm_type(data_type.base_type).as_pointer()
return to_llvm_type(data_type.base_type).as_pointer(1 if nvvm_target else 0)
else:
return to_llvm_type.map[data_type.numpy_dtype]
......
......@@ -21,13 +21,13 @@ def _call_sreg(builder, name):
return builder.call(fn, ())
def generate_llvm(ast_node, module=None, builder=None):
def generate_llvm(ast_node, module=None, builder=None, target='cpu'):
"""Prints the ast as llvm code."""
if module is None:
module = lc.Module()
if builder is None:
builder = ir.IRBuilder()
printer = LLVMPrinter(module, builder)
printer = LLVMPrinter(module, builder, target=target)
return printer._print(ast_node)
......@@ -173,7 +173,7 @@ class LLVMPrinter(Printer):
parameter_type = []
parameters = func.get_parameters()
for parameter in parameters:
parameter_type.append(to_llvm_type(parameter.symbol.dtype))
parameter_type.append(to_llvm_type(parameter.symbol.dtype, nvvm_target=self.target == 'gpu'))
func_type = ir.FunctionType(return_type, tuple(parameter_type))
name = func.function_name
fn = ir.Function(self.module, func_type, name)
......@@ -307,7 +307,7 @@ class LLVMPrinter(Printer):
self.builder.branch(after_block)
self.builder.position_at_end(false_block)
phi = self.builder.phi(to_llvm_type(get_type_of_expression(piece)))
phi = self.builder.phi(to_llvm_type(get_type_of_expression(piece), nvvm_target=self.target == 'gpu'))
for (val, block) in phi_data:
phi.add_incoming(val, block)
return phi
......
......@@ -101,7 +101,7 @@ def make_python_function_incomplete_params(kernel_function_node, argument_dict,
def generate_and_jit(ast):
target = 'gpu' if ast._backend == 'llvm_gpu' else 'cpu'
gen = generate_llvm(ast)
gen = generate_llvm(ast, target=target)
if isinstance(gen, ir.Module):
return compile_llvm(gen, target)
else:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment