diff --git a/pystencils/kernel_wrapper.py b/pystencils/kernel_wrapper.py
index 0e327711e5a355219cc2664ac9a6c8a02d88bc09..3494b52a9fd060bbccbfe493f165dcc7d63c8c06 100644
--- a/pystencils/kernel_wrapper.py
+++ b/pystencils/kernel_wrapper.py
@@ -1,11 +1,14 @@
-"""
-Light-weight wrapper around a compiled kernel
-"""
 import pystencils
 
 
 class KernelWrapper:
-    def __init__(self, kernel, parameters, ast_node):
+    """
+    Light-weight wrapper around a compiled kernel.
+
+    Can be called while still providing access to underlying AST.
+    """
+
+    def __init__(self, kernel, parameters, ast_node: pystencils.astnodes.KernelFunction):
         self.kernel = kernel
         self.parameters = parameters
         self.ast = ast_node