Newer
Older
from typing import Generator, Union, Optional, Sequence
from os import path
from pystencils import Field
from .configuration import SfgConfiguration, config_from_commandline, merge_configurations, SfgCodeStyle
from .kernel_namespace import SfgKernelNamespace, SfgKernelHandle
from .tree import SfgCallTreeNode, SfgSequence, SfgKernelCallNode, SfgStatements
from .tree.deferred_nodes import SfgDeferredFieldMapping
from .tree.builders import SfgBranchBuilder, make_sequence
from .tree.visitors import CollectIncludes
from .source_concepts import SrcField, TypedSymbolOrObject
from .source_components import SfgFunction, SfgHeaderInclude
# NOTE(review): indentation has been stripped from this file and some lines appear to be
# missing (e.g. the `for` loop at the end of __init__ has no body); code kept verbatim.
class SourceFileGenerator:
# Constructor: validate the optional user configuration, merge it with the project and
# command-line configuration, and create the SfgContext stored in self._context.
# Raises TypeError if sfg_config is given but is not an SfgConfiguration.
def __init__(self, sfg_config: SfgConfiguration | None = None):
if sfg_config and not isinstance(sfg_config, SfgConfiguration):
raise TypeError("sfg_config is not an SfgConfiguration.")
# Derive the running script's base name (file name without extension) from __main__.
# NOTE(review): `basename` has no visible use in this chunk.
import __main__
scriptpath = __main__.__file__
scriptname = path.split(scriptpath)[1]
basename = path.splitext(scriptname)[0]
# NOTE(review): `sys` is not among the visible imports at the top of the file; unless it is
# imported elsewhere, this line raises NameError -- add `import sys` to the import block.
project_config, cmdline_config, script_args = config_from_commandline(sys.argv)
config = merge_configurations(project_config, cmdline_config, sfg_config)
self._context = SfgContext(script_args, config)
# NOTE(review): BasicCpuEmitter is imported but not used in the visible lines, and
# `self._emitter` is never assigned here -- the emitter construction appears to be missing.
from .emitters.cpu.basic_cpu import BasicCpuEmitter
for file in self._emitter.output_files:
def __enter__(self):
    """Enter the ``with`` block by handing out the SfgContext built in the constructor."""
    ctx = self._context
    return ctx
# Exit the `with` block.
# NOTE(review): the body of the `if` below is missing from this view; presumably it writes
# the generated output only when the block finished without an exception -- confirm.
def __exit__(self, exc_type, exc_value, traceback):
if exc_type is None:
# State of one generator run: script arguments and kernel namespaces, plus (as used by the
# methods below) header includes and functions.
class SfgContext:
# NOTE(review): `config` is not stored in the visible lines, and `self._includes` /
# `self._functions` (read by includes(), add_include(), functions(), add_function())
# are never initialized here -- lines appear to be missing; verify.
def __init__(self, argv, config: SfgConfiguration):
self._argv = argv
# Every context starts with a default kernel namespace named "kernels", which is also
# registered in the name -> namespace map.
self._default_kernel_namespace = SfgKernelNamespace(self, "kernels")
self._kernel_namespaces = {self._default_kernel_namespace.name: self._default_kernel_namespace}
@property
def argv(self) -> Sequence[str]:
    """The argument list this context was constructed with."""
    stored = self._argv
    return stored
# NOTE(review): the `def` line (and any decorator) for this body is missing from this view.
# Joins root_namespace and inner_namespace into a C++-qualified name "outer::inner";
# if only one of them is set, that one is returned, and None if neither is set.
match (self.root_namespace, self.inner_namespace):
case None, None: return None
case outer, None: return outer
case None, inner: return inner
case outer, inner: return f"{outer}::{inner}"
# ----------------------------------------------------------------------------------------------
# Source Component Getters
# ----------------------------------------------------------------------------------------------
def includes(self) -> Generator[SfgHeaderInclude, None, None]:
    """Lazily yield every header include recorded on this context."""
    for header in self._includes:
        yield header
def kernel_namespaces(self) -> Generator[SfgKernelNamespace, None, None]:
    """Lazily yield each registered kernel namespace, the default "kernels" one included."""
    for namespace in self._kernel_namespaces.values():
        yield namespace
def functions(self) -> Generator[SfgFunction, None, None]:
    """Lazily yield every function registered on this context, in insertion order."""
    for fn in self._functions.values():
        yield fn
# ----------------------------------------------------------------------------------------------
# Source Component Adders
# ----------------------------------------------------------------------------------------------
def add_include(self, include: SfgHeaderInclude):
    """Record a header include on this context (via the include collection's ``add``)."""
    registry = self._includes
    registry.add(include)
def add_function(self, func: SfgFunction):
    """Register ``func`` under its name and record the header includes its body requires.

    Raises:
        ValueError: if a function with the same name is already registered.
    """
    fname = func.name
    if fname in self._functions:
        raise ValueError(f"Duplicate function: {fname}")
    self._functions[fname] = func
    # Required headers are discovered by walking the function's call tree.
    collected = CollectIncludes().visit(func._tree)
    for header in collected:
        self.add_include(header)
# ----------------------------------------------------------------------------------------------
# Factory-like Adders
# ----------------------------------------------------------------------------------------------
# NOTE(review): the `def` line (and any decorator) for this body is missing from this view.
# Returns the default "kernels" namespace created in SfgContext.__init__.
return self._default_kernel_namespace
def kernel_namespace(self, name: str) -> SfgKernelNamespace:
    """Create, register and return a new kernel namespace called ``name``.

    Raises:
        ValueError: if a namespace with that name already exists (this includes the
            default "kernels" namespace created with the context).
    """
    registry = self._kernel_namespaces
    if name in registry:
        raise ValueError(f"Duplicate kernel namespace: {name}")
    registry[name] = new_ns = SfgKernelNamespace(self, name)
    return new_ns
# NOTE(review): the `def` line for this body is missing; `header_file` is presumably its parameter.
# A header written in angle brackets (e.g. "<vector>") is stored without the brackets and
# flagged as a system header; any other name is recorded as a non-system include.
system_header = False
if header_file.startswith("<") and header_file.endswith(">"):
header_file = header_file[1:-1]
system_header = True
self.add_include(SfgHeaderInclude(header_file, system_header=system_header))
# NOTE(review): the start of this method's signature is missing; only its last parameter
# line is visible. `KernelFunction` does not appear in the visible imports.
ast_or_kernel_handle: Optional[Union[KernelFunction, SfgKernelHandle]] = None):
# Function names must be unique within the context.
if name in self._functions:
raise ValueError(f"Duplicate function: {name}")
if ast_or_kernel_handle is not None:
if isinstance(ast_or_kernel_handle, KernelFunction):
# A raw kernel AST is first registered in the default "kernels" namespace.
# NOTE(review): `tree` is not assigned on this branch in the visible lines.
khandle = self._default_kernel_namespace.add(ast_or_kernel_handle)
# NOTE(review): the annotation advertises SfgKernelHandle, but this checks for
# SfgKernelCallNode, so a handle presumably falls through to the TypeError -- likely a bug.
elif isinstance(ast_or_kernel_handle, SfgKernelCallNode):
tree = ast_or_kernel_handle
else:
raise TypeError("Invalid type of argument `ast_or_kernel_handle`!")
func = SfgFunction(self, name, tree)
self.add_function(func)
# NOTE(review): the lines below seem to belong to a second code path whose header is
# missing; `args` is undefined here and `func` is not rebuilt from the new `tree`.
tree = make_sequence(*args)
self.add_function(func)
# ----------------------------------------------------------------------------------------------
# In-Sequence builders to be used within the second phase of SfgContext.function().
# ----------------------------------------------------------------------------------------------
def call(self, kernel_handle: SfgKernelHandle) -> SfgKernelCallNode:
    """Wrap ``kernel_handle`` in a kernel-call node for use inside a function body."""
    node = SfgKernelCallNode(kernel_handle)
    return node
# NOTE(review): the `def` line for this body is missing from this view.
# Returns a fresh SfgBranchBuilder (presumably for building a conditional branch in a function body).
return SfgBranchBuilder()
def map_field(self, field: Field, src_object: Optional[SrcField] = None) -> SfgDeferredFieldMapping:
    """Create a deferred mapping between the pystencils ``field`` and the source object ``src_object``.

    Raises:
        NotImplementedError: if ``src_object`` is omitted, because automatic extraction
            is not supported yet.
    """
    if src_object is not None:
        return SfgDeferredFieldMapping(field, src_object)
    raise NotImplementedError("Automatic field extraction is not implemented yet.")
def map_param(self, lhs: TypedSymbolOrObject, rhs: TypedSymbolOrObject, mapping: str):
    """Wrap the code string ``mapping`` in an SfgStatements node.

    ``lhs`` and ``rhs`` are each passed on as one-element tuples; presumably these are the
    object defined and the object required by ``mapping`` -- TODO confirm against SfgStatements.
    """
    lhs_group = (lhs,)
    rhs_group = (rhs,)
    return SfgStatements(mapping, lhs_group, rhs_group)