Newer
Older
from typing import Generator, Union, Optional, Sequence
from os import path
from argparse import ArgumentParser
from pystencils import Field
from pystencils.astnodes import KernelFunction
from .kernel_namespace import SfgKernelNamespace, SfgKernelHandle
from .tree import SfgCallTreeNode, SfgSequence, SfgKernelCallNode
from .tree.deferred_nodes import SfgDeferredFieldMapping
from .tree.builders import SfgBranchBuilder, make_sequence
from .tree.visitors import CollectIncludes
from .source_concepts.containers import SrcField
from .source_components import SfgFunction, SfgHeaderInclude
@dataclass
class SfgCodeStyle:
indent_width: int = 2
def indent(self, s: str):
return do_indent(s, self.indent_width, first=True)
class SourceFileGenerator:
def __init__(self,
namespace: str = "pystencils",
codestyle: SfgCodeStyle = SfgCodeStyle()):
parser = ArgumentParser(
"pystencilssfg",
description="pystencils Source File Generator")
parser.add_argument("script_args", nargs='*')
parser.add_argument("-d", "--output-dir", type=str, default='.', dest='output_directory')
args = parser.parse_args(sys.argv)
import __main__
scriptpath = __main__.__file__
scriptname = path.split(scriptpath)[1]
basename = path.splitext(scriptname)[0]
self._context = SfgContext(args.script_args, namespace, codestyle)
from .emitters.cpu.basic_cpu import BasicCpuEmitter
self._emitter = BasicCpuEmitter(self._context, basename, args.output_directory)
for file in self._emitter.output_files:
def __enter__(self):
return self._context
def __exit__(self, exc_type, exc_value, traceback):
if exc_type is None:
self._emitter.write_files()
class SfgContext:
def __init__(self, argv, root_namespace: str, codestyle: SfgCodeStyle):
self._argv = argv
self._root_namespace = root_namespace
self._default_kernel_namespace = SfgKernelNamespace(self, "kernels")
self._includes = set()
self._kernel_namespaces = { self._default_kernel_namespace.name : self._default_kernel_namespace }
@property
def argv(self) -> Sequence[str]:
return self._argv
return self._root_namespace
@property
def codestyle(self) -> SfgCodeStyle:
return self._codestyle
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#----------------------------------------------------------------------------------------------
# Source Component Getters
#----------------------------------------------------------------------------------------------
def includes(self) -> Generator[SfgHeaderInclude, None, None]:
yield from self._includes
def kernel_namespaces(self) -> Generator[SfgKernelNamespace, None, None]:
yield from self._kernel_namespaces.values()
def functions(self) -> Generator[SfgFunction, None, None]:
yield from self._functions.values()
#----------------------------------------------------------------------------------------------
# Source Component Adders
#----------------------------------------------------------------------------------------------
def add_include(self, include: SfgHeaderInclude):
self._includes.add(include)
def add_function(self, func: SfgFunction):
if func.name in self._functions:
raise ValueError(f"Duplicate function: {func.name}")
self._functions[func.name] = func
for incl in CollectIncludes().visit(func._tree):
self.add_include(incl)
#----------------------------------------------------------------------------------------------
# Factory-like Adders
#----------------------------------------------------------------------------------------------
return self._default_kernel_namespace
def kernel_namespace(self, name: str) -> SfgKernelNamespace:
if name in self._kernel_namespaces:
raise ValueError(f"Duplicate kernel namespace: {name}")
kns = SfgKernelNamespace(self, name)
self._kernel_namespaces[name] = kns
return kns
system_header = False
if header_file.startswith("<") and header_file.endswith(">"):
header_file = header_file[1:-1]
system_header = True
self.add_include(SfgHeaderInclude(header_file, system_header=system_header))
def function(self,
name: str,
ast_or_kernel_handle : Optional[Union[KernelFunction, SfgKernelHandle]] = None):
if name in self._functions:
raise ValueError(f"Duplicate function: {name}")
if ast_or_kernel_handle is not None:
if isinstance(ast_or_kernel_handle, KernelFunction):
khandle = self._default_kernel_namespace.add(ast_or_kernel_handle)
tree = SfgKernelCallNode(self, khandle)
elif isinstance(ast_or_kernel_handle, SfgKernelCallNode):
tree = ast_or_kernel_handle
else:
raise TypeError(f"Invalid type of argument `ast_or_kernel_handle`!")
else:
def sequencer(*args: SfgCallTreeNode):
tree = make_sequence(*args)
self.add_function(func)
return sequencer
#----------------------------------------------------------------------------------------------
# Call Tree Node Factory
#----------------------------------------------------------------------------------------------
def call(self, kernel_handle: SfgKernelHandle) -> SfgKernelCallNode:
return SfgKernelCallNode(kernel_handle)
@property
def branch(self) -> SfgBranchBuilder:
return SfgBranchBuilder()
def map_field(self, field: Field, src_object: Optional[SrcField] = None) -> SfgSequence:
if src_object is None:
raise NotImplementedError("Automatic field extraction is not implemented yet.")
else:
return SfgDeferredFieldMapping(field, src_object)