diff --git a/integration/test_classes.py b/integration/test_classes.py index 446921ccbdbed683a8a2445ad163ad9f51e60bea..40d7098fd848458ba8f4e162eec64f6589ab4cfd 100644 --- a/integration/test_classes.py +++ b/integration/test_classes.py @@ -1,5 +1,6 @@ # type: ignore from pystencilssfg import SourceFileGenerator, SfgConfiguration +from pystencilssfg.configuration import SfgCodeStyle from pystencilssfg.source_concepts import SrcObject from pystencilssfg.source_components import SfgClass, SfgMemberVariable, SfgConstructor, SfgMethod, SfgVisibility @@ -7,7 +8,11 @@ from pystencils import fields, kernel sfg_config = SfgConfiguration( output_directory="out/test_classes", - outer_namespace="gen_code" + outer_namespace="gen_code", + codestyle=SfgCodeStyle( + code_style="Mozilla", + force_clang_format=True + ) ) f, g = fields("f, g(1): double[2D]") diff --git a/src/pystencilssfg/configuration.py b/src/pystencilssfg/configuration.py index 8eaab0796cb74488f551994f2748ee6445f7e9bd..abaac9d42af528367a6297c53b5571cd89fae360 100644 --- a/src/pystencilssfg/configuration.py +++ b/src/pystencilssfg/configuration.py @@ -34,7 +34,7 @@ from __future__ import annotations from typing import Sequence, Any from os import path from enum import Enum, auto -from dataclasses import dataclass, replace, asdict, InitVar +from dataclasses import dataclass, replace, fields, InitVar from argparse import ArgumentParser from textwrap import indent @@ -70,6 +70,9 @@ class SfgCodeStyle: force_clang_format: bool = False """If set to True, abort code generation if `clang-format` binary cannot be found.""" + clang_format_binary: str = "clang-format" + """Path to the clang-format executable""" + def indent(self, s: str): prefix = " " * self.indent_width return indent(s, prefix) @@ -143,7 +146,7 @@ class SfgConfiguration: self.impl_extension = self.impl_extension[1:] def override(self, other: SfgConfiguration): - other_dict: dict[str, Any] = {k: v for k, v in asdict(other).items() if v is not None} + other_dict: dict[str, Any] = {k: v for k, v in _shallow_dict(other).items() if v is not None} return replace(self, **other_dict) def get_output_spec(self, basename: str) -> SfgOutputSpec: @@ -254,7 +257,7 @@ def merge_configurations(project_config: SfgConfiguration | None, config = config.override(project_config) if cmdline_config is not None: - cmdline_dict = asdict(cmdline_config) + cmdline_dict = _shallow_dict(cmdline_config) # Commandline config completely overrides project and default config config = config.override(cmdline_config) else: @@ -262,7 +265,7 @@ def merge_configurations(project_config: SfgConfiguration | None, if script_config is not None: # User config may only set values not specified on the command line - script_dict = asdict(script_config) + script_dict = _shallow_dict(script_config) for key, cmdline_value in cmdline_dict.items(): if cmdline_value is not None and script_dict[key] is not None: raise SfgException( @@ -293,3 +296,9 @@ def _get_file_extensions(cfgsrc: SfgConfigSource, extensions: Sequence[str]): raise SfgConfigException(cfgsrc, f"Don't know how to interpret file extension '.{ext}'") return h_ext, src_ext + + +def _shallow_dict(obj): + """Workaround to create a shallow dict of a dataclass object, see + https://docs.python.org/3/library/dataclasses.html#dataclasses.asdict.""" + return dict((field.name, getattr(obj, field.name)) for field in fields(obj)) diff --git a/src/pystencilssfg/emitters/clang_format.py b/src/pystencilssfg/emitters/clang_format.py new file mode 100644 index 0000000000000000000000000000000000000000..309908fcf16e09e89c029799445fea6ae88a2b23 --- /dev/null +++ b/src/pystencilssfg/emitters/clang_format.py @@ -0,0 +1,25 @@ +import subprocess +import shutil + +from ..configuration import SfgCodeStyle +from ..exceptions import SfgException + + +def invoke_clang_format(code: str, codestyle: SfgCodeStyle) -> str: + args = [codestyle.clang_format_binary, f"--style={codestyle.code_style}"] + + if not shutil.which("clang-format"): + if codestyle.force_clang_format: + raise SfgException("Could not find clang-format binary.") + else: + return code + + result = subprocess.run(args, input=code, capture_output=True, text=True) + + if result.returncode != 0: + if codestyle.force_clang_format: + raise SfgException(f"Call to clang-format failed: \n{result.stderr}") + else: + return code + + return result.stdout diff --git a/src/pystencilssfg/emitters/classes_printing.py b/src/pystencilssfg/emitters/class_declaration.py similarity index 100% rename from src/pystencilssfg/emitters/classes_printing.py rename to src/pystencilssfg/emitters/class_declaration.py diff --git a/src/pystencilssfg/emitters/header_source_pair.py b/src/pystencilssfg/emitters/header_source_pair.py index 553f84b4054ec4f01cf821d3edc05e92edacffa7..c5fd1c373b1596ce4e682e23e9e181faaaebffe5 100644 --- a/src/pystencilssfg/emitters/header_source_pair.py +++ b/src/pystencilssfg/emitters/header_source_pair.py @@ -5,6 +5,8 @@ from os import path, makedirs from ..configuration import SfgOutputSpec from ..context import SfgContext +from .clang_format import invoke_clang_format + class HeaderSourcePairEmitter: def __init__(self, output_spec: SfgOutputSpec): @@ -53,6 +55,9 @@ class HeaderSourcePairEmitter: header = env.get_template(f"{template_name}.tmpl.h").render(**jinja_context) source = env.get_template(f"{template_name}.tmpl.cpp").render(**jinja_context) + header = invoke_clang_format(header, ctx.codestyle) + source = invoke_clang_format(source, ctx.codestyle) + makedirs(self._output_directory, exist_ok=True) with open(self._ospec.get_header_filepath(), 'w') as headerfile: diff --git a/src/pystencilssfg/emitters/jinja_filters.py b/src/pystencilssfg/emitters/jinja_filters.py index 31169f4aa9e21b0248630207cce14d234e64dd04..67d96f018282944e28b234f3f1df5dfc7a3e726b 100644 --- a/src/pystencilssfg/emitters/jinja_filters.py +++ b/src/pystencilssfg/emitters/jinja_filters.py @@ -6,7 +6,7 @@ from pystencils import Backend from pystencils.backends import generate_c from pystencilssfg.source_components import SfgFunction, SfgClass -from .classes_printing import ClassDeclarationPrinter +from .class_declaration import ClassDeclarationPrinter def format_prelude_comment(prelude_comment: str): diff --git a/test_classses_out/test_classes.cpp b/test_classses_out/test_classes.cpp deleted file mode 100644 index 93b26bc2e71d7e7db9b6bb590a1f27b3c3e4017a..0000000000000000000000000000000000000000 --- a/test_classses_out/test_classes.cpp +++ /dev/null @@ -1,32 +0,0 @@ - - -#include "test_classes.h" - - -#define FUNC_PREFIX inline - - -/************************************************************************************* - * Kernels -*************************************************************************************/ - -namespace kernels { - -FUNC_PREFIX void kernel(double * RESTRICT _data_f, double * RESTRICT const _data_g, int64_t const _size_f_0, int64_t const _size_f_1, int64_t const _stride_f_0, int64_t const _stride_f_1, int64_t const _stride_g_0, int64_t const _stride_g_1) -{ - for (int64_t ctr_0 = 0; ctr_0 < _size_f_0; ctr_0 += 1) - { - for (int64_t ctr_1 = 0; ctr_1 < _size_f_1; ctr_1 += 1) - { - _data_f[_stride_f_0*ctr_0 + _stride_f_1*ctr_1] = 3.0*_data_g[_stride_g_0*ctr_0 + _stride_g_1*ctr_1]; - } - } -} - -} // namespace kernels - -/************************************************************************************* - * Functions -*************************************************************************************/ - - diff --git a/test_classses_out/test_classes.h b/test_classses_out/test_classes.h deleted file mode 100644 index 320266589cbf54882ae87b599a668053e62155c2..0000000000000000000000000000000000000000 --- a/test_classses_out/test_classes.h +++ /dev/null @@ -1,24 +0,0 @@ - - -#pragma once - -#include <cstdint> - - - -#define RESTRICT __restrict__ - - - -class MyClass - : -{ - -// default: - - -}; - - - -