diff --git a/src/pystencilssfg/__init__.py b/src/pystencilssfg/__init__.py index 13847b51d1ffc2da80827c6ad4b01ef9548698bf..ae3cd1a774207f5a8d2515fce5f98254411ce366 100644 --- a/src/pystencilssfg/__init__.py +++ b/src/pystencilssfg/__init__.py @@ -1,11 +1,8 @@ -from .context import SourceFileGenerator, SfgContext -from .kernel_namespace import SfgKernelNamespace, SfgKernelHandle - -from .types import PsType, SrcType +from .generator import SourceFileGenerator +from .composer import SfgComposer __all__ = [ - "SourceFileGenerator", "SfgContext", "SfgKernelNamespace", "SfgKernelHandle", - "PsType", "SrcType" + "SourceFileGenerator", "SfgComposer", ] from . import _version diff --git a/src/pystencilssfg/composer.py b/src/pystencilssfg/composer.py new file mode 100644 index 0000000000000000000000000000000000000000..ed0dced3c34c56a85e43726b5b091aceb25d1f72 --- /dev/null +++ b/src/pystencilssfg/composer.py @@ -0,0 +1,150 @@ +from __future__ import annotations +from typing import TYPE_CHECKING, Optional +from abc import ABC, abstractmethod + +from pystencils import Field +from pystencils.astnodes import KernelFunction + +from .tree import SfgCallTreeNode, SfgKernelCallNode, SfgStatements, SfgSequence, SfgBlock +from .tree.deferred_nodes import SfgDeferredFieldMapping +from .tree.conditional import SfgCondition, SfgCustomCondition, SfgBranch +from .source_components import SfgFunction, SfgHeaderInclude, SfgKernelNamespace, SfgKernelHandle +from .source_concepts import SrcField, TypedSymbolOrObject + +if TYPE_CHECKING: + from .context import SfgContext + + +class SfgComposer: + def __init__(self, ctx: SfgContext): + self._ctx = ctx + + @property + def kernels(self) -> SfgKernelNamespace: + return self._ctx._default_kernel_namespace + + def kernel_namespace(self, name: str) -> SfgKernelNamespace: + kns = self._ctx.get_kernel_namespace(name) + if kns is None: + kns = SfgKernelNamespace(self, name) + self._ctx.add_kernel_namespace(kns) + + return kns + + def include(self, header_file: str): + system_header = False + if header_file.startswith("<") and header_file.endswith(">"): + header_file = header_file[1:-1] + system_header = True + + self._ctx.add_include(SfgHeaderInclude(header_file, system_header=system_header)) + + def kernel_function(self, name: str, ast_or_kernel_handle: KernelFunction | SfgKernelHandle): + if self._ctx.get_function(name) is not None: + raise ValueError(f"Function {name} already exists.") + + if isinstance(ast_or_kernel_handle, KernelFunction): + khandle = self._ctx.default_kernel_namespace.add(ast_or_kernel_handle) + tree = SfgKernelCallNode(khandle) + elif isinstance(ast_or_kernel_handle, SfgKernelCallNode): + tree = ast_or_kernel_handle + else: + raise TypeError("Invalid type of argument `ast_or_kernel_handle`!") + + func = SfgFunction(self._ctx, name, tree) + self._ctx.add_function(func) + + def function(self, name: str): + if self._ctx.get_function(name) is not None: + raise ValueError(f"Function {name} already exists.") + + def sequencer(*args: str | tuple | SfgCallTreeNode | SfgNodeBuilder): + tree = make_sequence(*args) + func = SfgFunction(self._ctx, name, tree) + self._ctx.add_function(func) + + return sequencer + + def call(self, kernel_handle: SfgKernelHandle) -> SfgKernelCallNode: + return SfgKernelCallNode(kernel_handle) + + def seq(self, *args: SfgCallTreeNode) -> SfgSequence: + return make_sequence(*args) + + @property + def branch(self) -> SfgBranchBuilder: + return SfgBranchBuilder() + + def map_field(self, field: Field, src_object: Optional[SrcField] = None) -> SfgDeferredFieldMapping: + if src_object is None: + raise NotImplementedError("Automatic field extraction is not implemented yet.") + else: + return SfgDeferredFieldMapping(field, src_object) + + def map_param(self, lhs: TypedSymbolOrObject, rhs: TypedSymbolOrObject, mapping: str): + return SfgStatements(mapping, (lhs,), (rhs,)) + + +class SfgNodeBuilder(ABC): + @abstractmethod + def resolve(self) -> SfgCallTreeNode: + pass + + +def make_sequence(*args: tuple | str | SfgCallTreeNode | SfgNodeBuilder) -> SfgSequence: + children = [] + for i, arg in enumerate(args): + if isinstance(arg, SfgNodeBuilder): + children.append(arg.resolve()) + elif isinstance(arg, SfgCallTreeNode): + children.append(arg) + elif isinstance(arg, str): + children.append(SfgStatements(arg, (), ())) + elif isinstance(arg, tuple): + # Tuples are treated as blocks + subseq = make_sequence(*arg) + children.append(SfgBlock(subseq)) + else: + raise TypeError(f"Sequence argument {i} has invalid type.") + + return SfgSequence(children) + + +class SfgBranchBuilder(SfgNodeBuilder): + def __init__(self): + self._phase = 0 + + self._cond = None + self._branch_true = SfgSequence(()) + self._branch_false = None + + def __call__(self, *args) -> SfgBranchBuilder: + match self._phase: + case 0: # Condition + if len(args) != 1: + raise ValueError("Must specify exactly one argument as branch condition!") + + cond = args[0] + + if isinstance(cond, str): + cond = SfgCustomCondition(cond) + elif not isinstance(cond, SfgCondition): + raise ValueError( + "Invalid type for branch condition. Must be either `str` or a subclass of `SfgCondition`.") + + self._cond = cond + + case 1: # Then-branch + self._branch_true = make_sequence(*args) + case 2: # Else-branch + self._branch_false = make_sequence(*args) + case _: # There's no third branch! + raise TypeError("Branch construct already complete.") + + self._phase += 1 + + return self + + def resolve(self) -> SfgCallTreeNode: + assert self._cond is not None + return SfgBranch(self._cond, self._branch_true, self._branch_false) diff --git a/src/pystencilssfg/context.py b/src/pystencilssfg/context.py index e11d82703ddbb7a7bfab4241cdb0a19673d8d33c..a83fe12857af7950aee13cf7df18eab009cf1302 100644 --- a/src/pystencilssfg/context.py +++ b/src/pystencilssfg/context.py @@ -1,54 +1,8 @@ -from typing import Generator, Union, Optional, Sequence +from typing import Generator, Sequence -import sys -import os -from os import path - - -from pystencils import Field -from pystencils.astnodes import KernelFunction - -from .configuration import SfgConfiguration, config_from_commandline, merge_configurations, SfgCodeStyle -from .kernel_namespace import SfgKernelNamespace, SfgKernelHandle -from .tree import SfgCallTreeNode, SfgKernelCallNode, SfgStatements -from .tree.deferred_nodes import SfgDeferredFieldMapping -from .tree.builders import SfgBranchBuilder, make_sequence +from .configuration import SfgConfiguration, SfgCodeStyle from .tree.visitors import CollectIncludes -from .source_concepts import SrcField, TypedSymbolOrObject -from .source_components import SfgFunction, SfgHeaderInclude - - -class SourceFileGenerator: - def __init__(self, sfg_config: SfgConfiguration | None = None): - if sfg_config and not isinstance(sfg_config, SfgConfiguration): - raise TypeError("sfg_config is not an SfgConfiguration.") - - import __main__ - scriptpath = __main__.__file__ - scriptname = path.split(scriptpath)[1] - basename = path.splitext(scriptname)[0] - - project_config, cmdline_config, script_args = config_from_commandline(sys.argv) - - config = merge_configurations(project_config, cmdline_config, sfg_config) - - self._context = SfgContext(script_args, config) - - from .emitters.cpu.basic_cpu import BasicCpuEmitter - self._emitter = BasicCpuEmitter(basename, config) - - def clean_files(self): - for file in self._emitter.output_files: - if path.exists(file): - os.remove(file) - - def __enter__(self): - self.clean_files() - return self._context - - def __exit__(self, exc_type, exc_value, traceback): - if exc_type is None: - self._emitter.write_files(self._context) +from .source_components import SfgHeaderInclude, SfgKernelNamespace, SfgFunction class SfgContext: @@ -91,98 +45,49 @@ class SfgContext: return self._config.codestyle # ---------------------------------------------------------------------------------------------- - # Source Component Getters + # Kernel Namespaces # ---------------------------------------------------------------------------------------------- def includes(self) -> Generator[SfgHeaderInclude, None, None]: yield from self._includes - def kernel_namespaces(self) -> Generator[SfgKernelNamespace, None, None]: - yield from self._kernel_namespaces.values() - - def functions(self) -> Generator[SfgFunction, None, None]: - yield from self._functions.values() - - # ---------------------------------------------------------------------------------------------- - # Source Component Adders - # ---------------------------------------------------------------------------------------------- - def add_include(self, include: SfgHeaderInclude): self._includes.add(include) - def add_function(self, func: SfgFunction): - if func.name in self._functions: - raise ValueError(f"Duplicate function: {func.name}") - - self._functions[func.name] = func - for incl in CollectIncludes().visit(func._tree): - self.add_include(incl) - # ---------------------------------------------------------------------------------------------- - # Factory-like Adders + # Kernel Namespaces # ---------------------------------------------------------------------------------------------- @property - def kernels(self) -> SfgKernelNamespace: + def default_kernel_namespace(self) -> SfgKernelNamespace: return self._default_kernel_namespace - def kernel_namespace(self, name: str) -> SfgKernelNamespace: - if name in self._kernel_namespaces: - raise ValueError(f"Duplicate kernel namespace: {name}") - - kns = SfgKernelNamespace(self, name) - self._kernel_namespaces[name] = kns - return kns - - def include(self, header_file: str): - system_header = False - if header_file.startswith("<") and header_file.endswith(">"): - header_file = header_file[1:-1] - system_header = True - - self.add_include(SfgHeaderInclude(header_file, system_header=system_header)) - - def function(self, - name: str, - ast_or_kernel_handle: Optional[Union[KernelFunction, SfgKernelHandle]] = None): - if name in self._functions: - raise ValueError(f"Duplicate function: {name}") - - if ast_or_kernel_handle is not None: - if isinstance(ast_or_kernel_handle, KernelFunction): - khandle = self._default_kernel_namespace.add(ast_or_kernel_handle) - tree = SfgKernelCallNode(khandle) - elif isinstance(ast_or_kernel_handle, SfgKernelCallNode): - tree = ast_or_kernel_handle - else: - raise TypeError("Invalid type of argument `ast_or_kernel_handle`!") - - func = SfgFunction(self, name, tree) - self.add_function(func) - else: - def sequencer(*args: SfgCallTreeNode): - tree = make_sequence(*args) - func = SfgFunction(self, name, tree) - self.add_function(func) - - return sequencer + def kernel_namespaces(self) -> Generator[SfgKernelNamespace, None, None]: + yield from self._kernel_namespaces.values() + + def get_kernel_namespace(self, str) -> SfgKernelNamespace | None: + return self._kernel_namespaces.get(str) + + def add_kernel_namespace(self, namespace: SfgKernelNamespace): + if namespace.name in self._kernel_namespaces: + raise ValueError(f"Duplicate kernel namespace: {namespace.name}") + + self._kernel_namespaces[namespace.name] = namespace # ---------------------------------------------------------------------------------------------- - # In-Sequence builders to be used within the second phase of SfgContext.function(). + # Functions # ---------------------------------------------------------------------------------------------- - def call(self, kernel_handle: SfgKernelHandle) -> SfgKernelCallNode: - return SfgKernelCallNode(kernel_handle) + def functions(self) -> Generator[SfgFunction, None, None]: + yield from self._functions.values() - @property - def branch(self) -> SfgBranchBuilder: - return SfgBranchBuilder() + def get_function(self, name: str) -> SfgFunction | None: + return self._functions.get(name, None) - def map_field(self, field: Field, src_object: Optional[SrcField] = None) -> SfgDeferredFieldMapping: - if src_object is None: - raise NotImplementedError("Automatic field extraction is not implemented yet.") - else: - return SfgDeferredFieldMapping(field, src_object) + def add_function(self, func: SfgFunction) -> None: + if func.name in self._functions: + raise ValueError(f"Duplicate function: {func.name}") - def map_param(self, lhs: TypedSymbolOrObject, rhs: TypedSymbolOrObject, mapping: str): - return SfgStatements(mapping, (lhs,), (rhs,)) + self._functions[func.name] = func + for incl in CollectIncludes().visit(func._tree): + self.add_include(incl) diff --git a/src/pystencilssfg/generator.py b/src/pystencilssfg/generator.py new file mode 100644 index 0000000000000000000000000000000000000000..9f1c98d3b312d328bcbd654db5bb9e98539c7d0d --- /dev/null +++ b/src/pystencilssfg/generator.py @@ -0,0 +1,40 @@ +import sys +import os +from os import path + +from .configuration import SfgConfiguration, config_from_commandline, merge_configurations +from .context import SfgContext +from .composer import SfgComposer + + +class SourceFileGenerator: + def __init__(self, sfg_config: SfgConfiguration | None = None): + if sfg_config and not isinstance(sfg_config, SfgConfiguration): + raise TypeError("sfg_config is not an SfgConfiguration.") + + import __main__ + scriptpath = __main__.__file__ + scriptname = path.split(scriptpath)[1] + basename = path.splitext(scriptname)[0] + + project_config, cmdline_config, script_args = config_from_commandline(sys.argv) + + config = merge_configurations(project_config, cmdline_config, sfg_config) + + self._context = SfgContext(script_args, config) + + from .emitters.cpu.basic_cpu import BasicCpuEmitter + self._emitter = BasicCpuEmitter(basename, config) + + def clean_files(self): + for file in self._emitter.output_files: + if path.exists(file): + os.remove(file) + + def __enter__(self): + self.clean_files() + return SfgComposer(self._context) + + def __exit__(self, exc_type, exc_value, traceback): + if exc_type is None: + self._emitter.write_files(self._context) diff --git a/src/pystencilssfg/kernel_namespace.py b/src/pystencilssfg/source_components.py similarity index 53% rename from src/pystencilssfg/kernel_namespace.py rename to src/pystencilssfg/source_components.py index f11b8f1b97362f76ba00821097d2795107c142ac..1c0e104517144fa8aee5ebc6ab9b83d4c2d7fbe1 100644 --- a/src/pystencilssfg/kernel_namespace.py +++ b/src/pystencilssfg/source_components.py @@ -1,8 +1,44 @@ -from typing import Sequence +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence from pystencils import CreateKernelConfig, create_kernel from pystencils.astnodes import KernelFunction +if TYPE_CHECKING: + from .context import SfgContext + from .tree import SfgCallTreeNode + + +class SfgHeaderInclude: + def __init__(self, header_file: str, system_header: bool = False, private: bool = False): + self._header_file = header_file + self._system_header = system_header + self._private = private + + @property + def system_header(self): + return self._system_header + + @property + def private(self): + return self._private + + def get_code(self): + if self._system_header: + return f"#include <{self._header_file}>" + else: + return f'#include "{self._header_file}"' + + def __hash__(self) -> int: + return hash((self._header_file, self._system_header, self._private)) + + def __eq__(self, other: object) -> bool: + return (isinstance(other, SfgHeaderInclude) + and self._header_file == other._header_file + and self._system_header == other._system_header + and self._private == other._private) + class SfgKernelNamespace: def __init__(self, ctx, name: str): @@ -28,7 +64,8 @@ class SfgKernelNamespace: return SfgKernelHandle(self._ctx, astname, self, ast.get_parameters()) - def create(self, assignments, config: CreateKernelConfig = None): + def create(self, assignments, config: CreateKernelConfig | None = None): + # type: ignore ast = create_kernel(assignments, config=config) return self.add(ast) @@ -74,3 +111,30 @@ class SfgKernelHandle: @property def fields(self): return self.fields + + +class SfgFunction: + def __init__(self, ctx: SfgContext, name: str, tree: SfgCallTreeNode): + self._ctx = ctx + self._name = name + self._tree = tree + + from .tree.visitors import ExpandingParameterCollector + + param_collector = ExpandingParameterCollector(self._ctx) + self._parameters = param_collector.visit(self._tree) + + @property + def name(self): + return self._name + + @property + def parameters(self): + return self._parameters + + @property + def tree(self): + return self._tree + + def get_code(self): + return self._tree.get_code(self._ctx) diff --git a/src/pystencilssfg/source_components/__init__.py b/src/pystencilssfg/source_components/__init__.py deleted file mode 100644 index 143d02c962738fce10e7f6b5187ce1148d11df60..0000000000000000000000000000000000000000 --- a/src/pystencilssfg/source_components/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from .function import SfgFunction -from .header_include import SfgHeaderInclude - -__all__ = [ - "SfgFunction", "SfgHeaderInclude" -] diff --git a/src/pystencilssfg/source_components/function.py b/src/pystencilssfg/source_components/function.py deleted file mode 100644 index cb15809e5d2f8515c7c99b8f0af1f24600ca2e1e..0000000000000000000000000000000000000000 --- a/src/pystencilssfg/source_components/function.py +++ /dev/null @@ -1,33 +0,0 @@ -from __future__ import annotations -from typing import TYPE_CHECKING - -from ..tree import SfgCallTreeNode -from ..tree.visitors import ExpandingParameterCollector - -if TYPE_CHECKING: - from ..context import SfgContext - - -class SfgFunction: - def __init__(self, ctx: SfgContext, name: str, tree: SfgCallTreeNode): - self._ctx = ctx - self._name = name - self._tree = tree - - param_collector = ExpandingParameterCollector(self._ctx) - self._parameters = param_collector.visit(self._tree) - - @property - def name(self): - return self._name - - @property - def parameters(self): - return self._parameters - - @property - def tree(self): - return self._tree - - def get_code(self): - return self._tree.get_code(self._ctx) diff --git a/src/pystencilssfg/source_components/header_include.py b/src/pystencilssfg/source_components/header_include.py deleted file mode 100644 index fc9cd87b95005422840eebd4b33e145c1e24813f..0000000000000000000000000000000000000000 --- a/src/pystencilssfg/source_components/header_include.py +++ /dev/null @@ -1,31 +0,0 @@ -from __future__ import annotations - - -class SfgHeaderInclude: - def __init__(self, header_file: str, system_header: bool = False, private: bool = False): - self._header_file = header_file - self._system_header = system_header - self._private = private - - @property - def system_header(self): - return self._system_header - - @property - def private(self): - return self._private - - def get_code(self): - if self._system_header: - return f"#include <{self._header_file}>" - else: - return f'#include "{self._header_file}"' - - def __hash__(self) -> int: - return hash((self._header_file, self._system_header, self._private)) - - def __eq__(self, other: object) -> bool: - return (isinstance(other, SfgHeaderInclude) - and self._header_file == other._header_file - and self._system_header == other._system_header - and self._private == other._private) diff --git a/src/pystencilssfg/source_concepts/cpp/std_mdspan.py b/src/pystencilssfg/source_concepts/cpp/std_mdspan.py index 2550991b009799f6ec8235560987f9ac0559c200..bd5c0eae5fc781704068504f4f0fb071ceb9b81b 100644 --- a/src/pystencilssfg/source_concepts/cpp/std_mdspan.py +++ b/src/pystencilssfg/source_concepts/cpp/std_mdspan.py @@ -4,7 +4,7 @@ from pystencils.typing import FieldPointerSymbol, FieldStrideSymbol, FieldShapeS from ...tree import SfgStatements from ..source_objects import SrcField -from ...source_components.header_include import SfgHeaderInclude +from ...source_components import SfgHeaderInclude from ...types import PsType, cpp_typename, SrcType from ...exceptions import SfgException diff --git a/src/pystencilssfg/source_concepts/cpp/std_vector.py b/src/pystencilssfg/source_concepts/cpp/std_vector.py index a4e5ac25e996a4473719f5b054ed746f59a3558b..1b4c7d80313d413f6b34f9ef580091d39f749738 100644 --- a/src/pystencilssfg/source_concepts/cpp/std_vector.py +++ b/src/pystencilssfg/source_concepts/cpp/std_vector.py @@ -6,7 +6,7 @@ from ...tree import SfgStatements from ..source_objects import SrcField, SrcVector from ..source_objects import TypedSymbolOrObject from ...types import SrcType, PsType, cpp_typename -from ...source_components.header_include import SfgHeaderInclude +from ...source_components import SfgHeaderInclude from ...exceptions import SfgException diff --git a/src/pystencilssfg/source_concepts/source_objects.py b/src/pystencilssfg/source_concepts/source_objects.py index f7e5a8ae4158588eb193d2c68aaf7bdb2a24afc7..be1eebdc2a82d54d86b46beb8922e73474474b81 100644 --- a/src/pystencilssfg/source_concepts/source_objects.py +++ b/src/pystencilssfg/source_concepts/source_objects.py @@ -71,7 +71,7 @@ class SrcField(SrcObject, ABC): def extract_parameters(self, field: Field) -> SfgSequence: ptr = FieldPointerSymbol(field.name, field.dtype, False) - from ..tree import make_sequence + from ..composer import make_sequence return make_sequence( self.extract_ptr(ptr), diff --git a/src/pystencilssfg/tree/__init__.py b/src/pystencilssfg/tree/__init__.py index b76d8cb5457402eed91fbe192fa4816d564ce131..8ecc06149874426dfe3ef98c6d8682c4774b9d42 100644 --- a/src/pystencilssfg/tree/__init__.py +++ b/src/pystencilssfg/tree/__init__.py @@ -1,9 +1,7 @@ from .basic_nodes import SfgCallTreeNode, SfgKernelCallNode, SfgBlock, SfgSequence, SfgStatements from .conditional import SfgBranch, SfgCondition -from .builders import make_sequence __all__ = [ "SfgCallTreeNode", "SfgKernelCallNode", "SfgSequence", "SfgBlock", "SfgStatements", - "SfgCondition", "SfgBranch", - "make_sequence" + "SfgCondition", "SfgBranch" ] diff --git a/src/pystencilssfg/tree/basic_nodes.py b/src/pystencilssfg/tree/basic_nodes.py index 1363f352ae607837eb1b86b533d6aec45e4f84c9..b4015d0255c0d17167825ad5b6956399f4ae16f4 100644 --- a/src/pystencilssfg/tree/basic_nodes.py +++ b/src/pystencilssfg/tree/basic_nodes.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Sequence from abc import ABC, abstractmethod from itertools import chain -from ..kernel_namespace import SfgKernelHandle +from ..source_components import SfgKernelHandle from ..source_concepts.source_objects import SrcObject, TypedSymbolOrObject if TYPE_CHECKING: diff --git a/src/pystencilssfg/tree/builders.py b/src/pystencilssfg/tree/builders.py deleted file mode 100644 index 0e006af40b26e55937573bda8a70d1338a2e043f..0000000000000000000000000000000000000000 --- a/src/pystencilssfg/tree/builders.py +++ /dev/null @@ -1,70 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod - -from .basic_nodes import SfgCallTreeNode, SfgSequence, SfgBlock, SfgStatements -from .conditional import SfgCondition, SfgCustomCondition, SfgBranch - - -class SfgNodeBuilder(ABC): - @abstractmethod - def resolve(self) -> SfgCallTreeNode: - pass - - -def make_sequence(*args) -> SfgSequence: - children = [] - for i, arg in enumerate(args): - if isinstance(arg, SfgNodeBuilder): - children.append(arg.resolve()) - elif isinstance(arg, SfgCallTreeNode): - children.append(arg) - elif isinstance(arg, str): - children.append(SfgStatements(arg, (), ())) - elif isinstance(arg, tuple): - # Tuples are treated as blocks - subseq = make_sequence(*arg) - children.append(SfgBlock(subseq)) - else: - raise TypeError(f"Sequence argument {i} has invalid type.") - - return SfgSequence(children) - - -class SfgBranchBuilder(SfgNodeBuilder): - def __init__(self): - self._phase = 0 - - self._cond = None - self._branch_true = SfgSequence(()) - self._branch_false = None - - def __call__(self, *args) -> SfgBranchBuilder: - match self._phase: - case 0: # Condition - if len(args) != 1: - raise ValueError("Must specify exactly one argument as branch condition!") - - cond = args[0] - - if isinstance(cond, str): - cond = SfgCustomCondition(cond) - elif not isinstance(cond, SfgCondition): - raise ValueError( - "Invalid type for branch condition. Must be either `str` or a subclass of `SfgCondition`.") - - self._cond = cond - - case 1: # Then-branch - self._branch_true = make_sequence(*args) - case 2: # Else-branch - self._branch_false = make_sequence(*args) - case _: # There's no third branch! - raise TypeError("Branch construct already complete.") - - self._phase += 1 - - return self - - def resolve(self) -> SfgCallTreeNode: - return SfgBranch(self._cond, self._branch_true, self._branch_false) diff --git a/src/pystencilssfg/tree/deferred_nodes.py b/src/pystencilssfg/tree/deferred_nodes.py index ae364bb22632824112defeeef913da22dc8ac3af..bf2cd32a64b8060227ec7714c567d6c05a9144d3 100644 --- a/src/pystencilssfg/tree/deferred_nodes.py +++ b/src/pystencilssfg/tree/deferred_nodes.py @@ -8,8 +8,7 @@ from pystencils.typing import FieldPointerSymbol, FieldShapeSymbol, FieldStrideS from ..exceptions import SfgException -from .basic_nodes import SfgCallTreeNode -from .builders import make_sequence +from .basic_nodes import SfgCallTreeNode, SfgSequence from ..source_concepts import SrcField from ..source_concepts.source_objects import TypedSymbolOrObject @@ -73,8 +72,12 @@ class SfgDeferredFieldMapping(SfgParamCollectionDeferredNode): else: strides.append((c, s)) - return make_sequence( - self._src_field.extract_ptr(ptr), - *(self._src_field.extract_size(c, s) for c, s in shape), - *(self._src_field.extract_stride(c, s) for c, s in strides) - ) + nodes = [] + + if ptr is not None: + nodes += [self._src_field.extract_ptr(ptr)] + + nodes += [self._src_field.extract_size(c, s) for c, s in shape] + nodes += [self._src_field.extract_stride(c, s) for c, s in strides] + + return SfgSequence(nodes) diff --git a/src/pystencilssfg/tree/visitors.py b/src/pystencilssfg/tree/visitors.py index 7660cf20f347112a67faaff9db238dd10ee6b29f..877ed5a68ec9e97f9d4a772be2a75969fae33864 100644 --- a/src/pystencilssfg/tree/visitors.py +++ b/src/pystencilssfg/tree/visitors.py @@ -23,7 +23,7 @@ class FlattenSequences(): @visit.case(SfgSequence) def sequence(self, sequence: SfgSequence) -> None: - children_flattened = [] + children_flattened: list[SfgCallTreeNode] = [] def flatten(seq: SfgSequence): for c in seq.children: diff --git a/tests/TestSequencing.py b/tests/TestSequencing.py index ad9c522dd8a81340a3d89d4c574c9ee360fdff58..9e97052c5925e5507827906a7b206e3f87b8575f 100644 --- a/tests/TestSequencing.py +++ b/tests/TestSequencing.py @@ -9,7 +9,7 @@ with SourceFileGenerator() as sfg: lb_ast_even = create_lb_ast(lbm_config=lb_config, timestep=Timestep.EVEN) lb_ast_even.function_name = "streamCollide_even" - + lb_ast_odd = create_lb_ast(lbm_config=lb_config, timestep=Timestep.ODD) lb_ast_odd.function_name = "streamCollide_odd" diff --git a/tests/mdspan/kernels.py b/tests/mdspan/kernels.py index 27a482ae02d852c8de439a07ac41fc61efc37744..3f045bad4c63f2877e2d6e2b20df531dfd4ec8f2 100644 --- a/tests/mdspan/kernels.py +++ b/tests/mdspan/kernels.py @@ -14,7 +14,7 @@ def field_t(field: ps.Field): reference=True) -with SourceFileGenerator("poisson") as sfg: +with SourceFileGenerator() as sfg: src, dst = ps.fields("src, dst(1) : double[2D]") h = sp.Symbol('h')