Skip to content
Snippets Groups Projects
context.py 5.33 KiB
Newer Older
from typing import Generator, Union, Optional, Sequence
Frederik Hennig's avatar
Frederik Hennig committed
from dataclasses import dataclass

Frederik Hennig's avatar
Frederik Hennig committed
import os
from argparse import ArgumentParser

Frederik Hennig's avatar
Frederik Hennig committed
from jinja2.filters import do_indent

from pystencils import Field
Frederik Hennig's avatar
Frederik Hennig committed
from pystencils.astnodes import KernelFunction

from .kernel_namespace import SfgKernelNamespace, SfgKernelHandle
from .tree import SfgCallTreeNode, SfgSequence, SfgKernelCallNode, SfgCondition, SfgBranch
from .tree.builders import SfgBranchBuilder, make_sequence
from .source_concepts.containers import SrcField
Frederik Hennig's avatar
Frederik Hennig committed
from .source_components import SfgFunction


@dataclass
class SfgCodeStyle:
    """Formatting options for emitted source code."""

    # Number of spaces used per indentation step.
    indent_width: int = 2

    def indent(self, s: str):
        """Indent `s` by `indent_width` spaces using jinja2's `do_indent`.

        The first line is indented as well (``first=True``).
        """
        width = self.indent_width
        return do_indent(s, width, first=True)

Frederik Hennig's avatar
Frederik Hennig committed
    def __init__(self,
                 namespace: str = "pystencils",
                 codestyle: Optional[SfgCodeStyle] = None):
        """Set up the generator from the command line of the running script.

        Parses the process arguments, derives a base name from the file name
        of the ``__main__`` script, and creates the `SfgContext` together
        with a `BasicCpuEmitter`.

        Args:
            namespace: Root namespace handed to the `SfgContext`.
            codestyle: Code style handed to the `SfgContext`. A fresh
                `SfgCodeStyle` is created when omitted.
        """
        import sys
        import __main__

        # A default instance written into the signature would be created once
        # and shared (and mutated) across all generators; build it per call.
        if codestyle is None:
            codestyle = SfgCodeStyle()

        parser = ArgumentParser(
            "pystencilssfg",
            description="pystencils Source File Generator")

        parser.add_argument("script_args", nargs='*')
        parser.add_argument("-d", "--output-dir", type=str, default='.', dest='output_directory')

        # NOTE(review): sys.argv still starts with the script path, so that
        # path is parsed as the first entry of `script_args`. Confirm this is
        # intended before switching to sys.argv[1:].
        args = parser.parse_args(sys.argv)

        # The script's file name without its extension is handed to the
        # emitter as base name.
        scriptname = os.path.basename(__main__.__file__)
        basename = os.path.splitext(scriptname)[0]

        self._context = SfgContext(args.script_args, namespace, codestyle)

        from .emitters.cpu.basic_cpu import BasicCpuEmitter
        self._emitter = BasicCpuEmitter(self._context, basename, args.output_directory)
Frederik Hennig's avatar
Frederik Hennig committed

    def clean_files(self):
        """Delete every output file of the emitter that is present on disk.

        Files that are already absent are skipped silently.
        """
        for file in self._emitter.output_files:
            # Remove directly rather than testing for existence first: the
            # file could disappear between the check and the removal.
            try:
                os.remove(file)
            except FileNotFoundError:
                pass
Frederik Hennig's avatar
Frederik Hennig committed
        self.clean_files()
Frederik Hennig's avatar
Frederik Hennig committed
    def __exit__(self, exc_type, exc_value, traceback):
        """Write the emitter's files unless the ``with`` body raised.

        Always returns None, so an exception from the body keeps propagating.
        """
        body_failed = exc_type is not None
        if body_failed:
            return
        self._emitter.write_files()
    def __init__(self, argv, root_namespace: str, codestyle: SfgCodeStyle):
        """Create a context with no includes or functions and one kernel namespace.

        Args:
            argv: Arguments later returned by `argv`.
            root_namespace: Name later returned by `root_namespace`.
            codestyle: Style settings later returned by `codestyle`.
        """
        self._argv = argv
        self._root_namespace = root_namespace
        self._codestyle = codestyle

        # The namespace is handed this context, so it is created only after
        # the attributes above are in place.
        default_ns = SfgKernelNamespace(self, "kernels")
        self._default_kernel_namespace = default_ns

        #   Source Components
        self._includes = []
        self._kernel_namespaces = {default_ns.name: default_ns}
        self._functions = {}

    @property
    def argv(self) -> Sequence[str]:
        """Arguments this context was constructed with."""
        return self._argv
Frederik Hennig's avatar
Frederik Hennig committed
    def root_namespace(self) -> str:
        """Return the root namespace name this context was constructed with."""
        return self._root_namespace
Frederik Hennig's avatar
Frederik Hennig committed
    
    @property
    def codestyle(self) -> SfgCodeStyle:
        """Code style this context was constructed with."""
        return self._codestyle
Frederik Hennig's avatar
Frederik Hennig committed
    def kernels(self) -> SfgKernelNamespace:
        """Return the default kernel namespace of this context."""
        return self._default_kernel_namespace

Frederik Hennig's avatar
Frederik Hennig committed
    def kernel_namespace(self, name: str) -> SfgKernelNamespace:
        """Create, register and return a new kernel namespace called `name`.

        Raises:
            ValueError: If a kernel namespace with that name is already registered.
        """
        registry = self._kernel_namespaces
        if name in registry:
            raise ValueError(f"Duplicate kernel namespace: {name}")

        created = SfgKernelNamespace(self, name)
        registry[name] = created
        return created
    
    def includes(self) -> Generator[str, None, None]:
        """Yield the registered include strings in the order they were added."""
        registered = self._includes
        yield from registered
Frederik Hennig's avatar
Frederik Hennig committed
    def kernel_namespaces(self) -> Generator[SfgKernelNamespace, None, None]:
        """Yield every registered kernel namespace."""
        namespaces = self._kernel_namespaces.values()
        yield from namespaces
Frederik Hennig's avatar
Frederik Hennig committed

    def functions(self) -> Generator[SfgFunction, None, None]:
        """Yield every registered function."""
        registered = self._functions.values()
        yield from registered

    def include(self, header_file: str):
        """Register a header file to be included.

        Names already wrapped in ``<...>`` or ``"..."`` are stored unchanged;
        any other name is wrapped in double quotes before it is stored.
        """
        angle_bracketed = header_file.startswith("<") and header_file.endswith(">")
        quoted = header_file.startswith('"') and header_file.endswith('"')
        if not (angle_bracketed or quoted):
            header_file = f'"{header_file}"'

        self._includes.append(header_file)
Frederik Hennig's avatar
Frederik Hennig committed

    def function(self,
                 name: str,
                 ast_or_kernel_handle: Optional[Union[KernelFunction, SfgKernelHandle]] = None):
        """Add a function called `name` to this context.

        When a kernel is given, the function body is a single kernel call and
        the function is registered immediately. Otherwise a callable is
        returned that takes the call tree nodes of the body and registers the
        function once it is invoked.

        Args:
            name: Name of the new function; must not be registered yet.
            ast_or_kernel_handle: A `KernelFunction` (added to the default
                kernel namespace first), an `SfgKernelHandle`, or an already
                built `SfgKernelCallNode`.

        Returns:
            None when `ast_or_kernel_handle` is given; otherwise the
            sequencer callable described above.

        Raises:
            ValueError: If a function called `name` already exists.
            TypeError: If `ast_or_kernel_handle` has an unsupported type.
        """
        if name in self._functions:
            raise ValueError(f"Duplicate function: {name}")

        if ast_or_kernel_handle is None:
            def sequencer(*args: SfgCallTreeNode):
                tree = make_sequence(*args)
                func = SfgFunction(self, name, tree)
                self._functions[name] = func

            return sequencer

        # NOTE(review): `call` constructs SfgKernelCallNode from the handle
        # alone, while this method also passes the context. Confirm which
        # signature SfgKernelCallNode actually has.
        if isinstance(ast_or_kernel_handle, KernelFunction):
            khandle = self._default_kernel_namespace.add(ast_or_kernel_handle)
            tree = SfgKernelCallNode(self, khandle)
        elif isinstance(ast_or_kernel_handle, SfgKernelHandle):
            # The annotated handle type used to be rejected with a TypeError.
            tree = SfgKernelCallNode(self, ast_or_kernel_handle)
        elif isinstance(ast_or_kernel_handle, SfgKernelCallNode):
            # Ready-made call nodes were accepted before; keep accepting them.
            tree = ast_or_kernel_handle
        else:
            raise TypeError("Invalid type of argument `ast_or_kernel_handle`!")

        # Previously the tree was built and then dropped, so the function
        # never reached `_functions`.
        self._functions[name] = SfgFunction(self, name, tree)
        

    #----------------------------------------------------------------------------------------------
    #   Call Tree Node Factory
    #----------------------------------------------------------------------------------------------

    def call(self, kernel_handle: SfgKernelHandle) -> SfgKernelCallNode:
        """Create a kernel call node for `kernel_handle`."""
        return SfgKernelCallNode(kernel_handle)
    
    @property
    def branch(self) -> SfgBranchBuilder:
        """A new `SfgBranchBuilder`; a fresh builder is created on every access."""
        return SfgBranchBuilder()
    
    def map_field(self, field: Field, src_object: Optional[SrcField] = None) -> SfgSequence:
        """Return what `src_object.extract_parameters(field)` produces.

        Raises:
            NotImplementedError: If `src_object` is omitted; extraction without
                a source object is not implemented.
        """
        if src_object is not None:
            return src_object.extract_parameters(field)

        raise NotImplementedError("Automatic field extraction is not implemented yet.")
Frederik Hennig's avatar
Frederik Hennig committed