diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py index 309ee8bf247c922fe5e1e340dc92f170ae08c94c..7c61b61e2940b3c05227005380de8f7d151f4763 100644 --- a/src/pystencilssfg/composer/basic_composer.py +++ b/src/pystencilssfg/composer/basic_composer.py @@ -610,10 +610,10 @@ def make_sequence(*args: SequencerArg) -> SfgSequence: return SfgSequence(children) -class SfgFunctionSequencer: - """Sequencer for constructing free functions. +class SfgFunctionSequencerBase: + """Common base class for function and method sequencers. - This builder uses call sequencing to specify the function's properties. + This builder uses call sequencing to specify the function or method's properties. Example: @@ -641,12 +641,12 @@ class SfgFunctionSequencer: # Attributes self._attributes: list[str] = [] - def returns(self, rtype: UserTypeSpec) -> SfgFunctionSequencer: + def returns(self, rtype: UserTypeSpec): """Set the return type of the function""" self._return_type = create_type(rtype) return self - def params(self, *args: VarLike) -> SfgFunctionSequencer: + def params(self, *args: VarLike): """Specify the parameters for this function. Use this to manually specify the function's parameter list. @@ -657,21 +657,25 @@ class SfgFunctionSequencer: self._params = [asvar(v) for v in args] return self - def inline(self) -> SfgFunctionSequencer: + def inline(self): """Mark this function as ``inline``.""" self._inline = True return self - def constexpr(self) -> SfgFunctionSequencer: + def constexpr(self): """Mark this function as ``constexpr``.""" self._constexpr = True return self - def attr(self, *attrs: str) -> SfgFunctionSequencer: + def attr(self, *attrs: str): """Add attributes to this function""" self._attributes += attrs return self + +class SfgFunctionSequencer(SfgFunctionSequencerBase): + """Sequencer for constructing functions.""" + def __call__(self, *args: SequencerArg) -> None: """Populate the function body""" tree = make_sequence(*args) diff --git a/src/pystencilssfg/composer/class_composer.py b/src/pystencilssfg/composer/class_composer.py index 0a72e8089ecd5e32be53cd335df57b58b21ec578..f277c7e9e60e36b326d4511f1e3aea07df825dd9 100644 --- a/src/pystencilssfg/composer/class_composer.py +++ b/src/pystencilssfg/composer/class_composer.py @@ -3,9 +3,9 @@ from typing import Sequence from itertools import takewhile, dropwhile import numpy as np -from pystencils.types import PsCustomType, UserTypeSpec, create_type +from pystencils.types import create_type -from ..context import SfgContext +from ..context import SfgContext, SfgCursor from ..lang import ( VarLike, ExprLike, @@ -32,9 +32,69 @@ from .mixin import SfgComposerMixIn from .basic_composer import ( make_sequence, SequencerArg, + SfgFunctionSequencerBase, ) +class SfgMethodSequencer(SfgFunctionSequencerBase): + def __init__(self, cursor: SfgCursor, name: str) -> None: + super().__init__(cursor, name) + + self._const: bool = False + self._static: bool = False + self._virtual: bool = False + self._override: bool = False + + self._tree: SfgCallTreeNode + + def const(self): + """Mark this method as ``const``.""" + self._const = True + return self + + def static(self): + """Mark this method as ``static``.""" + self._static = True + return self + + def virtual(self): + """Mark this method as ``virtual``.""" + self._virtual = True + return self + + def override(self): + """Mark this method as ``override``.""" + self._override = True + return self + + def __call__(self, *args: SequencerArg): + self._tree = make_sequence(*args) + return self + + def _resolve(self, ctx: SfgContext, cls: SfgClass, vis_block: SfgVisibilityBlock): + method = SfgMethod( + self._name, + cls, + self._tree, + return_type=self._return_type, + inline=self._inline, + const=self._const, + static=self._static, + constexpr=self._constexpr, + virtual=self._virtual, + override=self._override, + attributes=self._attributes, + required_params=self._params, + ) + cls.add_member(method, vis_block.visibility) + + if self._inline: + vis_block.elements.append(SfgEntityDef(method)) + else: + vis_block.elements.append(SfgEntityDecl(method)) + ctx._cursor.write_impl(SfgEntityDef(method)) + + class SfgClassComposer(SfgComposerMixIn): """Composer for classes and structs. @@ -53,7 +113,7 @@ class SfgClassComposer(SfgComposerMixIn): def __init__(self, visibility: SfgVisibility): self._visibility = visibility self._args: tuple[ - SfgClassComposer.MethodSequencer + SfgMethodSequencer | SfgClassComposer.ConstructorBuilder | VarLike | str, @@ -63,10 +123,7 @@ class SfgClassComposer(SfgComposerMixIn): def __call__( self, *args: ( - SfgClassComposer.MethodSequencer - | SfgClassComposer.ConstructorBuilder - | VarLike - | str + SfgMethodSequencer | SfgClassComposer.ConstructorBuilder | VarLike | str ), ): self._args = args @@ -76,10 +133,7 @@ class SfgClassComposer(SfgComposerMixIn): vis_block = SfgVisibilityBlock(self._visibility) for arg in self._args: match arg: - case ( - SfgClassComposer.MethodSequencer() - | SfgClassComposer.ConstructorBuilder() - ): + case SfgMethodSequencer() | SfgClassComposer.ConstructorBuilder(): arg._resolve(ctx, cls, vis_block) case str(): vis_block.elements.append(arg) @@ -90,43 +144,6 @@ class SfgClassComposer(SfgComposerMixIn): vis_block.elements.append(SfgEntityDef(member_var)) return vis_block - class MethodSequencer: - def __init__( - self, - name: str, - returns: UserTypeSpec = PsCustomType("void"), - inline: bool = False, - const: bool = False, - ) -> None: - self._name = name - self._returns = create_type(returns) - self._inline = inline - self._const = const - self._tree: SfgCallTreeNode - - def __call__(self, *args: SequencerArg): - self._tree = make_sequence(*args) - return self - - def _resolve( - self, ctx: SfgContext, cls: SfgClass, vis_block: SfgVisibilityBlock - ): - method = SfgMethod( - self._name, - cls, - self._tree, - return_type=self._returns, - inline=self._inline, - const=self._const, - ) - cls.add_member(method, vis_block.visibility) - - if self._inline: - vis_block.elements.append(SfgEntityDef(method)) - else: - vis_block.elements.append(SfgEntityDecl(method)) - ctx._cursor.write_impl(SfgEntityDef(method)) - class ConstructorBuilder: """Composer syntax for constructor building. @@ -197,9 +214,7 @@ class SfgClassComposer(SfgComposerMixIn): """ return self._class(class_name, SfgClassKeyword.STRUCT, bases) - def numpy_struct( - self, name: str, dtype: np.dtype, add_constructor: bool = True - ): + def numpy_struct(self, name: str, dtype: np.dtype, add_constructor: bool = True): """Add a numpy structured data type as a C++ struct Returns: @@ -230,13 +245,7 @@ class SfgClassComposer(SfgComposerMixIn): """ return SfgClassComposer.ConstructorBuilder(*params) - def method( - self, - name: str, - returns: UserTypeSpec = PsCustomType("void"), - inline: bool = False, - const: bool = False, - ): + def method(self, name: str): """In a class or struct body or visibility block, add a method. The usage is similar to :any:`SfgBasicComposer.function`. @@ -247,7 +256,7 @@ class SfgClassComposer(SfgComposerMixIn): const: Whether or not the method is const-qualified. """ - return SfgClassComposer.MethodSequencer(name, returns, inline, const) + return SfgMethodSequencer(self._cursor, name) # INTERNALS @@ -270,7 +279,7 @@ class SfgClassComposer(SfgComposerMixIn): def sequencer( *args: ( SfgClassComposer.VisibilityBlockSequencer - | SfgClassComposer.MethodSequencer + | SfgMethodSequencer | SfgClassComposer.ConstructorBuilder | VarLike | str diff --git a/src/pystencilssfg/emission/file_printer.py b/src/pystencilssfg/emission/file_printer.py index dea6605b10dddc792df316eefc50c40d01a73126..765bf70550504fd499746504236f1adaa224664a 100644 --- a/src/pystencilssfg/emission/file_printer.py +++ b/src/pystencilssfg/emission/file_printer.py @@ -183,11 +183,14 @@ class SfgFilePrinter: if func.attributes: code += "[[" + ", ".join(func.attributes) + "]]" - if func.inline: + if func.inline and not inclass: code += "inline " - if isinstance(func, SfgMethod) and func.static: - code += "static " + if isinstance(func, SfgMethod) and inclass: + if func.static: + code += "static " + if func.virtual: + code += "virtual " if func.constexpr: code += "constexpr " @@ -200,7 +203,10 @@ class SfgFilePrinter: code += f"{func.owning_class.name}::" code += f"{func.name}({params_str})" - if isinstance(func, SfgMethod) and func.const: - code += " const" + if isinstance(func, SfgMethod): + if func.const: + code += " const" + if func.override and inclass: + code += " override" return code diff --git a/src/pystencilssfg/ir/entities.py b/src/pystencilssfg/ir/entities.py index 3cb828c2072b9fb118bce90d4a40cccf52fb937c..b5822624c5455eff98f41572f46e3aa037746ac1 100644 --- a/src/pystencilssfg/ir/entities.py +++ b/src/pystencilssfg/ir/entities.py @@ -204,7 +204,7 @@ class SfgKernelNamespace(SfgNamespace): self._kernels[kernel.name] = kernel -@dataclass(frozen=True) +@dataclass(frozen=True, match_args=False) class CommonFunctionProperties: tree: SfgCallTreeNode parameters: tuple[SfgVar, ...] @@ -213,6 +213,26 @@ class CommonFunctionProperties: constexpr: bool attributes: Sequence[str] + @staticmethod + def collect_params(tree: SfgCallTreeNode, required_params: Sequence[SfgVar] | None): + from .postprocessing import CallTreePostProcessing + + param_collector = CallTreePostProcessing() + params_set = param_collector(tree).function_params + + if required_params is not None: + if not (params_set <= set(required_params)): + extras = params_set - set(required_params) + raise SfgException( + "Extraenous function parameters: " + f"Found free variables {extras} that were not listed in manually specified function parameters." + ) + parameters = tuple(required_params) + else: + parameters = tuple(sorted(params_set, key=lambda p: p.name)) + + return parameters + class SfgFunction(SfgCodeEntity, CommonFunctionProperties): """A free function.""" @@ -232,21 +252,7 @@ class SfgFunction(SfgCodeEntity, CommonFunctionProperties): ): super().__init__(name, namespace) - from .postprocessing import CallTreePostProcessing - - param_collector = CallTreePostProcessing() - params_set = param_collector(tree).function_params - - if required_params is not None: - if not (params_set <= set(required_params)): - extras = params_set - set(required_params) - raise SfgException( - "Extraenous function parameters: " - f"Found free variables {extras} that were not listed in manually specified function parameters." - ) - parameters = tuple(required_params) - else: - parameters = tuple(sorted(params_set, key=lambda p: p.name)) + parameters = self.collect_params(tree, required_params) CommonFunctionProperties.__init__( self, @@ -349,21 +355,20 @@ class SfgMethod(SfgClassMember, CommonFunctionProperties): const: bool = False, static: bool = False, constexpr: bool = False, + virtual: bool = False, + override: bool = False, attributes: Sequence[str] = (), + required_params: Sequence[SfgVar] | None = None, ): super().__init__(cls) self._name = name - - from .postprocessing import CallTreePostProcessing - - param_collector = CallTreePostProcessing() - parameters = tuple( - sorted(param_collector(tree).function_params, key=lambda p: p.name) - ) - self._static = static self._const = const + self._virtual = virtual + self._override = override + + parameters = self.collect_params(tree, required_params) CommonFunctionProperties.__init__( self, @@ -387,6 +392,14 @@ class SfgMethod(SfgClassMember, CommonFunctionProperties): def const(self) -> bool: return self._const + @property + def virtual(self) -> bool: + return self._virtual + + @property + def override(self) -> bool: + return self._override + class SfgConstructor(SfgClassMember): """Constructor of a class""" diff --git a/tests/generator_scripts/index.yaml b/tests/generator_scripts/index.yaml index 57cee9172c26f897b8480ef3322dfde85a1c4481..68352fe2c2904c11e551955eeb29a6bf9424e126 100644 --- a/tests/generator_scripts/index.yaml +++ b/tests/generator_scripts/index.yaml @@ -48,6 +48,10 @@ SimpleClasses: output-mode: header-only ComposerFeatures: + expect-code: + hpp: + - regex: >- + \[\[nodiscard\]\]\s*static\s*double\s*geometric\(\s*double\s*q,\s*uint64_t\s*k\) Conditionals: expect-code: diff --git a/tests/generator_scripts/source/ComposerFeatures.harness.cpp b/tests/generator_scripts/source/ComposerFeatures.harness.cpp index f24b726558346f7ad148365b2f3fc102f81f6b89..4ca651aef66e714d39cf65cf624a5ba9de52dabd 100644 --- a/tests/generator_scripts/source/ComposerFeatures.harness.cpp +++ b/tests/generator_scripts/source/ComposerFeatures.harness.cpp @@ -1,6 +1,11 @@ #include "ComposerFeatures.hpp" -/* factorial is constexpr -> evaluate at compile-time */ +#include <cmath> + +#undef NDEBUG +#include <cassert> + +/* Evaluate constexpr functions at compile-time */ static_assert( factorial(0) == 1 ); static_assert( factorial(1) == 1 ); static_assert( factorial(2) == 2 ); @@ -8,6 +13,23 @@ static_assert( factorial(3) == 6 ); static_assert( factorial(4) == 24 ); static_assert( factorial(5) == 120 ); +static_assert( ConstexprMath::abs(ConstexprMath::geometric(0.5, 0) - 1.0) < 1e-10 ); +static_assert( ConstexprMath::abs(ConstexprMath::geometric(0.5, 1) - 1.5) < 1e-10 ); +static_assert( ConstexprMath::abs(ConstexprMath::geometric(0.5, 2) - 1.75) < 1e-10 ); +static_assert( ConstexprMath::abs(ConstexprMath::geometric(0.5, 3) - 1.875) < 1e-10 ); + int main(void) { - return 0; + assert( std::fabs(Series::geometric(0.5, 0) - 1.0) < 1e-10 ); + assert( std::fabs(Series::geometric(0.5, 1) - 1.5) < 1e-10 ); + assert( std::fabs(Series::geometric(0.5, 2) - 1.75) < 1e-10 ); + assert( std::fabs(Series::geometric(0.5, 3) - 1.875) < 1e-10 ); + + inheritance_test::Parent p; + assert( p.compute() == 24 ); + + inheritance_test::Child c; + assert( c.compute() == 31 ); + + auto & cp = dynamic_cast< inheritance_test::Parent & >(c); + assert( cp.compute() == 31 ); } diff --git a/tests/generator_scripts/source/ComposerFeatures.py b/tests/generator_scripts/source/ComposerFeatures.py index fd64de541e96756cf6815f3a031f3317155ea108..ab97d7acc01884c679b2645ccce149b15005569c 100644 --- a/tests/generator_scripts/source/ComposerFeatures.py +++ b/tests/generator_scripts/source/ComposerFeatures.py @@ -3,11 +3,67 @@ from pystencilssfg import SourceFileGenerator with SourceFileGenerator() as sfg: - # Inline constexpr function with explicit parameter list - sfg.function("factorial").params(sfg.var("n", "uint64")).returns("uint64").inline().constexpr()( - sfg.branch("n == 0")( - "return 1;" - )( - "return n * factorial(n - 1);" + sfg.function("factorial").params(sfg.var("n", "uint64")).returns( + "uint64" + ).inline().constexpr()( + sfg.branch("n == 0")("return 1;")("return n * factorial(n - 1);") + ) + + q = sfg.var("q", "double") + k = sfg.var("k", "uint64_t") + x = sfg.var("x", "double") + + sfg.include("<cmath>") + + sfg.struct("Series")( + sfg.method("geometric") + .static() + .attr("nodiscard") + .params(q, k) + .returns("double")( + sfg.branch("k == 0")( + "return 1.0;" + )( + "return Series::geometric(q, k - 1) + std::pow(q, k);" + ) + ) + ) + + sfg.struct("ConstexprMath")( + sfg.method("abs").static().constexpr().inline() + .params(x) + .returns("double") + ( + "if (x >= 0.0) return x; else return -x;" + ), + + sfg.method("geometric") + .static() + .constexpr() + .inline() + .params(q, k) + .returns("double")( + sfg.branch("k == 0")( + "return 1.0;" + )( + "return 1 + q * ConstexprMath::geometric(q, k - 1);" + ) ) ) + + with sfg.namespace("inheritance_test"): + sfg.klass("Parent")( + sfg.public( + sfg.method("compute").returns("int").virtual().const()( + "return 24;" + ) + ) + ) + + sfg.klass("Child", bases=["public Parent"])( + sfg.public( + sfg.method("compute").returns("int").override().const()( + "return 31;" + ) + ) + ) diff --git a/tests/generator_scripts/source/SimpleClasses.py b/tests/generator_scripts/source/SimpleClasses.py index 26502f0e149c11d470e700269f5ff526aff3ce85..d2bc6da77c01914d0c477977209efbef005da39b 100644 --- a/tests/generator_scripts/source/SimpleClasses.py +++ b/tests/generator_scripts/source/SimpleClasses.py @@ -12,7 +12,7 @@ with SourceFileGenerator() as sfg: sfg.klass("Point")( sfg.public( sfg.constructor(x, y, z).init(x_)(x).init(y_)(y).init(z_)(z), - sfg.method("getX", returns="const int64_t", const=True, inline=True)( + sfg.method("getX").returns("const int64_t").const().inline()( "return this->x_;" ), ), @@ -22,7 +22,7 @@ with SourceFileGenerator() as sfg: sfg.klass("SpecialPoint", bases=["public Point"])( sfg.public( "using Point::Point;", - sfg.method("getY", returns="const int64_t", const=True, inline=True)( + sfg.method("getY").returns("const int64_t").const().inline()( "return this->y_;" ), )