Skip to content
Snippets Groups Projects
deferred_nodes.py 2.78 KiB
Newer Older
from __future__ import annotations
Frederik Hennig's avatar
Frederik Hennig committed
from typing import TYPE_CHECKING

from abc import ABC, abstractmethod

Frederik Hennig's avatar
Frederik Hennig committed
from pystencils import Field
from pystencils.typing import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol

from ..exceptions import SfgException

from .basic_nodes import SfgCallTreeNode
from .builders import make_sequence

from ..source_concepts import SrcField
Frederik Hennig's avatar
Frederik Hennig committed
from ..source_concepts.source_objects import TypedSymbolOrObject
Frederik Hennig's avatar
Frederik Hennig committed
if TYPE_CHECKING:
    from ..context import SfgContext


class SfgDeferredNode(SfgCallTreeNode, ABC):
    """Nodes of this type are inserted as placeholders into the kernel call tree
    and need to be expanded at a later time.

    Subclasses of SfgDeferredNode correspond to nodes that cannot be created yet
    because information required for their construction is not yet known.

    An unexpanded deferred node has no valid children and cannot emit code:
    reading its ``_children`` attribute or calling `get_code` raises an
    `SfgException`. Deferred nodes must be expanded first.
    """

    class InvalidAccess:
        """Non-data descriptor that rejects every instance-level read.

        Installed as the class attribute ``_children`` so that code trying to
        inspect the children of an unexpanded deferred node fails loudly
        instead of silently receiving a bogus value.
        """

        def __get__(self, instance, owner=None):
            if instance is None:
                # Class-level access (introspection, help()) yields the
                # descriptor itself, as is customary for descriptors.
                return self
            raise SfgException("Invalid access into deferred node; deferred nodes must be expanded first.")

    # The guard must live on the class: descriptors only take effect as class
    # attributes, never when stored in an instance ``__dict__``.
    _children = InvalidAccess()

    def __init__(self):
        # Intentionally does not assign ``self._children``: an instance
        # attribute would shadow the (non-data) descriptor above and disable
        # the guard. As before, the base-class initializer is not invoked.
        pass

    def get_code(self, ctx: SfgContext) -> str:
        """Always raises; a deferred node must be expanded before code generation.

        Raises:
            SfgException: unconditionally.
        """
        raise SfgException("Invalid access into deferred node; deferred nodes must be expanded first.")


class SfgParamCollectionDeferredNode(SfgDeferredNode, ABC):
    """Deferred node whose expansion is driven by a collection of visible parameters.

    Concrete subclasses implement `expand`, which is handed the set of
    parameters the caller considers visible and returns the call tree node
    produced from this placeholder.
    """

    @abstractmethod
    def expand(self, ctx: SfgContext, visible_params: set[TypedSymbolOrObject]) -> SfgCallTreeNode:
        """Produce the concrete node for this deferred placeholder.

        Args:
            ctx: The generation context.
            visible_params: Symbols/objects the caller reports as visible to
                this node.

        Returns:
            The call tree node resulting from the expansion.
        """


class SfgDeferredFieldMapping(SfgParamCollectionDeferredNode):
    """Deferred extraction of a pystencils field's pointer, sizes and strides
    from a source-level field object.

    Which of these get extracted is only decided at expansion time, based on
    the parameters that are visible there.
    """

    def __init__(self, field: Field, src_field: SrcField):
        super().__init__()
        self._field = field
        self._src_field = src_field

    def expand(self, ctx: SfgContext, visible_params: set[TypedSymbolOrObject]) -> SfgCallTreeNode:
        """Emit the extraction statements required by the visible parameters.

        Args:
            ctx: The generation context (not used by this implementation).
            visible_params: Parameters visible to this node. Pointer, shape and
                stride *symbols* are only extracted if they appear in this set.

        Returns:
            A sequence node containing, in this order: the data pointer
            extraction (only if a matching pointer symbol is visible), the size
            extractions and the stride extractions, each in coordinate order.
            Non-symbolic shape/stride entries are always passed on.

        Raises:
            SfgException: If the visible pointer symbol's base type differs from
                the field's data type.
        """
        ptr = self._find_pointer(visible_params)

        nodes = []
        # If no matching pointer symbol is visible there is nothing to extract;
        # skip it (mirroring how invisible size/stride symbols are skipped)
        # instead of calling ``extract_ptr(None)``.
        if ptr is not None:
            nodes.append(self._src_field.extract_ptr(ptr))

        nodes.extend(
            self._src_field.extract_size(c, s)
            for c, s in self._required_entries(self._field.shape, FieldShapeSymbol, visible_params)
        )
        nodes.extend(
            self._src_field.extract_stride(c, s)
            for c, s in self._required_entries(self._field.strides, FieldStrideSymbol, visible_params)
        )

        return make_sequence(*nodes)

    def _find_pointer(self, visible_params: set[TypedSymbolOrObject]) -> FieldPointerSymbol | None:
        """Return the visible pointer symbol belonging to this field, or ``None``."""
        ptr = None
        for param in visible_params:
            if isinstance(param, FieldPointerSymbol) and param.field_name == self._field.name:
                if param.dtype.base_type != self._field.dtype:
                    raise SfgException("Data type mismatch between field and encountered pointer symbol")
                ptr = param
        return ptr

    @staticmethod
    def _required_entries(entries, symbol_type: type, visible_params: set[TypedSymbolOrObject]) -> list:
        """Pair entries with their coordinate, dropping ``symbol_type`` symbols that are not visible."""
        return [
            (coord, entry)
            for coord, entry in enumerate(entries)
            if not (isinstance(entry, symbol_type) and entry not in visible_params)
        ]