Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Sequence, Set, Union, Iterable
if TYPE_CHECKING:
from ..context import SfgContext
from abc import ABC, abstractmethod
from pystencils import Field, TypedSymbol
from pystencils.typing import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
from ..exceptions import SfgException
from .basic_nodes import SfgCallTreeNode
from .builders import make_sequence
from ..source_concepts.containers import SrcField
class SfgDeferredNode(SfgCallTreeNode, ABC):
    """Placeholder node in the kernel call tree that must be expanded before use.

    Subclasses stand in for nodes that cannot be constructed yet because the
    information needed to build them is not yet known. Until `expand` has
    produced the actual node, a deferred node can neither be descended into
    nor asked to generate code; both operations raise `SfgException`.
    """

    @abstractmethod
    def expand(self, ctx: SfgContext, *args, **kwargs) -> SfgCallTreeNode:
        """Build and return the call tree node that replaces this placeholder.

        Concrete subclasses define which additional arguments they need.
        """

    @property
    def children(self) -> Sequence[SfgCallTreeNode]:
        """Always raises `SfgException`; a deferred node cannot be traversed."""
        raise SfgException("Deferred nodes cannot be descended into; expand it first.")

    def replace_child(self, child_idx: int, node: SfgCallTreeNode) -> None:
        """Always raises `SfgException`; a deferred node has no children to replace."""
        raise SfgException("Deferred nodes do not have children.")

    def get_code(self, ctx: SfgContext) -> str:
        """Always raises `SfgException`; code is only available after expansion."""
        raise SfgException("Deferred nodes can not generate code; they need to be expanded first.")
class SfgParamCollectionDeferredNode(SfgDeferredNode, ABC):
    """Deferred node whose expansion is driven by a set of parameter symbols.

    Narrows the generic `expand` signature of `SfgDeferredNode` to a single
    extra argument, `visible_params`.
    """

    @abstractmethod
    def expand(self, ctx: SfgContext, visible_params: Set[TypedSymbol]) -> SfgCallTreeNode:
        """Return the node replacing this placeholder, given `visible_params`.

        `visible_params` is the set of `TypedSymbol` objects the subclass may
        inspect while building its replacement node.
        """
class SfgDeferredFieldMapping(SfgParamCollectionDeferredNode):
    """Deferred node that maps a pystencils field onto a source-level field object.

    On expansion, the node looks through the given parameter symbols for the
    field's pointer, shape and stride symbols and produces a sequence of
    extraction nodes obtained from the `SrcField`.
    """

    def __init__(self, field: Field, src_field: SrcField):
        """Store the field to map and the `SrcField` its properties are extracted from."""
        self._field = field
        self._src_field = src_field

    def expand(self, ctx: SfgContext, visible_params: Set[TypedSymbol]) -> SfgCallTreeNode:
        """Expand into a sequence of pointer, size and stride extractions.

        Args:
            ctx: The generator context (not used by this implementation).
            visible_params: Symbols to search for this field's pointer, shape
                and stride symbols.

        Returns:
            The result of `make_sequence` applied to, in order: the pointer
            extraction (only if a pointer symbol for this field was found),
            one size extraction per retained shape entry and one stride
            extraction per retained stride entry.

        Raises:
            SfgException: If a pointer symbol for this field is found whose
                base type differs from the field's data type.
        """
        # Find field pointer. No early exit: every matching pointer symbol is
        # type-checked, and the last match encountered is the one used.
        ptr = None
        for param in visible_params:
            if isinstance(param, FieldPointerSymbol) and param.field_name == self._field.name:
                if param.dtype.base_type != self._field.dtype:
                    raise SfgException("Data type mismatch between field and encountered pointer symbol")
                ptr = param

        # Find required sizes: symbolic shape entries are only kept when they
        # are among the visible parameters; all other entries are always kept.
        shape = [
            (coord, size)
            for coord, size in enumerate(self._field.shape)
            if not isinstance(size, FieldShapeSymbol) or size in visible_params
        ]

        # Find required strides, using the same rule as for the shape.
        strides = [
            (coord, stride)
            for coord, stride in enumerate(self._field.strides)
            if not isinstance(stride, FieldStrideSymbol) or stride in visible_params
        ]

        nodes = []

        # Only extract the pointer if a pointer symbol was actually found;
        # previously `extract_ptr(None)` was called when it was missing.
        if ptr is not None:
            nodes.append(self._src_field.extract_ptr(ptr))

        nodes.extend(self._src_field.extract_size(coord, size) for coord, size in shape)
        nodes.extend(self._src_field.extract_stride(coord, stride) for coord, stride in strides)

        # NOTE(review): if nothing is required, `make_sequence` is called
        # without arguments - confirm that it accepts an empty argument list.
        return make_sequence(*nodes)