diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000000000000000000000000000000000000..ca7990ba101438a4b0bebc14824640a2f0c29479 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,5 @@ +[mypy] +python_version=3.10 + +[mypy-pystencils.*] +ignore_missing_imports=true diff --git a/pystencilssfg/configuration.py b/pystencilssfg/configuration.py index 2d68d35df32e6becfc239d16e3fcba5d303d4d0a..07a6b170d4012b92c743133797ddb54a26ebf2f8 100644 --- a/pystencilssfg/configuration.py +++ b/pystencilssfg/configuration.py @@ -1,12 +1,12 @@ from __future__ import annotations -from typing import List, Sequence, Any +from typing import Sequence, Any from enum import Enum, auto from dataclasses import dataclass, replace, asdict, InitVar from argparse import ArgumentParser from os import path -import importlib +from importlib import util as iutil from jinja2.filters import do_indent @@ -24,7 +24,7 @@ class SfgConfigSource(Enum): class SfgConfigException(Exception): - def __init__(self, cfg_src: SfgConfigSource, message: str): + def __init__(self, cfg_src: SfgConfigSource | None, message: str): assert cfg_src != SfgConfigSource.DEFAULT, "Invalid default config. Contact a developer." super().__init__(cfg_src, message) @@ -44,29 +44,29 @@ class SfgCodeStyle: class SfgConfiguration: config_source: InitVar[SfgConfigSource | None] = None - header_extension: str = None + header_extension: str | None = None """File extension for generated header files.""" - source_extension: str = None + source_extension: str | None = None """File extension for generated source files.""" - header_only: bool = None + header_only: bool | None = None """If set to `True`, generate only a header file without accompaning source file.""" - base_namespace: str = None + base_namespace: str | None = None """The outermost namespace in the generated file. May be a valid C++ nested namespace qualifier (like `a::b::c`) or `None` if no outer namespace should be generated.""" - codestyle: SfgCodeStyle = None + codestyle: SfgCodeStyle | None = None """Code style that should be used by the code generator.""" - output_directory: str = None + output_directory: str | None = None """Directory to which the generated files should be written.""" project_info: Any = None """Object for managing project-specific information. To be set by the configurator script.""" - def __post_init__(self, cfg_src: SfgConfigSource = None): + def __post_init__(self, cfg_src: SfgConfigSource | None = None): if self.header_only: raise SfgConfigException(cfg_src, "Header-only code generation is not implemented yet.") @@ -94,12 +94,13 @@ DEFAULT_CONFIG = SfgConfiguration( def run_configurator(configurator_script: str): - if not path.exists(configurator_script): + cfg_spec = iutil.spec_from_file_location(configurator_script) + + if cfg_spec is None: raise SfgConfigException(SfgConfigSource.PROJECT, - f"Configurator script not found: {configurator_script} is not a file.") + f"Unable to load configurator script {configurator_script}") - cfg_spec = importlib.util.spec_from_file_location(configurator_script) - configurator = importlib.util.module_from_spec(cfg_spec) + configurator = iutil.module_from_spec(cfg_spec) if not hasattr(configurator, "sfg_config"): raise SfgConfigException(SfgConfigSource.PROJECT, "Project configurator does not define function `sfg_config`.") @@ -150,7 +151,7 @@ def config_from_parser_args(args): return project_config, cmdline_config -def config_from_commandline(argv: List[str]): +def config_from_commandline(argv: list[str]): parser = ArgumentParser("pystencilssfg", description="pystencils Source File Generator", allow_abbrev=False) @@ -163,9 +164,9 @@ def config_from_commandline(argv: List[str]): return project_config, cmdline_config, script_args -def merge_configurations(project_config: SfgConfiguration, - cmdline_config: SfgConfiguration, - script_config: SfgConfiguration): +def merge_configurations(project_config: SfgConfiguration | None, + cmdline_config: SfgConfiguration | None, + script_config: SfgConfiguration | None): # Project config completely overrides default config config = DEFAULT_CONFIG @@ -197,7 +198,7 @@ def _get_file_extensions(cfgsrc: SfgConfigSource, extensions: Sequence[str]): h_ext = None src_ext = None - extensions = ((ext[1:] if ext[0] == '.' else ext) for ext in extensions) + extensions = tuple((ext[1:] if ext[0] == '.' else ext) for ext in extensions) for ext in extensions: if ext in HEADER_FILE_EXTENSIONS: diff --git a/pystencilssfg/context.py b/pystencilssfg/context.py index 61cb60c5876e6c3f9ebb19963d3d44b56ac71dcb..9bed7e0ec60c54ad1205748cc9a8ecc408cde1e2 100644 --- a/pystencilssfg/context.py +++ b/pystencilssfg/context.py @@ -19,7 +19,7 @@ from .source_components import SfgFunction, SfgHeaderInclude class SourceFileGenerator: - def __init__(self, sfg_config: SfgConfiguration = None): + def __init__(self, sfg_config: SfgConfiguration | None = None): if sfg_config and not isinstance(sfg_config, SfgConfiguration): raise TypeError("sfg_config is not an SfgConfiguration.") @@ -60,32 +60,34 @@ class SfgContext: self._code_namespace = None # Source Components - self._includes = set() + self._includes: set[SfgHeaderInclude] = set() self._kernel_namespaces = {self._default_kernel_namespace.name: self._default_kernel_namespace} - self._functions = dict() + self._functions: dict[str, SfgFunction] = dict() @property def argv(self) -> Sequence[str]: return self._argv @property - def root_namespace(self) -> str: + def root_namespace(self) -> str | None: return self._config.base_namespace @property - def inner_namespace(self) -> str: + def inner_namespace(self) -> str | None: return self._code_namespace @property - def fully_qualified_namespace(self) -> str: + def fully_qualified_namespace(self) -> str | None: match (self.root_namespace, self.inner_namespace): case None, None: return None case outer, None: return outer case None, inner: return inner case outer, inner: return f"{outer}::{inner}" + case _: assert False @property def codestyle(self) -> SfgCodeStyle: + assert self._config.codestyle is not None return self._config.codestyle # ---------------------------------------------------------------------------------------------- @@ -176,7 +178,7 @@ class SfgContext: def branch(self) -> SfgBranchBuilder: return SfgBranchBuilder() - def map_field(self, field: Field, src_object: Optional[SrcField] = None) -> SfgSequence: + def map_field(self, field: Field, src_object: Optional[SrcField] = None) -> SfgDeferredFieldMapping: if src_object is None: raise NotImplementedError("Automatic field extraction is not implemented yet.") else: diff --git a/pystencilssfg/emitters/cpu/basic_cpu.py b/pystencilssfg/emitters/cpu/basic_cpu.py index e6c3438869d40ef69390b8631e3162cd396d798b..46ddfb5ccac5c8106a92b1af34d38184d7ca69ce 100644 --- a/pystencilssfg/emitters/cpu/basic_cpu.py +++ b/pystencilssfg/emitters/cpu/basic_cpu.py @@ -1,3 +1,4 @@ +from typing import cast from jinja2 import Environment, PackageLoader, StrictUndefined from os import path @@ -9,12 +10,12 @@ from ...context import SfgContext class BasicCpuEmitter: def __init__(self, basename: str, config: SfgConfiguration): self._basename = basename - self._output_directory = config.output_directory + self._output_directory = cast(str, config.output_directory) self._header_filename = f"{basename}.{config.header_extension}" self._source_filename = f"{basename}.{config.source_extension}" @property - def output_files(self) -> str: + def output_files(self) -> tuple[str, str]: return ( path.join(self._output_directory, self._header_filename), path.join(self._output_directory, self._source_filename) diff --git a/pystencilssfg/kernel_namespace.py b/pystencilssfg/kernel_namespace.py index ce5dde5abfc87e1ecee0afdff397e9e46c51edd2..f11b8f1b97362f76ba00821097d2795107c142ac 100644 --- a/pystencilssfg/kernel_namespace.py +++ b/pystencilssfg/kernel_namespace.py @@ -8,7 +8,7 @@ class SfgKernelNamespace: def __init__(self, ctx, name: str): self._ctx = ctx self._name = name - self._asts = dict() + self._asts: dict[str, KernelFunction] = dict() @property def name(self): diff --git a/pystencilssfg/source_components/header_include.py b/pystencilssfg/source_components/header_include.py index fb4aad62ab6a370ed489f05458cefb0a1e41bae6..fc9cd87b95005422840eebd4b33e145c1e24813f 100644 --- a/pystencilssfg/source_components/header_include.py +++ b/pystencilssfg/source_components/header_include.py @@ -24,7 +24,7 @@ class SfgHeaderInclude: def __hash__(self) -> int: return hash((self._header_file, self._system_header, self._private)) - def __eq__(self, other: SfgHeaderInclude) -> bool: + def __eq__(self, other: object) -> bool: return (isinstance(other, SfgHeaderInclude) and self._header_file == other._header_file and self._system_header == other._system_header diff --git a/pystencilssfg/source_concepts/cpp/std_mdspan.py b/pystencilssfg/source_concepts/cpp/std_mdspan.py index c87b5b1e35f4de210f3f2399e7aef7c1816d2a37..2550991b009799f6ec8235560987f9ac0559c200 100644 --- a/pystencilssfg/source_concepts/cpp/std_mdspan.py +++ b/pystencilssfg/source_concepts/cpp/std_mdspan.py @@ -1,11 +1,11 @@ -from typing import Set, Union, Tuple +from typing import Union from pystencils.typing import FieldPointerSymbol, FieldStrideSymbol, FieldShapeSymbol from ...tree import SfgStatements from ..source_objects import SrcField from ...source_components.header_include import SfgHeaderInclude -from ...types import PsType, cpp_typename +from ...types import PsType, cpp_typename, SrcType from ...exceptions import SfgException @@ -14,20 +14,20 @@ class std_mdspan(SrcField): def __init__(self, identifer: str, T: PsType, - extents: Tuple[int, str], + extents: tuple[int, str], extents_type: PsType = int, reference: bool = False): - T = cpp_typename(T) - extents_type = cpp_typename(extents_type) + cpp_typestr = cpp_typename(T) + extents_type_str = cpp_typename(extents_type) - extents_str = f"std::extents< {extents_type}, {', '.join(str(e) for e in extents)} >" - typestring = f"std::mdspan< {T}, {extents_str} > {'&' if reference else ''}" - super().__init__(typestring, identifer) + extents_str = f"std::extents< {extents_type_str}, {', '.join(str(e) for e in extents)} >" + typestring = f"std::mdspan< {cpp_typestr}, {extents_str} > {'&' if reference else ''}" + super().__init__(SrcType(typestring), identifer) self._extents = extents @property - def required_includes(self) -> Set[SfgHeaderInclude]: + def required_includes(self) -> set[SfgHeaderInclude]: return {SfgHeaderInclude("experimental/mdspan", system_header=True)} def extract_ptr(self, ptr_symbol: FieldPointerSymbol): diff --git a/pystencilssfg/source_concepts/cpp/std_vector.py b/pystencilssfg/source_concepts/cpp/std_vector.py index f203f06437d8ed44d6c4c7266f440978ab980da6..5398f5dfdc5a84488f53d22fd1af7707d558f938 100644 --- a/pystencilssfg/source_concepts/cpp/std_vector.py +++ b/pystencilssfg/source_concepts/cpp/std_vector.py @@ -1,4 +1,4 @@ -from typing import Set, Union +from typing import Union from pystencils.typing import FieldPointerSymbol, FieldStrideSymbol, FieldShapeSymbol @@ -13,13 +13,13 @@ from ...exceptions import SfgException class std_vector(SrcVector, SrcField): def __init__(self, identifer: str, T: Union[SrcType, PsType], unsafe: bool = False): typestring = f"std::vector< {cpp_typename(T)} >" - super(SrcObject, self).__init__(identifer, typestring) + super(std_vector, self).__init__(SrcType(typestring), identifer) self._element_type = T self._unsafe = unsafe @property - def required_includes(self) -> Set[SfgHeaderInclude]: + def required_includes(self) -> set[SfgHeaderInclude]: return {SfgHeaderInclude("vector", system_header=True)} def extract_ptr(self, ptr_symbol: FieldPointerSymbol): @@ -71,4 +71,4 @@ class std_vector(SrcVector, SrcField): class std_vector_ref(std_vector): def __init__(self, identifer: str, T: Union[SrcType, PsType]): typestring = f"std::vector< {T} > &" - super(SrcObject, self).__init__(identifer, typestring) + super(std_vector_ref, self).__init__(identifer, SrcType(typestring)) diff --git a/pystencilssfg/source_concepts/source_objects.py b/pystencilssfg/source_concepts/source_objects.py index 6e2849c1f77b91d88057c3f649daee2ae4c39ebd..f7e5a8ae4158588eb193d2c68aaf7bdb2a24afc7 100644 --- a/pystencilssfg/source_concepts/source_objects.py +++ b/pystencilssfg/source_concepts/source_objects.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional, Union, Set, TypeAlias +from typing import TYPE_CHECKING, Optional, Union, TypeAlias from abc import ABC, abstractmethod @@ -37,13 +37,13 @@ class SrcObject: return self._src_type @property - def required_includes(self) -> Set[SfgHeaderInclude]: + def required_includes(self) -> set[SfgHeaderInclude]: return set() def __hash__(self) -> int: return hash((self._identifier, self._src_type)) - def __eq__(self, other: SrcObject) -> bool: + def __eq__(self, other: object) -> bool: return (isinstance(other, SrcObject) and self._identifier == other._identifier and self._src_type == other._src_type) diff --git a/pystencilssfg/tree/basic_nodes.py b/pystencilssfg/tree/basic_nodes.py index 89d0edfb22ebe3f0794802943f4dab98e7ab840e..1363f352ae607837eb1b86b533d6aec45e4f84c9 100644 --- a/pystencilssfg/tree/basic_nodes.py +++ b/pystencilssfg/tree/basic_nodes.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Sequence, Set, Tuple +from typing import TYPE_CHECKING, Sequence from abc import ABC, abstractmethod from itertools import chain @@ -15,14 +15,11 @@ if TYPE_CHECKING: class SfgCallTreeNode(ABC): """Base class for all nodes comprising SFG call trees. """ def __init__(self, *children: SfgCallTreeNode): - self._children = children + self._children = list(children) @property - def children(self) -> Tuple[SfgCallTreeNode]: - return self._children - - def child(self, idx: int) -> SfgCallTreeNode: - return self._children[idx] + def children(self) -> tuple[SfgCallTreeNode, ...]: + return tuple(self._children) @children.setter def children(self, cs: Sequence[SfgCallTreeNode]) -> None: @@ -30,6 +27,9 @@ class SfgCallTreeNode(ABC): raise ValueError("The number of child nodes must remain the same!") self._children = list(cs) + def child(self, idx: int) -> SfgCallTreeNode: + return self._children[idx] + def __getitem__(self, idx: int) -> SfgCallTreeNode: return self._children[idx] @@ -44,7 +44,7 @@ class SfgCallTreeNode(ABC): """ @property - def required_includes(self) -> Set[SfgHeaderInclude]: + def required_includes(self) -> set[SfgHeaderInclude]: return set() @@ -52,7 +52,7 @@ class SfgCallTreeLeaf(SfgCallTreeNode, ABC): @property @abstractmethod - def required_parameters(self) -> Set[TypedSymbolOrObject]: + def required_parameters(self) -> set[TypedSymbolOrObject]: pass @@ -90,15 +90,15 @@ class SfgStatements(SfgCallTreeLeaf): self._required_includes |= obj.required_includes @property - def required_parameters(self) -> Set[TypedSymbolOrObject]: + def required_parameters(self) -> set[TypedSymbolOrObject]: return self._required_params @property - def defined_parameters(self) -> Set[TypedSymbolOrObject]: + def defined_parameters(self) -> set[TypedSymbolOrObject]: return self._defined_params @property - def required_includes(self) -> Set[SfgHeaderInclude]: + def required_includes(self) -> set[SfgHeaderInclude]: return self._required_includes def get_code(self, ctx: SfgContext) -> str: @@ -122,7 +122,7 @@ class SfgBlock(SfgCallTreeNode): return self._children[0] def get_code(self, ctx: SfgContext) -> str: - subtree_code = ctx.codestyle.indent(self._subtree.get_code(ctx)) + subtree_code = ctx.codestyle.indent(self.subtree.get_code(ctx)) return "{\n" + subtree_code + "\n}" @@ -133,7 +133,7 @@ class SfgKernelCallNode(SfgCallTreeLeaf): self._kernel_handle = kernel_handle @property - def required_parameters(self) -> Set[TypedSymbolOrObject]: + def required_parameters(self) -> set[TypedSymbolOrObject]: return set(p.symbol for p in self._kernel_handle.parameters) def get_code(self, ctx: SfgContext) -> str: diff --git a/pystencilssfg/tree/conditional.py b/pystencilssfg/tree/conditional.py index 45e91d8f8914c04870c3f97ed3427b3fa892a042..39663cbf60eace8f13053b6064a8deb6ee53f99e 100644 --- a/pystencilssfg/tree/conditional.py +++ b/pystencilssfg/tree/conditional.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional, Set +from typing import TYPE_CHECKING, Optional, cast from .basic_nodes import SfgCallTreeNode, SfgCallTreeLeaf from ..source_concepts.source_objects import TypedSymbolOrObject @@ -18,7 +18,7 @@ class SfgCustomCondition(SfgCondition): self._cond_text = cond_text @property - def required_parameters(self) -> Set[TypedSymbolOrObject]: + def required_parameters(self) -> set[TypedSymbolOrObject]: return set() def get_code(self, ctx: SfgContext) -> str: @@ -38,7 +38,7 @@ class SfgBranch(SfgCallTreeNode): @property def condition(self) -> SfgCondition: - return self._children[0] + return cast(SfgCondition, self._children[0]) @property def branch_true(self) -> SfgCallTreeNode: diff --git a/pystencilssfg/tree/deferred_nodes.py b/pystencilssfg/tree/deferred_nodes.py index 4f2aef554c161fe13cb55cf574105fdc5773c9ec..3eb42e5df94e0e02ac8f639852e31290640f0c45 100644 --- a/pystencilssfg/tree/deferred_nodes.py +++ b/pystencilssfg/tree/deferred_nodes.py @@ -1,5 +1,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Set +from typing import TYPE_CHECKING + +from pystencilssfg.context import SfgContext if TYPE_CHECKING: from ..context import SfgContext @@ -15,6 +17,7 @@ from .basic_nodes import SfgCallTreeNode from .builders import make_sequence from ..source_concepts import SrcField +from ..source_concepts.source_objects import TypedSymbolOrObject class SfgDeferredNode(SfgCallTreeNode, ABC): @@ -32,16 +35,13 @@ class SfgDeferredNode(SfgCallTreeNode, ABC): def __init__(self): self._children = SfgDeferredNode.InvalidAccess - get_code = InvalidAccess - - @abstractmethod - def expand(self, ctx: SfgContext, *args, **kwargs) -> SfgCallTreeNode: - pass + def get_code(self, ctx: SfgContext) -> str: + raise SfgException("Invalid access into deferred node; deferred nodes must be expanded first.") class SfgParamCollectionDeferredNode(SfgDeferredNode, ABC): @abstractmethod - def expand(self, ctx: SfgContext, visible_params: Set[TypedSymbol]) -> SfgCallTreeNode: + def expand(self, ctx: SfgContext, visible_params: set[TypedSymbolOrObject]) -> SfgCallTreeNode: pass @@ -50,7 +50,7 @@ class SfgDeferredFieldMapping(SfgParamCollectionDeferredNode): self._field = field self._src_field = src_field - def expand(self, ctx: SfgContext, visible_params: Set[TypedSymbol]) -> SfgCallTreeNode: + def expand(self, ctx: SfgContext, visible_params: set[TypedSymbolOrObject]) -> SfgCallTreeNode: # Find field pointer ptr = None for param in visible_params: diff --git a/pystencilssfg/tree/dispatcher.py b/pystencilssfg/tree/dispatcher.py index 5cab6ae732482457f37e3ead6d5393b61cb10107..9ecba4801e0b10e7501c9276adecdb9c17a048bd 100644 --- a/pystencilssfg/tree/dispatcher.py +++ b/pystencilssfg/tree/dispatcher.py @@ -1,21 +1,24 @@ from __future__ import annotations -from typing import Callable +from typing import Callable, TypeVar, Generic, Any, ParamSpec, Concatenate from types import MethodType from functools import wraps from .basic_nodes import SfgCallTreeNode +V = TypeVar("V") +R = TypeVar("R") +P = ParamSpec("P") -class VisitorDispatcher: - def __init__(self, wrapped_method): - self._dispatch_dict = {} - self._wrapped_method = wrapped_method +class VisitorDispatcher(Generic[V, R]): + def __init__(self, wrapped_method: Callable[..., R]): + self._dispatch_dict: dict[type, Callable[..., R]] = {} + self._wrapped_method: Callable[..., R] = wrapped_method def case(self, node_type: type): """Decorator for visitor's methods""" - def decorate(handler: Callable): + def decorate(handler: Callable[..., R]): if node_type in self._dispatch_dict: raise ValueError(f"Duplicate visitor case {node_type}") self._dispatch_dict[node_type] = handler @@ -23,14 +26,14 @@ class VisitorDispatcher: return decorate - def __call__(self, instance, node: SfgCallTreeNode, *args, **kwargs): + def __call__(self, instance: V, node: SfgCallTreeNode, *args, **kwargs) -> R: for cls in node.__class__.mro(): if cls in self._dispatch_dict: return self._dispatch_dict[cls](instance, node, *args, **kwargs) return self._wrapped_method(instance, node, *args, **kwargs) - def __get__(self, obj, objtype=None): + def __get__(self, obj: V, objtype=None) -> Callable[..., R]: if obj is None: return self return MethodType(self, obj) diff --git a/pystencilssfg/tree/visitors.py b/pystencilssfg/tree/visitors.py index 043e8ad9008034d85e586b952246e20d7cfc69b7..434e271540562b0827b0f0e9e2b6505c50012655 100644 --- a/pystencilssfg/tree/visitors.py +++ b/pystencilssfg/tree/visitors.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Set +from typing import TYPE_CHECKING from functools import reduce @@ -9,6 +9,7 @@ from pystencils.typing import TypedSymbol from .basic_nodes import SfgCallTreeNode, SfgCallTreeLeaf, SfgSequence, SfgStatements from .deferred_nodes import SfgParamCollectionDeferredNode from .dispatcher import visitor +from ..source_concepts.source_objects import TypedSymbolOrObject if TYPE_CHECKING: from ..context import SfgContext @@ -61,28 +62,28 @@ class ExpandingParameterCollector(): self._flattener = FlattenSequences() @visitor - def visit(self, node: SfgCallTreeNode) -> Set[TypedSymbol]: + def visit(self, node: SfgCallTreeNode) -> set[TypedSymbolOrObject]: return self.branching_node(node) @visit.case(SfgCallTreeLeaf) - def leaf(self, leaf: SfgCallTreeLeaf) -> Set[TypedSymbol]: + def leaf(self, leaf: SfgCallTreeLeaf) -> set[TypedSymbolOrObject]: return leaf.required_parameters @visit.case(SfgSequence) - def sequence(self, sequence: SfgSequence) -> Set[TypedSymbol]: + def sequence(self, sequence: SfgSequence) -> set[TypedSymbolOrObject]: """ Only in a sequence may parameters be defined and visible to subsequent nodes. """ - params = set() + params: set[TypedSymbolOrObject] = set() - def iter_nested_sequences(seq: SfgSequence, visible_params: Set[TypedSymbol]): + def iter_nested_sequences(seq: SfgSequence, visible_params: set[TypedSymbolOrObject]): for i in range(len(seq.children) - 1, -1, -1): c = seq.children[i] if isinstance(c, SfgParamCollectionDeferredNode): c = c.expand(self._ctx, visible_params=visible_params) - seq.replace_child(i, c) + seq[i] = c if isinstance(c, SfgSequence): iter_nested_sequences(c, visible_params) @@ -96,7 +97,7 @@ class ExpandingParameterCollector(): return params - def branching_node(self, node: SfgCallTreeNode): + def branching_node(self, node: SfgCallTreeNode) -> set[TypedSymbolOrObject]: """ Each interior node that is not a sequence simply requires the union of all parameters required by its children. @@ -111,20 +112,20 @@ class ParameterCollector(): """ @visitor - def visit(self, node: SfgCallTreeNode) -> Set[TypedSymbol]: + def visit(self, node: SfgCallTreeNode) -> set[TypedSymbolOrObject]: return self.branching_node(node) @visit.case(SfgCallTreeLeaf) - def leaf(self, leaf: SfgCallTreeLeaf) -> Set[TypedSymbol]: + def leaf(self, leaf: SfgCallTreeLeaf) -> set[TypedSymbolOrObject]: return leaf.required_parameters @visit.case(SfgSequence) - def sequence(self, sequence: SfgSequence) -> Set[TypedSymbol]: + def sequence(self, sequence: SfgSequence) -> set[TypedSymbolOrObject]: """ Only in a sequence may parameters be defined and visible to subsequent nodes. """ - params = set() + params: set[TypedSymbolOrObject] = set() for c in sequence.children[::-1]: if isinstance(c, SfgStatements): params -= c.defined_parameters @@ -133,7 +134,7 @@ class ParameterCollector(): params |= self.visit(c) return params - def branching_node(self, node: SfgCallTreeNode): + def branching_node(self, node: SfgCallTreeNode) -> set[TypedSymbolOrObject]: """ Each interior node that is not a sequence simply requires the union of all parameters required by its children.