# (Diff-view residue: "Newer" / "Older" column labels — not part of the module.)
from typing import Generator, Union, Optional, Sequence
from os import path
from pystencils import Field
from .configuration import SfgConfiguration, config_from_commandline, merge_configurations, SfgCodeStyle
from .kernel_namespace import SfgKernelNamespace, SfgKernelHandle
from .tree import SfgCallTreeNode, SfgSequence, SfgKernelCallNode, SfgStatements
from .tree.deferred_nodes import SfgDeferredFieldMapping
from .tree.builders import SfgBranchBuilder, make_sequence
from .tree.visitors import CollectIncludes
from .source_concepts import SrcField, TypedSymbolOrObject
from .source_components import SfgFunction, SfgHeaderInclude
class SourceFileGenerator:
    """Context manager for a code-generation script.

    Builds an ``SfgContext`` from the merged configuration, hands it out in
    ``__enter__``, and writes the emitter's files in ``__exit__`` when the
    ``with`` block finished without an exception.

    NOTE(review): this view of the file appears to be missing lines (the loop
    in ``__init__`` has no body here); comments describe only what is visible.
    """
    def __init__(self, sfg_config: SfgConfiguration):
        # __main__ is the module of the running script; its file name (without
        # directory and extension) becomes the base name of the emitted files.
        import __main__
        scriptpath = __main__.__file__
        scriptname = path.split(scriptpath)[1]
        basename = path.splitext(scriptname)[0]
        # NOTE(review): `sys` is not imported in the visible import block — confirm it is imported.
        project_config, cmdline_config, script_args = config_from_commandline(sys.argv)
        config = merge_configurations(project_config, cmdline_config, sfg_config)
        # The leftover script arguments are what SfgContext exposes as `argv`.
        self._context = SfgContext(script_args, config)
        # Function-local import; presumably to defer loading the emitter module — TODO confirm intent.
        from .emitters.cpu.basic_cpu import BasicCpuEmitter
        self._emitter = BasicCpuEmitter(self._context, basename, config.output_directory)
        for file in self._emitter.output_files:
            # NOTE(review): the body of this loop is not present in this view of the file.
    def __enter__(self):
        """Return the SfgContext that the ``with`` block should populate."""
        return self._context
    def __exit__(self, exc_type, exc_value, traceback):
        # Files are only written when the with-block exited normally; returning
        # None (implicitly) lets any exception propagate.
        if exc_type is None:
            self._emitter.write_files()
class SfgContext:
    """Registry of the includes, kernel namespaces and functions built by a generator script."""
    def __init__(self, argv, config: SfgConfiguration):
        """Create a context holding only the default ``"kernels"`` namespace.

        Args:
            argv: Script arguments, exposed unchanged through the ``argv`` property.
            config: Merged generator configuration.
                NOTE(review): not stored by the visible code — confirm whether it should be kept.
        """
        self._argv = argv
        # Kernels passed to `function()` as a bare KernelFunction are added here.
        self._default_kernel_namespace = SfgKernelNamespace(self, "kernels")
        self._includes = set()
        self._kernel_namespaces = {
            self._default_kernel_namespace.name: self._default_kernel_namespace
        }
        # Read by functions(), add_function() and function(); it was never
        # initialized before, so the first access raised AttributeError.
        self._functions = {}
@property
def argv(self) -> Sequence[str]:
    """The argument sequence this context was constructed with, returned as-is (no copy)."""
    stored_args = self._argv
    return stored_args
# NOTE: original file lines 76-107 appear to be collapsed in this diff view and are not shown.
#----------------------------------------------------------------------------------------------
# Source Component Getters
#----------------------------------------------------------------------------------------------
def includes(self) -> Generator[SfgHeaderInclude, None, None]:
    """Lazily yield each header include registered on this context.

    The includes live in a set, so the order is unspecified.
    """
    for header_include in self._includes:
        yield header_include
def kernel_namespaces(self) -> Generator[SfgKernelNamespace, None, None]:
    """Lazily yield every registered kernel namespace, the default one included."""
    for namespace in self._kernel_namespaces.values():
        yield namespace
def functions(self) -> Generator[SfgFunction, None, None]:
    """Lazily yield every function that has been registered via ``add_function``."""
    for registered_function in self._functions.values():
        yield registered_function
#----------------------------------------------------------------------------------------------
# Source Component Adders
#----------------------------------------------------------------------------------------------
def add_include(self, include: SfgHeaderInclude):
    """Record ``include``; the backing set stores equal (same-hash) includes only once."""
    include_set = self._includes
    include_set.add(include)
def add_function(self, func: SfgFunction):
    """Register ``func`` under its name and add the headers its call tree requires.

    Raises:
        ValueError: if a function with the same name is already registered.
    """
    fname = func.name
    if fname in self._functions:
        raise ValueError(f"Duplicate function: {fname}")
    self._functions[fname] = func
    # Every include found while visiting the function's call tree becomes a
    # context-wide include.
    collector = CollectIncludes()
    for header in collector.visit(func._tree):
        self.add_include(header)
#----------------------------------------------------------------------------------------------
# Factory-like Adders
#----------------------------------------------------------------------------------------------
# NOTE(review): the `def` line (and any decorator) enclosing this return is not
# visible in this view of the file. It hands back the "kernels" namespace that
# __init__ creates.
return self._default_kernel_namespace
def kernel_namespace(self, name: str) -> SfgKernelNamespace:
    """Create, register and return a new kernel namespace called ``name``.

    Raises:
        ValueError: if a kernel namespace with this name is already registered.
    """
    if name not in self._kernel_namespaces:
        fresh_namespace = SfgKernelNamespace(self, name)
        self._kernel_namespaces[name] = fresh_namespace
        return fresh_namespace
    raise ValueError(f"Duplicate kernel namespace: {name}")
# NOTE(review): the `def` line enclosing this body is not visible in this view;
# it appears to take a `header_file` string and register it as an include.
# A name wrapped in angle brackets (e.g. "<vector>") is treated as a system
# header: the brackets are stripped and system_header=True is passed on.
system_header = False
if header_file.startswith("<") and header_file.endswith(">"):
    header_file = header_file[1:-1]
    system_header = True
self.add_include(SfgHeaderInclude(header_file, system_header=system_header))
def function(self,
             name: str,
             ast_or_kernel_handle : Optional[Union[KernelFunction, SfgKernelHandle]] = None):
    """Start defining a function called ``name`` on this context.

    Without a second argument, returns a ``sequencer`` callable. Calling it with
    call-tree nodes builds a tree via ``make_sequence`` and passes the function
    to ``add_function``.

    NOTE(review): lines of this method appear to be missing from this view, and
    `KernelFunction` has no visible import — confirm both against the full file.

    Raises:
        ValueError: if a function called ``name`` is already registered.
        TypeError: if ``ast_or_kernel_handle`` is not None, not a KernelFunction
            and not an SfgKernelCallNode.
    """
    if name in self._functions:
        raise ValueError(f"Duplicate function: {name}")
    if ast_or_kernel_handle is not None:
        if isinstance(ast_or_kernel_handle, KernelFunction):
            # A bare kernel is registered in the default "kernels" namespace.
            khandle = self._default_kernel_namespace.add(ast_or_kernel_handle)
            # NOTE(review): `call()` builds SfgKernelCallNode(kernel_handle) without
            # `self`; one of the two call sites likely passes the wrong arguments.
            tree = SfgKernelCallNode(self, khandle)
        elif isinstance(ast_or_kernel_handle, SfgKernelCallNode):
            # NOTE(review): the annotation promises SfgKernelHandle, but this checks
            # for SfgKernelCallNode, so passing a kernel handle raises TypeError.
            tree = ast_or_kernel_handle
        else:
            raise TypeError(f"Invalid type of argument `ast_or_kernel_handle`!")
        # NOTE(review): in the visible code `tree` is not used after this point and
        # no function is registered on this path (the method returns None).
    else:
        def sequencer(*args: SfgCallTreeNode):
            tree = make_sequence(*args)
            # NOTE(review): `func` is never assigned in the visible code (a line
            # building the function object from `name` and `tree` seems to be
            # missing), so this call raises NameError as shown.
            self.add_function(func)
        return sequencer
#----------------------------------------------------------------------------------------------
# In-Sequence builders to be used within the second phase of SfgContext.function().
#----------------------------------------------------------------------------------------------
def call(self, kernel_handle: SfgKernelHandle) -> SfgKernelCallNode:
    """Wrap ``kernel_handle`` in an ``SfgKernelCallNode`` for use inside a function sequence."""
    # NOTE(review): `function()` constructs SfgKernelCallNode(self, khandle) with an
    # extra leading argument — confirm which constructor signature is correct.
    call_node = SfgKernelCallNode(kernel_handle)
    return call_node
@property
def branch(self) -> SfgBranchBuilder:
    """A new ``SfgBranchBuilder`` on every access; nothing is cached on the context."""
    builder = SfgBranchBuilder()
    return builder
def map_field(self, field: Field, src_object: Optional[SrcField] = None) -> SfgSequence:
    """Create a deferred mapping of the pystencils ``field`` onto ``src_object``.

    NOTE(review): the annotation says SfgSequence while an SfgDeferredFieldMapping
    is returned — confirm the latter is an SfgSequence subtype.

    Raises:
        NotImplementedError: if ``src_object`` is omitted.
    """
    if src_object is not None:
        return SfgDeferredFieldMapping(field, src_object)
    raise NotImplementedError("Automatic field extraction is not implemented yet.")
def map_param(self, lhs: TypedSymbolOrObject, rhs: TypedSymbolOrObject, mapping: str):
    """Wrap the code string ``mapping`` in an ``SfgStatements`` node.

    ``lhs`` and ``rhs`` are each passed as one-element tuples; presumably the
    first tuple lists the objects the code defines and the second the objects
    it depends on — confirm against SfgStatements.
    """
    lhs_group = (lhs,)
    rhs_group = (rhs,)
    return SfgStatements(mapping, lhs_group, rhs_group)