diff --git a/src/pystencilssfg/__init__.py b/src/pystencilssfg/__init__.py index 88a635ffd96cfea562ef3ad344ddb0c8b3bfa2e2..48de907236b8f515f4219be92e7b92e7888c53e0 100644 --- a/src/pystencilssfg/__init__.py +++ b/src/pystencilssfg/__init__.py @@ -1,9 +1,10 @@ from .configuration import SfgConfiguration from .generator import SourceFileGenerator from .composer import SfgComposer +from .context import SfgContext __all__ = [ - "SourceFileGenerator", "SfgComposer", "SfgConfiguration" + "SourceFileGenerator", "SfgComposer", "SfgConfiguration", "SfgContext" ] from . import _version diff --git a/src/pystencilssfg/composer.py b/src/pystencilssfg/composer.py index ce5799f22232478766bb19456922a642f17d7244..fa1e8a81418bb50ac6ce064adb9b13cdc1cbe44f 100644 --- a/src/pystencilssfg/composer.py +++ b/src/pystencilssfg/composer.py @@ -21,6 +21,10 @@ class SfgComposer: def __init__(self, ctx: SfgContext): self._ctx = ctx + @property + def context(self): + return self._ctx + @property def kernels(self) -> SfgKernelNamespace: """The default kernel namespace.""" diff --git a/src/pystencilssfg/emitters/cpu/templates/BasicCpu.tmpl.cpp b/src/pystencilssfg/emitters/cpu/templates/BasicCpu.tmpl.cpp index dfb7a43b0feb1a1a98b8a2237d6c014f7f7f0a67..6132d661fef182f73570c1c67f4a6e14ad3fe168 100644 --- a/src/pystencilssfg/emitters/cpu/templates/BasicCpu.tmpl.cpp +++ b/src/pystencilssfg/emitters/cpu/templates/BasicCpu.tmpl.cpp @@ -29,9 +29,11 @@ namespace {{ kns.name }} { *************************************************************************************/ {% for function in functions %} + void {{ function.name }} ( {{ function | generate_function_parameter_list }} ) { {{ function | generate_function_body | indent(2) }} } + {% endfor %} {% if fq_namespace is not none %} diff --git a/src/pystencilssfg/source_components.py b/src/pystencilssfg/source_components.py index 1c0e104517144fa8aee5ebc6ab9b83d4c2d7fbe1..9c656ca66112cb51a095ba32b56c979db7125ee2 100644 --- a/src/pystencilssfg/source_components.py +++ b/src/pystencilssfg/source_components.py @@ -1,6 +1,7 @@ from __future__ import annotations from typing import TYPE_CHECKING, Sequence +from dataclasses import replace from pystencils import CreateKernelConfig, create_kernel from pystencils.astnodes import KernelFunction @@ -54,17 +55,32 @@ class SfgKernelNamespace: def asts(self): yield from self._asts.values() - def add(self, ast: KernelFunction): + def add(self, ast: KernelFunction, name: str | None = None): """Adds an existing pystencils AST to this namespace.""" - astname = ast.function_name + if name is not None: + astname = name + else: + astname = ast.function_name + if astname in self._asts: raise ValueError(f"Duplicate ASTs: An AST with name {astname} already exists in namespace {self._name}") + if name is not None: + ast.function_name = name + self._asts[astname] = ast return SfgKernelHandle(self._ctx, astname, self, ast.get_parameters()) - def create(self, assignments, config: CreateKernelConfig | None = None): + def create(self, assignments, name: str | None = None, config: CreateKernelConfig | None = None): + if config is None: + config = CreateKernelConfig() + + if name is not None: + if name in self._asts: + raise ValueError(f"Duplicate ASTs: An AST with name {name} already exists in namespace {self._name}") + config = replace(config, function_name=name) + # type: ignore ast = create_kernel(assignments, config=config) return self.add(ast) diff --git a/src/pystencilssfg/tree/__init__.py b/src/pystencilssfg/tree/__init__.py index 8ecc06149874426dfe3ef98c6d8682c4774b9d42..f81c9178c9fa315cd49ae5feba66feb0082cd728 100644 --- a/src/pystencilssfg/tree/__init__.py +++ b/src/pystencilssfg/tree/__init__.py @@ -1,7 +1,7 @@ from .basic_nodes import SfgCallTreeNode, SfgKernelCallNode, SfgBlock, SfgSequence, SfgStatements -from .conditional import SfgBranch, SfgCondition +from .conditional import SfgBranch, SfgCondition, IntEven, IntOdd __all__ = [ "SfgCallTreeNode", "SfgKernelCallNode", "SfgSequence", "SfgBlock", "SfgStatements", - "SfgCondition", "SfgBranch" + "SfgCondition", "SfgBranch", "IntEven", "IntOdd" ] diff --git a/src/pystencilssfg/tree/conditional.py b/src/pystencilssfg/tree/conditional.py index 39663cbf60eace8f13053b6064a8deb6ee53f99e..4c9021a65d5629ec2f74843162f54d08f7e8f0d2 100644 --- a/src/pystencilssfg/tree/conditional.py +++ b/src/pystencilssfg/tree/conditional.py @@ -1,6 +1,8 @@ from __future__ import annotations from typing import TYPE_CHECKING, Optional, cast +from pystencils.typing import TypedSymbol, BasicType + from .basic_nodes import SfgCallTreeNode, SfgCallTreeLeaf from ..source_concepts.source_objects import TypedSymbolOrObject @@ -25,8 +27,36 @@ class SfgCustomCondition(SfgCondition): return self._cond_text -# class IntEven(SfgCondition): -# def __init__(self, ) +class IntEven(SfgCondition): + def __init__(self, symbol: TypedSymbol): + super().__init__() + if not isinstance(symbol.dtype, BasicType) or not symbol.dtype.is_int(): + raise ValueError(f"Symbol {symbol} does not have integer type.") + + self._symbol = symbol + + @property + def required_parameters(self) -> set[TypedSymbolOrObject]: + return {self._symbol} + + def get_code(self, ctx: SfgContext) -> str: + return f"(({self._symbol.name} & 1) ^ 1)" + + +class IntOdd(SfgCondition): + def __init__(self, symbol: TypedSymbol): + super().__init__() + if not isinstance(symbol.dtype, BasicType) or not symbol.dtype.is_int(): + raise ValueError(f"Symbol {symbol} does not have integer type.") + + self._symbol = symbol + + @property + def required_parameters(self) -> set[TypedSymbolOrObject]: + return {self._symbol} + + def get_code(self, ctx: SfgContext) -> str: + return f"({self._symbol.name} & 1)" class SfgBranch(SfgCallTreeNode):