diff --git a/pdm.lock b/pdm.lock index 315541a523f73d4afdf538ef3db0901aa3f34e12..c6638a202638b423a78dee411100e6da56792cce 100644 --- a/pdm.lock +++ b/pdm.lock @@ -2,10 +2,10 @@ # It is not intended for manual editing. [metadata] -groups = ["default", "docs", "interactive", "code_quality"] +groups = ["default", "code_quality", "docs", "interactive"] strategy = ["cross_platform"] lock_version = "4.4" -content_hash = "sha256:f9e6e0eed785eecd687dfee3afe479681a2082cee5c00285e1975fbf8b7d50df" +content_hash = "sha256:2c854f8da4b29c3080cd89c774409f95c47d3532c953cf10ecaa67d0b77ff9cf" [[package]] name = "appdirs" @@ -16,15 +16,6 @@ files = [ {file = "appdirs-1.4.4.tar.gz", hash = "sha256:7d5d0167b2b1ba821647616af46a749d1c653740dd0d2415100fe26e27afdf41"}, ] -[[package]] -name = "appnope" -version = "0.1.3" -summary = "Disable App Nap on macOS >= 10.9" -files = [ - {file = "appnope-0.1.3-py2.py3-none-any.whl", hash = "sha256:265a455292d0bd8a72453494fa24df5a11eb18373a60c7c0430889f22548605e"}, - {file = "appnope-0.1.3.tar.gz", hash = "sha256:02bd91c4de869fbb1e1c50aafc4098827a7a54ab2f39d9dcba6c9547ed920e24"}, -] - [[package]] name = "asttokens" version = "2.4.1" @@ -210,35 +201,34 @@ files = [ [[package]] name = "idna" -version = "3.4" +version = "3.6" requires_python = ">=3.5" summary = "Internationalized Domain Names in Applications (IDNA)" files = [ - {file = "idna-3.4-py3-none-any.whl", hash = "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2"}, - {file = "idna-3.4.tar.gz", hash = "sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4"}, + {file = "idna-3.6-py3-none-any.whl", hash = "sha256:c05567e9c24a6b9faaa835c4821bad0590fbb9d5779e7caa6e1cc4978e7eb24f"}, + {file = "idna-3.6.tar.gz", hash = "sha256:9ecdbbd083b06798ae1e86adcbfe8ab1479cf864e4ee30fe4e46a003d12491ca"}, ] [[package]] name = "ipython" -version = "8.17.2" +version = "8.18.1" requires_python = ">=3.9" summary = "IPython: Productive Interactive Computing" dependencies = [ - "appnope; sys_platform == \"darwin\"", "colorama; sys_platform == \"win32\"", "decorator", "exceptiongroup; python_version < \"3.11\"", "jedi>=0.16", "matplotlib-inline", "pexpect>4.3; sys_platform != \"win32\"", - "prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30", + "prompt-toolkit<3.1.0,>=3.0.41", "pygments>=2.4.0", "stack-data", "traitlets>=5", ] files = [ - {file = "ipython-8.17.2-py3-none-any.whl", hash = "sha256:1e4d1d666a023e3c93585ba0d8e962867f7a111af322efff6b9c58062b3e5444"}, - {file = "ipython-8.17.2.tar.gz", hash = "sha256:126bb57e1895594bb0d91ea3090bbd39384f6fe87c3d57fd558d0670f50339bb"}, + {file = "ipython-8.18.1-py3-none-any.whl", hash = "sha256:e8267419d72d81955ec1177f8a29aaa90ac80ad647499201119e2f05e99aa397"}, + {file = "ipython-8.18.1.tar.gz", hash = "sha256:ca6f079bb33457c66e233e4580ebfc4128855b4cf6370dddd73842a9563e8a27"}, ] [[package]] @@ -400,7 +390,7 @@ files = [ [[package]] name = "mkdocs-material" -version = "9.4.11" +version = "9.4.14" requires_python = ">=3.8" summary = "Documentation that simply works" dependencies = [ @@ -417,18 +407,18 @@ dependencies = [ "requests~=2.26", ] files = [ - {file = "mkdocs_material-9.4.11-py3-none-any.whl", hash = "sha256:794b81d74df4fd7dee952dd4502f7b6a7913a1fc56021e5f36f8e96eb20ffb25"}, - {file = "mkdocs_material-9.4.11.tar.gz", hash = "sha256:82c2bdbdc8445854f400d12831a8b0f7602efaaead7b264ac3c45aa3aa240755"}, + {file = "mkdocs_material-9.4.14-py3-none-any.whl", hash = "sha256:dbc78a4fea97b74319a6aa9a2f0be575a6028be6958f813ba367188f7b8428f6"}, + {file = "mkdocs_material-9.4.14.tar.gz", hash = "sha256:a511d3ff48fa8718b033e7e37d17abd9cc1de0fdf0244a625ca2ae2387e2416d"}, ] [[package]] name = "mkdocs-material-extensions" -version = "1.3" +version = "1.3.1" requires_python = ">=3.8" summary = "Extension pack for Python Markdown and MkDocs Material." files = [ - {file = "mkdocs_material_extensions-1.3-py3-none-any.whl", hash = "sha256:0297cc48ba68a9fdd1ef3780a3b41b534b0d0df1d1181a44676fda5f464eeadc"}, - {file = "mkdocs_material_extensions-1.3.tar.gz", hash = "sha256:f0446091503acb110a7cab9349cbc90eeac51b58d1caa92a704a81ca1e24ddbd"}, + {file = "mkdocs_material_extensions-1.3.1-py3-none-any.whl", hash = "sha256:adff8b62700b25cb77b53358dad940f3ef973dd6db797907c49e3c2ef3ab4e31"}, + {file = "mkdocs_material_extensions-1.3.1.tar.gz", hash = "sha256:10c9511cea88f568257f960358a467d12b970e1f7b2c0e5fb2bb48cab1928443"}, ] [[package]] @@ -605,14 +595,14 @@ files = [ [[package]] name = "pexpect" -version = "4.8.0" +version = "4.9.0" summary = "Pexpect allows easy control of interactive console applications." dependencies = [ "ptyprocess>=0.5", ] files = [ - {file = "pexpect-4.8.0-py2.py3-none-any.whl", hash = "sha256:0b48a55dcb3c05f3329815901ea4fc1537514d6ba867a152b581d69ae3710937"}, - {file = "pexpect-4.8.0.tar.gz", hash = "sha256:fc65a43959d153d0114afe13997d439c22823a27cefceb5ff35c2178c6784c0c"}, + {file = "pexpect-4.9.0-py2.py3-none-any.whl", hash = "sha256:7236d1e080e4936be2dc3e326cec0af72acf9212a7e1d060210e70a47e253523"}, + {file = "pexpect-4.9.0.tar.gz", hash = "sha256:ee7d41123f3c9911050ea2c2dac107568dc43b2d3b0c7557a33212c398ead30f"}, ] [[package]] @@ -688,16 +678,16 @@ files = [ [[package]] name = "pymdown-extensions" -version = "10.4" +version = "10.5" requires_python = ">=3.8" summary = "Extension pack for Python Markdown." dependencies = [ - "markdown>=3.2", + "markdown>=3.5", "pyyaml", ] files = [ - {file = "pymdown_extensions-10.4-py3-none-any.whl", hash = "sha256:cfc28d6a09d19448bcbf8eee3ce098c7d17ff99f7bd3069db4819af181212037"}, - {file = "pymdown_extensions-10.4.tar.gz", hash = "sha256:bc46f11749ecd4d6b71cf62396104b4a200bad3498cb0f5dad1b8502fe461a35"}, + {file = "pymdown_extensions-10.5-py3-none-any.whl", hash = "sha256:1f0ca8bb5beff091315f793ee17683bc1390731f6ac4c5eb01e27464b80fe879"}, + {file = "pymdown_extensions-10.5.tar.gz", hash = "sha256:1b60f1e462adbec5a1ed79dac91f666c9c0d241fa294de1989f29d20096cfd0b"}, ] [[package]] @@ -899,12 +889,12 @@ files = [ [[package]] name = "traitlets" -version = "5.13.0" +version = "5.14.0" requires_python = ">=3.8" summary = "Traitlets Python configuration system" files = [ - {file = "traitlets-5.13.0-py3-none-any.whl", hash = "sha256:baf991e61542da48fe8aef8b779a9ea0aa38d8a54166ee250d5af5ecf4486619"}, - {file = "traitlets-5.13.0.tar.gz", hash = "sha256:9b232b9430c8f57288c1024b34a8f0251ddcc47268927367a0dd3eeaca40deb5"}, + {file = "traitlets-5.14.0-py3-none-any.whl", hash = "sha256:f14949d23829023013c47df20b4a76ccd1a85effb786dc060f34de7948361b33"}, + {file = "traitlets-5.14.0.tar.gz", hash = "sha256:fcdaa8ac49c04dfa0ed3ee3384ef6dfdb5d6f3741502be247279407679296772"}, ] [[package]] diff --git a/pyproject.toml b/pyproject.toml index cc7875a9af6fe60f18b6e9265c3ca7c18a047f7d..85ad9662ac2ef54325bf442ed8d5d6f63cd26674 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,6 @@ authors = [ {name = "Frederik Hennig", email = "frederik.hennig@fau.de"}, ] dependencies = [ - "jinja2>=3.1.2", "pystencils>=1.3.2", ] requires-python = ">=3.10" diff --git a/src/pystencilssfg/cli.py b/src/pystencilssfg/cli.py index 972aae13259182b66524ccaef3845dbf30aad8e1..933c28a457dd4a25f2a4a1f5c794646ff1f7c000 100644 --- a/src/pystencilssfg/cli.py +++ b/src/pystencilssfg/cli.py @@ -72,7 +72,7 @@ def list_files(args): _, scriptname = path.split(args.codegen_script) basename = path.splitext(scriptname)[0] - from .emitters import HeaderSourcePairEmitter + from .emission import HeaderSourcePairEmitter emitter = HeaderSourcePairEmitter(config.get_output_spec(basename)) diff --git a/src/pystencilssfg/context.py b/src/pystencilssfg/context.py index 48ecfd2bb8368f9d534ca8c753d229333f8178cc..69fa58dab860863e5d987aa651ca581ef03b29a0 100644 --- a/src/pystencilssfg/context.py +++ b/src/pystencilssfg/context.py @@ -36,6 +36,10 @@ class SfgContext: self._functions: dict[str, SfgFunction] = dict() self._classes: dict[str, SfgClass] = dict() + # Standard stuff + self.add_include(SfgHeaderInclude("cstdint", system_header=True)) + self.add_definition("#define RESTRICT __restrict__") + @property def argv(self) -> Sequence[str]: """If this context was created by a `pystencilssfg.SourceFileGenerator`, provides the command diff --git a/src/pystencilssfg/emitters/__init__.py b/src/pystencilssfg/emission/__init__.py similarity index 100% rename from src/pystencilssfg/emitters/__init__.py rename to src/pystencilssfg/emission/__init__.py diff --git a/src/pystencilssfg/emitters/clang_format.py b/src/pystencilssfg/emission/clang_format.py similarity index 100% rename from src/pystencilssfg/emitters/clang_format.py rename to src/pystencilssfg/emission/clang_format.py diff --git a/src/pystencilssfg/emission/header_source_pair.py b/src/pystencilssfg/emission/header_source_pair.py new file mode 100644 index 0000000000000000000000000000000000000000..b21b9952a7c77c82187cdbdb8fb9e2ee2ce1c8dd --- /dev/null +++ b/src/pystencilssfg/emission/header_source_pair.py @@ -0,0 +1,42 @@ +from os import path, makedirs + +from ..configuration import SfgOutputSpec +from ..context import SfgContext +from .printers import SfgHeaderPrinter, SfgImplPrinter + +from .clang_format import invoke_clang_format + + +class HeaderSourcePairEmitter: + def __init__(self, output_spec: SfgOutputSpec): + self._basename = output_spec.basename + self._output_directory = output_spec.output_directory + self._header_filename = output_spec.get_header_filename() + self._impl_filename = output_spec.get_impl_filename() + + self._ospec = output_spec + + @property + def output_files(self) -> tuple[str, str]: + return ( + path.join(self._output_directory, self._header_filename), + path.join(self._output_directory, self._impl_filename) + ) + + def write_files(self, ctx: SfgContext): + header_printer = SfgHeaderPrinter(ctx, self._ospec) + impl_printer = SfgImplPrinter(ctx, self._ospec) + + header = header_printer.get_code() + impl = impl_printer.get_code() + + header = invoke_clang_format(header, ctx.codestyle) + impl = invoke_clang_format(impl, ctx.codestyle) + + makedirs(self._output_directory, exist_ok=True) + + with open(self._ospec.get_header_filepath(), 'w') as headerfile: + headerfile.write(header) + + with open(self._ospec.get_impl_filepath(), 'w') as cppfile: + cppfile.write(impl) diff --git a/src/pystencilssfg/emission/printers.py b/src/pystencilssfg/emission/printers.py new file mode 100644 index 0000000000000000000000000000000000000000..bfd48c8d52d6922ec475a06c3aee19364dadf9b8 --- /dev/null +++ b/src/pystencilssfg/emission/printers.py @@ -0,0 +1,242 @@ +from __future__ import annotations + +from textwrap import indent +from itertools import chain, repeat, cycle + +from pystencils.astnodes import KernelFunction +from pystencils import Backend +from pystencils.backends import generate_c + +from ..context import SfgContext +from ..configuration import SfgOutputSpec +from ..visitors import visitor +from ..exceptions import SfgException + +from ..source_components import ( + SfgEmptyLines, + SfgHeaderInclude, + SfgKernelNamespace, + SfgFunction, + SfgClass, + SfgConstructor, + SfgMemberVariable, + SfgMethod, + SfgVisibility, +) + + +def interleave(*iters): + try: + for iter in cycle(iters): + yield next(iter) + except StopIteration: + pass + + +class SfgGeneralPrinter: + + @visitor + def visit(self, obj: object) -> str: + raise SfgException(f"Can't print object of type {type(obj)}") + + @visit.case(SfgEmptyLines) + def emptylines(self, el: SfgEmptyLines) -> str: + return "\n" * el.lines + + @visit.case(str) + def string(self, s: str) -> str: + return s + + @visit.case(SfgHeaderInclude) + def include(self, incl: SfgHeaderInclude) -> str: + if incl.system_header: + return f"#include <{incl.file}>" + else: + return f'#include "{incl.file}"' + + def prelude(self, ctx: SfgContext) -> str: + if ctx.prelude_comment: + return "/*\n" + indent(ctx.prelude_comment, "* ", predicate=lambda _: True) + "*/\n" + else: + return "" + + +class SfgHeaderPrinter(SfgGeneralPrinter): + + def __init__(self, ctx: SfgContext, output_spec: SfgOutputSpec): + self._output_spec = output_spec + self._ctx = ctx + + def get_code(self) -> str: + return self.visit(self._ctx) + + @visitor + def visit(self, obj: object) -> str: + return super().visit(obj) + + @visit.case(SfgContext) + def frame(self, ctx: SfgContext) -> str: + code = super().prelude(ctx) + + code += "\n#pragma once\n\n" + + includes = filter(lambda incl: not incl.private, ctx.includes()) + code += "\n".join(self.visit(incl) for incl in includes) + code += "\n\n" + + fq_namespace = ctx.fully_qualified_namespace + if fq_namespace is not None: + code += f"namespace {fq_namespace} {{\n\n" + + parts = interleave( + chain( + ctx.definitions(), + ctx.classes(), + ctx.functions() + ), + repeat(SfgEmptyLines(1)) + ) + + code += "\n".join(self.visit(p) for p in parts) + + if fq_namespace is not None: + code += f"}} // namespace {fq_namespace}\n" + + return code + + @visit.case(SfgFunction) + def function(self, func: SfgFunction): + params = sorted(list(func.parameters), key=lambda p: p.name) + param_list = ", ".join(f"{param.dtype} {param.name}" for param in params) + return f"void {func.name} ( {param_list} );" + + @visit.case(SfgClass) + def sfg_class(self, cls: SfgClass): + code = f"{cls.class_keyword} {cls.class_name} \n" + + if cls.base_classes: + code += f" : {','.join(cls.base_classes)}\n" + + code += "{\n" + for visibility in ( + SfgVisibility.DEFAULT, + SfgVisibility.PUBLIC, + SfgVisibility.PRIVATE, + ): + if visibility != SfgVisibility.DEFAULT: + code += f"\n{visibility}:\n" + for member in cls.members(visibility): + code += self._ctx.codestyle.indent(self.visit(member)) + "\n" + code += "};\n" + + return code + + @visit.case(SfgConstructor) + def sfg_constructor(self, constr: SfgConstructor): + code = f"{constr.owning_class.class_name} (" + code += ", ".join(f"{param.dtype} {param.name}" for param in constr.parameters) + code += ")\n" + if constr.initializers: + code += " : " + ", ".join(constr.initializers) + "\n" + if constr.body: + code += "{\n" + self._ctx.codestyle.indent(constr.body) + "\n}\n" + else: + code += "{ }\n" + return code + + @visit.case(SfgMemberVariable) + def sfg_member_var(self, var: SfgMemberVariable): + return f"{var.dtype} {var.name};" + + @visit.case(SfgMethod) + def sfg_method(self, method: SfgMethod): + code = f"void {method.name} (" + code += ", ".join(f"{param.dtype} {param.name}" for param in method.parameters) + code += ");" + return code + + +def delimiter(content): + return f"""\ +/************************************************************************************* + * {content} +*************************************************************************************/ +""" + + +class SfgImplPrinter(SfgGeneralPrinter): + def __init__(self, ctx: SfgContext, output_spec: SfgOutputSpec): + self._output_spec = output_spec + self._ctx = ctx + + def get_code(self) -> str: + return self.visit(self._ctx) + + @visitor + def visit(self, obj: object) -> str: + return super().visit(obj) + + @visit.case(SfgContext) + def frame(self, ctx: SfgContext) -> str: + code = super().prelude(ctx) + + code += f'\n#include "{self._output_spec.get_header_filename()}"\n\n' + + includes = filter(lambda incl: incl.private, ctx.includes()) + code += "\n".join(self.visit(incl) for incl in includes) + + code += "\n\n#define FUNC_PREFIX inline\n\n" + + fq_namespace = ctx.fully_qualified_namespace + if fq_namespace is not None: + code += f"namespace {fq_namespace} {{\n\n" + + parts = interleave( + chain( + [delimiter("Kernels")], + ctx.kernel_namespaces(), + [delimiter("Functions")], + ctx.functions(), + [delimiter("Class Methods")], + ctx.classes() + ), + repeat(SfgEmptyLines(1)) + ) + + code += "\n".join(self.visit(p) for p in parts) + + if fq_namespace is not None: + code += f"}} // namespace {fq_namespace}\n" + + return code + + @visit.case(SfgKernelNamespace) + def kernel_namespace(self, kns: SfgKernelNamespace) -> str: + code = f"namespace {kns.name} {{\n\n" + code += "\n\n".join(self.visit(ast) for ast in kns.asts) + code += f"\n}} // namespace {kns.name}\n" + return code + + @visit.case(KernelFunction) + def kernel(self, kfunc: KernelFunction) -> str: + return generate_c(kfunc, dialect=Backend.C) + + @visit.case(SfgFunction) + def function(self, func: SfgFunction) -> str: + return self.method_or_func(func, func.name) + + @visit.case(SfgClass) + def sfg_class(self, cls: SfgClass) -> str: + return "\n".join(self.visit(m) for m in cls.methods()) + + @visit.case(SfgMethod) + def sfg_method(self, method: SfgMethod) -> str: + return self.method_or_func(method, f"{method.owning_class.class_name}::{method.name}") + + def method_or_func(self, func: SfgFunction, fully_qualified_name: str) -> str: + params = sorted(list(func.parameters), key=lambda p: p.name) + param_list = ", ".join(f"{param.dtype} {param.name}" for param in params) + code = f"void {fully_qualified_name} ({param_list}) {{\n" + code += self._ctx.codestyle.indent(func.tree.get_code(self._ctx)) + code += "}\n" + return code diff --git a/src/pystencilssfg/emitters/class_declaration.py b/src/pystencilssfg/emitters/class_declaration.py deleted file mode 100644 index f80c7ca688d7fa7cc5b98ab35f434ed07d97ea31..0000000000000000000000000000000000000000 --- a/src/pystencilssfg/emitters/class_declaration.py +++ /dev/null @@ -1,67 +0,0 @@ -from ..context import SfgContext -from ..visitors import visitor -from ..source_components import ( - SfgClass, - SfgConstructor, - SfgMemberVariable, - SfgMethod, - SfgVisibility, -) -from ..exceptions import SfgException - - -class ClassDeclarationPrinter: - def __init__(self, ctx: SfgContext): - self._codestyle = ctx.codestyle - - def print(self, cls: SfgClass): - return self.visit(cls, cls) - - @visitor - def visit(self, obj: object, cls: SfgClass) -> str: - raise SfgException("Can't print this.") - - @visit.case(SfgClass) - def sfg_class(self, cls: SfgClass, _: SfgClass): - code = f"{cls.class_keyword} {cls.class_name} \n" - - if cls.base_classes: - code += f" : {','.join(cls.base_classes)}\n" - - code += "{\n" - for visibility in ( - SfgVisibility.DEFAULT, - SfgVisibility.PUBLIC, - SfgVisibility.PRIVATE, - ): - if visibility != SfgVisibility.DEFAULT: - code += f"\n{visibility}:\n" - for member in cls.members(visibility): - code += self._codestyle.indent(self.visit(member, cls)) + "\n" - code += "};\n" - - return code - - @visit.case(SfgConstructor) - def sfg_constructor(self, constr: SfgConstructor, cls: SfgClass): - code = f"{cls.class_name} (" - code += ", ".join(f"{param.dtype} {param.name}" for param in constr.parameters) - code += ")\n" - if constr.initializers: - code += " : " + ", ".join(constr.initializers) + "\n" - if constr.body: - code += "{\n" + self._codestyle.indent(constr.body) + "\n}\n" - else: - code += "{ }\n" - return code - - @visit.case(SfgMemberVariable) - def sfg_member_var(self, var: SfgMemberVariable, _: SfgClass): - return f"{var.dtype} {var.name};" - - @visit.case(SfgMethod) - def sfg_method(self, method: SfgMethod, _: SfgClass): - code = f"void {method.name} (" - code += ", ".join(f"{param.dtype} {param.name}" for param in method.parameters) - code += ");" - return code diff --git a/src/pystencilssfg/emitters/header_source_pair.py b/src/pystencilssfg/emitters/header_source_pair.py deleted file mode 100644 index c5fd1c373b1596ce4e682e23e9e181faaaebffe5..0000000000000000000000000000000000000000 --- a/src/pystencilssfg/emitters/header_source_pair.py +++ /dev/null @@ -1,67 +0,0 @@ -from jinja2 import Environment, PackageLoader, StrictUndefined - -from os import path, makedirs - -from ..configuration import SfgOutputSpec -from ..context import SfgContext - -from .clang_format import invoke_clang_format - - -class HeaderSourcePairEmitter: - def __init__(self, output_spec: SfgOutputSpec): - self._basename = output_spec.basename - self._output_directory = output_spec.output_directory - self._header_filename = output_spec.get_header_filename() - self._impl_filename = output_spec.get_impl_filename() - - self._ospec = output_spec - - @property - def output_files(self) -> tuple[str, str]: - return ( - path.join(self._output_directory, self._header_filename), - path.join(self._output_directory, self._impl_filename) - ) - - def write_files(self, ctx: SfgContext): - fq_namespace = ctx.fully_qualified_namespace - - jinja_context = { - 'ctx': ctx, - 'header_filename': self._header_filename, - 'source_filename': self._impl_filename, - 'basename': self._basename, - 'prelude_comment': ctx.prelude_comment, - 'definitions': tuple(ctx.definitions()), - 'fq_namespace': fq_namespace, - 'public_includes': tuple(incl.get_code() for incl in ctx.includes() if not incl.private), - 'private_includes': tuple(incl.get_code() for incl in ctx.includes() if incl.private), - 'kernel_namespaces': tuple(ctx.kernel_namespaces()), - 'functions': tuple(ctx.functions()), - 'classes': tuple(ctx.classes()) - } - - template_name = "HeaderSourcePair" - - env = Environment(loader=PackageLoader('pystencilssfg.emitters'), - undefined=StrictUndefined, - trim_blocks=True, - lstrip_blocks=True) - - from .jinja_filters import add_filters_to_jinja - add_filters_to_jinja(env) - - header = env.get_template(f"{template_name}.tmpl.h").render(**jinja_context) - source = env.get_template(f"{template_name}.tmpl.cpp").render(**jinja_context) - - header = invoke_clang_format(header, ctx.codestyle) - source = invoke_clang_format(source, ctx.codestyle) - - makedirs(self._output_directory, exist_ok=True) - - with open(self._ospec.get_header_filepath(), 'w') as headerfile: - headerfile.write(header) - - with open(self._ospec.get_impl_filepath(), 'w') as cppfile: - cppfile.write(source) diff --git a/src/pystencilssfg/emitters/jinja_filters.py b/src/pystencilssfg/emitters/jinja_filters.py deleted file mode 100644 index 67d96f018282944e28b234f3f1df5dfc7a3e726b..0000000000000000000000000000000000000000 --- a/src/pystencilssfg/emitters/jinja_filters.py +++ /dev/null @@ -1,48 +0,0 @@ -from jinja2 import pass_context -from textwrap import indent - -from pystencils.astnodes import KernelFunction -from pystencils import Backend -from pystencils.backends import generate_c - -from pystencilssfg.source_components import SfgFunction, SfgClass -from .class_declaration import ClassDeclarationPrinter - - -def format_prelude_comment(prelude_comment: str): - if not prelude_comment: - return "" - - return "/*\n" + indent(prelude_comment, "* ", predicate=lambda _: True) + "*/\n" - - -@pass_context -def generate_kernel_definition(ctx, ast: KernelFunction): - return generate_c(ast, dialect=Backend.C) - - -@pass_context -def generate_function_parameter_list(ctx, func: SfgFunction): - params = sorted(list(func.parameters), key=lambda p: p.name) - return ", ".join(f"{param.dtype} {param.name}" for param in params) - - -@pass_context -def generate_function_body(ctx, func: SfgFunction): - return func.get_code(ctx["ctx"]) - - -@pass_context -def print_class_declaration(ctx, cls: SfgClass): - return ClassDeclarationPrinter(ctx["ctx"]).print(cls) - - -def add_filters_to_jinja(jinja_env): - jinja_env.filters["format_prelude_comment"] = format_prelude_comment - jinja_env.filters["generate_kernel_definition"] = generate_kernel_definition - jinja_env.filters[ - "generate_function_parameter_list" - ] = generate_function_parameter_list - jinja_env.filters["generate_function_body"] = generate_function_body - - jinja_env.filters["print_class_declaration"] = print_class_declaration diff --git a/src/pystencilssfg/emitters/templates/HeaderSourcePair.tmpl.cpp b/src/pystencilssfg/emitters/templates/HeaderSourcePair.tmpl.cpp deleted file mode 100644 index 48c965336939f947737fb7cdbdb3889d5f14fb04..0000000000000000000000000000000000000000 --- a/src/pystencilssfg/emitters/templates/HeaderSourcePair.tmpl.cpp +++ /dev/null @@ -1,57 +0,0 @@ -{{ prelude_comment | format_prelude_comment }} - -#include "{{header_filename}}" - -{% for incl in private_includes %} -{{incl}} -{% endfor %} - -#define FUNC_PREFIX inline - -{% if fq_namespace is not none %} -namespace {{fq_namespace}} { -{% endif %} - -/************************************************************************************* - * Kernels -*************************************************************************************/ - -{% for kns in kernel_namespaces %} -namespace {{ kns.name }} { - -{% for ast in kns.asts %} -{{ ast | generate_kernel_definition }} -{% endfor %} - -} // namespace {{ kns.name }} -{% endfor %} - -/************************************************************************************* - * Functions -*************************************************************************************/ - -{% for function in functions %} -void {{ function.name }} ( {{ function | generate_function_parameter_list }} ) { - {{ function | generate_function_body | indent(ctx.codestyle.indent_width) }} -} - - -{% endfor -%} - -/************************************************************************************* - * Class Methods -*************************************************************************************/ - -{% for cls in classes %} -{% for method in cls.methods() %} -void {{ cls.class_name }}::{{ method.name }} ( {{ method | generate_function_parameter_list }} ) { - {{ method | generate_function_body | indent(ctx.codestyle.indent_width) }} -} - - -{% endfor %} -{% endfor -%} - -{% if fq_namespace is not none %} -} // namespace {{fq_namespace}} -{% endif %} diff --git a/src/pystencilssfg/emitters/templates/HeaderSourcePair.tmpl.h b/src/pystencilssfg/emitters/templates/HeaderSourcePair.tmpl.h deleted file mode 100644 index 95b3ed8e207e829f48ee64a29fecc44a8c581517..0000000000000000000000000000000000000000 --- a/src/pystencilssfg/emitters/templates/HeaderSourcePair.tmpl.h +++ /dev/null @@ -1,31 +0,0 @@ -{{ prelude_comment | format_prelude_comment }} - -#pragma once - -#include <cstdint> - -{% for incl in public_includes %} -{{incl}} -{% endfor %} - -{% for definition in definitions %} -{{ definition }} -{% endfor %} - -#define RESTRICT __restrict__ - -{% if fq_namespace is not none %} -namespace {{fq_namespace}} { -{% endif %} - -{% for cls in classes %} -{{ cls | print_class_declaration }} -{% endfor %} - -{% for function in functions %} -void {{ function.name }} ( {{ function | generate_function_parameter_list }} ); -{% endfor %} - -{% if fq_namespace is not none %} -} // namespace {{fq_namespace}} -{% endif %} \ No newline at end of file diff --git a/src/pystencilssfg/generator.py b/src/pystencilssfg/generator.py index cea341b5570efa0954b1a6bc8c41f07e4a652049..03fd82ccf854ecde1f8802036f10d5f8294160cc 100644 --- a/src/pystencilssfg/generator.py +++ b/src/pystencilssfg/generator.py @@ -27,7 +27,7 @@ class SourceFileGenerator: self._context = SfgContext(config.outer_namespace, config.codestyle, argv=script_args) - from .emitters import HeaderSourcePairEmitter + from .emission import HeaderSourcePairEmitter self._emitter = HeaderSourcePairEmitter(config.get_output_spec(basename)) def clean_files(self): diff --git a/src/pystencilssfg/printing/header_printer.py b/src/pystencilssfg/printing/header_printer.py deleted file mode 100644 index e49ba9c7131eada3a4770b1df7d41b7af8ea6404..0000000000000000000000000000000000000000 --- a/src/pystencilssfg/printing/header_printer.py +++ /dev/null @@ -1,81 +0,0 @@ -from __future__ import annotations - -from textwrap import indent -from itertools import chain, repeat - -from ..context import SfgContext -from ..configuration import SfgOutputSpec -from ..visitors import visitor -from ..exceptions import SfgException - -from ..source_components import ( - SfgEmptyLines, SfgHeaderInclude -) - - -def interleave(*iters): - try: - for iter in iters: - yield next(iter) - except StopIteration: - pass - - -class SfgHeaderPrinter: - def __init__(self, output_spec: SfgOutputSpec): - self._output_spec = output_spec - - def code_string(self, ctx: SfgContext) -> str: - return self.visit(ctx) - - @visitor - def visit(self, obj: object) -> str: - raise SfgException(f"Can't print object of type {type(obj)}") - - @visit.case(SfgEmptyLines) - def emptylines(self, el: SfgEmptyLines) -> str: - return "\n" * el.lines - - @visit.case(str) - def string(self, s: str) -> str: - return s - - @visit.case(SfgHeaderInclude) - def include(self, incl: SfgHeaderInclude) -> str: - if incl.system_header: - return f"#include <{incl.file}>" - else: - return f'#include "{incl.file}"' - - @visit.case(SfgContext) - def frame(self, ctx: SfgContext) -> str: - code = "" - - if ctx.prelude_comment: - code += "/*\n" + indent(ctx.prelude_comment, "* ", predicate=lambda _: True) + "*/\n" - - code += "\n#pragma once\n\n" - - includes = filter(lambda incl: not incl.private, ctx.includes()) - code += "\n".join(self.visit(incl) for incl in includes) - code += "\n" - - fq_namespace = ctx.fully_qualified_namespace - if fq_namespace is not None: - code += f"namespace {fq_namespace} {{\n" - - parts = interleave( - chain( - ctx.definitions(), - ctx.classes(), - ctx.functions() - ), - repeat(SfgEmptyLines(1)) - ) - - code += "".join(self.visit(p) for p in parts) - - if fq_namespace is not None: - code += f"}} \\ namespace {fq_namespace}\n" - - return code