Newer
Older
import os
import sys
from argparse import ArgumentParser
from dataclasses import dataclass
from os import path
from typing import Generator, Union, Optional, Sequence

from pystencils import Field
from pystencils.astnodes import KernelFunction

from .kernel_namespace import SfgKernelNamespace, SfgKernelHandle
from .tree import SfgCallTreeNode, SfgSequence, SfgKernelCallNode, SfgCondition, SfgBranch
from .tree.builders import SfgBranchBuilder, make_sequence
from .source_concepts.containers import SrcField
from .source_components import SfgFunction
@dataclass
class SfgCodeStyle:
    """Code-style settings handed to the generator context."""

    # Width argument forwarded to `do_indent` by `indent`.
    indent_width: int = 2

    def indent(self, s: str):
        """Return `s` passed through `do_indent` using this style's width.

        `do_indent` is called with ``first=True``.
        NOTE(review): `do_indent` is not imported in the reviewed source;
        it must be provided at module level — confirm where it comes from.
        """
        width = self.indent_width
        return do_indent(s, width, first=True)
class SourceFileGenerator:
    """Context manager that drives one run of a generator script.

    Construction parses the command line, derives the output base name from
    the running script's file name, and sets up an `SfgContext` together with
    a `BasicCpuEmitter`.  Entering the ``with`` block yields the context;
    leaving it without an exception makes the emitter write its files.
    """

    def __init__(self,
                 namespace: str = "pystencils",
                 codestyle: Optional["SfgCodeStyle"] = None):
        """Set up context and emitter for the currently running script.

        Args:
            namespace: Root namespace passed on to the `SfgContext`.
            codestyle: Code style passed on to the `SfgContext`; a fresh
                default `SfgCodeStyle` is created when omitted.
        """
        # Build the default per generator.  A default written in the
        # signature is evaluated once, so every generator would share (and
        # could mutate) the same style object.
        if codestyle is None:
            codestyle = SfgCodeStyle()

        parser = ArgumentParser(
            "pystencilssfg",
            description="pystencils Source File Generator")
        parser.add_argument("script_args", nargs='*')
        parser.add_argument("-d", "--output-dir", type=str, default='.', dest='output_directory')

        # NOTE(review): `sys.argv` still contains the script path, so it ends
        # up as `script_args[0]`.  Kept as-is because consumers of
        # `SfgContext.argv` may rely on it — confirm this is intended.
        args = parser.parse_args(sys.argv)

        # The generated files are named after the script being executed.
        import __main__
        scriptpath = __main__.__file__
        basename = path.splitext(path.basename(scriptpath))[0]

        self._context = SfgContext(args.script_args, namespace, codestyle)

        # Imported locally, as in the original code.
        from .emitters.cpu.basic_cpu import BasicCpuEmitter
        self._emitter = BasicCpuEmitter(self._context, basename, args.output_directory)

        # NOTE(review): the body of this loop is missing from the reviewed
        # source.  Removing output files left over from an earlier run is an
        # assumption — confirm against the intended behaviour.
        for file in self._emitter.output_files:
            if path.exists(file):
                os.remove(file)

    def __enter__(self):
        """Return the `SfgContext` the generator script populates."""
        return self._context

    def __exit__(self, exc_type, exc_value, traceback):
        """Write the output files unless the ``with`` body raised.

        Returns None, so an exception from the body propagates unchanged.
        """
        if exc_type is None:
            self._emitter.write_files()
class SfgContext:
    """State a generator script builds up for one set of output files.

    Holds the script arguments, the root namespace, the code style, the
    kernel namespaces (including a default one named ``"kernels"``), the
    include directives and the registered functions.
    """

    def __init__(self, argv, root_namespace: str, codestyle: "SfgCodeStyle"):
        """Create an empty context.

        Args:
            argv: Script arguments exposed through `argv`.
            root_namespace: Exposed through `root_namespace`.
            codestyle: Exposed through `codestyle`.
        """
        self._argv = argv
        self._root_namespace = root_namespace
        # `codestyle` was accepted but never stored, although the
        # `codestyle` property reads `self._codestyle`.
        self._codestyle = codestyle

        self._default_kernel_namespace = SfgKernelNamespace(self, "kernels")
        self._kernel_namespaces = {
            self._default_kernel_namespace.name: self._default_kernel_namespace
        }

        # Both containers are read by the methods below but were never
        # created.  Includes use a dict as an insertion-ordered set, so each
        # directive is kept once and iteration order is deterministic.
        self._includes = {}
        self._functions = {}

    @property
    def argv(self) -> Sequence[str]:
        """Script arguments this context was created with."""
        return self._argv

    # NOTE(review): the `def` line of this accessor is missing from the
    # reviewed source (only its `return` survived); the name is taken from
    # the attribute it returns — confirm.
    @property
    def root_namespace(self) -> str:
        """Root namespace this context was created with."""
        return self._root_namespace

    @property
    def codestyle(self) -> "SfgCodeStyle":
        """Code style this context was created with."""
        return self._codestyle

    # NOTE(review): the `def` line of this accessor is missing from the
    # reviewed source as well; `kernels` mirrors the name given to the
    # default namespace in `__init__` — confirm.
    @property
    def kernels(self) -> "SfgKernelNamespace":
        """The default kernel namespace."""
        return self._default_kernel_namespace

    def kernel_namespace(self, name: str) -> "SfgKernelNamespace":
        """Create, register and return a new kernel namespace.

        Raises:
            ValueError: If a namespace called `name` already exists.
        """
        if name in self._kernel_namespaces:
            raise ValueError(f"Duplicate kernel namespace: {name}")

        kns = SfgKernelNamespace(self, name)
        self._kernel_namespaces[name] = kns
        return kns

    def includes(self) -> Generator[str, None, None]:
        """Yield the registered include directives in first-added order."""
        yield from self._includes

    def kernel_namespaces(self) -> Generator["SfgKernelNamespace", None, None]:
        """Yield every kernel namespace, the default one included."""
        yield from self._kernel_namespaces.values()

    def functions(self) -> Generator["SfgFunction", None, None]:
        """Yield every registered function."""
        yield from self._functions.values()

    def include(self, header_file: str):
        """Register an include directive.

        A name already wrapped in ``<...>`` or ``"..."`` is stored unchanged;
        any other name is wrapped in double quotes first.
        """
        in_angle_brackets = header_file.startswith("<") and header_file.endswith(">")
        in_quotes = header_file.startswith('"') and header_file.endswith('"')
        if not (in_angle_brackets or in_quotes):
            header_file = f'"{header_file}"'

        # The normalised name was previously computed and then dropped.
        self._includes[header_file] = None

    def function(self,
                 name: str,
                 ast_or_kernel_handle: Optional[Union["KernelFunction", "SfgKernelHandle"]] = None):
        """Register a function called `name`.

        With `ast_or_kernel_handle` given, the function body is a single
        kernel call and the function is registered immediately:

        * a `KernelFunction` is first added to the default kernel namespace,
        * an `SfgKernelHandle` is wrapped in a kernel call node,
        * an existing `SfgKernelCallNode` is used as is.

        Without it, a callable is returned; calling that with call-tree nodes
        builds the body via `make_sequence` and registers the function.

        Returns:
            None when the function was registered directly, otherwise the
            sequencer callable described above.

        Raises:
            ValueError: If a function called `name` already exists.
            TypeError: If `ast_or_kernel_handle` has an unsupported type.
        """
        if name in self._functions:
            raise ValueError(f"Duplicate function: {name}")

        if ast_or_kernel_handle is None:
            def sequencer(*args: SfgCallTreeNode):
                tree = make_sequence(*args)
                self._functions[name] = SfgFunction(self, name, tree)

            return sequencer

        # Call nodes are built through `call()` so there is exactly one place
        # that constructs them.  NOTE(review): the original passed
        # `(self, khandle)` here but `(kernel_handle)` in `call()`; the
        # one-argument form is assumed to be the correct one — confirm.
        if isinstance(ast_or_kernel_handle, KernelFunction):
            khandle = self._default_kernel_namespace.add(ast_or_kernel_handle)
            tree = self.call(khandle)
        elif isinstance(ast_or_kernel_handle, SfgKernelHandle):
            # The annotated handle type used to fall through to the TypeError.
            tree = self.call(ast_or_kernel_handle)
        elif isinstance(ast_or_kernel_handle, SfgKernelCallNode):
            # Still accepted, as before.
            tree = ast_or_kernel_handle
        else:
            raise TypeError("Invalid type of argument `ast_or_kernel_handle`!")

        # The direct path built the tree but never registered the function.
        self._functions[name] = SfgFunction(self, name, tree)
        return None

    #----------------------------------------------------------------------------------------------
    #   Call Tree Node Factory
    #----------------------------------------------------------------------------------------------

    def call(self, kernel_handle: "SfgKernelHandle") -> "SfgKernelCallNode":
        """Wrap `kernel_handle` in an `SfgKernelCallNode`."""
        return SfgKernelCallNode(kernel_handle)

    @property
    def branch(self) -> "SfgBranchBuilder":
        """A new `SfgBranchBuilder`; every access creates a fresh one."""
        return SfgBranchBuilder()

    def map_field(self, field: "Field", src_object: Optional["SrcField"] = None) -> "SfgSequence":
        """Return the result of `src_object.extract_parameters(field)`.

        Raises:
            NotImplementedError: If `src_object` is None.
        """
        if src_object is None:
            raise NotImplementedError("Automatic field extraction is not implemented yet.")

        return src_object.extract_parameters(field)