diff --git a/pystencilssfg/configuration.py b/pystencilssfg/configuration.py index 18eabe8d213e017777d30c5feb397dcade4a7410..c712cb15a18fd3d8af6d76a6f88694d8a50d315e 100644 --- a/pystencilssfg/configuration.py +++ b/pystencilssfg/configuration.py @@ -1,6 +1,8 @@ +from __future__ import annotations + from typing import List, Sequence from enum import Enum, auto -from dataclasses import dataclass, replace +from dataclasses import dataclass, replace, asdict, fields from argparse import ArgumentParser from jinja2.filters import do_indent @@ -44,13 +46,18 @@ class SfgConfiguration: if self.header_only: raise SfgException( "Header-only code generation is not implemented yet.") - - if self.header_extension[0] == '.': + + if self.header_extension and self.header_extension[0] == '.': self.header_extension = self.header_extension[1:] - if self.source_extension[0] == '.': + if self.source_extension and self.source_extension[0] == '.': self.source_extension = self.source_extension[1:] + def override(self, other: SfgConfiguration): + other_dict = asdict(other) + other_dict = {k: v for k, v in other_dict.items() if v is not None} + return replace(self, **other_dict) + DEFAULT_CONFIG = SfgConfiguration( header_extension='h', @@ -62,32 +69,11 @@ DEFAULT_CONFIG = SfgConfiguration( ) -def get_file_extensions(self, extensions: Sequence[str]): - h_ext = None - src_ext = None - - extensions = ((ext[1:] if ext[0] == '.' else ext) for ext in extensions) - - for ext in extensions: - if ext in HEADER_FILE_EXTENSIONS: - if h_ext is not None: - raise ValueError("Multiple header file extensions found.") - h_ext = ext - elif ext in SOURCE_FILE_EXTENSIONS: - if src_ext is not None: - raise ValueError("Multiple source file extensions found.") - src_ext = ext - else: - raise ValueError(f"Don't know how to interpret extension '.{ext}'") - - return h_ext, src_ext - - def run_configurator(configurator_script: str): raise NotImplementedError() -def config_from_commandline(self, argv: List[str]): +def config_from_commandline(argv: List[str]): parser = ArgumentParser("pystencilssfg", description="pystencils Source File Generator", allow_abbrev=False) @@ -109,7 +95,7 @@ def config_from_commandline(self, argv: List[str]): project_config = None if args.file_extensions is not None: - h_ext, src_ext = get_file_extensions(args.file_extensions) + h_ext, src_ext = _get_file_extensions(args.file_extensions) else: h_ext, src_ext = None, None @@ -130,23 +116,44 @@ def merge_configurations(project_config: SfgConfiguration, config = DEFAULT_CONFIG if project_config is not None: - config = replace(DEFAULT_CONFIG, **(project_config.asdict())) + config = config.override(project_config) if cmdline_config is not None: - cmdline_dict = cmdline_config.asdict() + cmdline_dict = asdict(cmdline_config) # Commandline config completely overrides project and default config - config = replace(config, **cmdline_dict) + config = config.override(cmdline_config) else: cmdline_dict = {} if script_config is not None: # User config may only set values not specified on the command line - script_dict = script_config.asdict() + script_dict = asdict(script_config) for key, cmdline_value in cmdline_dict.items(): if cmdline_value is not None and script_dict[key] is not None: raise SfgException( f"Conflicting configuration: Parameter {key} was specified both in the script and on the command line.") - config = replace(config, **script_dict) + config = config.override(script_config) return config + + +def _get_file_extensions(extensions: Sequence[str]): + h_ext = None + src_ext = None + + extensions = ((ext[1:] if ext[0] == '.' else ext) for ext in extensions) + + for ext in extensions: + if ext in HEADER_FILE_EXTENSIONS: + if h_ext is not None: + raise ValueError("Multiple header file extensions found.") + h_ext = ext + elif ext in SOURCE_FILE_EXTENSIONS: + if src_ext is not None: + raise ValueError("Multiple source file extensions found.") + src_ext = ext + else: + raise ValueError(f"Don't know how to interpret extension '.{ext}'") + + return h_ext, src_ext diff --git a/pystencilssfg/context.py b/pystencilssfg/context.py index feece85c2f8c0f91e38f58abe7dd2f46e40a844d..954f3f09f287b51b8ce162d0ff0cf3227ca4f0ca 100644 --- a/pystencilssfg/context.py +++ b/pystencilssfg/context.py @@ -21,7 +21,10 @@ from .source_components import SfgFunction, SfgHeaderInclude class SourceFileGenerator: - def __init__(self, sfg_config: SfgConfiguration): + def __init__(self, sfg_config: SfgConfiguration = None): + if sfg_config and not isinstance(sfg_config, SfgConfiguration): + raise TypeError("sfg_config is not an SfgConfiguration.") + import __main__ scriptpath = __main__.__file__ scriptname = path.split(scriptpath)[1] @@ -34,7 +37,7 @@ class SourceFileGenerator: self._context = SfgContext(script_args, config) from .emitters.cpu.basic_cpu import BasicCpuEmitter - self._emitter = BasicCpuEmitter(self._context, basename, config.output_directory) + self._emitter = BasicCpuEmitter(basename, config) def clean_files(self): for file in self._emitter.output_files: @@ -47,7 +50,7 @@ class SourceFileGenerator: def __exit__(self, exc_type, exc_value, traceback): if exc_type is None: - self._emitter.write_files() + self._emitter.write_files(self._context) class SfgContext: diff --git a/pystencilssfg/emitters/cpu/basic_cpu.py b/pystencilssfg/emitters/cpu/basic_cpu.py index 954b300c1b6a7f7b347aeac7b89a3edc6ee9f2c1..962acda78aa2354c69ab96e5767a1fca1d968235 100644 --- a/pystencilssfg/emitters/cpu/basic_cpu.py +++ b/pystencilssfg/emitters/cpu/basic_cpu.py @@ -2,15 +2,15 @@ from jinja2 import Environment, PackageLoader, StrictUndefined from os import path +from ...configuration import SfgConfiguration from ...context import SfgContext class BasicCpuEmitter: - def __init__(self, ctx: SfgContext, basename: str, output_directory: str): - self._ctx = ctx + def __init__(self, basename: str, config: SfgConfiguration): self._basename = basename - self._output_directory = output_directory - self._header_filename = basename + ".h" - self._cpp_filename = basename + ".cpp" + self._output_directory = config.output_directory + self._header_filename = f"{basename}.{config.header_extension}" + self._cpp_filename = f"{basename}.{config.source_extension}" @property def output_files(self) -> str: @@ -19,15 +19,15 @@ class BasicCpuEmitter: path.join(self._output_directory, self._cpp_filename) ) - def write_files(self): + def write_files(self, ctx: SfgContext): jinja_context = { - 'ctx': self._ctx, + 'ctx': ctx, 'basename': self._basename, - 'root_namespace': self._ctx.root_namespace, - 'public_includes': list(incl.get_code() for incl in self._ctx.includes() if not incl.private), - 'private_includes': list(incl.get_code() for incl in self._ctx.includes() if incl.private), - 'kernel_namespaces': list(self._ctx.kernel_namespaces()), - 'functions': list(self._ctx.functions()) + 'root_namespace': ctx.root_namespace, + 'public_includes': list(incl.get_code() for incl in ctx.includes() if not incl.private), + 'private_includes': list(incl.get_code() for incl in ctx.includes() if incl.private), + 'kernel_namespaces': list(ctx.kernel_namespaces()), + 'functions': list(ctx.functions()) } template_name = "BasicCpu" diff --git a/tests/cmake_integration/kernels.py b/tests/cmake_integration/kernels.py index 7a679447fd9883a3b3593c67ca293afd6f9bab99..d780b8f04669b5b33ee5210e91cc03a60a762723 100644 --- a/tests/cmake_integration/kernels.py +++ b/tests/cmake_integration/kernels.py @@ -7,7 +7,7 @@ from pystencilssfg import SourceFileGenerator from pystencilssfg.source_concepts.cpp import std_mdspan -with SourceFileGenerator("poisson") as sfg: +with SourceFileGenerator() as sfg: src, dst = ps.fields("src, dst(1) : double[2D]") h = sp.Symbol('h')