Skip to content
Snippets Groups Projects
Commit 87846725 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Various small changes and additions:

 - IntEven and IntOdd Conditions
 - Blank lines in CPU template
 - Module Exports
parent d477cc99
No related merge requests found
Pipeline #57666 failed with stages
in 4 minutes and 50 seconds
from .configuration import SfgConfiguration
from .generator import SourceFileGenerator
from .composer import SfgComposer
from .context import SfgContext
__all__ = [
"SourceFileGenerator", "SfgComposer", "SfgConfiguration"
"SourceFileGenerator", "SfgComposer", "SfgConfiguration", "SfgContext"
]
from . import _version
......
......@@ -21,6 +21,10 @@ class SfgComposer:
def __init__(self, ctx: SfgContext):
self._ctx = ctx
@property
def context(self):
return self._ctx
@property
def kernels(self) -> SfgKernelNamespace:
"""The default kernel namespace."""
......
......@@ -29,9 +29,11 @@ namespace {{ kns.name }} {
*************************************************************************************/
{% for function in functions %}
void {{ function.name }} ( {{ function | generate_function_parameter_list }} ) {
{{ function | generate_function_body | indent(2) }}
}
{% endfor %}
{% if fq_namespace is not none %}
......
from __future__ import annotations
from typing import TYPE_CHECKING, Sequence
from dataclasses import replace
from pystencils import CreateKernelConfig, create_kernel
from pystencils.astnodes import KernelFunction
......@@ -54,17 +55,32 @@ class SfgKernelNamespace:
def asts(self):
yield from self._asts.values()
def add(self, ast: KernelFunction):
def add(self, ast: KernelFunction, name: str | None = None):
"""Adds an existing pystencils AST to this namespace."""
astname = ast.function_name
if name is not None:
astname = name
else:
astname = ast.function_name
if astname in self._asts:
raise ValueError(f"Duplicate ASTs: An AST with name {astname} already exists in namespace {self._name}")
if name is not None:
ast.function_name = name
self._asts[astname] = ast
return SfgKernelHandle(self._ctx, astname, self, ast.get_parameters())
def create(self, assignments, config: CreateKernelConfig | None = None):
def create(self, assignments, name: str | None = None, config: CreateKernelConfig | None = None):
if config is None:
config = CreateKernelConfig()
if name is not None:
if name in self._asts:
raise ValueError(f"Duplicate ASTs: An AST with name {name} already exists in namespace {self._name}")
config = replace(config, function_name=name)
# type: ignore
ast = create_kernel(assignments, config=config)
return self.add(ast)
......
from .basic_nodes import SfgCallTreeNode, SfgKernelCallNode, SfgBlock, SfgSequence, SfgStatements
from .conditional import SfgBranch, SfgCondition
from .conditional import SfgBranch, SfgCondition, IntEven, IntOdd
__all__ = [
"SfgCallTreeNode", "SfgKernelCallNode", "SfgSequence", "SfgBlock", "SfgStatements",
"SfgCondition", "SfgBranch"
"SfgCondition", "SfgBranch", "IntEven", "IntOdd"
]
from __future__ import annotations
from typing import TYPE_CHECKING, Optional, cast
from pystencils.typing import TypedSymbol, BasicType
from .basic_nodes import SfgCallTreeNode, SfgCallTreeLeaf
from ..source_concepts.source_objects import TypedSymbolOrObject
......@@ -25,8 +27,36 @@ class SfgCustomCondition(SfgCondition):
return self._cond_text
# class IntEven(SfgCondition):
# def __init__(self, )
class IntEven(SfgCondition):
def __init__(self, symbol: TypedSymbol):
super().__init__()
if not isinstance(symbol.dtype, BasicType) or not symbol.dtype.is_int():
raise ValueError(f"Symbol {symbol} does not have integer type.")
self._symbol = symbol
@property
def required_parameters(self) -> set[TypedSymbolOrObject]:
return {self._symbol}
def get_code(self, ctx: SfgContext) -> str:
return f"(({self._symbol.name} & 1) ^ 1)"
class IntOdd(SfgCondition):
def __init__(self, symbol: TypedSymbol):
super().__init__()
if not isinstance(symbol.dtype, BasicType) or not symbol.dtype.is_int():
raise ValueError(f"Symbol {symbol} does not have integer type.")
self._symbol = symbol
@property
def required_parameters(self) -> set[TypedSymbolOrObject]:
return {self._symbol}
def get_code(self, ctx: SfgContext) -> str:
return f"({self._symbol.name} & 1)"
class SfgBranch(SfgCallTreeNode):
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment