Skip to content
Snippets Groups Projects
context.py 7.01 KiB
from typing import Generator, Union, Optional, Sequence

import sys
import os
from os import path


from pystencils import Field
from pystencils.astnodes import KernelFunction

from .configuration import SfgConfiguration, config_from_commandline, merge_configurations, SfgCodeStyle
from .kernel_namespace import SfgKernelNamespace, SfgKernelHandle
from .tree import SfgCallTreeNode, SfgSequence, SfgKernelCallNode, SfgStatements
from .tree.deferred_nodes import SfgDeferredFieldMapping
from .tree.builders import SfgBranchBuilder, make_sequence
from .tree.visitors import CollectIncludes
from .source_concepts import SrcField, TypedSymbolOrObject
from .source_components import SfgFunction, SfgHeaderInclude


class SourceFileGenerator:
    """Context manager that drives source file generation from a generator script.

    On construction, the generator configuration is assembled from the result of
    ``config_from_commandline(sys.argv)`` and the optional ``sfg_config``, and an emitter is
    created for the base name of the running ``__main__`` script. Entering the ``with`` block
    deletes existing output files and yields an :class:`SfgContext`. Leaving the block without
    an exception writes the output files; if an exception occurred, nothing is written.
    """

    def __init__(self, sfg_config: SfgConfiguration | None = None):
        """Create the generator for the currently running script.

        Args:
            sfg_config: Optional configuration, merged with the project and command-line
                configurations via ``merge_configurations``.

        Raises:
            TypeError: If ``sfg_config`` is given but is not an ``SfgConfiguration``.
        """
        # Compare against None explicitly so that falsy non-config arguments
        # (e.g. ``{}`` or ``0``) are rejected as well instead of slipping through.
        if sfg_config is not None and not isinstance(sfg_config, SfgConfiguration):
            raise TypeError("sfg_config is not an SfgConfiguration.")

        # The script's file name without extension is handed to the emitter,
        # presumably to name the generated files.
        # NOTE(review): ``__main__.__file__`` does not exist in interactive sessions.
        import __main__
        scriptpath = __main__.__file__
        scriptname = path.split(scriptpath)[1]
        basename = path.splitext(scriptname)[0]

        project_config, cmdline_config, script_args = config_from_commandline(sys.argv)

        config = merge_configurations(project_config, cmdline_config, sfg_config)

        self._context = SfgContext(script_args, config)

        # Imported locally; possibly to avoid an import cycle with the emitter package — confirm.
        from .emitters.cpu.basic_cpu import BasicCpuEmitter
        self._emitter = BasicCpuEmitter(basename, config)

    def clean_files(self):
        """Delete every file in the emitter's ``output_files`` that currently exists."""
        for file in self._emitter.output_files:
            if path.exists(file):
                os.remove(file)

    def __enter__(self):
        """Remove stale output files and return the :class:`SfgContext` to populate."""
        self.clean_files()
        return self._context

    def __exit__(self, exc_type, exc_value, traceback):
        """Write the output files unless the ``with`` body raised.

        Returns ``None``, so any exception from the body propagates to the caller.
        """
        if exc_type is None:
            self._emitter.write_files(self._context)


class SfgContext:
    def __init__(self, argv, config: SfgConfiguration):
        self._argv = argv
        self._config = config
        self._default_kernel_namespace = SfgKernelNamespace(self, "kernels")

        self._code_namespace = None

        #   Source Components
        self._includes: set[SfgHeaderInclude] = set()
        self._kernel_namespaces = {self._default_kernel_namespace.name: self._default_kernel_namespace}
        self._functions: dict[str, SfgFunction] = dict()

    @property
    def argv(self) -> Sequence[str]:
        return self._argv

    @property
    def root_namespace(self) -> str | None:
        return self._config.base_namespace

    @property
    def inner_namespace(self) -> str | None:
        return self._code_namespace

    @property
    def fully_qualified_namespace(self) -> str | None:
        match (self.root_namespace, self.inner_namespace):
            case None, None: return None
            case outer, None: return outer
            case None, inner: return inner
            case outer, inner: return f"{outer}::{inner}"
            case _: assert False

    @property
    def codestyle(self) -> SfgCodeStyle:
        assert self._config.codestyle is not None
        return self._config.codestyle

    # ----------------------------------------------------------------------------------------------
    #   Source Component Getters
    # ----------------------------------------------------------------------------------------------

    def includes(self) -> Generator[SfgHeaderInclude, None, None]:
        yield from self._includes

    def kernel_namespaces(self) -> Generator[SfgKernelNamespace, None, None]:
        yield from self._kernel_namespaces.values()

    def functions(self) -> Generator[SfgFunction, None, None]:
        yield from self._functions.values()

    # ----------------------------------------------------------------------------------------------
    #   Source Component Adders
    # ----------------------------------------------------------------------------------------------

    def add_include(self, include: SfgHeaderInclude):
        self._includes.add(include)

    def add_function(self, func: SfgFunction):
        if func.name in self._functions:
            raise ValueError(f"Duplicate function: {func.name}")

        self._functions[func.name] = func
        for incl in CollectIncludes().visit(func._tree):
            self.add_include(incl)

    # ----------------------------------------------------------------------------------------------
    #   Factory-like Adders
    # ----------------------------------------------------------------------------------------------

    @property
    def kernels(self) -> SfgKernelNamespace:
        return self._default_kernel_namespace

    def kernel_namespace(self, name: str) -> SfgKernelNamespace:
        if name in self._kernel_namespaces:
            raise ValueError(f"Duplicate kernel namespace: {name}")

        kns = SfgKernelNamespace(self, name)
        self._kernel_namespaces[name] = kns
        return kns

    def include(self, header_file: str):
        system_header = False
        if header_file.startswith("<") and header_file.endswith(">"):
            header_file = header_file[1:-1]
            system_header = True

        self.add_include(SfgHeaderInclude(header_file, system_header=system_header))

    def function(self,
                 name: str,
                 ast_or_kernel_handle: Optional[Union[KernelFunction, SfgKernelHandle]] = None):
        if name in self._functions:
            raise ValueError(f"Duplicate function: {name}")

        if ast_or_kernel_handle is not None:
            if isinstance(ast_or_kernel_handle, KernelFunction):
                khandle = self._default_kernel_namespace.add(ast_or_kernel_handle)
                tree = SfgKernelCallNode(khandle)
            elif isinstance(ast_or_kernel_handle, SfgKernelCallNode):
                tree = ast_or_kernel_handle
            else:
                raise TypeError("Invalid type of argument `ast_or_kernel_handle`!")

            func = SfgFunction(self, name, tree)
            self.add_function(func)
        else:
            def sequencer(*args: SfgCallTreeNode):
                tree = make_sequence(*args)
                func = SfgFunction(self, name, tree)
                self.add_function(func)

            return sequencer

    # ----------------------------------------------------------------------------------------------
    #   In-Sequence builders to be used within the second phase of SfgContext.function().
    # ----------------------------------------------------------------------------------------------

    def call(self, kernel_handle: SfgKernelHandle) -> SfgKernelCallNode:
        return SfgKernelCallNode(kernel_handle)

    @property
    def branch(self) -> SfgBranchBuilder:
        return SfgBranchBuilder()

    def map_field(self, field: Field, src_object: Optional[SrcField] = None) -> SfgDeferredFieldMapping:
        if src_object is None:
            raise NotImplementedError("Automatic field extraction is not implemented yet.")
        else:
            return SfgDeferredFieldMapping(field, src_object)

    def map_param(self, lhs: TypedSymbolOrObject, rhs: TypedSymbolOrObject, mapping: str):
        return SfgStatements(mapping, (lhs,), (rhs,))