Skip to content
Snippets Groups Projects
Commit af279cf0 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

refactored prelude printing and file names

parent 1699c4db
No related merge requests found
Pipeline #57782 passed with stages
in 43 seconds
...@@ -10,7 +10,7 @@ def sfg_config(): ...@@ -10,7 +10,7 @@ def sfg_config():
return SfgConfiguration( return SfgConfiguration(
header_extension='hpp', header_extension='hpp',
source_extension='cpp', impl_extension='cpp',
outer_namespace='cmake_demo', outer_namespace='cmake_demo',
project_info=project_info project_info=project_info
) )
...@@ -76,7 +76,7 @@ def list_files(args): ...@@ -76,7 +76,7 @@ def list_files(args):
emitter = HeaderSourcePairEmitter(basename, emitter = HeaderSourcePairEmitter(basename,
config.header_extension, config.header_extension,
config.source_extension, config.impl_extension,
config.output_directory) config.output_directory)
print(args.sep.join(emitter.output_files), end=os.linesep if args.newline else '') print(args.sep.join(emitter.output_files), end=os.linesep if args.newline else '')
......
...@@ -14,7 +14,7 @@ from importlib import util as iutil ...@@ -14,7 +14,7 @@ from importlib import util as iutil
from .exceptions import SfgException from .exceptions import SfgException
HEADER_FILE_EXTENSIONS = {'h', 'hpp'} HEADER_FILE_EXTENSIONS = {'h', 'hpp'}
SOURCE_FILE_EXTENSIONS = {'c', 'cpp'} IMPL_FILE_EXTENSIONS = {'c', 'cpp', '.impl.h'}
class SfgConfigSource(Enum): class SfgConfigSource(Enum):
...@@ -35,20 +35,57 @@ class SfgConfigException(Exception): ...@@ -35,20 +35,57 @@ class SfgConfigException(Exception):
class SfgCodeStyle: class SfgCodeStyle:
indent_width: int = 2 indent_width: int = 2
code_style: str = "LLVM"
"""Code style to be used by clang-format. Passed verbatim to `--style` argument of the clang-format CLI."""
force_clang_format: bool = False
"""If set to True, abort code generation if `clang-format` binary cannot be found."""
def indent(self, s: str): def indent(self, s: str):
prefix = " " * self.indent_width prefix = " " * self.indent_width
return indent(s, prefix) return indent(s, prefix)
@dataclass
class SfgOutputSpec:
"""Name and path specification for files output by the code generator.
Filenames are constructed as `<output_directory>/<basename>.<extension>`."""
output_directory: str
"""Directory to which the generated files should be written."""
basename: str
"""Base name for output files."""
header_extension: str
"""File extension for generated header file."""
impl_extension: str
"""File extension for generated implementation file."""
def get_header_filename(self):
return f"{self.basename}.{self.header_extension}"
def get_impl_filename(self):
return f"{self.basename}.{self.impl_extension}"
def get_header_filepath(self):
return path.join(self.output_directory, self.get_header_filename())
def get_impl_filepath(self):
return path.join(self.output_directory, self.get_impl_filename())
@dataclass @dataclass
class SfgConfiguration: class SfgConfiguration:
config_source: InitVar[SfgConfigSource | None] = None config_source: InitVar[SfgConfigSource | None] = None
header_extension: str | None = None header_extension: str | None = None
"""File extension for generated header files.""" """File extension for generated header file."""
source_extension: str | None = None impl_extension: str | None = None
"""File extension for generated source files.""" """File extension for generated implementation file."""
header_only: bool | None = None header_only: bool | None = None
"""If set to `True`, generate only a header file without accompaning source file.""" """If set to `True`, generate only a header file without accompaning source file."""
...@@ -73,22 +110,34 @@ class SfgConfiguration: ...@@ -73,22 +110,34 @@ class SfgConfiguration:
if self.header_extension and self.header_extension[0] == '.': if self.header_extension and self.header_extension[0] == '.':
self.header_extension = self.header_extension[1:] self.header_extension = self.header_extension[1:]
if self.source_extension and self.source_extension[0] == '.': if self.impl_extension and self.impl_extension[0] == '.':
self.source_extension = self.source_extension[1:] self.impl_extension = self.impl_extension[1:]
def override(self, other: SfgConfiguration): def override(self, other: SfgConfiguration):
other_dict: dict[str, Any] = {k: v for k, v in asdict(other).items() if v is not None} other_dict: dict[str, Any] = {k: v for k, v in asdict(other).items() if v is not None}
return replace(self, **other_dict) return replace(self, **other_dict)
def get_output_spec(self, basename: str) -> SfgOutputSpec:
assert self.header_extension is not None
assert self.impl_extension is not None
assert self.output_directory is not None
return SfgOutputSpec(
self.output_directory,
basename,
self.header_extension,
self.impl_extension
)
DEFAULT_CONFIG = SfgConfiguration( DEFAULT_CONFIG = SfgConfiguration(
config_source=SfgConfigSource.DEFAULT, config_source=SfgConfigSource.DEFAULT,
header_extension='h', header_extension='h',
source_extension='cpp', impl_extension='cpp',
header_only=False, header_only=False,
outer_namespace=None, outer_namespace=None,
codestyle=SfgCodeStyle(), codestyle=SfgCodeStyle(),
output_directory="" output_directory="."
) )
...@@ -145,7 +194,7 @@ def config_from_parser_args(args): ...@@ -145,7 +194,7 @@ def config_from_parser_args(args):
cmdline_config = SfgConfiguration( cmdline_config = SfgConfiguration(
config_source=SfgConfigSource.COMMANDLINE, config_source=SfgConfigSource.COMMANDLINE,
header_extension=h_ext, header_extension=h_ext,
source_extension=src_ext, impl_extension=src_ext,
header_only=args.header_only, header_only=args.header_only,
output_directory=args.output_directory output_directory=args.output_directory
) )
...@@ -207,7 +256,7 @@ def _get_file_extensions(cfgsrc: SfgConfigSource, extensions: Sequence[str]): ...@@ -207,7 +256,7 @@ def _get_file_extensions(cfgsrc: SfgConfigSource, extensions: Sequence[str]):
if h_ext is not None: if h_ext is not None:
raise SfgConfigException(cfgsrc, "Multiple header file extensions specified.") raise SfgConfigException(cfgsrc, "Multiple header file extensions specified.")
h_ext = ext h_ext = ext
elif ext in SOURCE_FILE_EXTENSIONS: elif ext in IMPL_FILE_EXTENSIONS:
if src_ext is not None: if src_ext is not None:
raise SfgConfigException(cfgsrc, "Multiple source file extensions specified.") raise SfgConfigException(cfgsrc, "Multiple source file extensions specified.")
src_ext = ext src_ext = ext
......
from typing import cast
from jinja2 import Environment, PackageLoader, StrictUndefined from jinja2 import Environment, PackageLoader, StrictUndefined
from textwrap import indent
from os import path from os import path
from ..configuration import SfgOutputSpec
from ..context import SfgContext from ..context import SfgContext
class HeaderSourcePairEmitter: class HeaderSourcePairEmitter:
def __init__(self, def __init__(self, output_spec: SfgOutputSpec):
basename: str, self._basename = output_spec.basename
header_extension: str, self._output_directory = output_spec.output_directory
impl_extension: str, self._header_filename = output_spec.get_header_filename()
output_directory: str): self._impl_filename = output_spec.get_impl_filename()
self._basename = basename
self._output_directory = cast(str, output_directory) self._ospec = output_spec
self._header_filename = f"{basename}.{header_extension}"
self._source_filename = f"{basename}.{impl_extension}"
@property @property
def output_files(self) -> tuple[str, str]: def output_files(self) -> tuple[str, str]:
return ( return (
path.join(self._output_directory, self._header_filename), path.join(self._output_directory, self._header_filename),
path.join(self._output_directory, self._source_filename) path.join(self._output_directory, self._impl_filename)
) )
def write_files(self, ctx: SfgContext): def write_files(self, ctx: SfgContext):
...@@ -31,9 +28,9 @@ class HeaderSourcePairEmitter: ...@@ -31,9 +28,9 @@ class HeaderSourcePairEmitter:
jinja_context = { jinja_context = {
'ctx': ctx, 'ctx': ctx,
'header_filename': self._header_filename, 'header_filename': self._header_filename,
'source_filename': self._source_filename, 'source_filename': self._impl_filename,
'basename': self._basename, 'basename': self._basename,
'prelude': get_prelude_comment(ctx), 'prelude_comment': ctx.prelude_comment,
'definitions': list(ctx.definitions()), 'definitions': list(ctx.definitions()),
'fq_namespace': fq_namespace, 'fq_namespace': fq_namespace,
'public_includes': list(incl.get_code() for incl in ctx.includes() if not incl.private), 'public_includes': list(incl.get_code() for incl in ctx.includes() if not incl.private),
...@@ -55,15 +52,8 @@ class HeaderSourcePairEmitter: ...@@ -55,15 +52,8 @@ class HeaderSourcePairEmitter:
header = env.get_template(f"{template_name}.tmpl.h").render(**jinja_context) header = env.get_template(f"{template_name}.tmpl.h").render(**jinja_context)
source = env.get_template(f"{template_name}.tmpl.cpp").render(**jinja_context) source = env.get_template(f"{template_name}.tmpl.cpp").render(**jinja_context)
with open(path.join(self._output_directory, self._header_filename), 'w') as headerfile: with open(self._ospec.get_header_filepath(), 'w') as headerfile:
headerfile.write(header) headerfile.write(header)
with open(path.join(self._output_directory, self._source_filename), 'w') as cppfile: with open(self._ospec.get_impl_filepath(), 'w') as cppfile:
cppfile.write(source) cppfile.write(source)
def get_prelude_comment(ctx: SfgContext):
if not ctx.prelude_comment:
return ""
return "/*\n" + indent(ctx.prelude_comment, "* ", predicate=lambda _: True) + "*/\n"
from jinja2 import pass_context from jinja2 import pass_context
from textwrap import indent
from pystencils.astnodes import KernelFunction from pystencils.astnodes import KernelFunction
from pystencils import Backend from pystencils import Backend
...@@ -7,6 +8,13 @@ from pystencils.backends import generate_c ...@@ -7,6 +8,13 @@ from pystencils.backends import generate_c
from pystencilssfg.source_components import SfgFunction from pystencilssfg.source_components import SfgFunction
def format_prelude_comment(prelude_comment: str):
if not prelude_comment:
return ""
return "/*\n" + indent(prelude_comment, "* ", predicate=lambda _: True) + "*/\n"
@pass_context @pass_context
def generate_kernel_definition(ctx, ast: KernelFunction): def generate_kernel_definition(ctx, ast: KernelFunction):
return generate_c(ast, dialect=Backend.C) return generate_c(ast, dialect=Backend.C)
...@@ -23,6 +31,7 @@ def generate_function_body(func: SfgFunction): ...@@ -23,6 +31,7 @@ def generate_function_body(func: SfgFunction):
def add_filters_to_jinja(jinja_env): def add_filters_to_jinja(jinja_env):
jinja_env.filters['format_prelude_comment'] = format_prelude_comment
jinja_env.filters['generate_kernel_definition'] = generate_kernel_definition jinja_env.filters['generate_kernel_definition'] = generate_kernel_definition
jinja_env.filters['generate_function_parameter_list'] = generate_function_parameter_list jinja_env.filters['generate_function_parameter_list'] = generate_function_parameter_list
jinja_env.filters['generate_function_body'] = generate_function_body jinja_env.filters['generate_function_body'] = generate_function_body
{{ prelude }} {{ prelude_comment | format_prelude_comment }}
#include "{{header_filename}}" #include "{{header_filename}}"
......
{{ prelude }} {{ prelude_comment | format_prelude_comment }}
#pragma once #pragma once
......
...@@ -27,10 +27,7 @@ class SourceFileGenerator: ...@@ -27,10 +27,7 @@ class SourceFileGenerator:
self._context = SfgContext(config.outer_namespace, config.codestyle, argv=script_args) self._context = SfgContext(config.outer_namespace, config.codestyle, argv=script_args)
from .emitters import HeaderSourcePairEmitter from .emitters import HeaderSourcePairEmitter
self._emitter = HeaderSourcePairEmitter(basename, self._emitter = HeaderSourcePairEmitter(config.get_output_spec(basename))
config.header_extension,
config.source_extension,
config.output_directory)
def clean_files(self): def clean_files(self):
for file in self._emitter.output_files: for file in self._emitter.output_files:
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment