From 2ecaf2e4406b763c1b72ee03f7e3c9983a2f6cc1 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Wed, 15 Nov 2023 19:55:17 +0900 Subject: [PATCH] introduced flake8 --- .flake8 | 2 + pystencilssfg/__init__.py | 4 +- pystencilssfg/__main__.py | 2 +- pystencilssfg/configuration.py | 20 +++++--- pystencilssfg/context.py | 49 +++++++++---------- pystencilssfg/emitters/cpu/basic_cpu.py | 1 + pystencilssfg/emitters/cpu/jinja_filters.py | 4 ++ pystencilssfg/kernel_namespace.py | 4 +- pystencilssfg/source_components/__init__.py | 4 +- pystencilssfg/source_components/function.py | 13 ++--- .../source_components/header_include.py | 11 +++-- pystencilssfg/source_concepts/__init__.py | 4 +- pystencilssfg/source_concepts/cpp/__init__.py | 6 +-- .../source_concepts/cpp/std_mdspan.py | 38 ++++++++------ .../source_concepts/cpp/std_vector.py | 24 ++++----- .../source_concepts/source_objects.py | 21 ++++---- pystencilssfg/tree/__init__.py | 8 +-- pystencilssfg/tree/basic_nodes.py | 31 ++++++------ pystencilssfg/tree/builders.py | 36 ++++++-------- pystencilssfg/tree/conditional.py | 25 +++++----- pystencilssfg/tree/deferred_nodes.py | 17 ++++--- pystencilssfg/types.py | 1 - 22 files changed, 171 insertions(+), 154 deletions(-) create mode 100644 .flake8 diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..aa079ec --- /dev/null +++ b/.flake8 @@ -0,0 +1,2 @@ +[flake8] +max-line-length=120 diff --git a/pystencilssfg/__init__.py b/pystencilssfg/__init__.py index 100f11f..34bd784 100644 --- a/pystencilssfg/__init__.py +++ b/pystencilssfg/__init__.py @@ -4,8 +4,8 @@ from .kernel_namespace import SfgKernelNamespace, SfgKernelHandle from .types import PsType, SrcType __all__ = [ - SourceFileGenerator, SfgContext, SfgKernelNamespace, SfgKernelHandle, - PsType, SrcType + "SourceFileGenerator", "SfgContext", "SfgKernelNamespace", "SfgKernelHandle", + "PsType", "SrcType" ] __version__ = "0.0.0" diff --git a/pystencilssfg/__main__.py b/pystencilssfg/__main__.py index 36413da..3c793c4 100644 --- a/pystencilssfg/__main__.py +++ b/pystencilssfg/__main__.py @@ -35,7 +35,7 @@ def main(): def version(args, argv): from . import __version__ - print(version) + print(__version__) exit(0) diff --git a/pystencilssfg/configuration.py b/pystencilssfg/configuration.py index 4bb4559..2d68d35 100644 --- a/pystencilssfg/configuration.py +++ b/pystencilssfg/configuration.py @@ -97,17 +97,17 @@ def run_configurator(configurator_script: str): if not path.exists(configurator_script): raise SfgConfigException(SfgConfigSource.PROJECT, f"Configurator script not found: {configurator_script} is not a file.") - + cfg_spec = importlib.util.spec_from_file_location(configurator_script) configurator = importlib.util.module_from_spec(cfg_spec) - if not hasattr(project_config, "sfg_config"): + if not hasattr(configurator, "sfg_config"): raise SfgConfigException(SfgConfigSource.PROJECT, "Project configurator does not define function `sfg_config`.") project_config = configurator.sfg_config() if not isinstance(project_config, SfgConfiguration): raise SfgConfigException(SfgConfigSource.PROJECT, "sfg_config did not return a SfgConfiguration object.") - + return project_config @@ -115,12 +115,15 @@ def add_config_args_to_parser(parser: ArgumentParser): config_group = parser.add_argument_group("Configuration") config_group.add_argument("--sfg-output-dir", - type=str, default=None, dest='output_directory') + type=str, default=None, dest='output_directory') config_group.add_argument("--sfg-file-extensions", - type=str, default=None, dest='file_extensions', help="Comma-separated list of file extensions") + type=str, + default=None, + dest='file_extensions', + help="Comma-separated list of file extensions") config_group.add_argument("--sfg-header-only", default=None, action='store_true', dest='header_only') config_group.add_argument("--sfg-configurator", type=str, default=None, dest='configurator_script') - + return parser @@ -151,7 +154,7 @@ def config_from_commandline(argv: List[str]): parser = ArgumentParser("pystencilssfg", description="pystencils Source File Generator", allow_abbrev=False) - + add_config_args_to_parser(parser) args, script_args = parser.parse_known_args(argv) @@ -182,7 +185,8 @@ def merge_configurations(project_config: SfgConfiguration, for key, cmdline_value in cmdline_dict.items(): if cmdline_value is not None and script_dict[key] is not None: raise SfgException( - f"Conflicting configuration: Parameter {key} was specified both in the script and on the command line.") + "Conflicting configuration:" + + f" Parameter {key} was specified both in the script and on the command line.") config = config.override(script_config) diff --git a/pystencilssfg/context.py b/pystencilssfg/context.py index 954f3f0..69ad66c 100644 --- a/pystencilssfg/context.py +++ b/pystencilssfg/context.py @@ -1,5 +1,4 @@ from typing import Generator, Union, Optional, Sequence -from dataclasses import dataclass import sys import os @@ -19,7 +18,6 @@ from .source_concepts import SrcField, TypedSymbolOrObject from .source_components import SfgFunction, SfgHeaderInclude - class SourceFileGenerator: def __init__(self, sfg_config: SfgConfiguration = None): if sfg_config and not isinstance(sfg_config, SfgConfiguration): @@ -61,7 +59,7 @@ class SfgContext: # Source Components self._includes = set() - self._kernel_namespaces = { self._default_kernel_namespace.name : self._default_kernel_namespace } + self._kernel_namespaces = {self._default_kernel_namespace.name: self._default_kernel_namespace} self._functions = dict() @property @@ -71,15 +69,15 @@ class SfgContext: @property def root_namespace(self) -> str: return self._config.base_namespace - + @property def codestyle(self) -> SfgCodeStyle: return self._config.codestyle - #---------------------------------------------------------------------------------------------- + # ---------------------------------------------------------------------------------------------- # Source Component Getters - #---------------------------------------------------------------------------------------------- - + # ---------------------------------------------------------------------------------------------- + def includes(self) -> Generator[SfgHeaderInclude, None, None]: yield from self._includes @@ -89,9 +87,9 @@ class SfgContext: def functions(self) -> Generator[SfgFunction, None, None]: yield from self._functions.values() - #---------------------------------------------------------------------------------------------- + # ---------------------------------------------------------------------------------------------- # Source Component Adders - #---------------------------------------------------------------------------------------------- + # ---------------------------------------------------------------------------------------------- def add_include(self, include: SfgHeaderInclude): self._includes.add(include) @@ -99,14 +97,14 @@ class SfgContext: def add_function(self, func: SfgFunction): if func.name in self._functions: raise ValueError(f"Duplicate function: {func.name}") - + self._functions[func.name] = func for incl in CollectIncludes().visit(func._tree): self.add_include(incl) - #---------------------------------------------------------------------------------------------- + # ---------------------------------------------------------------------------------------------- # Factory-like Adders - #---------------------------------------------------------------------------------------------- + # ---------------------------------------------------------------------------------------------- @property def kernels(self) -> SfgKernelNamespace: @@ -125,23 +123,26 @@ class SfgContext: if header_file.startswith("<") and header_file.endswith(">"): header_file = header_file[1:-1] system_header = True - + self.add_include(SfgHeaderInclude(header_file, system_header=system_header)) - def function(self, + def function(self, name: str, - ast_or_kernel_handle : Optional[Union[KernelFunction, SfgKernelHandle]] = None): + ast_or_kernel_handle: Optional[Union[KernelFunction, SfgKernelHandle]] = None): if name in self._functions: raise ValueError(f"Duplicate function: {name}") - + if ast_or_kernel_handle is not None: if isinstance(ast_or_kernel_handle, KernelFunction): khandle = self._default_kernel_namespace.add(ast_or_kernel_handle) - tree = SfgKernelCallNode(self, khandle) + tree = SfgKernelCallNode(khandle) elif isinstance(ast_or_kernel_handle, SfgKernelCallNode): tree = ast_or_kernel_handle else: - raise TypeError(f"Invalid type of argument `ast_or_kernel_handle`!") + raise TypeError("Invalid type of argument `ast_or_kernel_handle`!") + + func = SfgFunction(self, name, tree) + self.add_function(func) else: def sequencer(*args: SfgCallTreeNode): tree = make_sequence(*args) @@ -149,25 +150,23 @@ class SfgContext: self.add_function(func) return sequencer - - #---------------------------------------------------------------------------------------------- + # ---------------------------------------------------------------------------------------------- # In-Sequence builders to be used within the second phase of SfgContext.function(). - #---------------------------------------------------------------------------------------------- + # ---------------------------------------------------------------------------------------------- def call(self, kernel_handle: SfgKernelHandle) -> SfgKernelCallNode: return SfgKernelCallNode(kernel_handle) - + @property def branch(self) -> SfgBranchBuilder: return SfgBranchBuilder() - + def map_field(self, field: Field, src_object: Optional[SrcField] = None) -> SfgSequence: if src_object is None: raise NotImplementedError("Automatic field extraction is not implemented yet.") else: return SfgDeferredFieldMapping(field, src_object) - + def map_param(self, lhs: TypedSymbolOrObject, rhs: TypedSymbolOrObject, mapping: str): return SfgStatements(mapping, (lhs,), (rhs,)) - \ No newline at end of file diff --git a/pystencilssfg/emitters/cpu/basic_cpu.py b/pystencilssfg/emitters/cpu/basic_cpu.py index fdb081c..69c7188 100644 --- a/pystencilssfg/emitters/cpu/basic_cpu.py +++ b/pystencilssfg/emitters/cpu/basic_cpu.py @@ -5,6 +5,7 @@ from os import path from ...configuration import SfgConfiguration from ...context import SfgContext + class BasicCpuEmitter: def __init__(self, basename: str, config: SfgConfiguration): self._basename = basename diff --git a/pystencilssfg/emitters/cpu/jinja_filters.py b/pystencilssfg/emitters/cpu/jinja_filters.py index d3602da..7152c98 100644 --- a/pystencilssfg/emitters/cpu/jinja_filters.py +++ b/pystencilssfg/emitters/cpu/jinja_filters.py @@ -6,18 +6,22 @@ from pystencils.backends import generate_c from pystencilssfg.source_components import SfgFunction + @pass_context def generate_kernel_definition(ctx, ast: KernelFunction): return generate_c(ast, dialect=Backend.C) + @pass_context def generate_function_parameter_list(ctx, func: SfgFunction): params = sorted(list(func.parameters), key=lambda p: p.name) return ", ".join(f"{param.dtype} {param.name}" for param in params) + def generate_function_body(func: SfgFunction): return func.get_code() + def add_filters_to_jinja(jinja_env): jinja_env.filters['generate_kernel_definition'] = generate_kernel_definition jinja_env.filters['generate_function_parameter_list'] = generate_function_parameter_list diff --git a/pystencilssfg/kernel_namespace.py b/pystencilssfg/kernel_namespace.py index 890ff60..1204e6c 100644 --- a/pystencilssfg/kernel_namespace.py +++ b/pystencilssfg/kernel_namespace.py @@ -3,6 +3,7 @@ from typing import Sequence from pystencils import CreateKernelConfig, create_kernel from pystencils.astnodes import KernelFunction + class SfgKernelNamespace: def __init__(self, ctx, name: str): self._ctx = ctx @@ -59,7 +60,7 @@ class SfgKernelHandle: @property def fully_qualified_name(self): return f"{self._ctx.root_namespace}::{self.kernel_namespace.name}::{self.kernel_name}" - + @property def parameters(self): return self._parameters @@ -71,4 +72,3 @@ class SfgKernelHandle: @property def fields(self): return self.fields - \ No newline at end of file diff --git a/pystencilssfg/source_components/__init__.py b/pystencilssfg/source_components/__init__.py index 307c354..143d02c 100644 --- a/pystencilssfg/source_components/__init__.py +++ b/pystencilssfg/source_components/__init__.py @@ -2,5 +2,5 @@ from .function import SfgFunction from .header_include import SfgHeaderInclude __all__ = [ - SfgFunction, SfgHeaderInclude -] \ No newline at end of file + "SfgFunction", "SfgHeaderInclude" +] diff --git a/pystencilssfg/source_components/function.py b/pystencilssfg/source_components/function.py index 26ff659..cb15809 100644 --- a/pystencilssfg/source_components/function.py +++ b/pystencilssfg/source_components/function.py @@ -1,21 +1,19 @@ from __future__ import annotations from typing import TYPE_CHECKING +from ..tree import SfgCallTreeNode +from ..tree.visitors import ExpandingParameterCollector + if TYPE_CHECKING: from ..context import SfgContext -from ..tree import SfgCallTreeNode -from ..tree.visitors import FlattenSequences, ExpandingParameterCollector class SfgFunction: def __init__(self, ctx: SfgContext, name: str, tree: SfgCallTreeNode): self._ctx = ctx self._name = name self._tree = tree - - flattener = FlattenSequences() - # flattener.visit(self._tree) - + param_collector = ExpandingParameterCollector(self._ctx) self._parameters = param_collector.visit(self._tree) @@ -26,11 +24,10 @@ class SfgFunction: @property def parameters(self): return self._parameters - + @property def tree(self): return self._tree def get_code(self): return self._tree.get_code(self._ctx) - diff --git a/pystencilssfg/source_components/header_include.py b/pystencilssfg/source_components/header_include.py index 3345456..fb4aad6 100644 --- a/pystencilssfg/source_components/header_include.py +++ b/pystencilssfg/source_components/header_include.py @@ -1,7 +1,8 @@ from __future__ import annotations + class SfgHeaderInclude: - def __init__(self, header_file: str, system_header : bool = False, private: bool = False): + def __init__(self, header_file: str, system_header: bool = False, private: bool = False): self._header_file = header_file self._system_header = system_header self._private = private @@ -9,7 +10,7 @@ class SfgHeaderInclude: @property def system_header(self): return self._system_header - + @property def private(self): return self._private @@ -19,12 +20,12 @@ class SfgHeaderInclude: return f"#include <{self._header_file}>" else: return f'#include "{self._header_file}"' - + def __hash__(self) -> int: return hash((self._header_file, self._system_header, self._private)) - + def __eq__(self, other: SfgHeaderInclude) -> bool: - return (isinstance(other, SfgHeaderInclude) + return (isinstance(other, SfgHeaderInclude) and self._header_file == other._header_file and self._system_header == other._system_header and self._private == other._private) diff --git a/pystencilssfg/source_concepts/__init__.py b/pystencilssfg/source_concepts/__init__.py index 0531fd2..6375f3e 100644 --- a/pystencilssfg/source_concepts/__init__.py +++ b/pystencilssfg/source_concepts/__init__.py @@ -1,5 +1,5 @@ from .source_objects import SrcObject, SrcField, SrcVector, TypedSymbolOrObject __all__ = [ - SrcObject, SrcField, SrcVector, TypedSymbolOrObject -] \ No newline at end of file + "SrcObject", "SrcField", "SrcVector", "TypedSymbolOrObject" +] diff --git a/pystencilssfg/source_concepts/cpp/__init__.py b/pystencilssfg/source_concepts/cpp/__init__.py index 4d3397a..206e2e3 100644 --- a/pystencilssfg/source_concepts/cpp/__init__.py +++ b/pystencilssfg/source_concepts/cpp/__init__.py @@ -1,6 +1,6 @@ from .std_mdspan import std_mdspan from .std_vector import std_vector, std_vector_ref -__all__= [ - std_mdspan, std_vector, std_vector_ref -] \ No newline at end of file +__all__ = [ + "std_mdspan", "std_vector", "std_vector_ref" +] diff --git a/pystencilssfg/source_concepts/cpp/std_mdspan.py b/pystencilssfg/source_concepts/cpp/std_mdspan.py index dcb5265..c87b5b1 100644 --- a/pystencilssfg/source_concepts/cpp/std_mdspan.py +++ b/pystencilssfg/source_concepts/cpp/std_mdspan.py @@ -8,21 +8,27 @@ from ...source_components.header_include import SfgHeaderInclude from ...types import PsType, cpp_typename from ...exceptions import SfgException + class std_mdspan(SrcField): dynamic_extent = "std::dynamic_extent" - def __init__(self, identifer: str, T: PsType, extents: Tuple[int, str], extents_type: PsType = int, reference: bool = False): + def __init__(self, identifer: str, + T: PsType, + extents: Tuple[int, str], + extents_type: PsType = int, + reference: bool = False): T = cpp_typename(T) extents_type = cpp_typename(extents_type) - typestring = f"std::mdspan< {T}, std::extents< {extents_type}, {', '.join(str(e) for e in extents)} > > {'&' if reference else ''}" + extents_str = f"std::extents< {extents_type}, {', '.join(str(e) for e in extents)} >" + typestring = f"std::mdspan< {T}, {extents_str} > {'&' if reference else ''}" super().__init__(typestring, identifer) self._extents = extents @property def required_includes(self) -> Set[SfgHeaderInclude]: - return { SfgHeaderInclude("experimental/mdspan", system_header=True) } + return {SfgHeaderInclude("experimental/mdspan", system_header=True)} def extract_ptr(self, ptr_symbol: FieldPointerSymbol): return SfgStatements( @@ -37,33 +43,35 @@ class std_mdspan(SrcField): if isinstance(size, FieldShapeSymbol): raise SfgException(f"Cannot extract size in coordinate {coordinate} from a {dim}-dimensional mdspan!") elif size != 1: - raise SfgException(f"Cannot map field with size {size} in coordinate {coordinate} to {dim}-dimensional mdspan!") + raise SfgException( + f"Cannot map field with size {size} in coordinate {coordinate} to {dim}-dimensional mdspan!") else: # trivial trailing index dimensions are OK -> do nothing return SfgStatements(f"// {self._identifier}.extents().extent({coordinate}) == 1", (), ()) if isinstance(size, FieldShapeSymbol): return SfgStatements( - f"{size.dtype} {size.name} = {self._identifier}.extents().extent({coordinate});", - (size, ), - (self, ) - ) + f"{size.dtype} {size.name} = {self._identifier}.extents().extent({coordinate});", + (size, ), + (self, ) + ) else: return SfgStatements( f"assert( {self._identifier}.extents().extent({coordinate}) == {size} );", (), (self, ) ) - + def extract_stride(self, coordinate: int, stride: Union[int, FieldStrideSymbol]) -> SfgStatements: if coordinate >= len(self._extents): - raise SfgException(f"Cannot extract stride in coordinate {coordinate} from a {len(self._extents)}-dimensional mdspan") - + raise SfgException( + f"Cannot extract stride in coordinate {coordinate} from a {len(self._extents)}-dimensional mdspan") + if isinstance(stride, FieldStrideSymbol): return SfgStatements( - f"{stride.dtype} {stride.name} = {self._identifier}.stride({coordinate});", - (stride, ), - (self, ) - ) + f"{stride.dtype} {stride.name} = {self._identifier}.stride({coordinate});", + (stride, ), + (self, ) + ) else: return SfgStatements( f"assert( {self._identifier}.stride({coordinate}) == {stride} );", diff --git a/pystencilssfg/source_concepts/cpp/std_vector.py b/pystencilssfg/source_concepts/cpp/std_vector.py index d01e881..f203f06 100644 --- a/pystencilssfg/source_concepts/cpp/std_vector.py +++ b/pystencilssfg/source_concepts/cpp/std_vector.py @@ -1,4 +1,4 @@ -from typing import Set, Union, Tuple +from typing import Set, Union from pystencils.typing import FieldPointerSymbol, FieldStrideSymbol, FieldShapeSymbol @@ -9,6 +9,7 @@ from ...types import SrcType, PsType, cpp_typename from ...source_components.header_include import SfgHeaderInclude from ...exceptions import SfgException + class std_vector(SrcVector, SrcField): def __init__(self, identifer: str, T: Union[SrcType, PsType], unsafe: bool = False): typestring = f"std::vector< {cpp_typename(T)} >" @@ -19,39 +20,40 @@ class std_vector(SrcVector, SrcField): @property def required_includes(self) -> Set[SfgHeaderInclude]: - return { SfgHeaderInclude("vector", system_header=True) } - + return {SfgHeaderInclude("vector", system_header=True)} + def extract_ptr(self, ptr_symbol: FieldPointerSymbol): if ptr_symbol.dtype != self._element_type: if self._unsafe: mapping = f"{ptr_symbol.dtype} {ptr_symbol.name} = ({ptr_symbol.dtype}) {self._identifier}.data();" else: - raise SfgException("Field type and std::vector element type do not match, and unsafe extraction was not enabled.") + raise SfgException( + "Field type and std::vector element type do not match, and unsafe extraction was not enabled.") else: mapping = f"{ptr_symbol.dtype} {ptr_symbol.name} = {self._identifier}.data();" return SfgStatements(mapping, (ptr_symbol,), (self,)) - + def extract_size(self, coordinate: int, size: Union[int, FieldShapeSymbol]) -> SfgStatements: if coordinate > 0: raise SfgException(f"Cannot extract size in coordinate {coordinate} from std::vector") if isinstance(size, FieldShapeSymbol): return SfgStatements( - f"{size.dtype} {size.name} = {self._identifier}.size();", - (size, ), - (self, ) - ) + f"{size.dtype} {size.name} = {self._identifier}.size();", + (size, ), + (self, ) + ) else: return SfgStatements( f"assert( {self._identifier}.size() == {size} );", (), (self, ) ) - + def extract_stride(self, coordinate: int, stride: Union[int, FieldStrideSymbol]) -> SfgStatements: if coordinate > 0: raise SfgException(f"Cannot extract stride in coordinate {coordinate} from std::vector") - + if isinstance(stride, FieldStrideSymbol): return SfgStatements(f"{stride.dtype} {stride.name} = 1;", (stride, ), ()) else: diff --git a/pystencilssfg/source_concepts/source_objects.py b/pystencilssfg/source_concepts/source_objects.py index ef08675..6e2849c 100644 --- a/pystencilssfg/source_concepts/source_objects.py +++ b/pystencilssfg/source_concepts/source_objects.py @@ -1,10 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional, Union, Set, TypeAlias, NewType - -if TYPE_CHECKING: - from ..source_components import SfgHeaderInclude - from ..tree import SfgStatements, SfgSequence +from typing import TYPE_CHECKING, Optional, Union, Set, TypeAlias from abc import ABC, abstractmethod @@ -13,15 +9,20 @@ from pystencils.typing import FieldPointerSymbol, FieldStrideSymbol, FieldShapeS from ..types import SrcType +if TYPE_CHECKING: + from ..source_components import SfgHeaderInclude + from ..tree import SfgStatements, SfgSequence + + class SrcObject: """C/C++ object of nonprimitive type. - + Two objects are identical if they have the same identifier and type string.""" def __init__(self, src_type: SrcType, identifier: Optional[str]): self._src_type = src_type self._identifier = identifier - + @property def identifier(self): return self._identifier @@ -38,12 +39,12 @@ class SrcObject: @property def required_includes(self) -> Set[SfgHeaderInclude]: return set() - + def __hash__(self) -> int: return hash((self._identifier, self._src_type)) - + def __eq__(self, other: SrcObject) -> bool: - return (isinstance(other, SrcObject) + return (isinstance(other, SrcObject) and self._identifier == other._identifier and self._src_type == other._src_type) diff --git a/pystencilssfg/tree/__init__.py b/pystencilssfg/tree/__init__.py index 44603b7..b76d8cb 100644 --- a/pystencilssfg/tree/__init__.py +++ b/pystencilssfg/tree/__init__.py @@ -3,7 +3,7 @@ from .conditional import SfgBranch, SfgCondition from .builders import make_sequence __all__ = [ - SfgCallTreeNode, SfgKernelCallNode, SfgSequence, SfgBlock, SfgStatements, - SfgCondition, SfgBranch, - make_sequence -] \ No newline at end of file + "SfgCallTreeNode", "SfgKernelCallNode", "SfgSequence", "SfgBlock", "SfgStatements", + "SfgCondition", "SfgBranch", + "make_sequence" +] diff --git a/pystencilssfg/tree/basic_nodes.py b/pystencilssfg/tree/basic_nodes.py index 6fbf99a..a54e587 100644 --- a/pystencilssfg/tree/basic_nodes.py +++ b/pystencilssfg/tree/basic_nodes.py @@ -5,13 +5,14 @@ from abc import ABC, abstractmethod from itertools import chain from ..kernel_namespace import SfgKernelHandle -from ..source_concepts.source_objects import SrcObject, TypedSymbolOrObject +from ..source_concepts.source_objects import SrcObject, TypedSymbolOrObject from ..exceptions import SfgException if TYPE_CHECKING: from ..context import SfgContext from ..source_components import SfgHeaderInclude + class SfgCallTreeNode(ABC): """Base class for all nodes comprising SFG call trees. """ @@ -38,11 +39,11 @@ class SfgCallTreeNode(ABC): class SfgCallTreeLeaf(SfgCallTreeNode, ABC): - + @property def children(self) -> Sequence[SfgCallTreeNode]: return () - + def replace_child(self, child_idx: int, node: SfgCallTreeNode) -> None: raise SfgException("Leaf nodes have no children.") @@ -54,7 +55,7 @@ class SfgCallTreeLeaf(SfgCallTreeNode, ABC): class SfgStatements(SfgCallTreeLeaf): """Represents (a sequence of) statements in the source language. - + This class groups together arbitrary code strings (e.g. sequences of C++ statements, cf. https://en.cppreference.com/w/cpp/language/statements), and annotates them with the set of symbols read and written by these statements. @@ -74,7 +75,7 @@ class SfgStatements(SfgCallTreeLeaf): defined_params: Sequence[TypedSymbolOrObject], required_params: Sequence[TypedSymbolOrObject]): self._code_string = code_string - + self._defined_params = set(defined_params) self._required_params = set(required_params) @@ -82,19 +83,19 @@ class SfgStatements(SfgCallTreeLeaf): for obj in chain(required_params, defined_params): if isinstance(obj, SrcObject): self._required_includes |= obj.required_includes - + @property def required_parameters(self) -> Set[TypedSymbolOrObject]: return self._required_params - + @property def defined_parameters(self) -> Set[TypedSymbolOrObject]: return self._defined_params - + @property def required_includes(self) -> Set[SfgHeaderInclude]: return self._required_includes - + def get_code(self, ctx: SfgContext) -> str: return self._code_string @@ -106,28 +107,28 @@ class SfgSequence(SfgCallTreeNode): @property def children(self) -> Sequence[SfgCallTreeNode]: return self._children - + def replace_child(self, child_idx: int, node: SfgCallTreeNode) -> None: self._children[child_idx] = node - + def get_code(self, ctx: SfgContext) -> str: return "\n".join(c.get_code(ctx) for c in self._children) class SfgBlock(SfgCallTreeNode): def __init__(self, subtree: SfgCallTreeNode): - super().__init__(ctx) + super().__init__() self._subtree = subtree @property def children(self) -> Sequence[SfgCallTreeNode]: return [self._subtree] - + def replace_child(self, child_idx: int, node: SfgCallTreeNode) -> None: match child_idx: case 0: self._subtree = node case _: raise IndexError(f"Invalid child index: {child_idx}. SfgBlock has only a single child.") - + def get_code(self, ctx: SfgContext) -> str: subtree_code = ctx.codestyle.indent(self._subtree.get_code(ctx)) @@ -141,7 +142,7 @@ class SfgKernelCallNode(SfgCallTreeLeaf): @property def required_parameters(self) -> Set[TypedSymbolOrObject]: return set(p.symbol for p in self._kernel_handle.parameters) - + def get_code(self, ctx: SfgContext) -> str: ast_params = self._kernel_handle.parameters fnc_name = self._kernel_handle.fully_qualified_name diff --git a/pystencilssfg/tree/builders.py b/pystencilssfg/tree/builders.py index 5c35c56..0e006af 100644 --- a/pystencilssfg/tree/builders.py +++ b/pystencilssfg/tree/builders.py @@ -1,21 +1,17 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Sequence - -if TYPE_CHECKING: - from ..context import SfgContext from abc import ABC, abstractmethod -from pystencils import Field - from .basic_nodes import SfgCallTreeNode, SfgSequence, SfgBlock, SfgStatements from .conditional import SfgCondition, SfgCustomCondition, SfgBranch - + + class SfgNodeBuilder(ABC): @abstractmethod def resolve(self) -> SfgCallTreeNode: pass + def make_sequence(*args) -> SfgSequence: children = [] for i, arg in enumerate(args): @@ -27,13 +23,13 @@ def make_sequence(*args) -> SfgSequence: children.append(SfgStatements(arg, (), ())) elif isinstance(arg, tuple): # Tuples are treated as blocks - subseq = self(*arg) + subseq = make_sequence(*arg) children.append(SfgBlock(subseq)) else: raise TypeError(f"Sequence argument {i} has invalid type.") - + return SfgSequence(children) - + class SfgBranchBuilder(SfgNodeBuilder): def __init__(self): @@ -45,30 +41,30 @@ class SfgBranchBuilder(SfgNodeBuilder): def __call__(self, *args) -> SfgBranchBuilder: match self._phase: - case 0: # Condition + case 0: # Condition if len(args) != 1: raise ValueError("Must specify exactly one argument as branch condition!") - + cond = args[0] - + if isinstance(cond, str): cond = SfgCustomCondition(cond) elif not isinstance(cond, SfgCondition): - raise ValueError("Invalid type for branch condition. Must be either `str` or a subclass of `SfgCondition`.") - + raise ValueError( + "Invalid type for branch condition. Must be either `str` or a subclass of `SfgCondition`.") + self._cond = cond - case 1: # Then-branch + case 1: # Then-branch self._branch_true = make_sequence(*args) - case 2: # Else-branch + case 2: # Else-branch self._branch_false = make_sequence(*args) - case _: # There's no third branch! + case _: # There's no third branch! raise TypeError("Branch construct already complete.") self._phase += 1 return self - + def resolve(self) -> SfgCallTreeNode: return SfgBranch(self._cond, self._branch_true, self._branch_false) - diff --git a/pystencilssfg/tree/conditional.py b/pystencilssfg/tree/conditional.py index 8731344..a2f3862 100644 --- a/pystencilssfg/tree/conditional.py +++ b/pystencilssfg/tree/conditional.py @@ -1,18 +1,17 @@ from __future__ import annotations from typing import TYPE_CHECKING, Sequence, Optional, Set +from .basic_nodes import SfgCallTreeNode, SfgCallTreeLeaf +from ..source_concepts.source_objects import TypedSymbolOrObject + if TYPE_CHECKING: from ..context import SfgContext -from jinja2.filters import do_indent -from pystencils.typing import TypedSymbol - -from .basic_nodes import SfgCallTreeNode, SfgCallTreeLeaf -from ..source_concepts.source_objects import TypedSymbolOrObject class SfgCondition(SfgCallTreeLeaf): pass + class SfgCustomCondition(SfgCondition): def __init__(self, cond_text: str): self._cond_text = cond_text @@ -22,40 +21,42 @@ class SfgCustomCondition(SfgCondition): def get_code(self, ctx: SfgContext) -> str: return self._cond_text - + # class IntEven(SfgCondition): # def __init__(self, ) class SfgBranch(SfgCallTreeNode): - def __init__(self, cond: SfgCondition, branch_true: SfgCallTreeNode, branch_false: Optional[SfgCallTreeNode] = None): + def __init__(self, + cond: SfgCondition, + branch_true: SfgCallTreeNode, + branch_false: Optional[SfgCallTreeNode] = None): self._cond = cond self._branch_true = branch_true self._branch_false = branch_false - + @property def children(self) -> Sequence[SfgCallTreeNode]: if self._branch_false is not None: return (self._branch_true, self._branch_false) else: return (self._branch_true,) - + def replace_child(self, child_idx: int, node: SfgCallTreeNode) -> None: match child_idx: case 0: self._branch_true = node case 1: self._branch_false = node case _: raise IndexError(f"Invalid child index: {child_idx}. SfgBlock has only two children.") - + def get_code(self, ctx: SfgContext) -> str: code = f"if({self._cond.get_code(ctx)}) {{\n" code += ctx.codestyle.indent(self._branch_true.get_code(ctx)) code += "\n}" - + if self._branch_false is not None: code += "else {\n" code += ctx.codestyle.indent(self._branch_false.get_code(ctx)) code += "\n}" return code - diff --git a/pystencilssfg/tree/deferred_nodes.py b/pystencilssfg/tree/deferred_nodes.py index b14ba2d..8e31dfb 100644 --- a/pystencilssfg/tree/deferred_nodes.py +++ b/pystencilssfg/tree/deferred_nodes.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Sequence, Set, Union, Iterable +from typing import TYPE_CHECKING, Sequence, Set if TYPE_CHECKING: from ..context import SfgContext @@ -18,10 +18,11 @@ from ..source_concepts import SrcField class SfgDeferredNode(SfgCallTreeNode, ABC): - """Nodes of this type are inserted as placeholders into the kernel call tree and need to be expanded at a later time. - - Subclasses of SfgDeferredNode correspond to nodes that cannot be created yet because information required for their - construction is not yet known. + """Nodes of this type are inserted as placeholders into the kernel call tree + and need to be expanded at a later time. + + Subclasses of SfgDeferredNode correspond to nodes that cannot be created yet + because information required for their construction is not yet known. """ @property @@ -33,7 +34,7 @@ class SfgDeferredNode(SfgCallTreeNode, ABC): def get_code(self, ctx: SfgContext) -> str: raise SfgException("Deferred nodes can not generate code; they need to be expanded first.") - + @abstractmethod def expand(self, ctx: SfgContext, *args, **kwargs) -> SfgCallTreeNode: pass @@ -64,7 +65,7 @@ class SfgDeferredFieldMapping(SfgParamCollectionDeferredNode): for c, s in enumerate(self._field.shape): if isinstance(s, FieldShapeSymbol) and s not in visible_params: continue - else: + else: shape.append((c, s)) # Find required strides @@ -72,7 +73,7 @@ class SfgDeferredFieldMapping(SfgParamCollectionDeferredNode): for c, s in enumerate(self._field.strides): if isinstance(s, FieldStrideSymbol) and s not in visible_params: continue - else: + else: strides.append((c, s)) return make_sequence( diff --git a/pystencilssfg/types.py b/pystencilssfg/types.py index fe12909..a0ea1b1 100644 --- a/pystencilssfg/types.py +++ b/pystencilssfg/types.py @@ -37,4 +37,3 @@ def cpp_typename(type_obj: Union[str, SrcType, PsType]): return numpy_name_to_c(np.dtype(type_obj).name) else: raise ValueError(f"Don't know how to interpret type object {type_obj}.") - -- GitLab