Skip to content
Snippets Groups Projects
source_objects.py 3.25 KiB
Newer Older
from __future__ import annotations

from typing import TYPE_CHECKING, Optional, Union, Set, TypeAlias, NewType

if TYPE_CHECKING:
    from ..source_components import SfgHeaderInclude
    from ..tree import SfgStatements, SfgSequence

from numpy import dtype

from abc import ABC, abstractmethod

from pystencils import TypedSymbol, Field
from pystencils.typing import AbstractType, FieldPointerSymbol, FieldStrideSymbol, FieldShapeSymbol

PsType: TypeAlias = Union[type, dtype, AbstractType]
"""Types used in interacting with pystencils.

PsType represents various ways of specifying types within pystencils.
In particular, it encompasses most ways to construct an instance of `AbstractType`,
for example via `create_type`.

(Note that, while `create_type` does accept strings, they are excluded here for
reasons of safety. It is discouraged to use strings for type specifications when working
with pystencils!)
"""

SrcType = NewType('SrcType', str)
"""Nonprimitive C/C++-Types occuring during source file generation.

Nonprimitive C/C++ types are represented by their names.
When necessary, the SFG package checks equality of types by these name strings; it does
not care about typedefs, aliases, namespaces, etc!
"""


class SrcObject:
    """C/C++ object of nonprimitive type.
    
    Two objects are identical if they have the same identifier and type string."""

    def __init__(self, src_type: SrcType, identifier: Optional[str]):
        self._src_type = src_type
        self._identifier = identifier
    
    @property
    def identifier(self):
        return self._identifier

    @property
    def name(self):
        """For interface compatibility with ps.TypedSymbol"""
        return self._identifier

    @property
    def dtype(self):
        return self._src_type

    @property
    def required_includes(self) -> Set[SfgHeaderInclude]:
        return set()
    
    def __hash__(self) -> int:
        return hash((self._identifier, self._src_type))
    
    def __eq__(self, other: SrcObject) -> bool:
        return (isinstance(other, SrcObject) 
                and self._identifier == other._identifier
                and self._src_type == other._src_type)


TypedSymbolOrObject: TypeAlias = Union[TypedSymbol, SrcObject]


class SrcField(SrcObject, ABC):
    def __init__(self, src_type: SrcType, identifier: Optional[str]):
        super().__init__(src_type, identifier)

    @abstractmethod
    def extract_ptr(self, ptr_symbol: FieldPointerSymbol) -> SfgStatements:
        pass

    @abstractmethod
    def extract_size(self, coordinate: int, size: Union[int, FieldShapeSymbol]) -> SfgStatements:
        pass

    @abstractmethod
    def extract_stride(self, coordinate: int, stride: Union[int, FieldStrideSymbol]) -> SfgStatements:
        pass

    def extract_parameters(self, field: Field) -> SfgSequence:
        ptr = FieldPointerSymbol(field.name, field.dtype, False)

        from ..tree import make_sequence

        return make_sequence(
            self.extract_ptr(ptr),
            *(self.extract_size(c, s) for c, s in enumerate(field.shape)),
            *(self.extract_stride(c, s) for c, s in enumerate(field.strides))
        )


class SrcVector(SrcObject):
    @abstractmethod
    def extract_component(self, destination: TypedSymbolOrObject, coordinate: int):
        pass