Newer
Older
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Sequence, Set, Union, Iterable
if TYPE_CHECKING:
from ..context import SfgContext
from abc import ABC, abstractmethod
from pystencils import Field, TypedSymbol
from pystencils.typing import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
from ..exceptions import SfgException
from .basic_nodes import SfgCallTreeNode
from .builders import make_sequence
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
class SfgDeferredNode(SfgCallTreeNode, ABC):
    """Placeholder inserted into the kernel call tree in place of a node that cannot be built yet.

    Subclasses stand in for nodes whose construction depends on information that is not
    known when the call tree is first assembled. A deferred node must be replaced by the
    result of `expand` before the tree is traversed or rendered; until then, every
    tree-node operation on it raises `SfgException`.
    """

    @abstractmethod
    def expand(self, ctx: SfgContext, *args, **kwargs) -> SfgCallTreeNode:
        """Build and return the concrete node that replaces this placeholder."""

    @property
    def children(self) -> Sequence[SfgCallTreeNode]:
        """Unavailable on deferred nodes.

        Raises:
            SfgException: always; the node has to be expanded first.
        """
        msg = "Deferred nodes cannot be descended into; expand it first."
        raise SfgException(msg)

    def replace_child(self, child_idx: int, node: SfgCallTreeNode) -> None:
        """Unavailable on deferred nodes.

        Raises:
            SfgException: always; deferred nodes have no children to replace.
        """
        msg = "Deferred nodes do not have children."
        raise SfgException(msg)

    def get_code(self, ctx: SfgContext) -> str:
        """Unavailable on deferred nodes.

        Raises:
            SfgException: always; only the expanded node can produce code.
        """
        msg = "Deferred nodes can not generate code; they need to be expanded first."
        raise SfgException(msg)
class SfgParamCollectionDeferredNode(SfgDeferredNode, ABC):
    """Deferred node whose expansion depends on the set of parameters visible to it.

    Subclasses implement `expand` taking the collection of `TypedSymbol` parameters
    that are visible at the node's position, and return the concrete replacement node.
    """

    @abstractmethod
    def expand(self, ctx: SfgContext, visible_params: Set[TypedSymbol]) -> SfgCallTreeNode:
        """Return the concrete node for this placeholder, given the visible parameters."""
class SfgDeferredFieldMapping(SfgParamCollectionDeferredNode):
    """Deferred mapping of a pystencils `Field` onto a source-level field object.

    On expansion, the node emits extraction statements (via the wrapped ``src_field``) for
    the field's pointer, shape entries and stride entries. Only the pieces that are
    actually needed by the visible parameters are extracted.
    """

    def __init__(self, field: Field, src_field: SrcField):
        self._field = field
        self._src_field = src_field

    def _find_pointer(self, visible_params: Set[TypedSymbol]):
        """Return the visible FieldPointerSymbol for this field, or None if there is none."""
        ptr = None
        for param in visible_params:
            if isinstance(param, FieldPointerSymbol) and param.field_name == self._field.name:
                if param.dtype.base_type != self._field.dtype:
                    raise SfgException("Data type mismatch between field and encountered pointer symbol")
                ptr = param
        return ptr

    def expand(self, ctx: SfgContext, visible_params: Set[TypedSymbol]) -> SfgCallTreeNode:
        """Build the sequence of extraction statements required by ``visible_params``.

        Args:
            ctx: Generation context (not used by this node).
            visible_params: Parameters visible at this node's position in the call tree.

        Returns:
            A sequence node containing, in order:
            - the pointer extraction, if a pointer symbol for the field is visible;
            - the size extractions;
            - the stride extractions.

        Raises:
            SfgException: if a visible pointer symbol for this field has a base type
                different from the field's dtype.
        """
        ptr = self._find_pointer(visible_params)

        # FieldShapeSymbol / FieldStrideSymbol entries are skipped unless some visible
        # parameter refers to them; any other entry of shape/strides is always extracted.
        shape = [
            (coord, size)
            for coord, size in enumerate(self._field.shape)
            if not (isinstance(size, FieldShapeSymbol) and size not in visible_params)
        ]
        strides = [
            (coord, stride)
            for coord, stride in enumerate(self._field.strides)
            if not (isinstance(stride, FieldStrideSymbol) and stride not in visible_params)
        ]

        nodes = []
        # Only extract the data pointer if the kernel actually takes it as a parameter;
        # previously ``extract_ptr(None)`` was emitted when no pointer symbol was visible.
        if ptr is not None:
            nodes.append(self._src_field.extract_ptr(ptr))
        nodes.extend(self._src_field.extract_size(c, s) for c, s in shape)
        nodes.extend(self._src_field.extract_stride(c, s) for c, s in strides)

        return make_sequence(*nodes)