diff --git a/integration/CMakeDemo/codegen_config.py b/integration/CMakeDemo/codegen_config.py index 40a70f91c70e0d9cec2a5f748bd002a99f875395..2f1231a95d4050b994235b4dd56e964819d774c0 100644 --- a/integration/CMakeDemo/codegen_config.py +++ b/integration/CMakeDemo/codegen_config.py @@ -10,7 +10,7 @@ def sfg_config(): return SfgConfiguration( header_extension='hpp', - source_extension='cpp', + impl_extension='cpp', outer_namespace='cmake_demo', project_info=project_info ) diff --git a/src/pystencilssfg/cli.py b/src/pystencilssfg/cli.py index c51db09d28adc0e91633856222a4d5d7dc1c9805..5640be817051c41c35190e10cf1ff0f61ad9b0d4 100644 --- a/src/pystencilssfg/cli.py +++ b/src/pystencilssfg/cli.py @@ -76,7 +76,7 @@ def list_files(args): emitter = HeaderSourcePairEmitter(basename, config.header_extension, - config.source_extension, + config.impl_extension, config.output_directory) print(args.sep.join(emitter.output_files), end=os.linesep if args.newline else '') diff --git a/src/pystencilssfg/configuration.py b/src/pystencilssfg/configuration.py index e155e8f51ad4c6c088475c0340f057b4de07d155..b05ba299a4272d5057879e964ddd18ce403fddfb 100644 --- a/src/pystencilssfg/configuration.py +++ b/src/pystencilssfg/configuration.py @@ -14,7 +14,7 @@ from importlib import util as iutil from .exceptions import SfgException HEADER_FILE_EXTENSIONS = {'h', 'hpp'} -SOURCE_FILE_EXTENSIONS = {'c', 'cpp'} +IMPL_FILE_EXTENSIONS = {'c', 'cpp', '.impl.h'} class SfgConfigSource(Enum): @@ -35,20 +35,57 @@ class SfgConfigException(Exception): class SfgCodeStyle: indent_width: int = 2 + code_style: str = "LLVM" + """Code style to be used by clang-format. Passed verbatim to `--style` argument of the clang-format CLI.""" + + force_clang_format: bool = False + """If set to True, abort code generation if `clang-format` binary cannot be found.""" + def indent(self, s: str): prefix = " " * self.indent_width return indent(s, prefix) +@dataclass +class SfgOutputSpec: + """Name and path specification for files output by the code generator. + + Filenames are constructed as `<output_directory>/<basename>.<extension>`.""" + + output_directory: str + """Directory to which the generated files should be written.""" + + basename: str + """Base name for output files.""" + + header_extension: str + """File extension for generated header file.""" + + impl_extension: str + """File extension for generated implementation file.""" + + def get_header_filename(self): + return f"{self.basename}.{self.header_extension}" + + def get_impl_filename(self): + return f"{self.basename}.{self.impl_extension}" + + def get_header_filepath(self): + return path.join(self.output_directory, self.get_header_filename()) + + def get_impl_filepath(self): + return path.join(self.output_directory, self.get_impl_filename()) + + @dataclass class SfgConfiguration: config_source: InitVar[SfgConfigSource | None] = None header_extension: str | None = None - """File extension for generated header files.""" + """File extension for generated header file.""" - source_extension: str | None = None - """File extension for generated source files.""" + impl_extension: str | None = None + """File extension for generated implementation file.""" header_only: bool | None = None """If set to `True`, generate only a header file without accompaning source file.""" @@ -73,22 +110,34 @@ class SfgConfiguration: if self.header_extension and self.header_extension[0] == '.': self.header_extension = self.header_extension[1:] - if self.source_extension and self.source_extension[0] == '.': - self.source_extension = self.source_extension[1:] + if self.impl_extension and self.impl_extension[0] == '.': + self.impl_extension = self.impl_extension[1:] def override(self, other: SfgConfiguration): other_dict: dict[str, Any] = {k: v for k, v in asdict(other).items() if v is not None} return replace(self, **other_dict) + def get_output_spec(self, basename: str) -> SfgOutputSpec: + assert self.header_extension is not None + assert self.impl_extension is not None + assert self.output_directory is not None + + return SfgOutputSpec( + self.output_directory, + basename, + self.header_extension, + self.impl_extension + ) + DEFAULT_CONFIG = SfgConfiguration( config_source=SfgConfigSource.DEFAULT, header_extension='h', - source_extension='cpp', + impl_extension='cpp', header_only=False, outer_namespace=None, codestyle=SfgCodeStyle(), - output_directory="" + output_directory="." ) @@ -145,7 +194,7 @@ def config_from_parser_args(args): cmdline_config = SfgConfiguration( config_source=SfgConfigSource.COMMANDLINE, header_extension=h_ext, - source_extension=src_ext, + impl_extension=src_ext, header_only=args.header_only, output_directory=args.output_directory ) @@ -207,7 +256,7 @@ def _get_file_extensions(cfgsrc: SfgConfigSource, extensions: Sequence[str]): if h_ext is not None: raise SfgConfigException(cfgsrc, "Multiple header file extensions specified.") h_ext = ext - elif ext in SOURCE_FILE_EXTENSIONS: + elif ext in IMPL_FILE_EXTENSIONS: if src_ext is not None: raise SfgConfigException(cfgsrc, "Multiple source file extensions specified.") src_ext = ext diff --git a/src/pystencilssfg/emitters/header_source_pair.py b/src/pystencilssfg/emitters/header_source_pair.py index fc4e9433a9bdf7697ce59774526df8844444fdbb..33653ab5d9772b8d2c6fe3a511abe2e7bb491cbc 100644 --- a/src/pystencilssfg/emitters/header_source_pair.py +++ b/src/pystencilssfg/emitters/header_source_pair.py @@ -1,28 +1,25 @@ -from typing import cast from jinja2 import Environment, PackageLoader, StrictUndefined -from textwrap import indent from os import path +from ..configuration import SfgOutputSpec from ..context import SfgContext class HeaderSourcePairEmitter: - def __init__(self, - basename: str, - header_extension: str, - impl_extension: str, - output_directory: str): - self._basename = basename - self._output_directory = cast(str, output_directory) - self._header_filename = f"{basename}.{header_extension}" - self._source_filename = f"{basename}.{impl_extension}" + def __init__(self, output_spec: SfgOutputSpec): + self._basename = output_spec.basename + self._output_directory = output_spec.output_directory + self._header_filename = output_spec.get_header_filename() + self._impl_filename = output_spec.get_impl_filename() + + self._ospec = output_spec @property def output_files(self) -> tuple[str, str]: return ( path.join(self._output_directory, self._header_filename), - path.join(self._output_directory, self._source_filename) + path.join(self._output_directory, self._impl_filename) ) def write_files(self, ctx: SfgContext): @@ -31,9 +28,9 @@ class HeaderSourcePairEmitter: jinja_context = { 'ctx': ctx, 'header_filename': self._header_filename, - 'source_filename': self._source_filename, + 'source_filename': self._impl_filename, 'basename': self._basename, - 'prelude': get_prelude_comment(ctx), + 'prelude_comment': ctx.prelude_comment, 'definitions': list(ctx.definitions()), 'fq_namespace': fq_namespace, 'public_includes': list(incl.get_code() for incl in ctx.includes() if not incl.private), @@ -55,15 +52,8 @@ class HeaderSourcePairEmitter: header = env.get_template(f"{template_name}.tmpl.h").render(**jinja_context) source = env.get_template(f"{template_name}.tmpl.cpp").render(**jinja_context) - with open(path.join(self._output_directory, self._header_filename), 'w') as headerfile: + with open(self._ospec.get_header_filepath(), 'w') as headerfile: headerfile.write(header) - with open(path.join(self._output_directory, self._source_filename), 'w') as cppfile: + with open(self._ospec.get_impl_filepath(), 'w') as cppfile: cppfile.write(source) - - -def get_prelude_comment(ctx: SfgContext): - if not ctx.prelude_comment: - return "" - - return "/*\n" + indent(ctx.prelude_comment, "* ", predicate=lambda _: True) + "*/\n" diff --git a/src/pystencilssfg/emitters/jinja_filters.py b/src/pystencilssfg/emitters/jinja_filters.py index 7152c98e486b16030076b63d938ecb2bf7b48114..f6ee5e19ffba51589f64465f6eac2b8acbb20541 100644 --- a/src/pystencilssfg/emitters/jinja_filters.py +++ b/src/pystencilssfg/emitters/jinja_filters.py @@ -1,4 +1,5 @@ from jinja2 import pass_context +from textwrap import indent from pystencils.astnodes import KernelFunction from pystencils import Backend @@ -7,6 +8,13 @@ from pystencils.backends import generate_c from pystencilssfg.source_components import SfgFunction +def format_prelude_comment(prelude_comment: str): + if not prelude_comment: + return "" + + return "/*\n" + indent(prelude_comment, "* ", predicate=lambda _: True) + "*/\n" + + @pass_context def generate_kernel_definition(ctx, ast: KernelFunction): return generate_c(ast, dialect=Backend.C) @@ -23,6 +31,7 @@ def generate_function_body(func: SfgFunction): def add_filters_to_jinja(jinja_env): + jinja_env.filters['format_prelude_comment'] = format_prelude_comment jinja_env.filters['generate_kernel_definition'] = generate_kernel_definition jinja_env.filters['generate_function_parameter_list'] = generate_function_parameter_list jinja_env.filters['generate_function_body'] = generate_function_body diff --git a/src/pystencilssfg/emitters/templates/HeaderSourcePair.tmpl.cpp b/src/pystencilssfg/emitters/templates/HeaderSourcePair.tmpl.cpp index 12b81bf6146298a88992b0a1aa0294af002e68cf..b1d37ccc1cf4de281db76adc82996c79705a5c91 100644 --- a/src/pystencilssfg/emitters/templates/HeaderSourcePair.tmpl.cpp +++ b/src/pystencilssfg/emitters/templates/HeaderSourcePair.tmpl.cpp @@ -1,4 +1,4 @@ -{{ prelude }} +{{ prelude_comment | format_prelude_comment }} #include "{{header_filename}}" diff --git a/src/pystencilssfg/emitters/templates/HeaderSourcePair.tmpl.h b/src/pystencilssfg/emitters/templates/HeaderSourcePair.tmpl.h index 312e50ede2d495aac55a30dbef8b04d078f4d02c..40ce1edd094e3322f3d55a8824888fa8f1acc8c1 100644 --- a/src/pystencilssfg/emitters/templates/HeaderSourcePair.tmpl.h +++ b/src/pystencilssfg/emitters/templates/HeaderSourcePair.tmpl.h @@ -1,4 +1,4 @@ -{{ prelude }} +{{ prelude_comment | format_prelude_comment }} #pragma once diff --git a/src/pystencilssfg/generator.py b/src/pystencilssfg/generator.py index 12449f83f09736ab1b03b004d412c7f13848a7b1..273d1421ce90dc2e7400dbf09bdf7029a9da05bb 100644 --- a/src/pystencilssfg/generator.py +++ b/src/pystencilssfg/generator.py @@ -27,10 +27,7 @@ class SourceFileGenerator: self._context = SfgContext(config.outer_namespace, config.codestyle, argv=script_args) from .emitters import HeaderSourcePairEmitter - self._emitter = HeaderSourcePairEmitter(basename, - config.header_extension, - config.source_extension, - config.output_directory) + self._emitter = HeaderSourcePairEmitter(config.get_output_spec(basename)) def clean_files(self): for file in self._emitter.output_files: