from typing import Callable, Dict, Generator, List, Union, Optional, Sequence
from dataclasses import dataclass

import sys
import os
from os import path
from argparse import ArgumentParser

from jinja2.filters import do_indent

from pystencils import Field
from pystencils.astnodes import KernelFunction

from .kernel_namespace import SfgKernelNamespace, SfgKernelHandle
from .tree import SfgCallTreeNode, SfgSequence, SfgKernelCallNode, SfgCondition, SfgBranch
from .tree.builders import SfgBranchBuilder, make_sequence
from .source_concepts.containers import SrcField
from .source_components import SfgFunction


@dataclass
class SfgCodeStyle:
    """Formatting options applied to generated source code."""

    # Indentation width passed to jinja2's `do_indent`.
    indent_width: int = 2

    def indent(self, s: str) -> str:
        """Return `s` indented by `indent_width`, including its first line."""
        return do_indent(s, self.indent_width, first=True)


class SourceFileGenerator:
    """Context manager that drives generation of source files for a generator script.

    Entering the manager deletes the emitter's already-existing output files
    and yields the `SfgContext` to be populated. Leaving it writes the output
    files, but only if the managed block raised no exception.
    """

    def __init__(self, namespace: str = "pystencils", codestyle: Optional[SfgCodeStyle] = None):
        """Parse the command line and set up the context and the emitter.

        Args:
            namespace: Root namespace handed to the `SfgContext`.
            codestyle: Code style for the generated files. When omitted, a
                fresh `SfgCodeStyle` is created for this generator, so that
                instances never share one mutable default object.
        """
        if codestyle is None:
            codestyle = SfgCodeStyle()

        parser = ArgumentParser(
            "pystencilssfg",
            description="pystencils Source File Generator")

        parser.add_argument("script_args", nargs='*')
        parser.add_argument("-d", "--output-dir", type=str, default='.', dest='output_directory')

        # NOTE(review): the complete `sys.argv` is parsed, so the script path
        # (sys.argv[0]) becomes the first entry of `script_args`. Kept as-is
        # because `SfgContext.argv` exposes this list to callers; confirm that
        # this argv-style layout is intended.
        args = parser.parse_args(sys.argv)

        # The output base name is the file name of the running script,
        # stripped of its extension.
        import __main__
        scriptpath = __main__.__file__
        scriptname = path.split(scriptpath)[1]
        basename = path.splitext(scriptname)[0]

        self._context = SfgContext(args.script_args, namespace, codestyle)

        # Imported locally -- presumably to avoid a circular import; TODO confirm.
        from .emitters.cpu.basic_cpu import BasicCpuEmitter
        self._emitter = BasicCpuEmitter(self._context, basename, args.output_directory)

    def clean_files(self) -> None:
        """Delete every output file of the emitter that currently exists."""
        for file in self._emitter.output_files:
            if path.exists(file):
                os.remove(file)

    def __enter__(self) -> "SfgContext":
        """Remove stale output files and return the context to populate."""
        self.clean_files()
        return self._context

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Write the output files unless the managed block raised.

        Returns `None`, so an exception from the block is never suppressed.
        """
        if exc_type is None:
            self._emitter.write_files()


class SfgContext:
    """Collects the includes, kernel namespaces and functions of one generator run."""

    def __init__(self, argv, root_namespace: str, codestyle: SfgCodeStyle):
        """Create an empty context holding a default kernel namespace named "kernels".

        Args:
            argv: Script arguments, exposed unchanged through `argv`.
            root_namespace: Name exposed through `root_namespace`.
            codestyle: Code style exposed through `codestyle`.
        """
        self._argv = argv
        self._root_namespace = root_namespace
        self._codestyle = codestyle
        self._default_kernel_namespace = SfgKernelNamespace(self, "kernels")

        #   Source Components
        self._includes: List[str] = []
        self._kernel_namespaces: Dict[str, SfgKernelNamespace] = {
            self._default_kernel_namespace.name: self._default_kernel_namespace
        }
        self._functions: Dict[str, SfgFunction] = dict()

    @property
    def argv(self) -> Sequence[str]:
        """Script arguments this context was created with."""
        return self._argv

    @property
    def root_namespace(self) -> str:
        """Root namespace this context was created with."""
        return self._root_namespace

    @property
    def codestyle(self) -> SfgCodeStyle:
        """Code style this context was created with."""
        return self._codestyle

    @property
    def kernels(self) -> SfgKernelNamespace:
        """The default kernel namespace."""
        return self._default_kernel_namespace

    def kernel_namespace(self, name: str) -> SfgKernelNamespace:
        """Create, register and return a new kernel namespace called `name`.

        Raises:
            ValueError: If a kernel namespace with that name already exists.
        """
        if name in self._kernel_namespaces:
            raise ValueError(f"Duplicate kernel namespace: {name}")

        kns = SfgKernelNamespace(self, name)
        self._kernel_namespaces[name] = kns
        return kns

    def includes(self) -> Generator[str, None, None]:
        """Yield the registered include strings in insertion order."""
        yield from self._includes

    def kernel_namespaces(self) -> Generator[SfgKernelNamespace, None, None]:
        """Yield all registered kernel namespaces, the default one included."""
        yield from self._kernel_namespaces.values()

    def functions(self) -> Generator[SfgFunction, None, None]:
        """Yield all registered functions."""
        yield from self._functions.values()

    def include(self, header_file: str) -> None:
        """Register a header file to include.

        A name already delimited by `<...>` or `"..."` is stored unchanged;
        any other name is wrapped in double quotes first.
        """
        if not (header_file.startswith("<") and header_file.endswith(">")):
            if not (header_file.startswith('"') and header_file.endswith('"')):
                header_file = f'"{header_file}"'

        self._includes.append(header_file)

    def _add_function(self, name: str, tree: SfgCallTreeNode) -> None:
        """Wrap `tree` in an `SfgFunction` and store it under `name`."""
        self._functions[name] = SfgFunction(self, name, tree)

    def function(self,
                 name: str,
                 ast_or_kernel_handle: Optional[Union[KernelFunction, SfgKernelHandle]] = None
                 ) -> Optional[Callable[..., None]]:
        """Register a function called `name`.

        Args:
            name: Name of the new function.
            ast_or_kernel_handle: When given, the function is registered
                immediately with a body that is a single kernel call. A
                `KernelFunction` is first added to the default kernel
                namespace; an `SfgKernelHandle` is wrapped via `call`; an
                already-built `SfgKernelCallNode` is used as the body directly.

        Returns:
            `None` if `ast_or_kernel_handle` was given. Otherwise a callable
            taking the call tree nodes of the function body; the function is
            registered once that callable is invoked.

        Raises:
            ValueError: If a function called `name` is already registered.
            TypeError: If `ast_or_kernel_handle` has an unsupported type.
        """
        if name in self._functions:
            raise ValueError(f"Duplicate function: {name}")

        if ast_or_kernel_handle is None:
            def sequencer(*args: SfgCallTreeNode) -> None:
                self._add_function(name, make_sequence(*args))

            return sequencer

        if isinstance(ast_or_kernel_handle, KernelFunction):
            khandle = self._default_kernel_namespace.add(ast_or_kernel_handle)
            # NOTE(review): this passes (context, handle) whereas `call` below
            # constructs the node from the handle alone -- confirm which one
            # matches the SfgKernelCallNode constructor.
            tree = SfgKernelCallNode(self, khandle)
        elif isinstance(ast_or_kernel_handle, SfgKernelHandle):
            tree = self.call(ast_or_kernel_handle)
        elif isinstance(ast_or_kernel_handle, SfgKernelCallNode):
            # Not covered by the annotation, but accepted by earlier revisions.
            tree = ast_or_kernel_handle
        else:
            raise TypeError("Invalid type of argument `ast_or_kernel_handle`!")

        # Register the function; without this the tree built above was dropped.
        self._add_function(name, tree)
        return None

    # ----------------------------------------------------------------------------------------------
    #   Call Tree Node Factory
    # ----------------------------------------------------------------------------------------------

    def call(self, kernel_handle: SfgKernelHandle) -> SfgKernelCallNode:
        """Return a new kernel call node for `kernel_handle`."""
        return SfgKernelCallNode(kernel_handle)

    @property
    def branch(self) -> SfgBranchBuilder:
        """A new `SfgBranchBuilder` on every access."""
        return SfgBranchBuilder()

    def map_field(self, field: Field, src_object: Optional[SrcField] = None) -> SfgSequence:
        """Return the result of `src_object.extract_parameters(field)`.

        Raises:
            NotImplementedError: If `src_object` is `None`; automatic field
                extraction is not implemented yet.
        """
        if src_object is None:
            raise NotImplementedError("Automatic field extraction is not implemented yet.")
        else:
            return src_object.extract_parameters(field)