diff --git a/pystencils/opencl/opencljit.py b/pystencils/opencl/opencljit.py
index 051dc1ec5e0d6baea989d0a42954ef154a6bcffb..9393873381c4e2d6d95e2d4470f774dca64cd94b 100644
--- a/pystencils/opencl/opencljit.py
+++ b/pystencils/opencl/opencljit.py
@@ -28,6 +28,20 @@ def init_globally(device_index=0):
     _global_cl_queue = cl.CommandQueue(_global_cl_ctx)
 
 
+def init_globally_with_context(opencl_ctx, opencl_queue):
+    global _global_cl_ctx
+    global _global_cl_queue
+    _global_cl_ctx = opencl_ctx
+    _global_cl_queue = opencl_queue
+
+
+def clear_global_ctx():
+    global _global_cl_ctx
+    global _global_cl_queue
+    _global_cl_ctx = None
+    _global_cl_queue = None
+
+
 def make_python_function(kernel_function_node, opencl_queue, opencl_ctx, argument_dict=None, custom_backend=None):
     """
     Creates a **OpenCL** kernel function from an abstract syntax tree which