# deferred_nodes.py
from __future__ import annotations
# Committed by Frederik Hennig
from typing import TYPE_CHECKING, Sequence, Set

if TYPE_CHECKING:
    from ..context import SfgContext

from abc import ABC, abstractmethod

from pystencils import Field, TypedSymbol
from pystencils.typing import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol

from ..exceptions import SfgException

from .basic_nodes import SfgCallTreeNode
from .builders import make_sequence

from ..source_concepts import SrcField


class SfgDeferredNode(SfgCallTreeNode, ABC):
    """Nodes of this type are inserted as placeholders into the kernel call tree
    and need to be expanded at a later time.

    Subclasses of SfgDeferredNode correspond to nodes that cannot be created yet
    because information required for their construction is not yet known.

    The regular tree operations defined here (``children``, ``replace_child``
    and ``get_code``) all raise ``SfgException``; the only operation a concrete
    subclass has to provide is ``expand``.
    """

    @property
    def children(self) -> Sequence[SfgCallTreeNode]:
        """Unconditionally refuse access to child nodes.

        Raises:
            SfgException: on every access.
        """
        raise SfgException("Deferred nodes cannot be descended into; expand it first.")

    def replace_child(self, child_idx: int, node: SfgCallTreeNode) -> None:
        """Unconditionally refuse to replace a child node.

        Args:
            child_idx: Ignored.
            node: Ignored.

        Raises:
            SfgException: on every call.
        """
        raise SfgException("Deferred nodes do not have children.")

    def get_code(self, ctx: SfgContext) -> str:
        """Unconditionally refuse to generate code.

        Args:
            ctx: Ignored.

        Raises:
            SfgException: on every call.
        """
        raise SfgException("Deferred nodes can not generate code; they need to be expanded first.")

    @abstractmethod
    def expand(self, ctx: SfgContext, *args, **kwargs) -> SfgCallTreeNode:
        """Return the call tree node this placeholder expands to.

        Args:
            ctx: The context passed in by whoever performs the expansion.
            *args: Additional positional information; its meaning is defined
                by the concrete subclass.
            **kwargs: Additional keyword information; its meaning is defined
                by the concrete subclass.

        Returns:
            The node produced by the expansion.
        """


class SfgParamCollectionDeferredNode(SfgDeferredNode, ABC):
    """Abstract deferred node whose expansion takes a set of visible parameters."""

    @abstractmethod
    def expand(self, ctx: SfgContext, visible_params: Set[TypedSymbol]) -> SfgCallTreeNode:
        """Return the call tree node this placeholder expands to.

        Args:
            ctx: The context passed in by whoever performs the expansion.
            visible_params: The set of typed symbols handed to the expansion.

        Returns:
            The node produced by the expansion.
        """


class SfgDeferredFieldMapping(SfgParamCollectionDeferredNode):
    """Placeholder that expands into the pointer, size and stride extraction
    nodes of a ``SrcField`` for a given ``Field``.
    """

    def __init__(self, field: Field, src_field: SrcField):
        """Store the field and the source-side field object it is paired with.

        Args:
            field: Field whose name, dtype, shape and strides drive the expansion.
            src_field: Object providing ``extract_ptr``, ``extract_size`` and
                ``extract_stride``.
        """
        self._field = field
        self._src_field = src_field

    def expand(self, ctx: SfgContext, visible_params: Set[TypedSymbol]) -> SfgCallTreeNode:
        """Build the sequence of extraction nodes for the stored field.

        Args:
            ctx: Not used by this implementation.
            visible_params: Symbols used to locate the field's pointer symbol
                and to decide which symbolic shape/stride entries are kept.

        Returns:
            The result of ``make_sequence`` applied to the pointer extraction,
            followed by one size extraction per retained shape entry and one
            stride extraction per retained stride entry.

        Raises:
            SfgException: If a pointer symbol carrying this field's name has a
                base type different from the field's dtype.
        """
        #   Find field pointer
        ptr = None
        for param in visible_params:
            if not isinstance(param, FieldPointerSymbol):
                continue
            if param.field_name != self._field.name:
                continue
            #   Every matching pointer symbol is type-checked; the last one
            #   encountered during iteration is the one that gets used.
            if param.dtype.base_type != self._field.dtype:
                raise SfgException("Data type mismatch between field and encountered pointer symbol")
            ptr = param
        #   NOTE(review): if no matching pointer symbol is found, `ptr` stays None
        #   and is handed to `extract_ptr` as-is -- confirm that this is intended.

        #   Find required sizes: symbolic shape entries missing from
        #   `visible_params` are dropped, every other entry is kept together
        #   with its coordinate index.
        shape = [
            (coord, size)
            for coord, size in enumerate(self._field.shape)
            if not isinstance(size, FieldShapeSymbol) or size in visible_params
        ]

        #   Find required strides: same rule as for the sizes.
        strides = [
            (coord, stride)
            for coord, stride in enumerate(self._field.strides)
            if not isinstance(stride, FieldStrideSymbol) or stride in visible_params
        ]

        return make_sequence(
            self._src_field.extract_ptr(ptr),
            *(self._src_field.extract_size(c, s) for c, s in shape),
            *(self._src_field.extract_stride(c, s) for c, s in strides)
        )