From 2e6f3efe2f32c8a662c7bcb061a8496245b06dfd Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Sun, 22 Sep 2019 21:08:27 +0200
Subject: [PATCH] llvm: Use addressspace 1 (global memory) for nvvm_target

---
 pystencils/data_types.py   | 4 ++--
 pystencils/llvm/llvm.py    | 8 ++++----
 pystencils/llvm/llvmjit.py | 2 +-
 3 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/pystencils/data_types.py b/pystencils/data_types.py
index 86ce1747..cf240e6b 100644
--- a/pystencils/data_types.py
+++ b/pystencils/data_types.py
@@ -300,7 +300,7 @@ def ctypes_from_llvm(data_type):
         raise NotImplementedError('Data type %s of %s is not supported yet' % (type(data_type), data_type))
 
 
-def to_llvm_type(data_type):
+def to_llvm_type(data_type, nvvm_target=False):
     """
     Transforms a given type into ctypes
     :param data_type: Subclass of Type
@@ -309,7 +309,7 @@ def to_llvm_type(data_type):
     if not ir:
         raise _ir_importerror
     if isinstance(data_type, PointerType):
-        return to_llvm_type(data_type.base_type).as_pointer()
+        return to_llvm_type(data_type.base_type).as_pointer(1 if nvvm_target else 0)
     else:
         return to_llvm_type.map[data_type.numpy_dtype]
 
diff --git a/pystencils/llvm/llvm.py b/pystencils/llvm/llvm.py
index 4e730f04..bfccbdc2 100644
--- a/pystencils/llvm/llvm.py
+++ b/pystencils/llvm/llvm.py
@@ -21,13 +21,13 @@ def _call_sreg(builder, name):
     return builder.call(fn, ())
 
 
-def generate_llvm(ast_node, module=None, builder=None):
+def generate_llvm(ast_node, module=None, builder=None, target='cpu'):
     """Prints the ast as llvm code."""
     if module is None:
         module = lc.Module()
     if builder is None:
         builder = ir.IRBuilder()
-    printer = LLVMPrinter(module, builder)
+    printer = LLVMPrinter(module, builder, target=target)
     return printer._print(ast_node)
 
 
@@ -173,7 +173,7 @@ class LLVMPrinter(Printer):
         parameter_type = []
         parameters = func.get_parameters()
         for parameter in parameters:
-            parameter_type.append(to_llvm_type(parameter.symbol.dtype))
+            parameter_type.append(to_llvm_type(parameter.symbol.dtype, nvvm_target=self.target == 'gpu'))
         func_type = ir.FunctionType(return_type, tuple(parameter_type))
         name = func.function_name
         fn = ir.Function(self.module, func_type, name)
@@ -307,7 +307,7 @@ class LLVMPrinter(Printer):
                     self.builder.branch(after_block)
                     self.builder.position_at_end(false_block)
 
-            phi = self.builder.phi(to_llvm_type(get_type_of_expression(piece)))
+            phi = self.builder.phi(to_llvm_type(get_type_of_expression(piece), nvvm_target=self.target == 'gpu'))
             for (val, block) in phi_data:
                 phi.add_incoming(val, block)
             return phi
diff --git a/pystencils/llvm/llvmjit.py b/pystencils/llvm/llvmjit.py
index 85850944..4d85e418 100644
--- a/pystencils/llvm/llvmjit.py
+++ b/pystencils/llvm/llvmjit.py
@@ -101,7 +101,7 @@ def make_python_function_incomplete_params(kernel_function_node, argument_dict,
 
 def generate_and_jit(ast):
     target = 'gpu' if ast._backend == 'llvm_gpu' else 'cpu'
-    gen = generate_llvm(ast)
+    gen = generate_llvm(ast, target=target)
     if isinstance(gen, ir.Module):
         return compile_llvm(gen, target)
     else:
-- 
GitLab