Skip to content
Snippets Groups Projects
Commit 1699c4db authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Refactor: Remove config argument from context and emitter

parent 973ba865
Branches
Tags
No related merge requests found
Pipeline #57781 passed with stages
in 5 minutes and 39 seconds
...@@ -72,9 +72,12 @@ def list_files(args): ...@@ -72,9 +72,12 @@ def list_files(args):
_, scriptname = path.split(args.codegen_script) _, scriptname = path.split(args.codegen_script)
basename = path.splitext(scriptname)[0] basename = path.splitext(scriptname)[0]
from .emitters.cpu.basic_cpu import BasicCpuEmitter from .emitters import HeaderSourcePairEmitter
emitter = BasicCpuEmitter(basename, config) emitter = HeaderSourcePairEmitter(basename,
config.header_extension,
config.source_extension,
config.output_directory)
print(args.sep.join(emitter.output_files), end=os.linesep if args.newline else '') print(args.sep.join(emitter.output_files), end=os.linesep if args.newline else '')
......
from typing import Generator, Sequence from typing import Generator, Sequence
from .configuration import SfgConfiguration, SfgCodeStyle from .configuration import SfgCodeStyle
from .tree.visitors import CollectIncludes from .tree.visitors import CollectIncludes
from .source_components import SfgHeaderInclude, SfgKernelNamespace, SfgFunction from .source_components import SfgHeaderInclude, SfgKernelNamespace, SfgFunction
from .exceptions import SfgException from .exceptions import SfgException
class SfgContext: class SfgContext:
def __init__(self, config: SfgConfiguration, argv: Sequence[str] | None = None): def __init__(self,
outer_namespace: str | None = None,
codestyle: SfgCodeStyle = SfgCodeStyle(),
argv: Sequence[str] | None = None):
self._argv = argv self._argv = argv
self._config = config
self._default_kernel_namespace = SfgKernelNamespace(self, "kernels") self._default_kernel_namespace = SfgKernelNamespace(self, "kernels")
self._code_namespace: str | None = None self._outer_namespace = outer_namespace
self._inner_namespace: str | None = None
self._codestyle = codestyle
# Source Components # Source Components
self._prelude: str = "" self._prelude: str = ""
...@@ -34,11 +39,11 @@ class SfgContext: ...@@ -34,11 +39,11 @@ class SfgContext:
@property @property
def outer_namespace(self) -> str | None: def outer_namespace(self) -> str | None:
return self._config.outer_namespace return self._outer_namespace
@property @property
def inner_namespace(self) -> str | None: def inner_namespace(self) -> str | None:
return self._code_namespace return self._inner_namespace
@property @property
def fully_qualified_namespace(self) -> str | None: def fully_qualified_namespace(self) -> str | None:
...@@ -51,8 +56,7 @@ class SfgContext: ...@@ -51,8 +56,7 @@ class SfgContext:
@property @property
def codestyle(self) -> SfgCodeStyle: def codestyle(self) -> SfgCodeStyle:
assert self._config.codestyle is not None return self._codestyle
return self._config.codestyle
# ---------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------
# Prelude, Includes, Definitions, Namespace # Prelude, Includes, Definitions, Namespace
...@@ -88,10 +92,10 @@ class SfgContext: ...@@ -88,10 +92,10 @@ class SfgContext:
self._definitions.append(definition) self._definitions.append(definition)
def set_namespace(self, namespace: str): def set_namespace(self, namespace: str):
if self._code_namespace is not None: if self._inner_namespace is not None:
raise SfgException("The code namespace was already set.") raise SfgException("The code namespace was already set.")
self._code_namespace = namespace self._inner_namespace = namespace
# ---------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------
# Kernel Namespaces # Kernel Namespaces
......
from .header_source_pair import HeaderSourcePairEmitter
__all__ = [
"HeaderSourcePairEmitter"
]
...@@ -4,16 +4,19 @@ from textwrap import indent ...@@ -4,16 +4,19 @@ from textwrap import indent
from os import path from os import path
from ...configuration import SfgConfiguration from ..context import SfgContext
from ...context import SfgContext
class BasicCpuEmitter: class HeaderSourcePairEmitter:
def __init__(self, basename: str, config: SfgConfiguration): def __init__(self,
basename: str,
header_extension: str,
impl_extension: str,
output_directory: str):
self._basename = basename self._basename = basename
self._output_directory = cast(str, config.output_directory) self._output_directory = cast(str, output_directory)
self._header_filename = f"{basename}.{config.header_extension}" self._header_filename = f"{basename}.{header_extension}"
self._source_filename = f"{basename}.{config.source_extension}" self._source_filename = f"{basename}.{impl_extension}"
@property @property
def output_files(self) -> tuple[str, str]: def output_files(self) -> tuple[str, str]:
...@@ -39,9 +42,9 @@ class BasicCpuEmitter: ...@@ -39,9 +42,9 @@ class BasicCpuEmitter:
'functions': list(ctx.functions()) 'functions': list(ctx.functions())
} }
template_name = "BasicCpu" template_name = "HeaderSourcePair"
env = Environment(loader=PackageLoader('pystencilssfg.emitters.cpu'), env = Environment(loader=PackageLoader('pystencilssfg.emitters'),
undefined=StrictUndefined, undefined=StrictUndefined,
trim_blocks=True, trim_blocks=True,
lstrip_blocks=True) lstrip_blocks=True)
......
# TODO
# mypy strict_optional=False
import sys import sys
import os import os
from os import path from os import path
...@@ -21,10 +24,13 @@ class SourceFileGenerator: ...@@ -21,10 +24,13 @@ class SourceFileGenerator:
config = merge_configurations(project_config, cmdline_config, sfg_config) config = merge_configurations(project_config, cmdline_config, sfg_config)
self._context = SfgContext(config, argv=script_args) self._context = SfgContext(config.outer_namespace, config.codestyle, argv=script_args)
from .emitters.cpu.basic_cpu import BasicCpuEmitter from .emitters import HeaderSourcePairEmitter
self._emitter = BasicCpuEmitter(basename, config) self._emitter = HeaderSourcePairEmitter(basename,
config.header_extension,
config.source_extension,
config.output_directory)
def clean_files(self): def clean_files(self):
for file in self._emitter.output_files: for file in self._emitter.output_files:
......
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, Optional, Union, TypeAlias from typing import TYPE_CHECKING, Union, TypeAlias
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
...@@ -19,7 +19,7 @@ class SrcObject: ...@@ -19,7 +19,7 @@ class SrcObject:
Two objects are identical if they have the same identifier and type string.""" Two objects are identical if they have the same identifier and type string."""
def __init__(self, src_type: SrcType, identifier: Optional[str]): def __init__(self, src_type: SrcType, identifier: str):
self._src_type = src_type self._src_type = src_type
self._identifier = identifier self._identifier = identifier
...@@ -28,7 +28,7 @@ class SrcObject: ...@@ -28,7 +28,7 @@ class SrcObject:
return self._identifier return self._identifier
@property @property
def name(self): def name(self) -> str:
"""For interface compatibility with ps.TypedSymbol""" """For interface compatibility with ps.TypedSymbol"""
return self._identifier return self._identifier
...@@ -53,7 +53,7 @@ TypedSymbolOrObject: TypeAlias = Union[TypedSymbol, SrcObject] ...@@ -53,7 +53,7 @@ TypedSymbolOrObject: TypeAlias = Union[TypedSymbol, SrcObject]
class SrcField(SrcObject, ABC): class SrcField(SrcObject, ABC):
def __init__(self, src_type: SrcType, identifier: Optional[str]): def __init__(self, src_type: SrcType, identifier: str):
super().__init__(src_type, identifier) super().__init__(src_type, identifier)
@abstractmethod @abstractmethod
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment