From ca3ebac4c2919538df94755ba62b228053fd4c95 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Thu, 16 Nov 2023 21:27:52 +0900 Subject: [PATCH] Introduced composer and refactored structure --- src/pystencilssfg/__init__.py | 9 +- src/pystencilssfg/composer.py | 150 +++++++++++++++++ src/pystencilssfg/context.py | 151 ++++-------------- src/pystencilssfg/generator.py | 40 +++++ ...rnel_namespace.py => source_components.py} | 68 +++++++- .../source_components/__init__.py | 6 - .../source_components/function.py | 33 ---- .../source_components/header_include.py | 31 ---- .../source_concepts/cpp/std_mdspan.py | 2 +- .../source_concepts/cpp/std_vector.py | 2 +- .../source_concepts/source_objects.py | 2 +- src/pystencilssfg/tree/__init__.py | 4 +- src/pystencilssfg/tree/basic_nodes.py | 2 +- src/pystencilssfg/tree/builders.py | 70 -------- src/pystencilssfg/tree/deferred_nodes.py | 17 +- src/pystencilssfg/tree/visitors.py | 2 +- tests/TestSequencing.py | 2 +- tests/mdspan/kernels.py | 2 +- 18 files changed, 305 insertions(+), 288 deletions(-) create mode 100644 src/pystencilssfg/composer.py create mode 100644 src/pystencilssfg/generator.py rename src/pystencilssfg/{kernel_namespace.py => source_components.py} (53%) delete mode 100644 src/pystencilssfg/source_components/__init__.py delete mode 100644 src/pystencilssfg/source_components/function.py delete mode 100644 src/pystencilssfg/source_components/header_include.py delete mode 100644 src/pystencilssfg/tree/builders.py diff --git a/src/pystencilssfg/__init__.py b/src/pystencilssfg/__init__.py index 13847b5..ae3cd1a 100644 --- a/src/pystencilssfg/__init__.py +++ b/src/pystencilssfg/__init__.py @@ -1,11 +1,8 @@ -from .context import SourceFileGenerator, SfgContext -from .kernel_namespace import SfgKernelNamespace, SfgKernelHandle - -from .types import PsType, SrcType +from .generator import SourceFileGenerator +from .composer import SfgComposer __all__ = [ - "SourceFileGenerator", "SfgContext", "SfgKernelNamespace", "SfgKernelHandle", - "PsType", "SrcType" + "SourceFileGenerator", "SfgComposer", ] from . import _version diff --git a/src/pystencilssfg/composer.py b/src/pystencilssfg/composer.py new file mode 100644 index 0000000..ed0dced --- /dev/null +++ b/src/pystencilssfg/composer.py @@ -0,0 +1,150 @@ +from __future__ import annotations +from typing import TYPE_CHECKING, Optional +from abc import ABC, abstractmethod + +from pystencils import Field +from pystencils.astnodes import KernelFunction + +from .tree import SfgCallTreeNode, SfgKernelCallNode, SfgStatements, SfgSequence, SfgBlock +from .tree.deferred_nodes import SfgDeferredFieldMapping +from .tree.conditional import SfgCondition, SfgCustomCondition, SfgBranch +from .source_components import SfgFunction, SfgHeaderInclude, SfgKernelNamespace, SfgKernelHandle +from .source_concepts import SrcField, TypedSymbolOrObject + +if TYPE_CHECKING: + from .context import SfgContext + + +class SfgComposer: + def __init__(self, ctx: SfgContext): + self._ctx = ctx + + @property + def kernels(self) -> SfgKernelNamespace: + return self._ctx._default_kernel_namespace + + def kernel_namespace(self, name: str) -> SfgKernelNamespace: + kns = self._ctx.get_kernel_namespace(name) + if kns is None: + kns = SfgKernelNamespace(self, name) + self._ctx.add_kernel_namespace(kns) + + return kns + + def include(self, header_file: str): + system_header = False + if header_file.startswith("<") and header_file.endswith(">"): + header_file = header_file[1:-1] + system_header = True + + self._ctx.add_include(SfgHeaderInclude(header_file, system_header=system_header)) + + def kernel_function(self, name: str, ast_or_kernel_handle: KernelFunction | SfgKernelHandle): + if self._ctx.get_function(name) is not None: + raise ValueError(f"Function {name} already exists.") + + if isinstance(ast_or_kernel_handle, KernelFunction): + khandle = self._ctx.default_kernel_namespace.add(ast_or_kernel_handle) + tree = SfgKernelCallNode(khandle) + elif isinstance(ast_or_kernel_handle, SfgKernelCallNode): + tree = ast_or_kernel_handle + else: + raise TypeError("Invalid type of argument `ast_or_kernel_handle`!") + + func = SfgFunction(self._ctx, name, tree) + self._ctx.add_function(func) + + def function(self, name: str): + if self._ctx.get_function(name) is not None: + raise ValueError(f"Function {name} already exists.") + + def sequencer(*args: str | tuple | SfgCallTreeNode | SfgNodeBuilder): + tree = make_sequence(*args) + func = SfgFunction(self._ctx, name, tree) + self._ctx.add_function(func) + + return sequencer + + def call(self, kernel_handle: SfgKernelHandle) -> SfgKernelCallNode: + return SfgKernelCallNode(kernel_handle) + + def seq(self, *args: SfgCallTreeNode) -> SfgSequence: + return make_sequence(*args) + + @property + def branch(self) -> SfgBranchBuilder: + return SfgBranchBuilder() + + def map_field(self, field: Field, src_object: Optional[SrcField] = None) -> SfgDeferredFieldMapping: + if src_object is None: + raise NotImplementedError("Automatic field extraction is not implemented yet.") + else: + return SfgDeferredFieldMapping(field, src_object) + + def map_param(self, lhs: TypedSymbolOrObject, rhs: TypedSymbolOrObject, mapping: str): + return SfgStatements(mapping, (lhs,), (rhs,)) + + +class SfgNodeBuilder(ABC): + @abstractmethod + def resolve(self) -> SfgCallTreeNode: + pass + + +def make_sequence(*args: tuple | str | SfgCallTreeNode | SfgNodeBuilder) -> SfgSequence: + children = [] + for i, arg in enumerate(args): + if isinstance(arg, SfgNodeBuilder): + children.append(arg.resolve()) + elif isinstance(arg, SfgCallTreeNode): + children.append(arg) + elif isinstance(arg, str): + children.append(SfgStatements(arg, (), ())) + elif isinstance(arg, tuple): + # Tuples are treated as blocks + subseq = make_sequence(*arg) + children.append(SfgBlock(subseq)) + else: + raise TypeError(f"Sequence argument {i} has invalid type.") + + return SfgSequence(children) + + +class SfgBranchBuilder(SfgNodeBuilder): + def __init__(self): + self._phase = 0 + + self._cond = None + self._branch_true = SfgSequence(()) + self._branch_false = None + + def __call__(self, *args) -> SfgBranchBuilder: + match self._phase: + case 0: # Condition + if len(args) != 1: + raise ValueError("Must specify exactly one argument as branch condition!") + + cond = args[0] + + if isinstance(cond, str): + cond = SfgCustomCondition(cond) + elif not isinstance(cond, SfgCondition): + raise ValueError( + "Invalid type for branch condition. Must be either `str` or a subclass of `SfgCondition`.") + + self._cond = cond + + case 1: # Then-branch + self._branch_true = make_sequence(*args) + case 2: # Else-branch + self._branch_false = make_sequence(*args) + case _: # There's no third branch! + raise TypeError("Branch construct already complete.") + + self._phase += 1 + + return self + + def resolve(self) -> SfgCallTreeNode: + assert self._cond is not None + return SfgBranch(self._cond, self._branch_true, self._branch_false) diff --git a/src/pystencilssfg/context.py b/src/pystencilssfg/context.py index e11d827..a83fe12 100644 --- a/src/pystencilssfg/context.py +++ b/src/pystencilssfg/context.py @@ -1,54 +1,8 @@ -from typing import Generator, Union, Optional, Sequence +from typing import Generator, Sequence -import sys -import os -from os import path - - -from pystencils import Field -from pystencils.astnodes import KernelFunction - -from .configuration import SfgConfiguration, config_from_commandline, merge_configurations, SfgCodeStyle -from .kernel_namespace import SfgKernelNamespace, SfgKernelHandle -from .tree import SfgCallTreeNode, SfgKernelCallNode, SfgStatements -from .tree.deferred_nodes import SfgDeferredFieldMapping -from .tree.builders import SfgBranchBuilder, make_sequence +from .configuration import SfgConfiguration, SfgCodeStyle from .tree.visitors import CollectIncludes -from .source_concepts import SrcField, TypedSymbolOrObject -from .source_components import SfgFunction, SfgHeaderInclude - - -class SourceFileGenerator: - def __init__(self, sfg_config: SfgConfiguration | None = None): - if sfg_config and not isinstance(sfg_config, SfgConfiguration): - raise TypeError("sfg_config is not an SfgConfiguration.") - - import __main__ - scriptpath = __main__.__file__ - scriptname = path.split(scriptpath)[1] - basename = path.splitext(scriptname)[0] - - project_config, cmdline_config, script_args = config_from_commandline(sys.argv) - - config = merge_configurations(project_config, cmdline_config, sfg_config) - - self._context = SfgContext(script_args, config) - - from .emitters.cpu.basic_cpu import BasicCpuEmitter - self._emitter = BasicCpuEmitter(basename, config) - - def clean_files(self): - for file in self._emitter.output_files: - if path.exists(file): - os.remove(file) - - def __enter__(self): - self.clean_files() - return self._context - - def __exit__(self, exc_type, exc_value, traceback): - if exc_type is None: - self._emitter.write_files(self._context) +from .source_components import SfgHeaderInclude, SfgKernelNamespace, SfgFunction class SfgContext: @@ -91,98 +45,49 @@ class SfgContext: return self._config.codestyle # ---------------------------------------------------------------------------------------------- - # Source Component Getters + # Kernel Namespaces # ---------------------------------------------------------------------------------------------- def includes(self) -> Generator[SfgHeaderInclude, None, None]: yield from self._includes - def kernel_namespaces(self) -> Generator[SfgKernelNamespace, None, None]: - yield from self._kernel_namespaces.values() - - def functions(self) -> Generator[SfgFunction, None, None]: - yield from self._functions.values() - - # ---------------------------------------------------------------------------------------------- - # Source Component Adders - # ---------------------------------------------------------------------------------------------- - def add_include(self, include: SfgHeaderInclude): self._includes.add(include) - def add_function(self, func: SfgFunction): - if func.name in self._functions: - raise ValueError(f"Duplicate function: {func.name}") - - self._functions[func.name] = func - for incl in CollectIncludes().visit(func._tree): - self.add_include(incl) - # ---------------------------------------------------------------------------------------------- - # Factory-like Adders + # Kernel Namespaces # ---------------------------------------------------------------------------------------------- @property - def kernels(self) -> SfgKernelNamespace: + def default_kernel_namespace(self) -> SfgKernelNamespace: return self._default_kernel_namespace - def kernel_namespace(self, name: str) -> SfgKernelNamespace: - if name in self._kernel_namespaces: - raise ValueError(f"Duplicate kernel namespace: {name}") - - kns = SfgKernelNamespace(self, name) - self._kernel_namespaces[name] = kns - return kns - - def include(self, header_file: str): - system_header = False - if header_file.startswith("<") and header_file.endswith(">"): - header_file = header_file[1:-1] - system_header = True - - self.add_include(SfgHeaderInclude(header_file, system_header=system_header)) - - def function(self, - name: str, - ast_or_kernel_handle: Optional[Union[KernelFunction, SfgKernelHandle]] = None): - if name in self._functions: - raise ValueError(f"Duplicate function: {name}") - - if ast_or_kernel_handle is not None: - if isinstance(ast_or_kernel_handle, KernelFunction): - khandle = self._default_kernel_namespace.add(ast_or_kernel_handle) - tree = SfgKernelCallNode(khandle) - elif isinstance(ast_or_kernel_handle, SfgKernelCallNode): - tree = ast_or_kernel_handle - else: - raise TypeError("Invalid type of argument `ast_or_kernel_handle`!") - - func = SfgFunction(self, name, tree) - self.add_function(func) - else: - def sequencer(*args: SfgCallTreeNode): - tree = make_sequence(*args) - func = SfgFunction(self, name, tree) - self.add_function(func) - - return sequencer + def kernel_namespaces(self) -> Generator[SfgKernelNamespace, None, None]: + yield from self._kernel_namespaces.values() + + def get_kernel_namespace(self, str) -> SfgKernelNamespace | None: + return self._kernel_namespaces.get(str) + + def add_kernel_namespace(self, namespace: SfgKernelNamespace): + if namespace.name in self._kernel_namespaces: + raise ValueError(f"Duplicate kernel namespace: {namespace.name}") + + self._kernel_namespaces[namespace.name] = namespace # ---------------------------------------------------------------------------------------------- - # In-Sequence builders to be used within the second phase of SfgContext.function(). + # Functions # ---------------------------------------------------------------------------------------------- - def call(self, kernel_handle: SfgKernelHandle) -> SfgKernelCallNode: - return SfgKernelCallNode(kernel_handle) + def functions(self) -> Generator[SfgFunction, None, None]: + yield from self._functions.values() - @property - def branch(self) -> SfgBranchBuilder: - return SfgBranchBuilder() + def get_function(self, name: str) -> SfgFunction | None: + return self._functions.get(name, None) - def map_field(self, field: Field, src_object: Optional[SrcField] = None) -> SfgDeferredFieldMapping: - if src_object is None: - raise NotImplementedError("Automatic field extraction is not implemented yet.") - else: - return SfgDeferredFieldMapping(field, src_object) + def add_function(self, func: SfgFunction) -> None: + if func.name in self._functions: + raise ValueError(f"Duplicate function: {func.name}") - def map_param(self, lhs: TypedSymbolOrObject, rhs: TypedSymbolOrObject, mapping: str): - return SfgStatements(mapping, (lhs,), (rhs,)) + self._functions[func.name] = func + for incl in CollectIncludes().visit(func._tree): + self.add_include(incl) diff --git a/src/pystencilssfg/generator.py b/src/pystencilssfg/generator.py new file mode 100644 index 0000000..9f1c98d --- /dev/null +++ b/src/pystencilssfg/generator.py @@ -0,0 +1,40 @@ +import sys +import os +from os import path + +from .configuration import SfgConfiguration, config_from_commandline, merge_configurations +from .context import SfgContext +from .composer import SfgComposer + + +class SourceFileGenerator: + def __init__(self, sfg_config: SfgConfiguration | None = None): + if sfg_config and not isinstance(sfg_config, SfgConfiguration): + raise TypeError("sfg_config is not an SfgConfiguration.") + + import __main__ + scriptpath = __main__.__file__ + scriptname = path.split(scriptpath)[1] + basename = path.splitext(scriptname)[0] + + project_config, cmdline_config, script_args = config_from_commandline(sys.argv) + + config = merge_configurations(project_config, cmdline_config, sfg_config) + + self._context = SfgContext(script_args, config) + + from .emitters.cpu.basic_cpu import BasicCpuEmitter + self._emitter = BasicCpuEmitter(basename, config) + + def clean_files(self): + for file in self._emitter.output_files: + if path.exists(file): + os.remove(file) + + def __enter__(self): + self.clean_files() + return SfgComposer(self._context) + + def __exit__(self, exc_type, exc_value, traceback): + if exc_type is None: + self._emitter.write_files(self._context) diff --git a/src/pystencilssfg/kernel_namespace.py b/src/pystencilssfg/source_components.py similarity index 53% rename from src/pystencilssfg/kernel_namespace.py rename to src/pystencilssfg/source_components.py index f11b8f1..1c0e104 100644 --- a/src/pystencilssfg/kernel_namespace.py +++ b/src/pystencilssfg/source_components.py @@ -1,8 +1,44 @@ -from typing import Sequence +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence from pystencils import CreateKernelConfig, create_kernel from pystencils.astnodes import KernelFunction +if TYPE_CHECKING: + from .context import SfgContext + from .tree import SfgCallTreeNode + + +class SfgHeaderInclude: + def __init__(self, header_file: str, system_header: bool = False, private: bool = False): + self._header_file = header_file + self._system_header = system_header + self._private = private + + @property + def system_header(self): + return self._system_header + + @property + def private(self): + return self._private + + def get_code(self): + if self._system_header: + return f"#include <{self._header_file}>" + else: + return f'#include "{self._header_file}"' + + def __hash__(self) -> int: + return hash((self._header_file, self._system_header, self._private)) + + def __eq__(self, other: object) -> bool: + return (isinstance(other, SfgHeaderInclude) + and self._header_file == other._header_file + and self._system_header == other._system_header + and self._private == other._private) + class SfgKernelNamespace: def __init__(self, ctx, name: str): @@ -28,7 +64,8 @@ class SfgKernelNamespace: return SfgKernelHandle(self._ctx, astname, self, ast.get_parameters()) - def create(self, assignments, config: CreateKernelConfig = None): + def create(self, assignments, config: CreateKernelConfig | None = None): + # type: ignore ast = create_kernel(assignments, config=config) return self.add(ast) @@ -74,3 +111,30 @@ class SfgKernelHandle: @property def fields(self): return self.fields + + +class SfgFunction: + def __init__(self, ctx: SfgContext, name: str, tree: SfgCallTreeNode): + self._ctx = ctx + self._name = name + self._tree = tree + + from .tree.visitors import ExpandingParameterCollector + + param_collector = ExpandingParameterCollector(self._ctx) + self._parameters = param_collector.visit(self._tree) + + @property + def name(self): + return self._name + + @property + def parameters(self): + return self._parameters + + @property + def tree(self): + return self._tree + + def get_code(self): + return self._tree.get_code(self._ctx) diff --git a/src/pystencilssfg/source_components/__init__.py b/src/pystencilssfg/source_components/__init__.py deleted file mode 100644 index 143d02c..0000000 --- a/src/pystencilssfg/source_components/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from .function import SfgFunction -from .header_include import SfgHeaderInclude - -__all__ = [ - "SfgFunction", "SfgHeaderInclude" -] diff --git a/src/pystencilssfg/source_components/function.py b/src/pystencilssfg/source_components/function.py deleted file mode 100644 index cb15809..0000000 --- a/src/pystencilssfg/source_components/function.py +++ /dev/null @@ -1,33 +0,0 @@ -from __future__ import annotations -from typing import TYPE_CHECKING - -from ..tree import SfgCallTreeNode -from ..tree.visitors import ExpandingParameterCollector - -if TYPE_CHECKING: - from ..context import SfgContext - - -class SfgFunction: - def __init__(self, ctx: SfgContext, name: str, tree: SfgCallTreeNode): - self._ctx = ctx - self._name = name - self._tree = tree - - param_collector = ExpandingParameterCollector(self._ctx) - self._parameters = param_collector.visit(self._tree) - - @property - def name(self): - return self._name - - @property - def parameters(self): - return self._parameters - - @property - def tree(self): - return self._tree - - def get_code(self): - return self._tree.get_code(self._ctx) diff --git a/src/pystencilssfg/source_components/header_include.py b/src/pystencilssfg/source_components/header_include.py deleted file mode 100644 index fc9cd87..0000000 --- a/src/pystencilssfg/source_components/header_include.py +++ /dev/null @@ -1,31 +0,0 @@ -from __future__ import annotations - - -class SfgHeaderInclude: - def __init__(self, header_file: str, system_header: bool = False, private: bool = False): - self._header_file = header_file - self._system_header = system_header - self._private = private - - @property - def system_header(self): - return self._system_header - - @property - def private(self): - return self._private - - def get_code(self): - if self._system_header: - return f"#include <{self._header_file}>" - else: - return f'#include "{self._header_file}"' - - def __hash__(self) -> int: - return hash((self._header_file, self._system_header, self._private)) - - def __eq__(self, other: object) -> bool: - return (isinstance(other, SfgHeaderInclude) - and self._header_file == other._header_file - and self._system_header == other._system_header - and self._private == other._private) diff --git a/src/pystencilssfg/source_concepts/cpp/std_mdspan.py b/src/pystencilssfg/source_concepts/cpp/std_mdspan.py index 2550991..bd5c0ea 100644 --- a/src/pystencilssfg/source_concepts/cpp/std_mdspan.py +++ b/src/pystencilssfg/source_concepts/cpp/std_mdspan.py @@ -4,7 +4,7 @@ from pystencils.typing import FieldPointerSymbol, FieldStrideSymbol, FieldShapeS from ...tree import SfgStatements from ..source_objects import SrcField -from ...source_components.header_include import SfgHeaderInclude +from ...source_components import SfgHeaderInclude from ...types import PsType, cpp_typename, SrcType from ...exceptions import SfgException diff --git a/src/pystencilssfg/source_concepts/cpp/std_vector.py b/src/pystencilssfg/source_concepts/cpp/std_vector.py index a4e5ac2..1b4c7d8 100644 --- a/src/pystencilssfg/source_concepts/cpp/std_vector.py +++ b/src/pystencilssfg/source_concepts/cpp/std_vector.py @@ -6,7 +6,7 @@ from ...tree import SfgStatements from ..source_objects import SrcField, SrcVector from ..source_objects import TypedSymbolOrObject from ...types import SrcType, PsType, cpp_typename -from ...source_components.header_include import SfgHeaderInclude +from ...source_components import SfgHeaderInclude from ...exceptions import SfgException diff --git a/src/pystencilssfg/source_concepts/source_objects.py b/src/pystencilssfg/source_concepts/source_objects.py index f7e5a8a..be1eebd 100644 --- a/src/pystencilssfg/source_concepts/source_objects.py +++ b/src/pystencilssfg/source_concepts/source_objects.py @@ -71,7 +71,7 @@ class SrcField(SrcObject, ABC): def extract_parameters(self, field: Field) -> SfgSequence: ptr = FieldPointerSymbol(field.name, field.dtype, False) - from ..tree import make_sequence + from ..composer import make_sequence return make_sequence( self.extract_ptr(ptr), diff --git a/src/pystencilssfg/tree/__init__.py b/src/pystencilssfg/tree/__init__.py index b76d8cb..8ecc061 100644 --- a/src/pystencilssfg/tree/__init__.py +++ b/src/pystencilssfg/tree/__init__.py @@ -1,9 +1,7 @@ from .basic_nodes import SfgCallTreeNode, SfgKernelCallNode, SfgBlock, SfgSequence, SfgStatements from .conditional import SfgBranch, SfgCondition -from .builders import make_sequence __all__ = [ "SfgCallTreeNode", "SfgKernelCallNode", "SfgSequence", "SfgBlock", "SfgStatements", - "SfgCondition", "SfgBranch", - "make_sequence" + "SfgCondition", "SfgBranch" ] diff --git a/src/pystencilssfg/tree/basic_nodes.py b/src/pystencilssfg/tree/basic_nodes.py index 1363f35..b4015d0 100644 --- a/src/pystencilssfg/tree/basic_nodes.py +++ b/src/pystencilssfg/tree/basic_nodes.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Sequence from abc import ABC, abstractmethod from itertools import chain -from ..kernel_namespace import SfgKernelHandle +from ..source_components import SfgKernelHandle from ..source_concepts.source_objects import SrcObject, TypedSymbolOrObject if TYPE_CHECKING: diff --git a/src/pystencilssfg/tree/builders.py b/src/pystencilssfg/tree/builders.py deleted file mode 100644 index 0e006af..0000000 --- a/src/pystencilssfg/tree/builders.py +++ /dev/null @@ -1,70 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod - -from .basic_nodes import SfgCallTreeNode, SfgSequence, SfgBlock, SfgStatements -from .conditional import SfgCondition, SfgCustomCondition, SfgBranch - - -class SfgNodeBuilder(ABC): - @abstractmethod - def resolve(self) -> SfgCallTreeNode: - pass - - -def make_sequence(*args) -> SfgSequence: - children = [] - for i, arg in enumerate(args): - if isinstance(arg, SfgNodeBuilder): - children.append(arg.resolve()) - elif isinstance(arg, SfgCallTreeNode): - children.append(arg) - elif isinstance(arg, str): - children.append(SfgStatements(arg, (), ())) - elif isinstance(arg, tuple): - # Tuples are treated as blocks - subseq = make_sequence(*arg) - children.append(SfgBlock(subseq)) - else: - raise TypeError(f"Sequence argument {i} has invalid type.") - - return SfgSequence(children) - - -class SfgBranchBuilder(SfgNodeBuilder): - def __init__(self): - self._phase = 0 - - self._cond = None - self._branch_true = SfgSequence(()) - self._branch_false = None - - def __call__(self, *args) -> SfgBranchBuilder: - match self._phase: - case 0: # Condition - if len(args) != 1: - raise ValueError("Must specify exactly one argument as branch condition!") - - cond = args[0] - - if isinstance(cond, str): - cond = SfgCustomCondition(cond) - elif not isinstance(cond, SfgCondition): - raise ValueError( - "Invalid type for branch condition. Must be either `str` or a subclass of `SfgCondition`.") - - self._cond = cond - - case 1: # Then-branch - self._branch_true = make_sequence(*args) - case 2: # Else-branch - self._branch_false = make_sequence(*args) - case _: # There's no third branch! - raise TypeError("Branch construct already complete.") - - self._phase += 1 - - return self - - def resolve(self) -> SfgCallTreeNode: - return SfgBranch(self._cond, self._branch_true, self._branch_false) diff --git a/src/pystencilssfg/tree/deferred_nodes.py b/src/pystencilssfg/tree/deferred_nodes.py index ae364bb..bf2cd32 100644 --- a/src/pystencilssfg/tree/deferred_nodes.py +++ b/src/pystencilssfg/tree/deferred_nodes.py @@ -8,8 +8,7 @@ from pystencils.typing import FieldPointerSymbol, FieldShapeSymbol, FieldStrideS from ..exceptions import SfgException -from .basic_nodes import SfgCallTreeNode -from .builders import make_sequence +from .basic_nodes import SfgCallTreeNode, SfgSequence from ..source_concepts import SrcField from ..source_concepts.source_objects import TypedSymbolOrObject @@ -73,8 +72,12 @@ class SfgDeferredFieldMapping(SfgParamCollectionDeferredNode): else: strides.append((c, s)) - return make_sequence( - self._src_field.extract_ptr(ptr), - *(self._src_field.extract_size(c, s) for c, s in shape), - *(self._src_field.extract_stride(c, s) for c, s in strides) - ) + nodes = [] + + if ptr is not None: + nodes += [self._src_field.extract_ptr(ptr)] + + nodes += [self._src_field.extract_size(c, s) for c, s in shape] + nodes += [self._src_field.extract_stride(c, s) for c, s in strides] + + return SfgSequence(nodes) diff --git a/src/pystencilssfg/tree/visitors.py b/src/pystencilssfg/tree/visitors.py index 7660cf2..877ed5a 100644 --- a/src/pystencilssfg/tree/visitors.py +++ b/src/pystencilssfg/tree/visitors.py @@ -23,7 +23,7 @@ class FlattenSequences(): @visit.case(SfgSequence) def sequence(self, sequence: SfgSequence) -> None: - children_flattened = [] + children_flattened: list[SfgCallTreeNode] = [] def flatten(seq: SfgSequence): for c in seq.children: diff --git a/tests/TestSequencing.py b/tests/TestSequencing.py index ad9c522..9e97052 100644 --- a/tests/TestSequencing.py +++ b/tests/TestSequencing.py @@ -9,7 +9,7 @@ with SourceFileGenerator() as sfg: lb_ast_even = create_lb_ast(lbm_config=lb_config, timestep=Timestep.EVEN) lb_ast_even.function_name = "streamCollide_even" - + lb_ast_odd = create_lb_ast(lbm_config=lb_config, timestep=Timestep.ODD) lb_ast_odd.function_name = "streamCollide_odd" diff --git a/tests/mdspan/kernels.py b/tests/mdspan/kernels.py index 27a482a..3f045ba 100644 --- a/tests/mdspan/kernels.py +++ b/tests/mdspan/kernels.py @@ -14,7 +14,7 @@ def field_t(field: ps.Field): reference=True) -with SourceFileGenerator("poisson") as sfg: +with SourceFileGenerator() as sfg: src, dst = ps.fields("src, dst(1) : double[2D]") h = sp.Symbol('h') -- GitLab