Skip to content
Snippets Groups Projects
Commit 3a49e20e authored by Frederik Hennig's avatar Frederik Hennig
Browse files

made mypy happy

parent d96021b8
No related merge requests found
Pipeline #57178 failed with stage
in 3 minutes and 12 seconds
Showing with 105 additions and 92 deletions
[mypy]
python_version=3.10
[mypy-pystencils.*]
ignore_missing_imports=true
from __future__ import annotations
from typing import List, Sequence, Any
from typing import Sequence, Any
from enum import Enum, auto
from dataclasses import dataclass, replace, asdict, InitVar
from argparse import ArgumentParser
from os import path
import importlib
from importlib import util as iutil
from jinja2.filters import do_indent
......@@ -24,7 +24,7 @@ class SfgConfigSource(Enum):
class SfgConfigException(Exception):
def __init__(self, cfg_src: SfgConfigSource, message: str):
def __init__(self, cfg_src: SfgConfigSource | None, message: str):
assert cfg_src != SfgConfigSource.DEFAULT, "Invalid default config. Contact a developer."
super().__init__(cfg_src, message)
......@@ -44,29 +44,29 @@ class SfgCodeStyle:
class SfgConfiguration:
config_source: InitVar[SfgConfigSource | None] = None
header_extension: str = None
header_extension: str | None = None
"""File extension for generated header files."""
source_extension: str = None
source_extension: str | None = None
"""File extension for generated source files."""
header_only: bool = None
header_only: bool | None = None
"""If set to `True`, generate only a header file without accompaning source file."""
base_namespace: str = None
base_namespace: str | None = None
"""The outermost namespace in the generated file. May be a valid C++ nested namespace qualifier
(like `a::b::c`) or `None` if no outer namespace should be generated."""
codestyle: SfgCodeStyle = None
codestyle: SfgCodeStyle | None = None
"""Code style that should be used by the code generator."""
output_directory: str = None
output_directory: str | None = None
"""Directory to which the generated files should be written."""
project_info: Any = None
"""Object for managing project-specific information. To be set by the configurator script."""
def __post_init__(self, cfg_src: SfgConfigSource = None):
def __post_init__(self, cfg_src: SfgConfigSource | None = None):
if self.header_only:
raise SfgConfigException(cfg_src, "Header-only code generation is not implemented yet.")
......@@ -94,12 +94,13 @@ DEFAULT_CONFIG = SfgConfiguration(
def run_configurator(configurator_script: str):
if not path.exists(configurator_script):
cfg_spec = iutil.spec_from_file_location(configurator_script)
if cfg_spec is None:
raise SfgConfigException(SfgConfigSource.PROJECT,
f"Configurator script not found: {configurator_script} is not a file.")
f"Unable to load configurator script {configurator_script}")
cfg_spec = importlib.util.spec_from_file_location(configurator_script)
configurator = importlib.util.module_from_spec(cfg_spec)
configurator = iutil.module_from_spec(cfg_spec)
if not hasattr(configurator, "sfg_config"):
raise SfgConfigException(SfgConfigSource.PROJECT, "Project configurator does not define function `sfg_config`.")
......@@ -150,7 +151,7 @@ def config_from_parser_args(args):
return project_config, cmdline_config
def config_from_commandline(argv: List[str]):
def config_from_commandline(argv: list[str]):
parser = ArgumentParser("pystencilssfg",
description="pystencils Source File Generator",
allow_abbrev=False)
......@@ -163,9 +164,9 @@ def config_from_commandline(argv: List[str]):
return project_config, cmdline_config, script_args
def merge_configurations(project_config: SfgConfiguration,
cmdline_config: SfgConfiguration,
script_config: SfgConfiguration):
def merge_configurations(project_config: SfgConfiguration | None,
cmdline_config: SfgConfiguration | None,
script_config: SfgConfiguration | None):
# Project config completely overrides default config
config = DEFAULT_CONFIG
......@@ -197,7 +198,7 @@ def _get_file_extensions(cfgsrc: SfgConfigSource, extensions: Sequence[str]):
h_ext = None
src_ext = None
extensions = ((ext[1:] if ext[0] == '.' else ext) for ext in extensions)
extensions = tuple((ext[1:] if ext[0] == '.' else ext) for ext in extensions)
for ext in extensions:
if ext in HEADER_FILE_EXTENSIONS:
......
......@@ -19,7 +19,7 @@ from .source_components import SfgFunction, SfgHeaderInclude
class SourceFileGenerator:
def __init__(self, sfg_config: SfgConfiguration = None):
def __init__(self, sfg_config: SfgConfiguration | None = None):
if sfg_config and not isinstance(sfg_config, SfgConfiguration):
raise TypeError("sfg_config is not an SfgConfiguration.")
......@@ -60,32 +60,34 @@ class SfgContext:
self._code_namespace = None
# Source Components
self._includes = set()
self._includes: set[SfgHeaderInclude] = set()
self._kernel_namespaces = {self._default_kernel_namespace.name: self._default_kernel_namespace}
self._functions = dict()
self._functions: dict[str, SfgFunction] = dict()
@property
def argv(self) -> Sequence[str]:
return self._argv
@property
def root_namespace(self) -> str:
def root_namespace(self) -> str | None:
return self._config.base_namespace
@property
def inner_namespace(self) -> str:
def inner_namespace(self) -> str | None:
return self._code_namespace
@property
def fully_qualified_namespace(self) -> str:
def fully_qualified_namespace(self) -> str | None:
match (self.root_namespace, self.inner_namespace):
case None, None: return None
case outer, None: return outer
case None, inner: return inner
case outer, inner: return f"{outer}::{inner}"
case _: assert False
@property
def codestyle(self) -> SfgCodeStyle:
assert self._config.codestyle is not None
return self._config.codestyle
# ----------------------------------------------------------------------------------------------
......@@ -176,7 +178,7 @@ class SfgContext:
def branch(self) -> SfgBranchBuilder:
return SfgBranchBuilder()
def map_field(self, field: Field, src_object: Optional[SrcField] = None) -> SfgSequence:
def map_field(self, field: Field, src_object: Optional[SrcField] = None) -> SfgDeferredFieldMapping:
if src_object is None:
raise NotImplementedError("Automatic field extraction is not implemented yet.")
else:
......
from typing import cast
from jinja2 import Environment, PackageLoader, StrictUndefined
from os import path
......@@ -9,12 +10,12 @@ from ...context import SfgContext
class BasicCpuEmitter:
def __init__(self, basename: str, config: SfgConfiguration):
self._basename = basename
self._output_directory = config.output_directory
self._output_directory = cast(str, config.output_directory)
self._header_filename = f"{basename}.{config.header_extension}"
self._source_filename = f"{basename}.{config.source_extension}"
@property
def output_files(self) -> str:
def output_files(self) -> tuple[str, str]:
return (
path.join(self._output_directory, self._header_filename),
path.join(self._output_directory, self._source_filename)
......
......@@ -8,7 +8,7 @@ class SfgKernelNamespace:
def __init__(self, ctx, name: str):
self._ctx = ctx
self._name = name
self._asts = dict()
self._asts: dict[str, KernelFunction] = dict()
@property
def name(self):
......
......@@ -24,7 +24,7 @@ class SfgHeaderInclude:
def __hash__(self) -> int:
return hash((self._header_file, self._system_header, self._private))
def __eq__(self, other: SfgHeaderInclude) -> bool:
def __eq__(self, other: object) -> bool:
return (isinstance(other, SfgHeaderInclude)
and self._header_file == other._header_file
and self._system_header == other._system_header
......
from typing import Set, Union, Tuple
from typing import Union
from pystencils.typing import FieldPointerSymbol, FieldStrideSymbol, FieldShapeSymbol
from ...tree import SfgStatements
from ..source_objects import SrcField
from ...source_components.header_include import SfgHeaderInclude
from ...types import PsType, cpp_typename
from ...types import PsType, cpp_typename, SrcType
from ...exceptions import SfgException
......@@ -14,20 +14,20 @@ class std_mdspan(SrcField):
def __init__(self, identifer: str,
T: PsType,
extents: Tuple[int, str],
extents: tuple[int, str],
extents_type: PsType = int,
reference: bool = False):
T = cpp_typename(T)
extents_type = cpp_typename(extents_type)
cpp_typestr = cpp_typename(T)
extents_type_str = cpp_typename(extents_type)
extents_str = f"std::extents< {extents_type}, {', '.join(str(e) for e in extents)} >"
typestring = f"std::mdspan< {T}, {extents_str} > {'&' if reference else ''}"
super().__init__(typestring, identifer)
extents_str = f"std::extents< {extents_type_str}, {', '.join(str(e) for e in extents)} >"
typestring = f"std::mdspan< {cpp_typestr}, {extents_str} > {'&' if reference else ''}"
super().__init__(SrcType(typestring), identifer)
self._extents = extents
@property
def required_includes(self) -> Set[SfgHeaderInclude]:
def required_includes(self) -> set[SfgHeaderInclude]:
return {SfgHeaderInclude("experimental/mdspan", system_header=True)}
def extract_ptr(self, ptr_symbol: FieldPointerSymbol):
......
from typing import Set, Union
from typing import Union
from pystencils.typing import FieldPointerSymbol, FieldStrideSymbol, FieldShapeSymbol
......@@ -13,13 +13,13 @@ from ...exceptions import SfgException
class std_vector(SrcVector, SrcField):
def __init__(self, identifer: str, T: Union[SrcType, PsType], unsafe: bool = False):
typestring = f"std::vector< {cpp_typename(T)} >"
super(SrcObject, self).__init__(identifer, typestring)
super(std_vector, self).__init__(SrcType(typestring), identifer)
self._element_type = T
self._unsafe = unsafe
@property
def required_includes(self) -> Set[SfgHeaderInclude]:
def required_includes(self) -> set[SfgHeaderInclude]:
return {SfgHeaderInclude("vector", system_header=True)}
def extract_ptr(self, ptr_symbol: FieldPointerSymbol):
......@@ -71,4 +71,4 @@ class std_vector(SrcVector, SrcField):
class std_vector_ref(std_vector):
def __init__(self, identifer: str, T: Union[SrcType, PsType]):
typestring = f"std::vector< {T} > &"
super(SrcObject, self).__init__(identifer, typestring)
super(std_vector_ref, self).__init__(identifer, SrcType(typestring))
from __future__ import annotations
from typing import TYPE_CHECKING, Optional, Union, Set, TypeAlias
from typing import TYPE_CHECKING, Optional, Union, TypeAlias
from abc import ABC, abstractmethod
......@@ -37,13 +37,13 @@ class SrcObject:
return self._src_type
@property
def required_includes(self) -> Set[SfgHeaderInclude]:
def required_includes(self) -> set[SfgHeaderInclude]:
return set()
def __hash__(self) -> int:
return hash((self._identifier, self._src_type))
def __eq__(self, other: SrcObject) -> bool:
def __eq__(self, other: object) -> bool:
return (isinstance(other, SrcObject)
and self._identifier == other._identifier
and self._src_type == other._src_type)
......
from __future__ import annotations
from typing import TYPE_CHECKING, Sequence, Set, Tuple
from typing import TYPE_CHECKING, Sequence
from abc import ABC, abstractmethod
from itertools import chain
......@@ -15,14 +15,11 @@ if TYPE_CHECKING:
class SfgCallTreeNode(ABC):
"""Base class for all nodes comprising SFG call trees. """
def __init__(self, *children: SfgCallTreeNode):
self._children = children
self._children = list(children)
@property
def children(self) -> Tuple[SfgCallTreeNode]:
return self._children
def child(self, idx: int) -> SfgCallTreeNode:
return self._children[idx]
def children(self) -> tuple[SfgCallTreeNode, ...]:
return tuple(self._children)
@children.setter
def children(self, cs: Sequence[SfgCallTreeNode]) -> None:
......@@ -30,6 +27,9 @@ class SfgCallTreeNode(ABC):
raise ValueError("The number of child nodes must remain the same!")
self._children = list(cs)
def child(self, idx: int) -> SfgCallTreeNode:
return self._children[idx]
def __getitem__(self, idx: int) -> SfgCallTreeNode:
return self._children[idx]
......@@ -44,7 +44,7 @@ class SfgCallTreeNode(ABC):
"""
@property
def required_includes(self) -> Set[SfgHeaderInclude]:
def required_includes(self) -> set[SfgHeaderInclude]:
return set()
......@@ -52,7 +52,7 @@ class SfgCallTreeLeaf(SfgCallTreeNode, ABC):
@property
@abstractmethod
def required_parameters(self) -> Set[TypedSymbolOrObject]:
def required_parameters(self) -> set[TypedSymbolOrObject]:
pass
......@@ -90,15 +90,15 @@ class SfgStatements(SfgCallTreeLeaf):
self._required_includes |= obj.required_includes
@property
def required_parameters(self) -> Set[TypedSymbolOrObject]:
def required_parameters(self) -> set[TypedSymbolOrObject]:
return self._required_params
@property
def defined_parameters(self) -> Set[TypedSymbolOrObject]:
def defined_parameters(self) -> set[TypedSymbolOrObject]:
return self._defined_params
@property
def required_includes(self) -> Set[SfgHeaderInclude]:
def required_includes(self) -> set[SfgHeaderInclude]:
return self._required_includes
def get_code(self, ctx: SfgContext) -> str:
......@@ -122,7 +122,7 @@ class SfgBlock(SfgCallTreeNode):
return self._children[0]
def get_code(self, ctx: SfgContext) -> str:
subtree_code = ctx.codestyle.indent(self._subtree.get_code(ctx))
subtree_code = ctx.codestyle.indent(self.subtree.get_code(ctx))
return "{\n" + subtree_code + "\n}"
......@@ -133,7 +133,7 @@ class SfgKernelCallNode(SfgCallTreeLeaf):
self._kernel_handle = kernel_handle
@property
def required_parameters(self) -> Set[TypedSymbolOrObject]:
def required_parameters(self) -> set[TypedSymbolOrObject]:
return set(p.symbol for p in self._kernel_handle.parameters)
def get_code(self, ctx: SfgContext) -> str:
......
from __future__ import annotations
from typing import TYPE_CHECKING, Optional, Set
from typing import TYPE_CHECKING, Optional, cast
from .basic_nodes import SfgCallTreeNode, SfgCallTreeLeaf
from ..source_concepts.source_objects import TypedSymbolOrObject
......@@ -18,7 +18,7 @@ class SfgCustomCondition(SfgCondition):
self._cond_text = cond_text
@property
def required_parameters(self) -> Set[TypedSymbolOrObject]:
def required_parameters(self) -> set[TypedSymbolOrObject]:
return set()
def get_code(self, ctx: SfgContext) -> str:
......@@ -38,7 +38,7 @@ class SfgBranch(SfgCallTreeNode):
@property
def condition(self) -> SfgCondition:
return self._children[0]
return cast(SfgCondition, self._children[0])
@property
def branch_true(self) -> SfgCallTreeNode:
......
from __future__ import annotations
from typing import TYPE_CHECKING, Set
from typing import TYPE_CHECKING
from pystencilssfg.context import SfgContext
if TYPE_CHECKING:
from ..context import SfgContext
......@@ -15,6 +17,7 @@ from .basic_nodes import SfgCallTreeNode
from .builders import make_sequence
from ..source_concepts import SrcField
from ..source_concepts.source_objects import TypedSymbolOrObject
class SfgDeferredNode(SfgCallTreeNode, ABC):
......@@ -32,16 +35,13 @@ class SfgDeferredNode(SfgCallTreeNode, ABC):
def __init__(self):
self._children = SfgDeferredNode.InvalidAccess
get_code = InvalidAccess
@abstractmethod
def expand(self, ctx: SfgContext, *args, **kwargs) -> SfgCallTreeNode:
pass
def get_code(self, ctx: SfgContext) -> str:
raise SfgException("Invalid access into deferred node; deferred nodes must be expanded first.")
class SfgParamCollectionDeferredNode(SfgDeferredNode, ABC):
@abstractmethod
def expand(self, ctx: SfgContext, visible_params: Set[TypedSymbol]) -> SfgCallTreeNode:
def expand(self, ctx: SfgContext, visible_params: set[TypedSymbolOrObject]) -> SfgCallTreeNode:
pass
......@@ -50,7 +50,7 @@ class SfgDeferredFieldMapping(SfgParamCollectionDeferredNode):
self._field = field
self._src_field = src_field
def expand(self, ctx: SfgContext, visible_params: Set[TypedSymbol]) -> SfgCallTreeNode:
def expand(self, ctx: SfgContext, visible_params: set[TypedSymbolOrObject]) -> SfgCallTreeNode:
# Find field pointer
ptr = None
for param in visible_params:
......
from __future__ import annotations
from typing import Callable
from typing import Callable, TypeVar, Generic, Any, ParamSpec, Concatenate
from types import MethodType
from functools import wraps
from .basic_nodes import SfgCallTreeNode
V = TypeVar("V")
R = TypeVar("R")
P = ParamSpec("P")
class VisitorDispatcher:
def __init__(self, wrapped_method):
self._dispatch_dict = {}
self._wrapped_method = wrapped_method
class VisitorDispatcher(Generic[V, R]):
def __init__(self, wrapped_method: Callable[..., R]):
self._dispatch_dict: dict[type, Callable[..., R]] = {}
self._wrapped_method: Callable[..., R] = wrapped_method
def case(self, node_type: type):
"""Decorator for visitor's methods"""
def decorate(handler: Callable):
def decorate(handler: Callable[..., R]):
if node_type in self._dispatch_dict:
raise ValueError(f"Duplicate visitor case {node_type}")
self._dispatch_dict[node_type] = handler
......@@ -23,14 +26,14 @@ class VisitorDispatcher:
return decorate
def __call__(self, instance, node: SfgCallTreeNode, *args, **kwargs):
def __call__(self, instance: V, node: SfgCallTreeNode, *args, **kwargs) -> R:
for cls in node.__class__.mro():
if cls in self._dispatch_dict:
return self._dispatch_dict[cls](instance, node, *args, **kwargs)
return self._wrapped_method(instance, node, *args, **kwargs)
def __get__(self, obj, objtype=None):
def __get__(self, obj: V, objtype=None) -> Callable[..., R]:
if obj is None:
return self
return MethodType(self, obj)
......
from __future__ import annotations
from typing import TYPE_CHECKING, Set
from typing import TYPE_CHECKING
from functools import reduce
......@@ -9,6 +9,7 @@ from pystencils.typing import TypedSymbol
from .basic_nodes import SfgCallTreeNode, SfgCallTreeLeaf, SfgSequence, SfgStatements
from .deferred_nodes import SfgParamCollectionDeferredNode
from .dispatcher import visitor
from ..source_concepts.source_objects import TypedSymbolOrObject
if TYPE_CHECKING:
from ..context import SfgContext
......@@ -61,28 +62,28 @@ class ExpandingParameterCollector():
self._flattener = FlattenSequences()
@visitor
def visit(self, node: SfgCallTreeNode) -> Set[TypedSymbol]:
def visit(self, node: SfgCallTreeNode) -> set[TypedSymbolOrObject]:
return self.branching_node(node)
@visit.case(SfgCallTreeLeaf)
def leaf(self, leaf: SfgCallTreeLeaf) -> Set[TypedSymbol]:
def leaf(self, leaf: SfgCallTreeLeaf) -> set[TypedSymbolOrObject]:
return leaf.required_parameters
@visit.case(SfgSequence)
def sequence(self, sequence: SfgSequence) -> Set[TypedSymbol]:
def sequence(self, sequence: SfgSequence) -> set[TypedSymbolOrObject]:
"""
Only in a sequence may parameters be defined and visible to subsequent nodes.
"""
params = set()
params: set[TypedSymbolOrObject] = set()
def iter_nested_sequences(seq: SfgSequence, visible_params: Set[TypedSymbol]):
def iter_nested_sequences(seq: SfgSequence, visible_params: set[TypedSymbolOrObject]):
for i in range(len(seq.children) - 1, -1, -1):
c = seq.children[i]
if isinstance(c, SfgParamCollectionDeferredNode):
c = c.expand(self._ctx, visible_params=visible_params)
seq.replace_child(i, c)
seq[i] = c
if isinstance(c, SfgSequence):
iter_nested_sequences(c, visible_params)
......@@ -96,7 +97,7 @@ class ExpandingParameterCollector():
return params
def branching_node(self, node: SfgCallTreeNode):
def branching_node(self, node: SfgCallTreeNode) -> set[TypedSymbolOrObject]:
"""
Each interior node that is not a sequence simply requires the union of all parameters
required by its children.
......@@ -111,20 +112,20 @@ class ParameterCollector():
"""
@visitor
def visit(self, node: SfgCallTreeNode) -> Set[TypedSymbol]:
def visit(self, node: SfgCallTreeNode) -> set[TypedSymbolOrObject]:
return self.branching_node(node)
@visit.case(SfgCallTreeLeaf)
def leaf(self, leaf: SfgCallTreeLeaf) -> Set[TypedSymbol]:
def leaf(self, leaf: SfgCallTreeLeaf) -> set[TypedSymbolOrObject]:
return leaf.required_parameters
@visit.case(SfgSequence)
def sequence(self, sequence: SfgSequence) -> Set[TypedSymbol]:
def sequence(self, sequence: SfgSequence) -> set[TypedSymbolOrObject]:
"""
Only in a sequence may parameters be defined and visible to subsequent nodes.
"""
params = set()
params: set[TypedSymbolOrObject] = set()
for c in sequence.children[::-1]:
if isinstance(c, SfgStatements):
params -= c.defined_parameters
......@@ -133,7 +134,7 @@ class ParameterCollector():
params |= self.visit(c)
return params
def branching_node(self, node: SfgCallTreeNode):
def branching_node(self, node: SfgCallTreeNode) -> set[TypedSymbolOrObject]:
"""
Each interior node that is not a sequence simply requires the union of all parameters
required by its children.
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment