Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
from __future__ import annotations
from typing import TYPE_CHECKING, Optional, Union, Set, TypeAlias, NewType
if TYPE_CHECKING:
from ..source_components import SfgHeaderInclude
from ..tree import SfgStatements, SfgSequence
from numpy import dtype
from abc import ABC, abstractmethod
from pystencils import TypedSymbol, Field
from pystencils.typing import AbstractType, FieldPointerSymbol, FieldStrideSymbol, FieldShapeSymbol
PsType: TypeAlias = Union[
    type,
    dtype,
    AbstractType,
]
"""Type specifications used when interacting with pystencils.

``PsType`` covers the various ways of specifying a type within pystencils: a Python
``type``, a NumPy ``dtype``, or a pystencils ``AbstractType``. Together these encompass
most of the ways to construct an ``AbstractType`` instance, e.g. via ``create_type``.

Strings are deliberately left out, even though ``create_type`` accepts them, for reasons
of safety: specifying types as strings is discouraged when working with pystencils.
"""
SrcType = NewType('SrcType', str)
"""Nonprimitive C/C++ types occurring during source file generation.

Nonprimitive C/C++ types are represented by their names.
When necessary, the SFG package checks equality of types by comparing these name strings
only; it does not take typedefs, aliases, namespaces, etc. into account!
"""
class SrcObject:
    """C/C++ object of nonprimitive type.

    Two objects compare equal (and hash equally) if they have the same identifier
    and the same type string.
    """

    def __init__(self, src_type: SrcType, identifier: Optional[str]):
        self._src_type = src_type
        self._identifier = identifier

    @property
    def identifier(self) -> Optional[str]:
        """The object's identifier as passed to the constructor (may be ``None``)."""
        return self._identifier

    @property
    def name(self) -> Optional[str]:
        """For interface compatibility with ps.TypedSymbol; same as ``identifier``."""
        return self._identifier

    @property
    def dtype(self) -> SrcType:
        """The object's type, given as its C/C++ type name string."""
        return self._src_type

    @property
    def required_includes(self) -> Set[SfgHeaderInclude]:
        """Header includes this object requires; the base implementation requires none."""
        return set()

    def __hash__(self) -> int:
        # Must stay consistent with __eq__: both use (identifier, type string).
        return hash((self._identifier, self._src_type))

    def __eq__(self, other: object) -> bool:
        # Returning NotImplemented for foreign types lets Python try the reflected
        # comparison and then fall back to identity (i.e. still False), as the data
        # model prescribes, instead of short-circuiting to False.
        if not isinstance(other, SrcObject):
            return NotImplemented
        return (self._identifier == other._identifier
                and self._src_type == other._src_type)

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self._src_type!r}, {self._identifier!r})"
# Anything usable as a named value in generated code: either a pystencils
# TypedSymbol or a nonprimitive C/C++ object (SrcObject).
TypedSymbolOrObject: TypeAlias = Union[
    TypedSymbol,
    SrcObject,
]
class SrcField(SrcObject, ABC):
    """Abstract C/C++ object from which the parameters of a pystencils ``Field`` are extracted.

    Subclasses implement the three ``extract_*`` hooks; ``extract_parameters`` combines
    them into a single sequence: first the data pointer, then one size per coordinate,
    then one stride per coordinate.
    """

    def __init__(self, src_type: SrcType, identifier: Optional[str]):
        super().__init__(src_type, identifier)

    @abstractmethod
    def extract_ptr(self, ptr_symbol: FieldPointerSymbol) -> SfgStatements:
        """Return statements extracting this object's data pointer into ``ptr_symbol``."""

    @abstractmethod
    def extract_size(self, coordinate: int, size: Union[int, FieldShapeSymbol]) -> SfgStatements:
        """Return statements for the size along ``coordinate``.

        ``size`` is the corresponding entry of ``field.shape``: either a fixed integer
        or a shape symbol.
        """

    @abstractmethod
    def extract_stride(self, coordinate: int, stride: Union[int, FieldStrideSymbol]) -> SfgStatements:
        """Return statements for the stride along ``coordinate``.

        ``stride`` is the corresponding entry of ``field.strides``: either a fixed
        integer or a stride symbol.
        """

    def extract_parameters(self, field: Field) -> SfgSequence:
        """Build the sequence extracting pointer, sizes and strides of ``field``."""
        pointer_symbol = FieldPointerSymbol(field.name, field.dtype, False)

        # ..tree is only imported under TYPE_CHECKING at module level, so the runtime
        # import happens here (presumably to avoid an import cycle -- TODO confirm).
        from ..tree import make_sequence

        # Hooks are invoked in the same order as before: pointer, sizes, strides.
        pointer_stmt = self.extract_ptr(pointer_symbol)
        size_stmts = [self.extract_size(dim, extent) for dim, extent in enumerate(field.shape)]
        stride_stmts = [self.extract_stride(dim, step) for dim, step in enumerate(field.strides)]
        return make_sequence(pointer_stmt, *size_stmts, *stride_stmts)
class SrcVector(SrcObject, ABC):
    """Abstract C/C++ object of nonprimitive type with indexable components.

    ``ABC`` is listed as a base (as for ``SrcField``) because ``@abstractmethod`` is only
    enforced for classes whose metaclass is ``ABCMeta``. Without it, ``SrcVector`` and
    subclasses that forget ``extract_component`` could be instantiated silently.
    """

    @abstractmethod
    def extract_component(self, destination: TypedSymbolOrObject, coordinate: int):
        """Extract the component at index ``coordinate`` into ``destination``.

        NOTE(review): return type is not declared here; presumably statements like the
        ``SrcField.extract_*`` hooks -- confirm against implementations.
        """