From 8f789bf6ce0033c57b2224532d9039b0636f1f68 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Fri, 1 Dec 2023 12:07:28 +0100
Subject: [PATCH] integrated classes

---
 integration/.gitignore                        |   1 +
 integration/test_classes.py                   |  45 ++++
 src/pystencilssfg/composer.py                 |   4 +-
 src/pystencilssfg/context.py                  |  64 +++--
 .../emitters/classes_printing.py              |  67 +++++
 .../emitters/header_source_pair.py            |  15 +-
 src/pystencilssfg/emitters/jinja_filters.py   |  25 +-
 .../templates/HeaderSourcePair.tmpl.cpp       |  18 +-
 .../templates/HeaderSourcePair.tmpl.h         |   4 +
 src/pystencilssfg/source_components.py        | 229 ++++++++++++++++--
 src/pystencilssfg/tree/deferred_nodes.py      |   6 +-
 src/pystencilssfg/tree/dispatcher.py          |  44 ----
 src/pystencilssfg/visitors/__init__.py        |  10 +
 src/pystencilssfg/visitors/collectors.py      |  45 ++++
 src/pystencilssfg/visitors/dispatcher.py      |  73 ++++++
 .../visitors.py => visitors/tree_visitors.py} |  51 ++--
 test_classses_out/test_classes.cpp            |  32 +++
 test_classses_out/test_classes.h              |  24 ++
 18 files changed, 631 insertions(+), 126 deletions(-)
 create mode 100644 integration/.gitignore
 create mode 100644 integration/test_classes.py
 create mode 100644 src/pystencilssfg/emitters/classes_printing.py
 delete mode 100644 src/pystencilssfg/tree/dispatcher.py
 create mode 100644 src/pystencilssfg/visitors/__init__.py
 create mode 100644 src/pystencilssfg/visitors/collectors.py
 create mode 100644 src/pystencilssfg/visitors/dispatcher.py
 rename src/pystencilssfg/{tree/visitors.py => visitors/tree_visitors.py} (73%)
 create mode 100644 test_classses_out/test_classes.cpp
 create mode 100644 test_classses_out/test_classes.h

diff --git a/integration/.gitignore b/integration/.gitignore
new file mode 100644
index 0000000..c585e19
--- /dev/null
+++ b/integration/.gitignore
@@ -0,0 +1 @@
+out
\ No newline at end of file
diff --git a/integration/test_classes.py b/integration/test_classes.py
new file mode 100644
index 0000000..446921c
--- /dev/null
+++ b/integration/test_classes.py
@@ -0,0 +1,45 @@
+# type: ignore
+from pystencilssfg import SourceFileGenerator, SfgConfiguration
+from pystencilssfg.source_concepts import SrcObject
+from pystencilssfg.source_components import SfgClass, SfgMemberVariable, SfgConstructor, SfgMethod, SfgVisibility
+
+from pystencils import fields, kernel
+
+sfg_config = SfgConfiguration(
+    output_directory="out/test_classes",
+    outer_namespace="gen_code"
+)
+
+f, g = fields("f, g(1): double[2D]")
+
+with SourceFileGenerator(sfg_config) as sfg:
+
+    @kernel
+    def assignments():
+        f[0,0] @= 3 * g[0,0]
+
+    khandle = sfg.kernels.create(assignments)
+
+    cls = SfgClass("MyClass")
+    cls.add_method(SfgMethod(
+        "callKernel",
+        sfg.call(khandle),
+        visibility=SfgVisibility.PUBLIC
+    ))
+
+    cls.add_member_variable(
+        SfgMemberVariable(
+            "stuff", "std::vector< int >",
+            SfgVisibility.PRIVATE
+        )
+    )
+
+    cls.add_constructor(
+        SfgConstructor(
+            [SrcObject("std::vector< int > &", "stuff")],
+            ["stuff_(stuff)"],
+            visibility=SfgVisibility.PUBLIC
+        )
+    )
+
+    sfg.context.add_class(cls)
diff --git a/src/pystencilssfg/composer.py b/src/pystencilssfg/composer.py
index d405b66..96f59ef 100644
--- a/src/pystencilssfg/composer.py
+++ b/src/pystencilssfg/composer.py
@@ -77,7 +77,7 @@ class SfgComposer:
         else:
             raise TypeError("Invalid type of argument `ast_or_kernel_handle`!")
 
-        func = SfgFunction(self._ctx, name, tree)
+        func = SfgFunction(name, tree)
         self._ctx.add_function(func)
 
     def function(self, name: str):
@@ -96,7 +96,7 @@ class SfgComposer:
 
         def sequencer(*args: str | tuple | SfgCallTreeNode | SfgNodeBuilder):
             tree = make_sequence(*args)
-            func = SfgFunction(self._ctx, name, tree)
+            func = SfgFunction(name, tree)
             self._ctx.add_function(func)
 
         return sequencer
diff --git a/src/pystencilssfg/context.py b/src/pystencilssfg/context.py
index d1add3c..48ecfd2 100644
--- a/src/pystencilssfg/context.py
+++ b/src/pystencilssfg/context.py
@@ -1,16 +1,23 @@
 from typing import Generator, Sequence
 
 from .configuration import SfgCodeStyle
-from .tree.visitors import CollectIncludes
-from .source_components import SfgHeaderInclude, SfgKernelNamespace, SfgFunction
+from .visitors import CollectIncludes
+from .source_components import (
+    SfgHeaderInclude,
+    SfgKernelNamespace,
+    SfgFunction,
+    SfgClass,
+)
 from .exceptions import SfgException
 
 
 class SfgContext:
-    def __init__(self,
-                 outer_namespace: str | None = None,
-                 codestyle: SfgCodeStyle = SfgCodeStyle(),
-                 argv: Sequence[str] | None = None):
+    def __init__(
+        self,
+        outer_namespace: str | None = None,
+        codestyle: SfgCodeStyle = SfgCodeStyle(),
+        argv: Sequence[str] | None = None,
+    ):
         self._argv = argv
         self._default_kernel_namespace = SfgKernelNamespace(self, "kernels")
 
@@ -23,8 +30,11 @@ class SfgContext:
         self._prelude: str = ""
         self._includes: set[SfgHeaderInclude] = set()
         self._definitions: list[str] = []
-        self._kernel_namespaces = {self._default_kernel_namespace.name: self._default_kernel_namespace}
+        self._kernel_namespaces = {
+            self._default_kernel_namespace.name: self._default_kernel_namespace
+        }
         self._functions: dict[str, SfgFunction] = dict()
+        self._classes: dict[str, SfgClass] = dict()
 
     @property
     def argv(self) -> Sequence[str]:
@@ -48,11 +58,16 @@ class SfgContext:
     @property
     def fully_qualified_namespace(self) -> str | None:
         match (self.outer_namespace, self.inner_namespace):
-            case None, None: return None
-            case outer, None: return outer
-            case None, inner: return inner
-            case outer, inner: return f"{outer}::{inner}"
-            case _: assert False
+            case None, None:
+                return None
+            case outer, None:
+                return outer
+            case None, inner:
+                return inner
+            case outer, inner:
+                return f"{outer}::{inner}"
+            case _:
+                assert False
 
     @property
     def codestyle(self) -> SfgCodeStyle:
@@ -127,10 +142,29 @@ class SfgContext:
     def get_function(self, name: str) -> SfgFunction | None:
         return self._functions.get(name, None)
 
-    def add_function(self, func: SfgFunction) -> None:
+    def add_function(self, func: SfgFunction):
         if func.name in self._functions:
-            raise ValueError(f"Duplicate function: {func.name}")
+            raise SfgException(f"Duplicate function: {func.name}")
 
         self._functions[func.name] = func
-        for incl in CollectIncludes().visit(func._tree):
+        for incl in CollectIncludes().visit(func):
+            self.add_include(incl)
+
+    # ----------------------------------------------------------------------------------------------
+    #   Classes
+    # ----------------------------------------------------------------------------------------------
+
+    def classes(self) -> Generator[SfgClass, None, None]:
+        yield from self._classes.values()
+
+    def get_class(self, name: str) -> SfgClass | None:
+        return self._classes.get(name, None)
+
+    def add_class(self, cls: SfgClass):
+        if cls.class_name in self._classes:
+            raise SfgException(f"Duplicate class: {cls.class_name}")
+
+        self._classes[cls.class_name] = cls
+
+        for incl in CollectIncludes().visit(cls):
             self.add_include(incl)
diff --git a/src/pystencilssfg/emitters/classes_printing.py b/src/pystencilssfg/emitters/classes_printing.py
new file mode 100644
index 0000000..f80c7ca
--- /dev/null
+++ b/src/pystencilssfg/emitters/classes_printing.py
@@ -0,0 +1,67 @@
+from ..context import SfgContext
+from ..visitors import visitor
+from ..source_components import (
+    SfgClass,
+    SfgConstructor,
+    SfgMemberVariable,
+    SfgMethod,
+    SfgVisibility,
+)
+from ..exceptions import SfgException
+
+
+class ClassDeclarationPrinter:
+    def __init__(self, ctx: SfgContext):
+        self._codestyle = ctx.codestyle
+
+    def print(self, cls: SfgClass):
+        return self.visit(cls, cls)
+
+    @visitor
+    def visit(self, obj: object, cls: SfgClass) -> str:
+        raise SfgException("Can't print this.")
+
+    @visit.case(SfgClass)
+    def sfg_class(self, cls: SfgClass, _: SfgClass):
+        code = f"{cls.class_keyword} {cls.class_name} \n"
+
+        if cls.base_classes:
+            code += f" : {','.join(cls.base_classes)}\n"
+
+        code += "{\n"
+        for visibility in (
+            SfgVisibility.DEFAULT,
+            SfgVisibility.PUBLIC,
+            SfgVisibility.PRIVATE,
+        ):
+            if visibility != SfgVisibility.DEFAULT:
+                code += f"\n{visibility}:\n"
+            for member in cls.members(visibility):
+                code += self._codestyle.indent(self.visit(member, cls)) + "\n"
+        code += "};\n"
+
+        return code
+
+    @visit.case(SfgConstructor)
+    def sfg_constructor(self, constr: SfgConstructor, cls: SfgClass):
+        code = f"{cls.class_name} ("
+        code += ", ".join(f"{param.dtype} {param.name}" for param in constr.parameters)
+        code += ")\n"
+        if constr.initializers:
+            code += "  : " + ", ".join(constr.initializers) + "\n"
+        if constr.body:
+            code += "{\n" + self._codestyle.indent(constr.body) + "\n}\n"
+        else:
+            code += "{ }\n"
+        return code
+
+    @visit.case(SfgMemberVariable)
+    def sfg_member_var(self, var: SfgMemberVariable, _: SfgClass):
+        return f"{var.dtype} {var.name};"
+
+    @visit.case(SfgMethod)
+    def sfg_method(self, method: SfgMethod, _: SfgClass):
+        code = f"void {method.name} ("
+        code += ", ".join(f"{param.dtype} {param.name}" for param in method.parameters)
+        code += ");"
+        return code
diff --git a/src/pystencilssfg/emitters/header_source_pair.py b/src/pystencilssfg/emitters/header_source_pair.py
index 33653ab..553f84b 100644
--- a/src/pystencilssfg/emitters/header_source_pair.py
+++ b/src/pystencilssfg/emitters/header_source_pair.py
@@ -1,6 +1,6 @@
 from jinja2 import Environment, PackageLoader, StrictUndefined
 
-from os import path
+from os import path, makedirs
 
 from ..configuration import SfgOutputSpec
 from ..context import SfgContext
@@ -31,12 +31,13 @@ class HeaderSourcePairEmitter:
             'source_filename': self._impl_filename,
             'basename': self._basename,
             'prelude_comment': ctx.prelude_comment,
-            'definitions': list(ctx.definitions()),
+            'definitions': tuple(ctx.definitions()),
             'fq_namespace': fq_namespace,
-            'public_includes': list(incl.get_code() for incl in ctx.includes() if not incl.private),
-            'private_includes': list(incl.get_code() for incl in ctx.includes() if incl.private),
-            'kernel_namespaces': list(ctx.kernel_namespaces()),
-            'functions': list(ctx.functions())
+            'public_includes': tuple(incl.get_code() for incl in ctx.includes() if not incl.private),
+            'private_includes': tuple(incl.get_code() for incl in ctx.includes() if incl.private),
+            'kernel_namespaces': tuple(ctx.kernel_namespaces()),
+            'functions': tuple(ctx.functions()),
+            'classes': tuple(ctx.classes())
         }
 
         template_name = "HeaderSourcePair"
@@ -52,6 +53,8 @@ class HeaderSourcePairEmitter:
         header = env.get_template(f"{template_name}.tmpl.h").render(**jinja_context)
         source = env.get_template(f"{template_name}.tmpl.cpp").render(**jinja_context)
 
+        makedirs(self._output_directory, exist_ok=True)
+
         with open(self._ospec.get_header_filepath(), 'w') as headerfile:
             headerfile.write(header)
 
diff --git a/src/pystencilssfg/emitters/jinja_filters.py b/src/pystencilssfg/emitters/jinja_filters.py
index f6ee5e1..31169f4 100644
--- a/src/pystencilssfg/emitters/jinja_filters.py
+++ b/src/pystencilssfg/emitters/jinja_filters.py
@@ -5,7 +5,8 @@ from pystencils.astnodes import KernelFunction
 from pystencils import Backend
 from pystencils.backends import generate_c
 
-from pystencilssfg.source_components import SfgFunction
+from pystencilssfg.source_components import SfgFunction, SfgClass
+from .classes_printing import ClassDeclarationPrinter
 
 
 def format_prelude_comment(prelude_comment: str):
@@ -26,12 +27,22 @@ def generate_function_parameter_list(ctx, func: SfgFunction):
     return ", ".join(f"{param.dtype} {param.name}" for param in params)
 
 
-def generate_function_body(func: SfgFunction):
-    return func.get_code()
+@pass_context
+def generate_function_body(ctx, func: SfgFunction):
+    return func.get_code(ctx["ctx"])
+
+
+@pass_context
+def print_class_declaration(ctx, cls: SfgClass):
+    return ClassDeclarationPrinter(ctx["ctx"]).print(cls)
 
 
 def add_filters_to_jinja(jinja_env):
-    jinja_env.filters['format_prelude_comment'] = format_prelude_comment
-    jinja_env.filters['generate_kernel_definition'] = generate_kernel_definition
-    jinja_env.filters['generate_function_parameter_list'] = generate_function_parameter_list
-    jinja_env.filters['generate_function_body'] = generate_function_body
+    jinja_env.filters["format_prelude_comment"] = format_prelude_comment
+    jinja_env.filters["generate_kernel_definition"] = generate_kernel_definition
+    jinja_env.filters[
+        "generate_function_parameter_list"
+    ] = generate_function_parameter_list
+    jinja_env.filters["generate_function_body"] = generate_function_body
+
+    jinja_env.filters["print_class_declaration"] = print_class_declaration
diff --git a/src/pystencilssfg/emitters/templates/HeaderSourcePair.tmpl.cpp b/src/pystencilssfg/emitters/templates/HeaderSourcePair.tmpl.cpp
index b1d37cc..48c9653 100644
--- a/src/pystencilssfg/emitters/templates/HeaderSourcePair.tmpl.cpp
+++ b/src/pystencilssfg/emitters/templates/HeaderSourcePair.tmpl.cpp
@@ -31,12 +31,26 @@ namespace {{ kns.name }} {
 *************************************************************************************/
 
 {% for function in functions %}
-
 void {{ function.name }} ( {{ function | generate_function_parameter_list }} ) { 
-  {{ function | generate_function_body | indent(2) }}
+  {{ function | generate_function_body | indent(ctx.codestyle.indent_width) }}
 }
 
+
+{% endfor -%}
+
+/*************************************************************************************
+ *                                Class Methods
+*************************************************************************************/
+
+{% for cls in classes %}
+{% for method in cls.methods() %}
+void {{ cls.class_name }}::{{ method.name }} ( {{ method | generate_function_parameter_list }} ) { 
+  {{ method | generate_function_body | indent(ctx.codestyle.indent_width) }}
+}
+
+
 {% endfor %}
+{% endfor -%}
 
 {% if fq_namespace is not none %}
 } // namespace {{fq_namespace}}
diff --git a/src/pystencilssfg/emitters/templates/HeaderSourcePair.tmpl.h b/src/pystencilssfg/emitters/templates/HeaderSourcePair.tmpl.h
index 40ce1ed..95b3ed8 100644
--- a/src/pystencilssfg/emitters/templates/HeaderSourcePair.tmpl.h
+++ b/src/pystencilssfg/emitters/templates/HeaderSourcePair.tmpl.h
@@ -18,6 +18,10 @@
 namespace {{fq_namespace}} {
 {% endif %}
 
+{% for cls in classes %}
+{{ cls | print_class_declaration }}
+{% endfor %}
+
 {% for function in functions %}
 void {{ function.name }} ( {{ function | generate_function_parameter_list }} );
 {% endfor %}
diff --git a/src/pystencilssfg/source_components.py b/src/pystencilssfg/source_components.py
index 17cd3c9..af89aff 100644
--- a/src/pystencilssfg/source_components.py
+++ b/src/pystencilssfg/source_components.py
@@ -1,18 +1,26 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Sequence
+from abc import ABC
+from enum import Enum, auto
+from typing import TYPE_CHECKING, Sequence, Generator
 from dataclasses import replace
 
 from pystencils import CreateKernelConfig, create_kernel
 from pystencils.astnodes import KernelFunction
 
+from .types import SrcType
+from .source_concepts import SrcObject
+from .exceptions import SfgException
+
 if TYPE_CHECKING:
     from .context import SfgContext
     from .tree import SfgCallTreeNode
 
 
 class SfgHeaderInclude:
-    def __init__(self, header_file: str, system_header: bool = False, private: bool = False):
+    def __init__(
+        self, header_file: str, system_header: bool = False, private: bool = False
+    ):
         self._header_file = header_file
         self._system_header = system_header
         self._private = private
@@ -35,10 +43,12 @@ class SfgHeaderInclude:
         return hash((self._header_file, self._system_header, self._private))
 
     def __eq__(self, other: object) -> bool:
-        return (isinstance(other, SfgHeaderInclude)
-                and self._header_file == other._header_file
-                and self._system_header == other._system_header
-                and self._private == other._private)
+        return (
+            isinstance(other, SfgHeaderInclude)
+            and self._header_file == other._header_file
+            and self._system_header == other._system_header
+            and self._private == other._private
+        )
 
 
 class SfgKernelNamespace:
@@ -64,7 +74,9 @@ class SfgKernelNamespace:
             astname = ast.function_name
 
         if astname in self._asts:
-            raise ValueError(f"Duplicate ASTs: An AST with name {astname} already exists in namespace {self._name}")
+            raise ValueError(
+                f"Duplicate ASTs: An AST with name {astname} already exists in namespace {self._name}"
+            )
 
         if name is not None:
             ast.function_name = name
@@ -73,7 +85,12 @@ class SfgKernelNamespace:
 
         return SfgKernelHandle(self._ctx, astname, self, ast.get_parameters())
 
-    def create(self, assignments, name: str | None = None, config: CreateKernelConfig | None = None):
+    def create(
+        self,
+        assignments,
+        name: str | None = None,
+        config: CreateKernelConfig | None = None,
+    ):
         """Creates a new pystencils kernel from a list of assignments and a configuration.
         This is a wrapper around
         [`pystencils.create_kernel`](
@@ -87,7 +104,9 @@ class SfgKernelNamespace:
 
         if name is not None:
             if name in self._asts:
-                raise ValueError(f"Duplicate ASTs: An AST with name {name} already exists in namespace {self._name}")
+                raise ValueError(
+                    f"Duplicate ASTs: An AST with name {name} already exists in namespace {self._name}"
+                )
             config = replace(config, function_name=name)
 
         # type: ignore
@@ -96,7 +115,13 @@ class SfgKernelNamespace:
 
 
 class SfgKernelHandle:
-    def __init__(self, ctx, name: str, namespace: SfgKernelNamespace, parameters: Sequence[KernelFunction.Parameter]):
+    def __init__(
+        self,
+        ctx,
+        name: str,
+        namespace: SfgKernelNamespace,
+        parameters: Sequence[KernelFunction.Parameter],
+    ):
         self._ctx = ctx
         self._name = name
         self._namespace = namespace
@@ -122,8 +147,10 @@ class SfgKernelHandle:
     @property
     def fully_qualified_name(self):
         match self._ctx.fully_qualified_namespace:
-            case None: return f"{self.kernel_namespace.name}::{self.kernel_name}"
-            case fqn: return f"{fqn}::{self.kernel_namespace.name}::{self.kernel_name}"
+            case None:
+                return f"{self.kernel_namespace.name}::{self.kernel_name}"
+            case fqn:
+                return f"{fqn}::{self.kernel_namespace.name}::{self.kernel_name}"
 
     @property
     def parameters(self):
@@ -139,14 +166,13 @@ class SfgKernelHandle:
 
 
 class SfgFunction:
-    def __init__(self, ctx: SfgContext, name: str, tree: SfgCallTreeNode):
-        self._ctx = ctx
+    def __init__(self, name: str, tree: SfgCallTreeNode):
         self._name = name
         self._tree = tree
 
-        from .tree.visitors import ExpandingParameterCollector
+        from .visitors.tree_visitors import ExpandingParameterCollector
 
-        param_collector = ExpandingParameterCollector(self._ctx)
+        param_collector = ExpandingParameterCollector()
         self._parameters = param_collector.visit(self._tree)
 
     @property
@@ -161,5 +187,172 @@ class SfgFunction:
     def tree(self):
         return self._tree
 
-    def get_code(self):
-        return self._tree.get_code(self._ctx)
+    def get_code(self, ctx: SfgContext):
+        return self._tree.get_code(ctx)
+
+
+class SfgVisibility(Enum):
+    DEFAULT = auto()
+    PRIVATE = auto()
+    PUBLIC = auto()
+
+    def __str__(self) -> str:
+        match self:
+            case SfgVisibility.DEFAULT:
+                return ""
+            case SfgVisibility.PRIVATE:
+                return "private"
+            case SfgVisibility.PUBLIC:
+                return "public"
+
+
+class SfgClassKeyword(Enum):
+    STRUCT = auto()
+    CLASS = auto()
+
+    def __str__(self) -> str:
+        match self:
+            case SfgClassKeyword.STRUCT:
+                return "struct"
+            case SfgClassKeyword.CLASS:
+                return "class"
+
+
+class SfgClassMember(ABC):
+    def __init__(self, visibility: SfgVisibility):
+        self._visibility = visibility
+
+    @property
+    def visibility(self) -> SfgVisibility:
+        return self._visibility
+
+
+class SfgMemberVariable(SrcObject, SfgClassMember):
+    def __init__(
+        self,
+        name: str,
+        type: SrcType,
+        visibility: SfgVisibility = SfgVisibility.DEFAULT,
+    ):
+        SrcObject.__init__(self, type, name)
+        SfgClassMember.__init__(self, visibility)
+
+
+class SfgMethod(SfgFunction, SfgClassMember):
+    def __init__(
+        self,
+        name: str,
+        tree: SfgCallTreeNode,
+        visibility: SfgVisibility = SfgVisibility.DEFAULT,
+    ):
+        SfgFunction.__init__(self, name, tree)
+        SfgClassMember.__init__(self, visibility)
+
+
+class SfgConstructor(SfgClassMember):
+    def __init__(
+        self,
+        parameters: Sequence[SrcObject] = (),
+        initializers: Sequence[str] = (),
+        body: str = "",
+        visibility: SfgVisibility = SfgVisibility.DEFAULT,
+    ):
+        SfgClassMember.__init__(self, visibility)
+        self._parameters = tuple(parameters)
+        self._initializers = tuple(initializers)
+        self._body = body
+
+    @property
+    def parameters(self) -> tuple[SrcObject, ...]:
+        return self._parameters
+
+    @property
+    def initializers(self) -> tuple[str, ...]:
+        return self._initializers
+
+    @property
+    def body(self) -> str:
+        return self._body
+
+
+class SfgClass:
+    def __init__(
+        self,
+        class_name: str,
+        class_keyword: SfgClassKeyword = SfgClassKeyword.CLASS,
+        bases: Sequence[str] = (),
+    ):
+        self._class_name = class_name
+        self._class_keyword = class_keyword
+        self._bases_classes = tuple(bases)
+
+        self._constructors: list[SfgConstructor] = []
+        self._methods: dict[str, SfgMethod] = dict()
+        self._member_vars: dict[str, SfgMemberVariable] = dict()
+
+    @property
+    def class_name(self) -> str:
+        return self._class_name
+
+    @property
+    def base_classes(self) -> tuple[str, ...]:
+        return self._bases_classes
+
+    @property
+    def class_keyword(self) -> SfgClassKeyword:
+        return self._class_keyword
+
+    def members(
+        self, visibility: SfgVisibility | None = None
+    ) -> Generator[SfgClassMember, None, None]:
+        yield from self.member_variables(visibility)
+        yield from self.constructors(visibility)
+        yield from self.methods(visibility)
+
+    def constructors(
+        self, visibility: SfgVisibility | None = None
+    ) -> Generator[SfgConstructor, None, None]:
+        if visibility is not None:
+            yield from filter(lambda m: m.visibility == visibility, self._constructors)
+        else:
+            yield from self._constructors
+
+    def add_constructor(self, constr: SfgConstructor):
+        #   TODO: Check for signature conflicts?
+        self._constructors.append(constr)
+
+    def methods(
+        self, visibility: SfgVisibility | None = None
+    ) -> Generator[SfgMethod, None, None]:
+        if visibility is not None:
+            yield from filter(
+                lambda m: m.visibility == visibility, self._methods.values()
+            )
+        else:
+            yield from self._methods.values()
+
+    def add_method(self, method: SfgMethod):
+        if method.name in self._methods:
+            raise SfgException(
+                f"Duplicate method name {method.name} in class {self._class_name}"
+            )
+
+        self._methods[method.name] = method
+
+    def member_variables(
+        self, visibility: SfgVisibility | None = None
+    ) -> Generator[SfgMemberVariable, None, None]:
+        if visibility is not None:
+            yield from filter(
+                lambda m: m.visibility == visibility, self._member_vars.values()
+            )
+        else:
+            yield from self._member_vars.values()
+
+    def add_member_variable(self, variable: SfgMemberVariable):
+        if variable.name in self._member_vars:
+            raise SfgException(
+                f"Duplicate field name {variable.name} in class {self._class_name}"
+            )
+
+        self._member_vars[variable.name] = variable
diff --git a/src/pystencilssfg/tree/deferred_nodes.py b/src/pystencilssfg/tree/deferred_nodes.py
index bf2cd32..bb3cab4 100644
--- a/src/pystencilssfg/tree/deferred_nodes.py
+++ b/src/pystencilssfg/tree/deferred_nodes.py
@@ -38,8 +38,8 @@ class SfgDeferredNode(SfgCallTreeNode, ABC):
 
 class SfgParamCollectionDeferredNode(SfgDeferredNode, ABC):
     @abstractmethod
-    def expand(self, ctx: SfgContext, visible_params: set[TypedSymbolOrObject]) -> SfgCallTreeNode:
-        pass
+    def expand(self, visible_params: set[TypedSymbolOrObject]) -> SfgCallTreeNode:
+        ...
 
 
 class SfgDeferredFieldMapping(SfgParamCollectionDeferredNode):
@@ -47,7 +47,7 @@ class SfgDeferredFieldMapping(SfgParamCollectionDeferredNode):
         self._field = field
         self._src_field = src_field
 
-    def expand(self, ctx: SfgContext, visible_params: set[TypedSymbolOrObject]) -> SfgCallTreeNode:
+    def expand(self, visible_params: set[TypedSymbolOrObject]) -> SfgCallTreeNode:
         #    Find field pointer
         ptr = None
         for param in visible_params:
diff --git a/src/pystencilssfg/tree/dispatcher.py b/src/pystencilssfg/tree/dispatcher.py
deleted file mode 100644
index 04108ed..0000000
--- a/src/pystencilssfg/tree/dispatcher.py
+++ /dev/null
@@ -1,44 +0,0 @@
-from __future__ import annotations
-from typing import Callable, TypeVar, Generic, ParamSpec
-from types import MethodType
-
-from functools import wraps
-
-from .basic_nodes import SfgCallTreeNode
-
-V = TypeVar("V")
-R = TypeVar("R")
-P = ParamSpec("P")
-
-
-class VisitorDispatcher(Generic[V, R]):
-    def __init__(self, wrapped_method: Callable[..., R]):
-        self._dispatch_dict: dict[type, Callable[..., R]] = {}
-        self._wrapped_method: Callable[..., R] = wrapped_method
-
-    def case(self, node_type: type):
-        """Decorator for visitor's methods"""
-
-        def decorate(handler: Callable[..., R]):
-            if node_type in self._dispatch_dict:
-                raise ValueError(f"Duplicate visitor case {node_type}")
-            self._dispatch_dict[node_type] = handler
-            return handler
-
-        return decorate
-
-    def __call__(self, instance: V, node: SfgCallTreeNode, *args, **kwargs) -> R:
-        for cls in node.__class__.mro():
-            if cls in self._dispatch_dict:
-                return self._dispatch_dict[cls](instance, node, *args, **kwargs)
-
-        return self._wrapped_method(instance, node, *args, **kwargs)
-
-    def __get__(self, obj: V, objtype=None) -> Callable[..., R]:
-        if obj is None:
-            return self
-        return MethodType(self, obj)
-
-
-def visitor(method):
-    return wraps(method)(VisitorDispatcher(method))
diff --git a/src/pystencilssfg/visitors/__init__.py b/src/pystencilssfg/visitors/__init__.py
new file mode 100644
index 0000000..48c673a
--- /dev/null
+++ b/src/pystencilssfg/visitors/__init__.py
@@ -0,0 +1,10 @@
+from .dispatcher import visitor
+from .collectors import CollectIncludes
+from .tree_visitors import FlattenSequences, ExpandingParameterCollector
+
+__all__ = [
+    "visitor",
+    "CollectIncludes",
+    "FlattenSequences",
+    "ExpandingParameterCollector",
+]
diff --git a/src/pystencilssfg/visitors/collectors.py b/src/pystencilssfg/visitors/collectors.py
new file mode 100644
index 0000000..9bd1774
--- /dev/null
+++ b/src/pystencilssfg/visitors/collectors.py
@@ -0,0 +1,45 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from functools import reduce
+
+from .dispatcher import visitor
+from ..exceptions import SfgException
+from ..tree import SfgCallTreeNode
+from ..source_components import SfgFunction, SfgClass, SfgConstructor, SfgMemberVariable
+
+if TYPE_CHECKING:
+    from ..source_components import SfgHeaderInclude
+
+
+class CollectIncludes:
+    @visitor
+    def visit(self, obj: object) -> set[SfgHeaderInclude]:
+        raise SfgException(f"Can't collect includes from object of type {type(obj)}")
+
+    @visit.case(SfgCallTreeNode)
+    def tree_node(self, node: SfgCallTreeNode) -> set[SfgHeaderInclude]:
+        return reduce(
+            lambda accu, child: accu | self.visit(child), node.children, node.required_includes
+        )
+
+    @visit.case(SfgFunction)
+    def sfg_function(self, func: SfgFunction) -> set[SfgHeaderInclude]:
+        return self.visit(func.tree)
+
+    @visit.case(SfgClass)
+    def sfg_class(self, cls: SfgClass) -> set[SfgHeaderInclude]:
+        return reduce(
+            lambda accu, member: accu | (self.visit(member)), cls.members(), set()
+        )
+
+    @visit.case(SfgConstructor)
+    def sfg_constructor(self, constr: SfgConstructor) -> set[SfgHeaderInclude]:
+        return reduce(
+            lambda accu, obj: accu | obj.required_includes, constr.parameters, set()
+        )
+
+    @visit.case(SfgMemberVariable)
+    def sfg_member_var(self, var: SfgMemberVariable) -> set[SfgHeaderInclude]:
+        return var.required_includes
diff --git a/src/pystencilssfg/visitors/dispatcher.py b/src/pystencilssfg/visitors/dispatcher.py
new file mode 100644
index 0000000..48bfda9
--- /dev/null
+++ b/src/pystencilssfg/visitors/dispatcher.py
@@ -0,0 +1,73 @@
+from __future__ import annotations
+from typing import Callable, TypeVar, Generic, ParamSpec
+from types import MethodType
+
+from functools import wraps
+
+from ..tree.basic_nodes import SfgCallTreeNode
+
+V = TypeVar("V")
+R = TypeVar("R")
+P = ParamSpec("P")
+
+
+class VisitorDispatcher(Generic[V, R]):
+    def __init__(self, wrapped_method: Callable[..., R]):
+        self._dispatch_dict: dict[type, Callable[..., R]] = {}
+        self._wrapped_method: Callable[..., R] = wrapped_method
+
+    def case(self, node_type: type):
+        """Decorator for visitor's methods"""
+
+        def decorate(handler: Callable[..., R]):
+            if node_type in self._dispatch_dict:
+                raise ValueError(f"Duplicate visitor case {node_type}")
+            self._dispatch_dict[node_type] = handler
+            return handler
+
+        return decorate
+
+    def __call__(self, instance: V, node: SfgCallTreeNode, *args, **kwargs) -> R:
+        for cls in node.__class__.mro():
+            if cls in self._dispatch_dict:
+                return self._dispatch_dict[cls](instance, node, *args, **kwargs)
+
+        return self._wrapped_method(instance, node, *args, **kwargs)
+
+    def __get__(self, obj: V, objtype=None) -> Callable[..., R]:
+        if obj is None:
+            return self
+        return MethodType(self, obj)
+
+
+def visitor(method):
+    """Decorator to create a visitor using type-based dispatch.
+
+    Use this decorator to convert a method into a visitor, like shown below.
+    After declaring a method `<method-name>` a visitor,
+    its dispatch variants can be declared using the `<method-name>,case` decarator, like this:
+
+    ```Python
+    class DemoVisitor:
+        @visitor
+        def visit(self, obj: object):
+            # fallback case
+            ...
+
+        @visit.case(str)
+        def visit_str(self, obj: str):
+            # code for handling a str
+    ```
+
+    Now, if `visit` is called with an object of type `str`, the call is dispatched to `visit_str`.
+    Dispatch follows the Python method resolution order; if cases are declared for both a type
+    and some of its parent types, the most specific case is executed.
+    If no case matches, the fallback code in the original `visit` method is executed.
+
+    This visitor dispatch method is primarily designed for traversing abstract syntax tree structures.
+    The primary visitor method (`visit` in above example) defines the common parent type of all object
+    types the visitor can handle - every case's type must be a subtype of this.
+    Of course, like in the example, this visitor dispatcher may be used with arbitrary types if the base
+    type is `object`.
+    """
+    return wraps(method)(VisitorDispatcher(method))
diff --git a/src/pystencilssfg/tree/visitors.py b/src/pystencilssfg/visitors/tree_visitors.py
similarity index 73%
rename from src/pystencilssfg/tree/visitors.py
rename to src/pystencilssfg/visitors/tree_visitors.py
index 877ed5a..76cb9c5 100644
--- a/src/pystencilssfg/tree/visitors.py
+++ b/src/pystencilssfg/visitors/tree_visitors.py
@@ -1,19 +1,21 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
+# from typing import TYPE_CHECKING
 
 from functools import reduce
 
-from .basic_nodes import SfgCallTreeNode, SfgCallTreeLeaf, SfgSequence, SfgStatements
-from .deferred_nodes import SfgParamCollectionDeferredNode
+from ..tree.basic_nodes import (
+    SfgCallTreeNode,
+    SfgCallTreeLeaf,
+    SfgSequence,
+    SfgStatements,
+)
+from ..tree.deferred_nodes import SfgParamCollectionDeferredNode
 from .dispatcher import visitor
 from ..source_concepts.source_objects import TypedSymbolOrObject
 
-if TYPE_CHECKING:
-    from ..context import SfgContext
 
-
-class FlattenSequences():
+class FlattenSequences:
     """Flattens any nested sequences occuring in a kernel call tree."""
 
     @visitor
@@ -40,23 +42,12 @@ class FlattenSequences():
         sequence._children = children_flattened
 
 
-class CollectIncludes:
-    def visit(self, node: SfgCallTreeNode):
-        includes = node.required_includes
-        for c in node.children:
-            includes |= self.visit(c)
-
-        return includes
-
-
-class ExpandingParameterCollector():
+class ExpandingParameterCollector:
     """Collects all parameters required but not defined in a kernel call tree.
     Expands any deferred nodes of type `SfgParamCollectionDeferredNode` found within sequences on the way.
     """
 
-    def __init__(self, ctx: SfgContext) -> None:
-        self._ctx = ctx
-
+    def __init__(self) -> None:
         self._flattener = FlattenSequences()
 
     @visitor
@@ -70,17 +61,19 @@ class ExpandingParameterCollector():
     @visit.case(SfgSequence)
     def sequence(self, sequence: SfgSequence) -> set[TypedSymbolOrObject]:
         """
-            Only in a sequence may parameters be defined and visible to subsequent nodes.
+        Only in a sequence may parameters be defined and visible to subsequent nodes.
         """
 
         params: set[TypedSymbolOrObject] = set()
 
-        def iter_nested_sequences(seq: SfgSequence, visible_params: set[TypedSymbolOrObject]):
+        def iter_nested_sequences(
+            seq: SfgSequence, visible_params: set[TypedSymbolOrObject]
+        ):
             for i in range(len(seq.children) - 1, -1, -1):
                 c = seq.children[i]
 
                 if isinstance(c, SfgParamCollectionDeferredNode):
-                    c = c.expand(self._ctx, visible_params=visible_params)
+                    c = c.expand(visible_params=visible_params)
                     seq[i] = c
 
                 if isinstance(c, SfgSequence):
@@ -97,13 +90,13 @@ class ExpandingParameterCollector():
 
     def branching_node(self, node: SfgCallTreeNode) -> set[TypedSymbolOrObject]:
         """
-            Each interior node that is not a sequence simply requires the union of all parameters
-            required by its children.
+        Each interior node that is not a sequence simply requires the union of all parameters
+        required by its children.
         """
         return reduce(lambda x, y: x | y, (self.visit(c) for c in node.children), set())
 
 
-class ParameterCollector():
+class ParameterCollector:
     """Collects all parameters required but not defined in a kernel call tree.
 
     Requires that all sequences in the tree are flattened.
@@ -120,7 +113,7 @@ class ParameterCollector():
     @visit.case(SfgSequence)
     def sequence(self, sequence: SfgSequence) -> set[TypedSymbolOrObject]:
         """
-            Only in a sequence may parameters be defined and visible to subsequent nodes.
+        Only in a sequence may parameters be defined and visible to subsequent nodes.
         """
 
         params: set[TypedSymbolOrObject] = set()
@@ -134,7 +127,7 @@ class ParameterCollector():
 
     def branching_node(self, node: SfgCallTreeNode) -> set[TypedSymbolOrObject]:
         """
-            Each interior node that is not a sequence simply requires the union of all parameters
-            required by its children.
+        Each interior node that is not a sequence simply requires the union of all parameters
+        required by its children.
         """
         return reduce(lambda x, y: x | y, (self.visit(c) for c in node.children), set())
diff --git a/test_classses_out/test_classes.cpp b/test_classses_out/test_classes.cpp
new file mode 100644
index 0000000..93b26bc
--- /dev/null
+++ b/test_classses_out/test_classes.cpp
@@ -0,0 +1,32 @@
+
+
+#include "test_classes.h"
+
+
+#define FUNC_PREFIX inline
+
+
+/*************************************************************************************
+ *                                Kernels
+*************************************************************************************/
+
+namespace kernels {
+
+FUNC_PREFIX void kernel(double * RESTRICT  _data_f, double * RESTRICT const _data_g, int64_t const _size_f_0, int64_t const _size_f_1, int64_t const _stride_f_0, int64_t const _stride_f_1, int64_t const _stride_g_0, int64_t const _stride_g_1)
+{
+   for (int64_t ctr_0 = 0; ctr_0 < _size_f_0; ctr_0 += 1)
+   {
+      for (int64_t ctr_1 = 0; ctr_1 < _size_f_1; ctr_1 += 1)
+      {
+         _data_f[_stride_f_0*ctr_0 + _stride_f_1*ctr_1] = 3.0*_data_g[_stride_g_0*ctr_0 + _stride_g_1*ctr_1];
+      }
+   }
+}
+
+} // namespace kernels
+
+/*************************************************************************************
+ *                                Functions
+*************************************************************************************/
+
+
diff --git a/test_classses_out/test_classes.h b/test_classses_out/test_classes.h
new file mode 100644
index 0000000..3202665
--- /dev/null
+++ b/test_classses_out/test_classes.h
@@ -0,0 +1,24 @@
+
+
+#pragma once
+
+#include <cstdint>
+
+
+
+#define RESTRICT __restrict__
+
+
+
+class MyClass
+    : 
+{
+
+// default:
+
+
+};
+
+
+
+
-- 
GitLab