Skip to content
Snippets Groups Projects
Commit b90499a3 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

some cleanup

 - move numpy_struct to class composer and use its sequencers
 - fix some type errors
parent fc465a22
Branches
Tags
1 merge request!17Improved Source File and Code Structure Modelling
Pipeline #73631 failed with stages
in 1 minute and 4 seconds
This commit is part of merge request !17. Comments created here will be created in the context of that merge request.
from __future__ import annotations
from typing import Sequence, TypeAlias
from abc import ABC, abstractmethod
import numpy as np
import sympy as sp
from functools import reduce
from warnings import warn
......@@ -33,10 +32,6 @@ from ..ir.source_components import (
SfgFunction,
SfgKernelNamespace,
SfgKernelHandle,
SfgClass,
SfgConstructor,
SfgMemberVariable,
SfgClassKeyword,
SfgEntityDecl,
SfgEntityDef,
SfgNamespaceBlock,
......@@ -86,7 +81,7 @@ SequencerArg: TypeAlias = tuple | ExprLike | SfgCallTreeNode | SfgNodeBuilder
class KernelsAdder:
def __init__(self, ctx: SfgContext, loc: SfgNamespaceBlock):
self._ctx = ctx
self._loc = SfgNamespaceBlock
self._loc = loc
assert isinstance(loc.namespace, SfgKernelNamespace)
self._kernel_namespace = loc.namespace
......@@ -110,6 +105,8 @@ class KernelsAdder:
khandle = SfgKernelHandle(kernel_name, self._kernel_namespace, kernel)
self._kernel_namespace.add_kernel(khandle)
self._loc.elements.append(SfgEntityDef(khandle))
for header in kernel.required_headers:
assert self._ctx.impl_file is not None
self._ctx.impl_file.includes.append(HeaderFile.parse(header))
......@@ -242,7 +239,7 @@ class SfgBasicComposer(SfgIComposer):
self._cursor.write_impl(kns_block)
return KernelsAdder(self._ctx, kns_block)
def include(self, header_file: str | HeaderFile, private: bool = False):
def include(self, header: str | HeaderFile, private: bool = False):
"""Include a header file.
Args:
......@@ -262,7 +259,7 @@ class SfgBasicComposer(SfgIComposer):
#include <vector>
#include "custom.h"
"""
header_file = HeaderFile.parse(header_file)
header_file = HeaderFile.parse(header)
if private:
if self._ctx.impl_file is None:
......@@ -273,21 +270,6 @@ class SfgBasicComposer(SfgIComposer):
else:
self._ctx.header_file.includes.append(header_file)
def numpy_struct(
self, name: str, dtype: np.dtype, add_constructor: bool = True
) -> SfgClass:
"""Add a numpy structured data type as a C++ struct
Returns:
The created class object
"""
cls = self._struct_from_numpy_dtype(
name, dtype, add_constructor=add_constructor
)
self._cursor.add_entity(cls)
self._cursor.write_header(SfgEntityDecl(cls))
return cls
def kernel_function(self, name: str, kernel: Kernel | SfgKernelHandle):
"""Create a function comprising just a single kernel call.
......@@ -295,9 +277,11 @@ class SfgBasicComposer(SfgIComposer):
ast_or_kernel_handle: Either a pystencils AST, or a kernel handle for an already registered AST.
"""
if isinstance(kernel, Kernel):
kernel = self.kernels.add(kernel, name)
khandle = self.kernels.add(kernel, name)
else:
khandle = kernel
self.function(name)(self.call(kernel))
self.function(name)(self.call(khandle))
def function(
self,
......@@ -536,41 +520,6 @@ class SfgBasicComposer(SfgIComposer):
]
return SfgDeferredVectorMapping(components, rhs)
def _struct_from_numpy_dtype(
self, struct_name: str, dtype: np.dtype, add_constructor: bool = True
):
cls = SfgClass(
struct_name,
self._cursor.current_namespace,
class_keyword=SfgClassKeyword.STRUCT,
)
fields = dtype.fields
if fields is None:
raise SfgException(f"Numpy dtype {dtype} is not a structured type.")
constr_params = []
constr_inits = []
for member_name, type_info in fields.items():
member_type = create_type(type_info[0])
member = SfgMemberVariable(member_name, member_type)
arg = SfgVar(f"{member_name}_", member_type)
cls.default.append_member(member)
constr_params.append(arg)
constr_inits.append(f"{member}({arg})")
if add_constructor:
cls.default.append_member(
SfgEntityDef(SfgConstructor(constr_params, constr_inits))
)
return cls
def make_statements(arg: ExprLike) -> SfgStatements:
return SfgStatements(str(arg), (), depends(arg), includes(arg))
......
from __future__ import annotations
from typing import Sequence
from itertools import takewhile, dropwhile
import numpy as np
from pystencils.types import PsCustomType, UserTypeSpec, create_type
......@@ -196,6 +197,16 @@ class SfgClassComposer(SfgComposerMixIn):
"""
return self._class(class_name, SfgClassKeyword.STRUCT, bases)
def numpy_struct(
self, name: str, dtype: np.dtype, add_constructor: bool = True
):
"""Add a numpy structured data type as a C++ struct
Returns:
The created class object
"""
return self._struct_from_numpy_dtype(name, dtype, add_constructor)
@property
def public(self) -> SfgClassComposer.VisibilityBlockSequencer:
"""Create a `public` visibility block in a class body"""
......@@ -241,6 +252,8 @@ class SfgClassComposer(SfgComposerMixIn):
# INTERNALS
def _class(self, class_name: str, keyword: SfgClassKeyword, bases: Sequence[str]):
# TODO: Return a `CppClass` instance representing the generated class
if self._cursor.get_entity(class_name) is not None:
raise ValueError(
f"Another entity with name {class_name} already exists in the current namespace."
......@@ -288,3 +301,30 @@ class SfgClassComposer(SfgComposerMixIn):
self._cursor.write_header(SfgEntityDef(cls))
return sequencer
def _struct_from_numpy_dtype(
self, struct_name: str, dtype: np.dtype, add_constructor: bool = True
):
fields = dtype.fields
if fields is None:
raise SfgException(f"Numpy dtype {dtype} is not a structured type.")
members: list[SfgClassComposer.ConstructorBuilder | SfgVar] = []
if add_constructor:
ctor = self.constructor()
members.append(ctor)
for member_name, type_info in fields.items():
member_type = create_type(type_info[0])
member = SfgVar(member_name, member_type)
members.append(member)
if add_constructor:
arg = SfgVar(f"{member_name}_", member_type)
ctor.add_param(arg)
ctor.init(member)(arg)
return self.struct(
struct_name,
)(*members)
......@@ -15,8 +15,6 @@ from .call_tree import (
)
from .source_components import (
SfgHeaderInclude,
SfgEmptyLines,
SfgKernelNamespace,
SfgKernelHandle,
SfgKernelParamVar,
......@@ -25,7 +23,6 @@ from .source_components import (
SfgClassKeyword,
SfgClassMember,
SfgVisibilityBlock,
SfgInClassDefinition,
SfgMemberVariable,
SfgMethod,
SfgConstructor,
......@@ -47,8 +44,6 @@ __all__ = [
"SfgBranch",
"SfgSwitchCase",
"SfgSwitch",
"SfgHeaderInclude",
"SfgEmptyLines",
"SfgKernelNamespace",
"SfgKernelHandle",
"SfgKernelParamVar",
......@@ -57,7 +52,6 @@ __all__ = [
"SfgClassKeyword",
"SfgClassMember",
"SfgVisibilityBlock",
"SfgInClassDefinition",
"SfgMemberVariable",
"SfgMethod",
"SfgConstructor",
......
......@@ -404,18 +404,7 @@ class SfgConstructor(SfgClassMember):
class SfgClass(SfgCodeEntity):
"""Models a C++ class.
### Adding members to classes
Members are never added directly to a class. Instead, they are added to
an [SfgVisibilityBlock][pystencilssfg.source_components.SfgVisibilityBlock]
which defines their syntactic position and visibility modifier in the code.
At the top of every class, there is a default visibility block
accessible through the `default` property.
To add members with custom visibility, create a new SfgVisibilityBlock,
add members to the block, and add the block using `append_visibility_block`.
"""
"""A C++ class."""
__match_args__ = ("class_name",)
......@@ -524,12 +513,6 @@ class SfgClass(SfgCodeEntity):
self._member_vars[variable.name] = variable
SourceEntity_T = TypeVar(
"SourceEntity_T", bound=SfgFunction | SfgClassMember | SfgClass, covariant=True
)
"""Source entities that may have declarations and definitions."""
# =========================================================================================================
#
# SYNTACTICAL ELEMENTS
......@@ -540,6 +523,12 @@ SourceEntity_T = TypeVar(
# =========================================================================================================
SourceEntity_T = TypeVar(
"SourceEntity_T", bound=SfgKernelHandle | SfgFunction | SfgClassMember | SfgClass, covariant=True
)
"""Source entities that may have declarations and definitions."""
class SfgEntityDecl(Generic[SourceEntity_T]):
"""Declaration of a function, class, method, or constructor"""
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment