Newer
Older
from typing import Generator, Sequence
from .configuration import SfgCodeStyle
from .visitors import CollectIncludes
from .source_components import (
SfgHeaderInclude,
SfgKernelNamespace,
SfgFunction,
SfgClass,
)
from .exceptions import SfgException
class SfgContext:
def __init__(
    self,
    outer_namespace: str | None = None,
    codestyle: SfgCodeStyle | None = None,
    argv: Sequence[str] | None = None,
):
    """Create an empty source-file generation context.

    Args:
        outer_namespace: Namespace that encloses the generated code, or ``None``.
        codestyle: Code style settings. When omitted (or ``None``), a fresh
            ``SfgCodeStyle()`` is created for this context.
        argv: Command-line arguments of the generator script, if any; they are
            exposed through the ``argv`` property.
    """
    self._argv = argv

    self._outer_namespace = outer_namespace
    #   The inner namespace starts unset; `set_namespace` may assign it once.
    self._inner_namespace: str | None = None

    #   Build the default style per instance: a default argument is evaluated only
    #   once, at definition time, so `codestyle=SfgCodeStyle()` in the signature
    #   would share one style object between all contexts created without one.
    self._codestyle = codestyle if codestyle is not None else SfgCodeStyle()

    #   Source component containers. The accessor methods of this class read and
    #   mutate these, so they must exist before any of them is called (and before
    #   the default kernel namespace, which receives `self`, is constructed).
    self._prelude: str = ""
    self._includes: set[SfgHeaderInclude] = set()
    self._definitions: list[str] = []
    self._functions: dict[str, SfgFunction] = {}
    self._classes: dict[str, SfgClass] = {}

    self._default_kernel_namespace = SfgKernelNamespace(self, "kernels")
    self._kernel_namespaces = {
        self._default_kernel_namespace.name: self._default_kernel_namespace
    }
@property
def argv(self) -> Sequence[str]:
    """Command-line arguments handed to the generator script.

    Only available when this context was created by a
    `pystencilssfg.SourceFileGenerator`; the configuration arguments meant for
    the code generator itself have already been stripped from the list.

    Raises:
        SfgException: if this context was created without command-line arguments.
    """
    provided = self._argv
    if provided is not None:
        return provided
    raise SfgException("This context provides no command-line arguments.")
# NOTE(review): the `def` headers (and any decorators) enclosing the next two
# statements are missing from this view. Presumably they are the
# `outer_namespace` / `inner_namespace` accessors that the namespace-joining
# code reads as `self.outer_namespace` / `self.inner_namespace` -- confirm.
# Return the outer namespace stored by the constructor.
return self._outer_namespace
# Return the inner namespace (None until it has been assigned).
return self._inner_namespace
# NOTE(review): the `def` header enclosing this fragment is missing from this view.
# Join the outer and inner namespace into a single `::`-separated name:
#   - None when neither is set,
#   - the one that is set when only one of them is,
#   - "outer::inner" when both are set.
# Only None counts as "unset" (literal `None` patterns); an empty string is
# treated as a set namespace.
match (self.outer_namespace, self.inner_namespace):
case None, None:
return None
case outer, None:
return outer
case None, inner:
return inner
case outer, inner:
return f"{outer}::{inner}"
# Defensive only: the capture pattern above already matches every 2-tuple,
# so this branch is unreachable.
case _:
assert False
# NOTE(review): the `def` header enclosing this statement is missing from this
# view; presumably it is a `codestyle` accessor -- confirm.
# Return the code style object stored by the constructor.
return self._codestyle
# ----------------------------------------------------------------------------------------------
# ----------------------------------------------------------------------------------------------
# NOTE(review): the `def` header(s) enclosing this fragment are missing from
# this view; `code_str` is presumably a parameter of one of them -- confirm.
"""The prelude is a comment block printed at the top of both generated files."""
# Append `code_str` to the prelude. A non-empty prelude first receives an extra
# "\n"; if all earlier content was added by this code it already ends in "\n",
# so this leaves a blank line between consecutive chunks.
if self._prelude:
self._prelude += "\n"
self._prelude += code_str
# Make sure the prelude always ends with a newline.
if not code_str.endswith("\n"):
self._prelude += "\n"
def includes(self) -> Generator[SfgHeaderInclude, None, None]:
    """Iterate over every header include registered with this context.

    Public includes are meant for the generated header file, private ones for
    the implementation file.
    """
    for header in self._includes:
        yield header
def add_include(self, include: SfgHeaderInclude):
    """Register ``include`` with this context's include collection."""
    registered = self._includes
    registered.add(include)
def definitions(self) -> Generator[str, None, None]:
    """Iterate over the registered definition lines.

    Definitions are code lines emitted near the top of the header file, right
    after the includes.
    """
    for line in self._definitions:
        yield line
def add_definition(self, definition: str):
    """Append one code line to the definitions emitted in the header file."""
    pending = self._definitions
    pending.append(definition)
def set_namespace(self, namespace: str):
    """Assign the inner code namespace; it can be assigned only once.

    Raises:
        SfgException: if the inner namespace has already been assigned.
    """
    already_assigned = self._inner_namespace is not None
    if already_assigned:
        raise SfgException("The code namespace was already set.")
    self._inner_namespace = namespace
# ----------------------------------------------------------------------------------------------
# ----------------------------------------------------------------------------------------------
def default_kernel_namespace(self) -> SfgKernelNamespace:
    """Return the default kernel namespace created by the constructor."""
    default_ns = self._default_kernel_namespace
    return default_ns
def kernel_namespaces(self) -> Generator[SfgKernelNamespace, None, None]:
    """Iterate over all registered kernel namespaces, including the default one."""
    for kernel_ns in self._kernel_namespaces.values():
        yield kernel_ns
def get_kernel_namespace(self, str) -> SfgKernelNamespace | None:
    """Return the kernel namespace registered under the given name, or ``None``.

    NOTE(review): the parameter is named ``str`` and shadows the builtin inside
    this method; the name is kept so that keyword callers keep working.
    """
    registry = self._kernel_namespaces
    return registry[str] if str in registry else None
def add_kernel_namespace(self, namespace: SfgKernelNamespace):
    """Register ``namespace`` under its ``name``.

    Raises:
        ValueError: if a kernel namespace with the same name is already registered.
    """
    # NOTE(review): the function and class registries raise SfgException on
    # duplicates, but this one raises ValueError; it is kept because callers may
    # catch ValueError.
    key = namespace.name
    registry = self._kernel_namespaces
    if key not in registry:
        registry[key] = namespace
        return
    raise ValueError(f"Duplicate kernel namespace: {namespace.name}")
# ----------------------------------------------------------------------------------------------
# ----------------------------------------------------------------------------------------------
def functions(self) -> Generator[SfgFunction, None, None]:
    """Iterate over all functions registered with this context."""
    registry = self._functions
    for fn in registry.values():
        yield fn
def get_function(self, name: str) -> SfgFunction | None:
    """Look up a registered function by ``name``; return ``None`` if there is none."""
    if name in self._functions:
        return self._functions[name]
    return None
# NOTE(review): the `def` header enclosing this fragment is missing from this
# view; `func` is presumably its function parameter -- confirm.
# Register `func` under its name, refusing duplicates with SfgException.
if func.name in self._functions:
raise SfgException(f"Duplicate function: {func.name}")
self._functions[func.name] = func
# Also register every header include that the CollectIncludes visitor
# reports for `func`.
for incl in CollectIncludes().visit(func):
self.add_include(incl)
# ----------------------------------------------------------------------------------------------
# Classes
# ----------------------------------------------------------------------------------------------
def classes(self) -> Generator[SfgClass, None, None]:
    """Iterate over all classes registered with this context."""
    for registered_cls in self._classes.values():
        yield registered_cls
def get_class(self, name: str) -> SfgClass | None:
    """Look up a registered class by ``name``; return ``None`` if there is none."""
    try:
        return self._classes[name]
    except KeyError:
        return None
def add_class(self, cls: SfgClass):
if cls.class_name in self._classes:
raise SfgException(f"Duplicate class: {cls.class_name}")
self._classes[cls.class_name] = cls
for incl in CollectIncludes().visit(cls):