diff --git a/src/pystencilssfg/context.py b/src/pystencilssfg/context.py
index ef69c06bd9921717fa745f990261fc3fca46dd7a..222b313424c67f031b580d6061375022dceda603 100644
--- a/src/pystencilssfg/context.py
+++ b/src/pystencilssfg/context.py
@@ -1,7 +1,6 @@
 from typing import Generator, Sequence
 
 from .configuration import SfgCodeStyle
-from .visitors import CollectIncludes
 from .source_components import (
     SfgHeaderInclude,
     SfgKernelNamespace,
@@ -156,9 +155,6 @@ class SfgContext:
         self._functions[func.name] = func
         self._declarations_ordered.append(func)
 
-        for incl in CollectIncludes().visit(func):
-            self.add_include(incl)
-
     # ----------------------------------------------------------------------------------------------
     #   Classes
     # ----------------------------------------------------------------------------------------------
@@ -176,9 +172,6 @@ class SfgContext:
         self._classes[cls.class_name] = cls
         self._declarations_ordered.append(cls)
 
-        for incl in CollectIncludes().visit(cls):
-            self.add_include(incl)
-
     # ----------------------------------------------------------------------------------------------
     #   Declarations in order of addition
     # ----------------------------------------------------------------------------------------------
diff --git a/src/pystencilssfg/emission/header_source_pair.py b/src/pystencilssfg/emission/header_source_pair.py
index b21b9952a7c77c82187cdbdb8fb9e2ee2ce1c8dd..20e590393a97eb567d025a14cfa884d1b83a7d9a 100644
--- a/src/pystencilssfg/emission/header_source_pair.py
+++ b/src/pystencilssfg/emission/header_source_pair.py
@@ -2,6 +2,7 @@ from os import path, makedirs
 
 from ..configuration import SfgOutputSpec
 from ..context import SfgContext
+from .prepare import prepare_context
 from .printers import SfgHeaderPrinter, SfgImplPrinter
 
 from .clang_format import invoke_clang_format
@@ -20,10 +21,12 @@ class HeaderSourcePairEmitter:
     def output_files(self) -> tuple[str, str]:
         return (
             path.join(self._output_directory, self._header_filename),
-            path.join(self._output_directory, self._impl_filename)
+            path.join(self._output_directory, self._impl_filename),
         )
 
     def write_files(self, ctx: SfgContext):
+        ctx = prepare_context(ctx)
+
         header_printer = SfgHeaderPrinter(ctx, self._ospec)
         impl_printer = SfgImplPrinter(ctx, self._ospec)
 
@@ -35,8 +38,8 @@ class HeaderSourcePairEmitter:
 
         makedirs(self._output_directory, exist_ok=True)
 
-        with open(self._ospec.get_header_filepath(), 'w') as headerfile:
+        with open(self._ospec.get_header_filepath(), "w") as headerfile:
             headerfile.write(header)
 
-        with open(self._ospec.get_impl_filepath(), 'w') as cppfile:
+        with open(self._ospec.get_impl_filepath(), "w") as cppfile:
             cppfile.write(impl)
diff --git a/src/pystencilssfg/emission/prepare.py b/src/pystencilssfg/emission/prepare.py
new file mode 100644
index 0000000000000000000000000000000000000000..28493bbb99abecad9dc0fbaf62d0cb1582373fac
--- /dev/null
+++ b/src/pystencilssfg/emission/prepare.py
@@ -0,0 +1,20 @@
+from ..context import SfgContext
+
+from ..visitors import CollectIncludes
+
+
+def prepare_context(ctx: SfgContext):
+    """Prepares a populated context for printing. Make sure to run this function on the
+    [SfgContext][pystencilssfg.SfgContext] before passing it to a printer.
+
+    Steps:
+     - Collection of includes: All defined functions and classes are traversed to collect all required
+       header includes
+    """
+
+    #   Collect all includes
+    required_includes = CollectIncludes().visit(ctx)
+    for incl in required_includes:
+        ctx.add_include(incl)
+
+    return ctx
diff --git a/src/pystencilssfg/source_components.py b/src/pystencilssfg/source_components.py
index e753b3717f49778af8fb65a423af3d4b6b53ec72..bf3a5147b47b3dd62bad99e450c1c89335a9cc49 100644
--- a/src/pystencilssfg/source_components.py
+++ b/src/pystencilssfg/source_components.py
@@ -440,9 +440,14 @@ class SfgClass:
     def members(
         self, visibility: SfgVisibility | None = None
     ) -> Generator[SfgClassMember, None, None]:
-        yield from chain.from_iterable(
-            b.members() for b in filter(lambda b: b.visibility == visibility, self._blocks)
-        )
+        if visibility is None:
+            yield from chain.from_iterable(
+                b.members() for b in self._blocks
+            )
+        else:
+            yield from chain.from_iterable(
+                b.members() for b in filter(lambda b: b.visibility == visibility, self._blocks)
+            )
 
     def definitions(
         self, visibility: SfgVisibility | None = None
diff --git a/src/pystencilssfg/visitors/collectors.py b/src/pystencilssfg/visitors/collectors.py
index 9bd1774ff393ae09b4883914f9ca745776dbd76a..fef8955bb8a5eab7c80d2c812b41e64ce7f938ba 100644
--- a/src/pystencilssfg/visitors/collectors.py
+++ b/src/pystencilssfg/visitors/collectors.py
@@ -7,7 +7,14 @@ from functools import reduce
 from .dispatcher import visitor
 from ..exceptions import SfgException
 from ..tree import SfgCallTreeNode
-from ..source_components import SfgFunction, SfgClass, SfgConstructor, SfgMemberVariable
+from ..source_components import (
+    SfgFunction,
+    SfgClass,
+    SfgConstructor,
+    SfgMemberVariable,
+    SfgInClassDefinition,
+)
+from ..context import SfgContext
 
 if TYPE_CHECKING:
     from ..source_components import SfgHeaderInclude
@@ -18,10 +25,23 @@ class CollectIncludes:
     def visit(self, obj: object) -> set[SfgHeaderInclude]:
         raise SfgException(f"Can't collect includes from object of type {type(obj)}")
 
+    @visit.case(SfgContext)
+    def context(self, ctx: SfgContext) -> set[SfgHeaderInclude]:
+        includes = set()
+        for func in ctx.functions():
+            includes |= self.visit(func)
+
+        for cls in ctx.classes():
+            includes |= self.visit(cls)
+
+        return includes
+
     @visit.case(SfgCallTreeNode)
     def tree_node(self, node: SfgCallTreeNode) -> set[SfgHeaderInclude]:
         return reduce(
-            lambda accu, child: accu | self.visit(child), node.children, node.required_includes
+            lambda accu, child: accu | self.visit(child),
+            node.children,
+            node.required_includes,
         )
 
     @visit.case(SfgFunction)
@@ -43,3 +63,7 @@ class CollectIncludes:
     @visit.case(SfgMemberVariable)
     def sfg_member_var(self, var: SfgMemberVariable) -> set[SfgHeaderInclude]:
         return var.required_includes
+
+    @visit.case(SfgInClassDefinition)
+    def sfg_cls_def(self, _: SfgInClassDefinition) -> set[SfgHeaderInclude]:
+        return set()