From 43daedc4687dd0ba43455faa209bf302cf86b988 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Tue, 14 Nov 2023 22:55:24 +0900 Subject: [PATCH] some minor refactoring --- Notes.md | 27 --------------------------- Readme.md | 2 +- pystencilssfg/tree/basic_nodes.py | 13 +++++-------- pystencilssfg/tree/visitors.py | 31 ++++++++++++++++++------------- 4 files changed, 24 insertions(+), 49 deletions(-) delete mode 100644 Notes.md diff --git a/Notes.md b/Notes.md deleted file mode 100644 index 1c8539b..0000000 --- a/Notes.md +++ /dev/null @@ -1,27 +0,0 @@ - -# Build System Integration - -## Configurator Script - -The configurator script should configure the code generator and provide global configuration to all codegen scripts. -In the CMake integration, it can be specified globally via the `PystencilsSfg_CONFIGURATOR_SCRIPT` cache variable. - -To decide and implement: - - - Use `runpy` and communicate via a global variable, or use `importlib.util.spec_from_file_location` and communicate via - a function call? In either case, there needs to be concensus about at least one name in the configurator script. - - Allow specifying a separate configurator file at `pystencilssfg_generate_target_sources`? Sound sensible... It's basically - for free with the potential to add lots of flexibility - -## Generator flags - -Two separate lists of flags may be passed to generator scripts: Some may be evaluated by the SFG, and the rest -will be passed on to the user script. - -Arguments to the SFG include: - - - Path of the configurator script - - Output directory - -How to separate user from generator arguments? - diff --git a/Readme.md b/Readme.md index 8bf4b6a..1f9ee2d 100644 --- a/Readme.md +++ b/Readme.md @@ -1,3 +1,3 @@ -# pystencils Source File Generator (ps-sfg) +# pystencils Source File Generator (pystencils-sfg) diff --git a/pystencilssfg/tree/basic_nodes.py b/pystencilssfg/tree/basic_nodes.py index 75fceb5..6fbf99a 100644 --- a/pystencilssfg/tree/basic_nodes.py +++ b/pystencilssfg/tree/basic_nodes.py @@ -1,19 +1,16 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Sequence, Set, Union, Iterable - -if TYPE_CHECKING: - from ..context import SfgContext - from ..source_components import SfgHeaderInclude +from typing import TYPE_CHECKING, Sequence, Set from abc import ABC, abstractmethod from itertools import chain from ..kernel_namespace import SfgKernelHandle from ..source_concepts.source_objects import SrcObject, TypedSymbolOrObject - from ..exceptions import SfgException -from pystencils.typing import TypedSymbol +if TYPE_CHECKING: + from ..context import SfgContext + from ..source_components import SfgHeaderInclude class SfgCallTreeNode(ABC): """Base class for all nodes comprising SFG call trees. """ @@ -72,7 +69,7 @@ class SfgStatements(SfgCallTreeLeaf): required_objects: Objects (as `SrcObject` or `TypedSymbol`) that are required as input to these statements. """ - def __init__(self, + def __init__(self, code_string: str, defined_params: Sequence[TypedSymbolOrObject], required_params: Sequence[TypedSymbolOrObject]): diff --git a/pystencilssfg/tree/visitors.py b/pystencilssfg/tree/visitors.py index 9da7db2..bc466c0 100644 --- a/pystencilssfg/tree/visitors.py +++ b/pystencilssfg/tree/visitors.py @@ -1,8 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Sequence, Set, Union, Iterable -if TYPE_CHECKING: - from ..context import SfgContext +from typing import TYPE_CHECKING, Set from functools import reduce @@ -12,8 +10,13 @@ from .basic_nodes import SfgCallTreeNode, SfgCallTreeLeaf, SfgSequence, SfgState from .deferred_nodes import SfgParamCollectionDeferredNode +if TYPE_CHECKING: + from ..context import SfgContext + + class FlattenSequences(): """Flattens any nested sequences occuring in a kernel call tree.""" + def visit(self, node: SfgCallTreeNode) -> None: if isinstance(node, SfgSequence): return self._visit_SfgSequence(node) @@ -23,14 +26,14 @@ class FlattenSequences(): def _visit_SfgSequence(self, sequence: SfgSequence) -> None: children_flattened = [] - + def flatten(seq: SfgSequence): for c in seq.children: if isinstance(c, SfgSequence): flatten(c) else: children_flattened.append(c) - + flatten(sequence) for c in children_flattened: @@ -49,14 +52,15 @@ class CollectIncludes: class ExpandingParameterCollector(): + """Collects all parameters required but not defined in a kernel call tree. + Expands any deferred nodes of type `SfgParamCollectionDeferredNode` found within sequences on the way. + """ + def __init__(self, ctx: SfgContext) -> None: self._ctx = ctx self._flattener = FlattenSequences() - """Collects all parameters required but not defined in a kernel call tree. - Expands any deferred nodes of type `SfgParamCollectionDeferredNode` found within sequences on the way. - """ def visit(self, node: SfgCallTreeNode) -> Set[TypedSymbol]: if isinstance(node, SfgCallTreeLeaf): return self._visit_SfgCallTreeLeaf(node) @@ -72,13 +76,13 @@ class ExpandingParameterCollector(): """ Only in a sequence may parameters be defined and visible to subsequent nodes. """ - + params = set() def iter_nested_sequences(seq: SfgSequence, visible_params: Set[TypedSymbol]): for i in range(len(seq.children) - 1, -1, -1): c = seq.children[i] - + if isinstance(c, SfgParamCollectionDeferredNode): c = c.expand(self._ctx, visible_params=visible_params) seq.replace_child(i, c) @@ -88,7 +92,7 @@ class ExpandingParameterCollector(): else: if isinstance(c, SfgStatements): visible_params -= c.defined_parameters - + visible_params |= self.visit(c) iter_nested_sequences(sequence, params) @@ -108,6 +112,7 @@ class ParameterCollector(): Requires that all sequences in the tree are flattened. """ + def visit(self, node: SfgCallTreeNode) -> Set[TypedSymbol]: if isinstance(node, SfgCallTreeLeaf): return self._visit_SfgCallTreeLeaf(node) @@ -123,12 +128,12 @@ class ParameterCollector(): """ Only in a sequence may parameters be defined and visible to subsequent nodes. """ - + params = set() for c in sequence.children[::-1]: if isinstance(c, SfgStatements): params -= c.defined_parameters - + assert not isinstance(c, SfgSequence), "Sequence not flattened." params |= self.visit(c) return params -- GitLab