From 8f789bf6ce0033c57b2224532d9039b0636f1f68 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Fri, 1 Dec 2023 12:07:28 +0100 Subject: [PATCH] integrated classes --- integration/.gitignore | 1 + integration/test_classes.py | 45 ++++ src/pystencilssfg/composer.py | 4 +- src/pystencilssfg/context.py | 64 +++-- .../emitters/classes_printing.py | 67 +++++ .../emitters/header_source_pair.py | 15 +- src/pystencilssfg/emitters/jinja_filters.py | 25 +- .../templates/HeaderSourcePair.tmpl.cpp | 18 +- .../templates/HeaderSourcePair.tmpl.h | 4 + src/pystencilssfg/source_components.py | 229 ++++++++++++++++-- src/pystencilssfg/tree/deferred_nodes.py | 6 +- src/pystencilssfg/tree/dispatcher.py | 44 ---- src/pystencilssfg/visitors/__init__.py | 10 + src/pystencilssfg/visitors/collectors.py | 45 ++++ src/pystencilssfg/visitors/dispatcher.py | 73 ++++++ .../visitors.py => visitors/tree_visitors.py} | 51 ++-- test_classses_out/test_classes.cpp | 32 +++ test_classses_out/test_classes.h | 24 ++ 18 files changed, 631 insertions(+), 126 deletions(-) create mode 100644 integration/.gitignore create mode 100644 integration/test_classes.py create mode 100644 src/pystencilssfg/emitters/classes_printing.py delete mode 100644 src/pystencilssfg/tree/dispatcher.py create mode 100644 src/pystencilssfg/visitors/__init__.py create mode 100644 src/pystencilssfg/visitors/collectors.py create mode 100644 src/pystencilssfg/visitors/dispatcher.py rename src/pystencilssfg/{tree/visitors.py => visitors/tree_visitors.py} (73%) create mode 100644 test_classses_out/test_classes.cpp create mode 100644 test_classses_out/test_classes.h diff --git a/integration/.gitignore b/integration/.gitignore new file mode 100644 index 0000000..c585e19 --- /dev/null +++ b/integration/.gitignore @@ -0,0 +1 @@ +out \ No newline at end of file diff --git a/integration/test_classes.py b/integration/test_classes.py new file mode 100644 index 0000000..446921c --- /dev/null +++ b/integration/test_classes.py @@ -0,0 +1,45 @@ +# type: ignore +from pystencilssfg import SourceFileGenerator, SfgConfiguration +from pystencilssfg.source_concepts import SrcObject +from pystencilssfg.source_components import SfgClass, SfgMemberVariable, SfgConstructor, SfgMethod, SfgVisibility + +from pystencils import fields, kernel + +sfg_config = SfgConfiguration( + output_directory="out/test_classes", + outer_namespace="gen_code" +) + +f, g = fields("f, g(1): double[2D]") + +with SourceFileGenerator(sfg_config) as sfg: + + @kernel + def assignments(): + f[0,0] @= 3 * g[0,0] + + khandle = sfg.kernels.create(assignments) + + cls = SfgClass("MyClass") + cls.add_method(SfgMethod( + "callKernel", + sfg.call(khandle), + visibility=SfgVisibility.PUBLIC + )) + + cls.add_member_variable( + SfgMemberVariable( + "stuff", "std::vector< int >", + SfgVisibility.PRIVATE + ) + ) + + cls.add_constructor( + SfgConstructor( + [SrcObject("std::vector< int > &", "stuff")], + ["stuff_(stuff)"], + visibility=SfgVisibility.PUBLIC + ) + ) + + sfg.context.add_class(cls) diff --git a/src/pystencilssfg/composer.py b/src/pystencilssfg/composer.py index d405b66..96f59ef 100644 --- a/src/pystencilssfg/composer.py +++ b/src/pystencilssfg/composer.py @@ -77,7 +77,7 @@ class SfgComposer: else: raise TypeError("Invalid type of argument `ast_or_kernel_handle`!") - func = SfgFunction(self._ctx, name, tree) + func = SfgFunction(name, tree) self._ctx.add_function(func) def function(self, name: str): @@ -96,7 +96,7 @@ class SfgComposer: def sequencer(*args: str | tuple | SfgCallTreeNode | SfgNodeBuilder): tree = make_sequence(*args) - func = SfgFunction(self._ctx, name, tree) + func = SfgFunction(name, tree) self._ctx.add_function(func) return sequencer diff --git a/src/pystencilssfg/context.py b/src/pystencilssfg/context.py index d1add3c..48ecfd2 100644 --- a/src/pystencilssfg/context.py +++ b/src/pystencilssfg/context.py @@ -1,16 +1,23 @@ from typing import Generator, Sequence from .configuration import SfgCodeStyle -from .tree.visitors import CollectIncludes -from .source_components import SfgHeaderInclude, SfgKernelNamespace, SfgFunction +from .visitors import CollectIncludes +from .source_components import ( + SfgHeaderInclude, + SfgKernelNamespace, + SfgFunction, + SfgClass, +) from .exceptions import SfgException class SfgContext: - def __init__(self, - outer_namespace: str | None = None, - codestyle: SfgCodeStyle = SfgCodeStyle(), - argv: Sequence[str] | None = None): + def __init__( + self, + outer_namespace: str | None = None, + codestyle: SfgCodeStyle = SfgCodeStyle(), + argv: Sequence[str] | None = None, + ): self._argv = argv self._default_kernel_namespace = SfgKernelNamespace(self, "kernels") @@ -23,8 +30,11 @@ class SfgContext: self._prelude: str = "" self._includes: set[SfgHeaderInclude] = set() self._definitions: list[str] = [] - self._kernel_namespaces = {self._default_kernel_namespace.name: self._default_kernel_namespace} + self._kernel_namespaces = { + self._default_kernel_namespace.name: self._default_kernel_namespace + } self._functions: dict[str, SfgFunction] = dict() + self._classes: dict[str, SfgClass] = dict() @property def argv(self) -> Sequence[str]: @@ -48,11 +58,16 @@ class SfgContext: @property def fully_qualified_namespace(self) -> str | None: match (self.outer_namespace, self.inner_namespace): - case None, None: return None - case outer, None: return outer - case None, inner: return inner - case outer, inner: return f"{outer}::{inner}" - case _: assert False + case None, None: + return None + case outer, None: + return outer + case None, inner: + return inner + case outer, inner: + return f"{outer}::{inner}" + case _: + assert False @property def codestyle(self) -> SfgCodeStyle: @@ -127,10 +142,29 @@ class SfgContext: def get_function(self, name: str) -> SfgFunction | None: return self._functions.get(name, None) - def add_function(self, func: SfgFunction) -> None: + def add_function(self, func: SfgFunction): if func.name in self._functions: - raise ValueError(f"Duplicate function: {func.name}") + raise SfgException(f"Duplicate function: {func.name}") self._functions[func.name] = func - for incl in CollectIncludes().visit(func._tree): + for incl in CollectIncludes().visit(func): + self.add_include(incl) + + # ---------------------------------------------------------------------------------------------- + # Classes + # ---------------------------------------------------------------------------------------------- + + def classes(self) -> Generator[SfgClass, None, None]: + yield from self._classes.values() + + def get_class(self, name: str) -> SfgClass | None: + return self._classes.get(name, None) + + def add_class(self, cls: SfgClass): + if cls.class_name in self._classes: + raise SfgException(f"Duplicate class: {cls.class_name}") + + self._classes[cls.class_name] = cls + + for incl in CollectIncludes().visit(cls): self.add_include(incl) diff --git a/src/pystencilssfg/emitters/classes_printing.py b/src/pystencilssfg/emitters/classes_printing.py new file mode 100644 index 0000000..f80c7ca --- /dev/null +++ b/src/pystencilssfg/emitters/classes_printing.py @@ -0,0 +1,67 @@ +from ..context import SfgContext +from ..visitors import visitor +from ..source_components import ( + SfgClass, + SfgConstructor, + SfgMemberVariable, + SfgMethod, + SfgVisibility, +) +from ..exceptions import SfgException + + +class ClassDeclarationPrinter: + def __init__(self, ctx: SfgContext): + self._codestyle = ctx.codestyle + + def print(self, cls: SfgClass): + return self.visit(cls, cls) + + @visitor + def visit(self, obj: object, cls: SfgClass) -> str: + raise SfgException("Can't print this.") + + @visit.case(SfgClass) + def sfg_class(self, cls: SfgClass, _: SfgClass): + code = f"{cls.class_keyword} {cls.class_name} \n" + + if cls.base_classes: + code += f" : {','.join(cls.base_classes)}\n" + + code += "{\n" + for visibility in ( + SfgVisibility.DEFAULT, + SfgVisibility.PUBLIC, + SfgVisibility.PRIVATE, + ): + if visibility != SfgVisibility.DEFAULT: + code += f"\n{visibility}:\n" + for member in cls.members(visibility): + code += self._codestyle.indent(self.visit(member, cls)) + "\n" + code += "};\n" + + return code + + @visit.case(SfgConstructor) + def sfg_constructor(self, constr: SfgConstructor, cls: SfgClass): + code = f"{cls.class_name} (" + code += ", ".join(f"{param.dtype} {param.name}" for param in constr.parameters) + code += ")\n" + if constr.initializers: + code += " : " + ", ".join(constr.initializers) + "\n" + if constr.body: + code += "{\n" + self._codestyle.indent(constr.body) + "\n}\n" + else: + code += "{ }\n" + return code + + @visit.case(SfgMemberVariable) + def sfg_member_var(self, var: SfgMemberVariable, _: SfgClass): + return f"{var.dtype} {var.name};" + + @visit.case(SfgMethod) + def sfg_method(self, method: SfgMethod, _: SfgClass): + code = f"void {method.name} (" + code += ", ".join(f"{param.dtype} {param.name}" for param in method.parameters) + code += ");" + return code diff --git a/src/pystencilssfg/emitters/header_source_pair.py b/src/pystencilssfg/emitters/header_source_pair.py index 33653ab..553f84b 100644 --- a/src/pystencilssfg/emitters/header_source_pair.py +++ b/src/pystencilssfg/emitters/header_source_pair.py @@ -1,6 +1,6 @@ from jinja2 import Environment, PackageLoader, StrictUndefined -from os import path +from os import path, makedirs from ..configuration import SfgOutputSpec from ..context import SfgContext @@ -31,12 +31,13 @@ class HeaderSourcePairEmitter: 'source_filename': self._impl_filename, 'basename': self._basename, 'prelude_comment': ctx.prelude_comment, - 'definitions': list(ctx.definitions()), + 'definitions': tuple(ctx.definitions()), 'fq_namespace': fq_namespace, - 'public_includes': list(incl.get_code() for incl in ctx.includes() if not incl.private), - 'private_includes': list(incl.get_code() for incl in ctx.includes() if incl.private), - 'kernel_namespaces': list(ctx.kernel_namespaces()), - 'functions': list(ctx.functions()) + 'public_includes': tuple(incl.get_code() for incl in ctx.includes() if not incl.private), + 'private_includes': tuple(incl.get_code() for incl in ctx.includes() if incl.private), + 'kernel_namespaces': tuple(ctx.kernel_namespaces()), + 'functions': tuple(ctx.functions()), + 'classes': tuple(ctx.classes()) } template_name = "HeaderSourcePair" @@ -52,6 +53,8 @@ class HeaderSourcePairEmitter: header = env.get_template(f"{template_name}.tmpl.h").render(**jinja_context) source = env.get_template(f"{template_name}.tmpl.cpp").render(**jinja_context) + makedirs(self._output_directory, exist_ok=True) + with open(self._ospec.get_header_filepath(), 'w') as headerfile: headerfile.write(header) diff --git a/src/pystencilssfg/emitters/jinja_filters.py b/src/pystencilssfg/emitters/jinja_filters.py index f6ee5e1..31169f4 100644 --- a/src/pystencilssfg/emitters/jinja_filters.py +++ b/src/pystencilssfg/emitters/jinja_filters.py @@ -5,7 +5,8 @@ from pystencils.astnodes import KernelFunction from pystencils import Backend from pystencils.backends import generate_c -from pystencilssfg.source_components import SfgFunction +from pystencilssfg.source_components import SfgFunction, SfgClass +from .classes_printing import ClassDeclarationPrinter def format_prelude_comment(prelude_comment: str): @@ -26,12 +27,22 @@ def generate_function_parameter_list(ctx, func: SfgFunction): return ", ".join(f"{param.dtype} {param.name}" for param in params) -def generate_function_body(func: SfgFunction): - return func.get_code() +@pass_context +def generate_function_body(ctx, func: SfgFunction): + return func.get_code(ctx["ctx"]) + + +@pass_context +def print_class_declaration(ctx, cls: SfgClass): + return ClassDeclarationPrinter(ctx["ctx"]).print(cls) def add_filters_to_jinja(jinja_env): - jinja_env.filters['format_prelude_comment'] = format_prelude_comment - jinja_env.filters['generate_kernel_definition'] = generate_kernel_definition - jinja_env.filters['generate_function_parameter_list'] = generate_function_parameter_list - jinja_env.filters['generate_function_body'] = generate_function_body + jinja_env.filters["format_prelude_comment"] = format_prelude_comment + jinja_env.filters["generate_kernel_definition"] = generate_kernel_definition + jinja_env.filters[ + "generate_function_parameter_list" + ] = generate_function_parameter_list + jinja_env.filters["generate_function_body"] = generate_function_body + + jinja_env.filters["print_class_declaration"] = print_class_declaration diff --git a/src/pystencilssfg/emitters/templates/HeaderSourcePair.tmpl.cpp b/src/pystencilssfg/emitters/templates/HeaderSourcePair.tmpl.cpp index b1d37cc..48c9653 100644 --- a/src/pystencilssfg/emitters/templates/HeaderSourcePair.tmpl.cpp +++ b/src/pystencilssfg/emitters/templates/HeaderSourcePair.tmpl.cpp @@ -31,12 +31,26 @@ namespace {{ kns.name }} { *************************************************************************************/ {% for function in functions %} - void {{ function.name }} ( {{ function | generate_function_parameter_list }} ) { - {{ function | generate_function_body | indent(2) }} + {{ function | generate_function_body | indent(ctx.codestyle.indent_width) }} } + +{% endfor -%} + +/************************************************************************************* + * Class Methods +*************************************************************************************/ + +{% for cls in classes %} +{% for method in cls.methods() %} +void {{ cls.class_name }}::{{ method.name }} ( {{ method | generate_function_parameter_list }} ) { + {{ method | generate_function_body | indent(ctx.codestyle.indent_width) }} +} + + {% endfor %} +{% endfor -%} {% if fq_namespace is not none %} } // namespace {{fq_namespace}} diff --git a/src/pystencilssfg/emitters/templates/HeaderSourcePair.tmpl.h b/src/pystencilssfg/emitters/templates/HeaderSourcePair.tmpl.h index 40ce1ed..95b3ed8 100644 --- a/src/pystencilssfg/emitters/templates/HeaderSourcePair.tmpl.h +++ b/src/pystencilssfg/emitters/templates/HeaderSourcePair.tmpl.h @@ -18,6 +18,10 @@ namespace {{fq_namespace}} { {% endif %} +{% for cls in classes %} +{{ cls | print_class_declaration }} +{% endfor %} + {% for function in functions %} void {{ function.name }} ( {{ function | generate_function_parameter_list }} ); {% endfor %} diff --git a/src/pystencilssfg/source_components.py b/src/pystencilssfg/source_components.py index 17cd3c9..af89aff 100644 --- a/src/pystencilssfg/source_components.py +++ b/src/pystencilssfg/source_components.py @@ -1,18 +1,26 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Sequence +from abc import ABC +from enum import Enum, auto +from typing import TYPE_CHECKING, Sequence, Generator from dataclasses import replace from pystencils import CreateKernelConfig, create_kernel from pystencils.astnodes import KernelFunction +from .types import SrcType +from .source_concepts import SrcObject +from .exceptions import SfgException + if TYPE_CHECKING: from .context import SfgContext from .tree import SfgCallTreeNode class SfgHeaderInclude: - def __init__(self, header_file: str, system_header: bool = False, private: bool = False): + def __init__( + self, header_file: str, system_header: bool = False, private: bool = False + ): self._header_file = header_file self._system_header = system_header self._private = private @@ -35,10 +43,12 @@ class SfgHeaderInclude: return hash((self._header_file, self._system_header, self._private)) def __eq__(self, other: object) -> bool: - return (isinstance(other, SfgHeaderInclude) - and self._header_file == other._header_file - and self._system_header == other._system_header - and self._private == other._private) + return ( + isinstance(other, SfgHeaderInclude) + and self._header_file == other._header_file + and self._system_header == other._system_header + and self._private == other._private + ) class SfgKernelNamespace: @@ -64,7 +74,9 @@ class SfgKernelNamespace: astname = ast.function_name if astname in self._asts: - raise ValueError(f"Duplicate ASTs: An AST with name {astname} already exists in namespace {self._name}") + raise ValueError( + f"Duplicate ASTs: An AST with name {astname} already exists in namespace {self._name}" + ) if name is not None: ast.function_name = name @@ -73,7 +85,12 @@ class SfgKernelNamespace: return SfgKernelHandle(self._ctx, astname, self, ast.get_parameters()) - def create(self, assignments, name: str | None = None, config: CreateKernelConfig | None = None): + def create( + self, + assignments, + name: str | None = None, + config: CreateKernelConfig | None = None, + ): """Creates a new pystencils kernel from a list of assignments and a configuration. This is a wrapper around [`pystencils.create_kernel`]( @@ -87,7 +104,9 @@ class SfgKernelNamespace: if name is not None: if name in self._asts: - raise ValueError(f"Duplicate ASTs: An AST with name {name} already exists in namespace {self._name}") + raise ValueError( + f"Duplicate ASTs: An AST with name {name} already exists in namespace {self._name}" + ) config = replace(config, function_name=name) # type: ignore @@ -96,7 +115,13 @@ class SfgKernelNamespace: class SfgKernelHandle: - def __init__(self, ctx, name: str, namespace: SfgKernelNamespace, parameters: Sequence[KernelFunction.Parameter]): + def __init__( + self, + ctx, + name: str, + namespace: SfgKernelNamespace, + parameters: Sequence[KernelFunction.Parameter], + ): self._ctx = ctx self._name = name self._namespace = namespace @@ -122,8 +147,10 @@ class SfgKernelHandle: @property def fully_qualified_name(self): match self._ctx.fully_qualified_namespace: - case None: return f"{self.kernel_namespace.name}::{self.kernel_name}" - case fqn: return f"{fqn}::{self.kernel_namespace.name}::{self.kernel_name}" + case None: + return f"{self.kernel_namespace.name}::{self.kernel_name}" + case fqn: + return f"{fqn}::{self.kernel_namespace.name}::{self.kernel_name}" @property def parameters(self): @@ -139,14 +166,13 @@ class SfgKernelHandle: class SfgFunction: - def __init__(self, ctx: SfgContext, name: str, tree: SfgCallTreeNode): - self._ctx = ctx + def __init__(self, name: str, tree: SfgCallTreeNode): self._name = name self._tree = tree - from .tree.visitors import ExpandingParameterCollector + from .visitors.tree_visitors import ExpandingParameterCollector - param_collector = ExpandingParameterCollector(self._ctx) + param_collector = ExpandingParameterCollector() self._parameters = param_collector.visit(self._tree) @property @@ -161,5 +187,172 @@ class SfgFunction: def tree(self): return self._tree - def get_code(self): - return self._tree.get_code(self._ctx) + def get_code(self, ctx: SfgContext): + return self._tree.get_code(ctx) + + +class SfgVisibility(Enum): + DEFAULT = auto() + PRIVATE = auto() + PUBLIC = auto() + + def __str__(self) -> str: + match self: + case SfgVisibility.DEFAULT: + return "" + case SfgVisibility.PRIVATE: + return "private" + case SfgVisibility.PUBLIC: + return "public" + + +class SfgClassKeyword(Enum): + STRUCT = auto() + CLASS = auto() + + def __str__(self) -> str: + match self: + case SfgClassKeyword.STRUCT: + return "struct" + case SfgClassKeyword.CLASS: + return "class" + + +class SfgClassMember(ABC): + def __init__(self, visibility: SfgVisibility): + self._visibility = visibility + + @property + def visibility(self) -> SfgVisibility: + return self._visibility + + +class SfgMemberVariable(SrcObject, SfgClassMember): + def __init__( + self, + name: str, + type: SrcType, + visibility: SfgVisibility = SfgVisibility.DEFAULT, + ): + SrcObject.__init__(self, type, name) + SfgClassMember.__init__(self, visibility) + + +class SfgMethod(SfgFunction, SfgClassMember): + def __init__( + self, + name: str, + tree: SfgCallTreeNode, + visibility: SfgVisibility = SfgVisibility.DEFAULT, + ): + SfgFunction.__init__(self, name, tree) + SfgClassMember.__init__(self, visibility) + + +class SfgConstructor(SfgClassMember): + def __init__( + self, + parameters: Sequence[SrcObject] = (), + initializers: Sequence[str] = (), + body: str = "", + visibility: SfgVisibility = SfgVisibility.DEFAULT, + ): + SfgClassMember.__init__(self, visibility) + self._parameters = tuple(parameters) + self._initializers = tuple(initializers) + self._body = body + + @property + def parameters(self) -> tuple[SrcObject, ...]: + return self._parameters + + @property + def initializers(self) -> tuple[str, ...]: + return self._initializers + + @property + def body(self) -> str: + return self._body + + +class SfgClass: + def __init__( + self, + class_name: str, + class_keyword: SfgClassKeyword = SfgClassKeyword.CLASS, + bases: Sequence[str] = (), + ): + self._class_name = class_name + self._class_keyword = class_keyword + self._bases_classes = tuple(bases) + + self._constructors: list[SfgConstructor] = [] + self._methods: dict[str, SfgMethod] = dict() + self._member_vars: dict[str, SfgMemberVariable] = dict() + + @property + def class_name(self) -> str: + return self._class_name + + @property + def base_classes(self) -> tuple[str, ...]: + return self._bases_classes + + @property + def class_keyword(self) -> SfgClassKeyword: + return self._class_keyword + + def members( + self, visibility: SfgVisibility | None = None + ) -> Generator[SfgClassMember, None, None]: + yield from self.member_variables(visibility) + yield from self.constructors(visibility) + yield from self.methods(visibility) + + def constructors( + self, visibility: SfgVisibility | None = None + ) -> Generator[SfgConstructor, None, None]: + if visibility is not None: + yield from filter(lambda m: m.visibility == visibility, self._constructors) + else: + yield from self._constructors + + def add_constructor(self, constr: SfgConstructor): + # TODO: Check for signature conflicts? + self._constructors.append(constr) + + def methods( + self, visibility: SfgVisibility | None = None + ) -> Generator[SfgMethod, None, None]: + if visibility is not None: + yield from filter( + lambda m: m.visibility == visibility, self._methods.values() + ) + else: + yield from self._methods.values() + + def add_method(self, method: SfgMethod): + if method.name in self._methods: + raise SfgException( + f"Duplicate method name {method.name} in class {self._class_name}" + ) + + self._methods[method.name] = method + + def member_variables( + self, visibility: SfgVisibility | None = None + ) -> Generator[SfgMemberVariable, None, None]: + if visibility is not None: + yield from filter( + lambda m: m.visibility == visibility, self._member_vars.values() + ) + else: + yield from self._member_vars.values() + + def add_member_variable(self, variable: SfgMemberVariable): + if variable.name in self._member_vars: + raise SfgException( + f"Duplicate field name {variable.name} in class {self._class_name}" + ) + + self._member_vars[variable.name] = variable diff --git a/src/pystencilssfg/tree/deferred_nodes.py b/src/pystencilssfg/tree/deferred_nodes.py index bf2cd32..bb3cab4 100644 --- a/src/pystencilssfg/tree/deferred_nodes.py +++ b/src/pystencilssfg/tree/deferred_nodes.py @@ -38,8 +38,8 @@ class SfgDeferredNode(SfgCallTreeNode, ABC): class SfgParamCollectionDeferredNode(SfgDeferredNode, ABC): @abstractmethod - def expand(self, ctx: SfgContext, visible_params: set[TypedSymbolOrObject]) -> SfgCallTreeNode: - pass + def expand(self, visible_params: set[TypedSymbolOrObject]) -> SfgCallTreeNode: + ... class SfgDeferredFieldMapping(SfgParamCollectionDeferredNode): @@ -47,7 +47,7 @@ class SfgDeferredFieldMapping(SfgParamCollectionDeferredNode): self._field = field self._src_field = src_field - def expand(self, ctx: SfgContext, visible_params: set[TypedSymbolOrObject]) -> SfgCallTreeNode: + def expand(self, visible_params: set[TypedSymbolOrObject]) -> SfgCallTreeNode: # Find field pointer ptr = None for param in visible_params: diff --git a/src/pystencilssfg/tree/dispatcher.py b/src/pystencilssfg/tree/dispatcher.py deleted file mode 100644 index 04108ed..0000000 --- a/src/pystencilssfg/tree/dispatcher.py +++ /dev/null @@ -1,44 +0,0 @@ -from __future__ import annotations -from typing import Callable, TypeVar, Generic, ParamSpec -from types import MethodType - -from functools import wraps - -from .basic_nodes import SfgCallTreeNode - -V = TypeVar("V") -R = TypeVar("R") -P = ParamSpec("P") - - -class VisitorDispatcher(Generic[V, R]): - def __init__(self, wrapped_method: Callable[..., R]): - self._dispatch_dict: dict[type, Callable[..., R]] = {} - self._wrapped_method: Callable[..., R] = wrapped_method - - def case(self, node_type: type): - """Decorator for visitor's methods""" - - def decorate(handler: Callable[..., R]): - if node_type in self._dispatch_dict: - raise ValueError(f"Duplicate visitor case {node_type}") - self._dispatch_dict[node_type] = handler - return handler - - return decorate - - def __call__(self, instance: V, node: SfgCallTreeNode, *args, **kwargs) -> R: - for cls in node.__class__.mro(): - if cls in self._dispatch_dict: - return self._dispatch_dict[cls](instance, node, *args, **kwargs) - - return self._wrapped_method(instance, node, *args, **kwargs) - - def __get__(self, obj: V, objtype=None) -> Callable[..., R]: - if obj is None: - return self - return MethodType(self, obj) - - -def visitor(method): - return wraps(method)(VisitorDispatcher(method)) diff --git a/src/pystencilssfg/visitors/__init__.py b/src/pystencilssfg/visitors/__init__.py new file mode 100644 index 0000000..48c673a --- /dev/null +++ b/src/pystencilssfg/visitors/__init__.py @@ -0,0 +1,10 @@ +from .dispatcher import visitor +from .collectors import CollectIncludes +from .tree_visitors import FlattenSequences, ExpandingParameterCollector + +__all__ = [ + "visitor", + "CollectIncludes", + "FlattenSequences", + "ExpandingParameterCollector", +] diff --git a/src/pystencilssfg/visitors/collectors.py b/src/pystencilssfg/visitors/collectors.py new file mode 100644 index 0000000..9bd1774 --- /dev/null +++ b/src/pystencilssfg/visitors/collectors.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from functools import reduce + +from .dispatcher import visitor +from ..exceptions import SfgException +from ..tree import SfgCallTreeNode +from ..source_components import SfgFunction, SfgClass, SfgConstructor, SfgMemberVariable + +if TYPE_CHECKING: + from ..source_components import SfgHeaderInclude + + +class CollectIncludes: + @visitor + def visit(self, obj: object) -> set[SfgHeaderInclude]: + raise SfgException(f"Can't collect includes from object of type {type(obj)}") + + @visit.case(SfgCallTreeNode) + def tree_node(self, node: SfgCallTreeNode) -> set[SfgHeaderInclude]: + return reduce( + lambda accu, child: accu | self.visit(child), node.children, node.required_includes + ) + + @visit.case(SfgFunction) + def sfg_function(self, func: SfgFunction) -> set[SfgHeaderInclude]: + return self.visit(func.tree) + + @visit.case(SfgClass) + def sfg_class(self, cls: SfgClass) -> set[SfgHeaderInclude]: + return reduce( + lambda accu, member: accu | (self.visit(member)), cls.members(), set() + ) + + @visit.case(SfgConstructor) + def sfg_constructor(self, constr: SfgConstructor) -> set[SfgHeaderInclude]: + return reduce( + lambda accu, obj: accu | obj.required_includes, constr.parameters, set() + ) + + @visit.case(SfgMemberVariable) + def sfg_member_var(self, var: SfgMemberVariable) -> set[SfgHeaderInclude]: + return var.required_includes diff --git a/src/pystencilssfg/visitors/dispatcher.py b/src/pystencilssfg/visitors/dispatcher.py new file mode 100644 index 0000000..48bfda9 --- /dev/null +++ b/src/pystencilssfg/visitors/dispatcher.py @@ -0,0 +1,73 @@ +from __future__ import annotations +from typing import Callable, TypeVar, Generic, ParamSpec +from types import MethodType + +from functools import wraps + +from ..tree.basic_nodes import SfgCallTreeNode + +V = TypeVar("V") +R = TypeVar("R") +P = ParamSpec("P") + + +class VisitorDispatcher(Generic[V, R]): + def __init__(self, wrapped_method: Callable[..., R]): + self._dispatch_dict: dict[type, Callable[..., R]] = {} + self._wrapped_method: Callable[..., R] = wrapped_method + + def case(self, node_type: type): + """Decorator for visitor's methods""" + + def decorate(handler: Callable[..., R]): + if node_type in self._dispatch_dict: + raise ValueError(f"Duplicate visitor case {node_type}") + self._dispatch_dict[node_type] = handler + return handler + + return decorate + + def __call__(self, instance: V, node: SfgCallTreeNode, *args, **kwargs) -> R: + for cls in node.__class__.mro(): + if cls in self._dispatch_dict: + return self._dispatch_dict[cls](instance, node, *args, **kwargs) + + return self._wrapped_method(instance, node, *args, **kwargs) + + def __get__(self, obj: V, objtype=None) -> Callable[..., R]: + if obj is None: + return self + return MethodType(self, obj) + + +def visitor(method): + """Decorator to create a visitor using type-based dispatch. + + Use this decorator to convert a method into a visitor, like shown below. + After declaring a method `<method-name>` a visitor, + its dispatch variants can be declared using the `<method-name>,case` decarator, like this: + + ```Python + class DemoVisitor: + @visitor + def visit(self, obj: object): + # fallback case + ... + + @visit.case(str) + def visit_str(self, obj: str): + # code for handling a str + ``` + + Now, if `visit` is called with an object of type `str`, the call is dispatched to `visit_str`. + Dispatch follows the Python method resolution order; if cases are declared for both a type + and some of its parent types, the most specific case is executed. + If no case matches, the fallback code in the original `visit` method is executed. + + This visitor dispatch method is primarily designed for traversing abstract syntax tree structures. + The primary visitor method (`visit` in above example) defines the common parent type of all object + types the visitor can handle - every case's type must be a subtype of this. + Of course, like in the example, this visitor dispatcher may be used with arbitrary types if the base + type is `object`. + """ + return wraps(method)(VisitorDispatcher(method)) diff --git a/src/pystencilssfg/tree/visitors.py b/src/pystencilssfg/visitors/tree_visitors.py similarity index 73% rename from src/pystencilssfg/tree/visitors.py rename to src/pystencilssfg/visitors/tree_visitors.py index 877ed5a..76cb9c5 100644 --- a/src/pystencilssfg/tree/visitors.py +++ b/src/pystencilssfg/visitors/tree_visitors.py @@ -1,19 +1,21 @@ from __future__ import annotations -from typing import TYPE_CHECKING +# from typing import TYPE_CHECKING from functools import reduce -from .basic_nodes import SfgCallTreeNode, SfgCallTreeLeaf, SfgSequence, SfgStatements -from .deferred_nodes import SfgParamCollectionDeferredNode +from ..tree.basic_nodes import ( + SfgCallTreeNode, + SfgCallTreeLeaf, + SfgSequence, + SfgStatements, +) +from ..tree.deferred_nodes import SfgParamCollectionDeferredNode from .dispatcher import visitor from ..source_concepts.source_objects import TypedSymbolOrObject -if TYPE_CHECKING: - from ..context import SfgContext - -class FlattenSequences(): +class FlattenSequences: """Flattens any nested sequences occuring in a kernel call tree.""" @visitor @@ -40,23 +42,12 @@ class FlattenSequences(): sequence._children = children_flattened -class CollectIncludes: - def visit(self, node: SfgCallTreeNode): - includes = node.required_includes - for c in node.children: - includes |= self.visit(c) - - return includes - - -class ExpandingParameterCollector(): +class ExpandingParameterCollector: """Collects all parameters required but not defined in a kernel call tree. Expands any deferred nodes of type `SfgParamCollectionDeferredNode` found within sequences on the way. """ - def __init__(self, ctx: SfgContext) -> None: - self._ctx = ctx - + def __init__(self) -> None: self._flattener = FlattenSequences() @visitor @@ -70,17 +61,19 @@ class ExpandingParameterCollector(): @visit.case(SfgSequence) def sequence(self, sequence: SfgSequence) -> set[TypedSymbolOrObject]: """ - Only in a sequence may parameters be defined and visible to subsequent nodes. + Only in a sequence may parameters be defined and visible to subsequent nodes. """ params: set[TypedSymbolOrObject] = set() - def iter_nested_sequences(seq: SfgSequence, visible_params: set[TypedSymbolOrObject]): + def iter_nested_sequences( + seq: SfgSequence, visible_params: set[TypedSymbolOrObject] + ): for i in range(len(seq.children) - 1, -1, -1): c = seq.children[i] if isinstance(c, SfgParamCollectionDeferredNode): - c = c.expand(self._ctx, visible_params=visible_params) + c = c.expand(visible_params=visible_params) seq[i] = c if isinstance(c, SfgSequence): @@ -97,13 +90,13 @@ class ExpandingParameterCollector(): def branching_node(self, node: SfgCallTreeNode) -> set[TypedSymbolOrObject]: """ - Each interior node that is not a sequence simply requires the union of all parameters - required by its children. + Each interior node that is not a sequence simply requires the union of all parameters + required by its children. """ return reduce(lambda x, y: x | y, (self.visit(c) for c in node.children), set()) -class ParameterCollector(): +class ParameterCollector: """Collects all parameters required but not defined in a kernel call tree. Requires that all sequences in the tree are flattened. @@ -120,7 +113,7 @@ class ParameterCollector(): @visit.case(SfgSequence) def sequence(self, sequence: SfgSequence) -> set[TypedSymbolOrObject]: """ - Only in a sequence may parameters be defined and visible to subsequent nodes. + Only in a sequence may parameters be defined and visible to subsequent nodes. """ params: set[TypedSymbolOrObject] = set() @@ -134,7 +127,7 @@ class ParameterCollector(): def branching_node(self, node: SfgCallTreeNode) -> set[TypedSymbolOrObject]: """ - Each interior node that is not a sequence simply requires the union of all parameters - required by its children. + Each interior node that is not a sequence simply requires the union of all parameters + required by its children. """ return reduce(lambda x, y: x | y, (self.visit(c) for c in node.children), set()) diff --git a/test_classses_out/test_classes.cpp b/test_classses_out/test_classes.cpp new file mode 100644 index 0000000..93b26bc --- /dev/null +++ b/test_classses_out/test_classes.cpp @@ -0,0 +1,32 @@ + + +#include "test_classes.h" + + +#define FUNC_PREFIX inline + + +/************************************************************************************* + * Kernels +*************************************************************************************/ + +namespace kernels { + +FUNC_PREFIX void kernel(double * RESTRICT _data_f, double * RESTRICT const _data_g, int64_t const _size_f_0, int64_t const _size_f_1, int64_t const _stride_f_0, int64_t const _stride_f_1, int64_t const _stride_g_0, int64_t const _stride_g_1) +{ + for (int64_t ctr_0 = 0; ctr_0 < _size_f_0; ctr_0 += 1) + { + for (int64_t ctr_1 = 0; ctr_1 < _size_f_1; ctr_1 += 1) + { + _data_f[_stride_f_0*ctr_0 + _stride_f_1*ctr_1] = 3.0*_data_g[_stride_g_0*ctr_0 + _stride_g_1*ctr_1]; + } + } +} + +} // namespace kernels + +/************************************************************************************* + * Functions +*************************************************************************************/ + + diff --git a/test_classses_out/test_classes.h b/test_classses_out/test_classes.h new file mode 100644 index 0000000..3202665 --- /dev/null +++ b/test_classses_out/test_classes.h @@ -0,0 +1,24 @@ + + +#pragma once + +#include <cstdint> + + + +#define RESTRICT __restrict__ + + + +class MyClass + : +{ + +// default: + + +}; + + + + -- GitLab