Skip to content
Snippets Groups Projects
Commit 0657df2a authored by Frederik Hennig's avatar Frederik Hennig
Browse files

integrated clang-format

parent 8f789bf6
No related merge requests found
Pipeline #57800 passed with stages
in 51 seconds
# type: ignore
from pystencilssfg import SourceFileGenerator, SfgConfiguration
from pystencilssfg.configuration import SfgCodeStyle
from pystencilssfg.source_concepts import SrcObject
from pystencilssfg.source_components import SfgClass, SfgMemberVariable, SfgConstructor, SfgMethod, SfgVisibility
......@@ -7,7 +8,11 @@ from pystencils import fields, kernel
sfg_config = SfgConfiguration(
output_directory="out/test_classes",
outer_namespace="gen_code"
outer_namespace="gen_code",
codestyle=SfgCodeStyle(
code_style="Mozilla",
force_clang_format=True
)
)
f, g = fields("f, g(1): double[2D]")
......
......@@ -34,7 +34,7 @@ from __future__ import annotations
from typing import Sequence, Any
from os import path
from enum import Enum, auto
from dataclasses import dataclass, replace, asdict, InitVar
from dataclasses import dataclass, replace, fields, InitVar
from argparse import ArgumentParser
from textwrap import indent
......@@ -70,6 +70,9 @@ class SfgCodeStyle:
force_clang_format: bool = False
"""If set to True, abort code generation if `clang-format` binary cannot be found."""
clang_format_binary: str = "clang-format"
"""Path to the clang-format executable"""
def indent(self, s: str):
prefix = " " * self.indent_width
return indent(s, prefix)
......@@ -143,7 +146,7 @@ class SfgConfiguration:
self.impl_extension = self.impl_extension[1:]
def override(self, other: SfgConfiguration):
other_dict: dict[str, Any] = {k: v for k, v in asdict(other).items() if v is not None}
other_dict: dict[str, Any] = {k: v for k, v in _shallow_dict(other).items() if v is not None}
return replace(self, **other_dict)
def get_output_spec(self, basename: str) -> SfgOutputSpec:
......@@ -254,7 +257,7 @@ def merge_configurations(project_config: SfgConfiguration | None,
config = config.override(project_config)
if cmdline_config is not None:
cmdline_dict = asdict(cmdline_config)
cmdline_dict = _shallow_dict(cmdline_config)
# Commandline config completely overrides project and default config
config = config.override(cmdline_config)
else:
......@@ -262,7 +265,7 @@ def merge_configurations(project_config: SfgConfiguration | None,
if script_config is not None:
# User config may only set values not specified on the command line
script_dict = asdict(script_config)
script_dict = _shallow_dict(script_config)
for key, cmdline_value in cmdline_dict.items():
if cmdline_value is not None and script_dict[key] is not None:
raise SfgException(
......@@ -293,3 +296,9 @@ def _get_file_extensions(cfgsrc: SfgConfigSource, extensions: Sequence[str]):
raise SfgConfigException(cfgsrc, f"Don't know how to interpret file extension '.{ext}'")
return h_ext, src_ext
def _shallow_dict(obj):
"""Workaround to create a shallow dict of a dataclass object, see
https://docs.python.org/3/library/dataclasses.html#dataclasses.asdict."""
return dict((field.name, getattr(obj, field.name)) for field in fields(obj))
import subprocess
import shutil
from ..configuration import SfgCodeStyle
from ..exceptions import SfgException
def invoke_clang_format(code: str, codestyle: SfgCodeStyle) -> str:
args = [codestyle.clang_format_binary, f"--style={codestyle.code_style}"]
if not shutil.which("clang-format"):
if codestyle.force_clang_format:
raise SfgException("Could not find clang-format binary.")
else:
return code
result = subprocess.run(args, input=code, capture_output=True, text=True)
if result.returncode != 0:
if codestyle.force_clang_format:
raise SfgException(f"Call to clang-format failed: \n{result.stderr}")
else:
return code
return result.stdout
......@@ -5,6 +5,8 @@ from os import path, makedirs
from ..configuration import SfgOutputSpec
from ..context import SfgContext
from .clang_format import invoke_clang_format
class HeaderSourcePairEmitter:
def __init__(self, output_spec: SfgOutputSpec):
......@@ -53,6 +55,9 @@ class HeaderSourcePairEmitter:
header = env.get_template(f"{template_name}.tmpl.h").render(**jinja_context)
source = env.get_template(f"{template_name}.tmpl.cpp").render(**jinja_context)
header = invoke_clang_format(header, ctx.codestyle)
source = invoke_clang_format(source, ctx.codestyle)
makedirs(self._output_directory, exist_ok=True)
with open(self._ospec.get_header_filepath(), 'w') as headerfile:
......
......@@ -6,7 +6,7 @@ from pystencils import Backend
from pystencils.backends import generate_c
from pystencilssfg.source_components import SfgFunction, SfgClass
from .classes_printing import ClassDeclarationPrinter
from .class_declaration import ClassDeclarationPrinter
def format_prelude_comment(prelude_comment: str):
......
#include "test_classes.h"
#define FUNC_PREFIX inline
/*************************************************************************************
* Kernels
*************************************************************************************/
namespace kernels {
FUNC_PREFIX void kernel(double * RESTRICT _data_f, double * RESTRICT const _data_g, int64_t const _size_f_0, int64_t const _size_f_1, int64_t const _stride_f_0, int64_t const _stride_f_1, int64_t const _stride_g_0, int64_t const _stride_g_1)
{
for (int64_t ctr_0 = 0; ctr_0 < _size_f_0; ctr_0 += 1)
{
for (int64_t ctr_1 = 0; ctr_1 < _size_f_1; ctr_1 += 1)
{
_data_f[_stride_f_0*ctr_0 + _stride_f_1*ctr_1] = 3.0*_data_g[_stride_g_0*ctr_0 + _stride_g_1*ctr_1];
}
}
}
} // namespace kernels
/*************************************************************************************
* Functions
*************************************************************************************/
#pragma once
#include <cstdint>
#define RESTRICT __restrict__
class MyClass
:
{
// default:
};
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment