diff --git a/src/pystencilssfg/context.py b/src/pystencilssfg/context.py index ef69c06bd9921717fa745f990261fc3fca46dd7a..222b313424c67f031b580d6061375022dceda603 100644 --- a/src/pystencilssfg/context.py +++ b/src/pystencilssfg/context.py @@ -1,7 +1,6 @@ from typing import Generator, Sequence from .configuration import SfgCodeStyle -from .visitors import CollectIncludes from .source_components import ( SfgHeaderInclude, SfgKernelNamespace, @@ -156,9 +155,6 @@ class SfgContext: self._functions[func.name] = func self._declarations_ordered.append(func) - for incl in CollectIncludes().visit(func): - self.add_include(incl) - # ---------------------------------------------------------------------------------------------- # Classes # ---------------------------------------------------------------------------------------------- @@ -176,9 +172,6 @@ class SfgContext: self._classes[cls.class_name] = cls self._declarations_ordered.append(cls) - for incl in CollectIncludes().visit(cls): - self.add_include(incl) - # ---------------------------------------------------------------------------------------------- # Declarations in order of addition # ---------------------------------------------------------------------------------------------- diff --git a/src/pystencilssfg/emission/header_source_pair.py b/src/pystencilssfg/emission/header_source_pair.py index b21b9952a7c77c82187cdbdb8fb9e2ee2ce1c8dd..20e590393a97eb567d025a14cfa884d1b83a7d9a 100644 --- a/src/pystencilssfg/emission/header_source_pair.py +++ b/src/pystencilssfg/emission/header_source_pair.py @@ -2,6 +2,7 @@ from os import path, makedirs from ..configuration import SfgOutputSpec from ..context import SfgContext +from .prepare import prepare_context from .printers import SfgHeaderPrinter, SfgImplPrinter from .clang_format import invoke_clang_format @@ -20,10 +21,12 @@ class HeaderSourcePairEmitter: def output_files(self) -> tuple[str, str]: return ( path.join(self._output_directory, self._header_filename), - path.join(self._output_directory, self._impl_filename) + path.join(self._output_directory, self._impl_filename), ) def write_files(self, ctx: SfgContext): + ctx = prepare_context(ctx) + header_printer = SfgHeaderPrinter(ctx, self._ospec) impl_printer = SfgImplPrinter(ctx, self._ospec) @@ -35,8 +38,8 @@ class HeaderSourcePairEmitter: makedirs(self._output_directory, exist_ok=True) - with open(self._ospec.get_header_filepath(), 'w') as headerfile: + with open(self._ospec.get_header_filepath(), "w") as headerfile: headerfile.write(header) - with open(self._ospec.get_impl_filepath(), 'w') as cppfile: + with open(self._ospec.get_impl_filepath(), "w") as cppfile: cppfile.write(impl) diff --git a/src/pystencilssfg/emission/prepare.py b/src/pystencilssfg/emission/prepare.py new file mode 100644 index 0000000000000000000000000000000000000000..28493bbb99abecad9dc0fbaf62d0cb1582373fac --- /dev/null +++ b/src/pystencilssfg/emission/prepare.py @@ -0,0 +1,20 @@ +from ..context import SfgContext + +from ..visitors import CollectIncludes + + +def prepare_context(ctx: SfgContext): + """Prepares a populated context for printing. Make sure to run this function on the + [SfgContext][pystencilssfg.SfgContext] before passing it to a printer. + + Steps: + - Collection of includes: All defined functions and classes are traversed to collect all required + header includes + """ + + # Collect all includes + required_includes = CollectIncludes().visit(ctx) + for incl in required_includes: + ctx.add_include(incl) + + return ctx diff --git a/src/pystencilssfg/source_components.py b/src/pystencilssfg/source_components.py index e753b3717f49778af8fb65a423af3d4b6b53ec72..bf3a5147b47b3dd62bad99e450c1c89335a9cc49 100644 --- a/src/pystencilssfg/source_components.py +++ b/src/pystencilssfg/source_components.py @@ -440,9 +440,14 @@ class SfgClass: def members( self, visibility: SfgVisibility | None = None ) -> Generator[SfgClassMember, None, None]: - yield from chain.from_iterable( - b.members() for b in filter(lambda b: b.visibility == visibility, self._blocks) - ) + if visibility is None: + yield from chain.from_iterable( + b.members() for b in self._blocks + ) + else: + yield from chain.from_iterable( + b.members() for b in filter(lambda b: b.visibility == visibility, self._blocks) + ) def definitions( self, visibility: SfgVisibility | None = None diff --git a/src/pystencilssfg/visitors/collectors.py b/src/pystencilssfg/visitors/collectors.py index 9bd1774ff393ae09b4883914f9ca745776dbd76a..fef8955bb8a5eab7c80d2c812b41e64ce7f938ba 100644 --- a/src/pystencilssfg/visitors/collectors.py +++ b/src/pystencilssfg/visitors/collectors.py @@ -7,7 +7,14 @@ from functools import reduce from .dispatcher import visitor from ..exceptions import SfgException from ..tree import SfgCallTreeNode -from ..source_components import SfgFunction, SfgClass, SfgConstructor, SfgMemberVariable +from ..source_components import ( + SfgFunction, + SfgClass, + SfgConstructor, + SfgMemberVariable, + SfgInClassDefinition, +) +from ..context import SfgContext if TYPE_CHECKING: from ..source_components import SfgHeaderInclude @@ -18,10 +25,23 @@ class CollectIncludes: def visit(self, obj: object) -> set[SfgHeaderInclude]: raise SfgException(f"Can't collect includes from object of type {type(obj)}") + @visit.case(SfgContext) + def context(self, ctx: SfgContext) -> set[SfgHeaderInclude]: + includes = set() + for func in ctx.functions(): + includes |= self.visit(func) + + for cls in ctx.classes(): + includes |= self.visit(cls) + + return includes + @visit.case(SfgCallTreeNode) def tree_node(self, node: SfgCallTreeNode) -> set[SfgHeaderInclude]: return reduce( - lambda accu, child: accu | self.visit(child), node.children, node.required_includes + lambda accu, child: accu | self.visit(child), + node.children, + node.required_includes, ) @visit.case(SfgFunction) @@ -43,3 +63,7 @@ class CollectIncludes: @visit.case(SfgMemberVariable) def sfg_member_var(self, var: SfgMemberVariable) -> set[SfgHeaderInclude]: return var.required_includes + + @visit.case(SfgInClassDefinition) + def sfg_cls_def(self, _: SfgInClassDefinition) -> set[SfgHeaderInclude]: + return set()