# data_types.py (18.10 KiB)
import ctypes
import sympy as sp
import numpy as np
try:
    import llvmlite.ir as ir
except ImportError as e:
    ir = None
    _ir_importerror = e
from sympy.core.cache import cacheit

from pystencils.cache import memorycache
from pystencils.utils import allEqual


# to work in conditions of sp.Piecewise castFunc has to be of type Relational as well
class castFunc(sp.Function, sp.Rel):
    """Symbolic function node that also derives from ``sp.Rel``.

    Deriving from ``sp.Rel`` is what allows instances to be used as conditions
    of ``sp.Piecewise``. Both properties below delegate to the first argument.
    """

    @property
    def canonical(self):
        """Canonical form of the first argument.

        :raises NotImplementedError: if the first argument has no ``canonical`` attribute
        """
        wrapped = self.args[0]
        if not hasattr(wrapped, 'canonical'):
            raise NotImplementedError()
        return wrapped.canonical

    @property
    def is_commutative(self):
        """Commutativity is taken over from the first argument."""
        return self.args[0].is_commutative


class pointerArithmeticFunc(sp.Function, sp.Rel):
    """Symbolic function node that also derives from ``sp.Rel``.

    ``canonical`` is delegated to the first argument.
    """

    @property
    def canonical(self):
        """Canonical form of the first argument.

        :raises NotImplementedError: if the first argument has no ``canonical`` attribute
        """
        wrapped = self.args[0]
        if not hasattr(wrapped, 'canonical'):
            raise NotImplementedError()
        return wrapped.canonical


class TypedSymbol(sp.Symbol):
    """A ``sp.Symbol`` that additionally stores a data type.

    The data type is part of the hashable content, so two symbols with the same
    name but different types are different symbols.
    """

    def __new__(cls, *args, **kwds):
        # All construction is routed through the cached stage-2 constructor.
        return TypedSymbol.__xnew_cached_(cls, *args, **kwds)

    def __new_stage2__(cls, name, dtype):
        """Create the symbol and attach the type built by ``createType(dtype)``.

        If ``createType`` raises ``TypeError`` the given ``dtype`` is stored unchanged.
        """
        symbol = super(TypedSymbol, cls).__xnew__(cls, name)
        try:
            resolved = createType(dtype)
        except TypeError:
            # on error keep what was passed in
            resolved = dtype
        symbol._dtype = resolved
        return symbol

    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    @property
    def dtype(self):
        """The data type stored at construction time."""
        return self._dtype

    def _hashable_content(self):
        # content of the plain symbol, extended by the hash of the data type
        inherited = super(TypedSymbol, self)._hashable_content()
        return tuple(inherited) + (hash(self._dtype),)

    def __getnewargs__(self):
        return self.name, self.dtype


def createType(specification):
    """
    Create a type object according to a string, a numpy dtype specification or an existing type object

    :param specification: Type or StructType object (returned unchanged), a c-like type string
                          (parsed by createTypeFromString), or anything ``np.dtype`` accepts
    :return: the given type object, or a new BasicType / PointerType / StructType
    :raises TypeError: if ``np.dtype`` cannot interpret a non-string specification
    """
    # StructType does not derive from Type, so it has to be listed explicitly here;
    # otherwise an existing StructType would be handed to np.dtype() below, which rejects it.
    if isinstance(specification, (Type, StructType)):
        return specification
    if isinstance(specification, str):
        return createTypeFromString(specification)

    npDataType = np.dtype(specification)
    # dtypes with named fields are structured types, everything else is a plain scalar type
    if npDataType.fields is None:
        return BasicType(npDataType, const=False)
    return StructType(npDataType, const=False)


@memorycache(maxsize=64)
def createTypeFromString(specification):
    """
    Creates a new Type object from a c-like string specification

    The string is lower-cased and read from left to right: first the base type, optionally
    accompanied by ``const``, then zero or more ``*``. Each ``*`` may be followed by ``const``
    and/or ``restrict``, which then apply to that pointer level,
    e.g. ``"double * restrict const"`` or ``"const int64 * *"``.
    A ``*`` does not need to be separated by blanks from its neighbours, so ``"double*"``,
    ``"double **"`` and ``"double *restrict"`` are accepted as well.

    :param specification: Specification string
    :return: Type object: a BasicType, wrapped into one PointerType per ``*``
    """
    # Surround every '*' with blanks so that it always becomes a token of its own. This keeps
    # the pointer levels in the order in which they were written, no matter whether a '*' was
    # glued to the base type or to a qualifier.
    tokens = specification.lower().replace('*', ' * ').split()

    # Group the tokens: the first group is the base type, every further group starts with '*'
    parts = []
    current = []
    for token in tokens:
        if token == '*':
            parts.append(current)
            current = [token]
        else:
            current.append(token)
    if len(current) > 0:
        parts.append(current)

    # Parse native part
    basePart = parts.pop(0)
    const = 'const' in basePart
    if const:
        basePart.remove('const')
    assert len(basePart) == 1, "Expected exactly one base type name"
    try:
        baseType = BasicType(basePart[0], const)
    except TypeError:
        # name unknown to numpy - try the additional short names
        baseType = BasicType(createTypeFromString.map[basePart[0]], const)
    currentType = baseType

    # Parse pointer parts
    for part in parts:
        restrict = 'restrict' in part
        if restrict:
            part.remove('restrict')
        const = 'const' in part
        if const:
            part.remove('const')
        assert len(part) == 1 and part[0] == '*', "Unknown qualifier in pointer specification"
        currentType = PointerType(currentType, const, restrict)
    return currentType

# Short names that are only looked up when BasicType raised TypeError for the name.
# NOTE(review): 'i8' is itself a valid numpy dtype string (8-byte integer), so the
# 'i8' entry below is presumably never consulted - confirm the intended meaning.
createTypeFromString.map = {
    'i64': np.int64,
    'i32': np.int32,
    'i16': np.int16,
    'i8': np.int8,
}


def getBaseType(type):
    """
    Follows the ``baseType`` chain to its end

    :param type: object with a ``baseType`` attribute
    :return: the first object in the chain whose ``baseType`` is None
    """
    current = type
    while True:
        inner = current.baseType
        if inner is None:
            return current
        current = inner


def toCtypes(dataType):
    """
    Translates a type object of this module into the corresponding ctypes type

    Pointer types become a ctypes pointer to the translated base type, struct types
    become a pointer to ``c_uint8``; everything else is looked up in ``toCtypes.map``
    by its numpy dtype.

    :param dataType: PointerType, StructType or an object with a ``numpyDtype`` attribute
    :return: ctypes type object
    """
    if isinstance(dataType, PointerType):
        pointee = toCtypes(dataType.baseType)
        return ctypes.POINTER(pointee)
    if isinstance(dataType, StructType):
        return ctypes.POINTER(ctypes.c_uint8)
    return toCtypes.map[dataType.numpyDtype]

# numpy dtype -> ctypes scalar type
toCtypes.map = {
    # signed integers
    np.dtype(np.int8): ctypes.c_int8,
    np.dtype(np.int16): ctypes.c_int16,
    np.dtype(np.int32): ctypes.c_int32,
    np.dtype(np.int64): ctypes.c_int64,

    # unsigned integers
    np.dtype(np.uint8): ctypes.c_uint8,
    np.dtype(np.uint16): ctypes.c_uint16,
    np.dtype(np.uint32): ctypes.c_uint32,
    np.dtype(np.uint64): ctypes.c_uint64,

    # floating point
    np.dtype(np.float32): ctypes.c_float,
    np.dtype(np.float64): ctypes.c_double,
}


def ctypes_from_llvm(data_type):
    """
    Translates an llvmlite IR type into the corresponding ctypes type

    :param data_type: llvmlite ``ir`` type object
    :return: ctypes type object, or None for ``ir.VoidType``
    :raises ValueError: for integer widths other than 8, 16, 32 and 64
    :raises NotImplementedError: for any other IR type
    :raises ImportError: the stored import error, if llvmlite is not available
    """
    if not ir:
        raise _ir_importerror

    if isinstance(data_type, ir.PointerType):
        pointee = ctypes_from_llvm(data_type.pointee)
        # a pointer to void has no pointee ctype, so it becomes a plain void pointer
        return ctypes.c_void_p if pointee is None else ctypes.POINTER(pointee)

    if isinstance(data_type, ir.IntType):
        intByWidth = {
            8: ctypes.c_int8,
            16: ctypes.c_int16,
            32: ctypes.c_int32,
            64: ctypes.c_int64,
        }
        if data_type.width not in intByWidth:
            raise ValueError("Int width %d is not supported" % data_type.width)
        return intByWidth[data_type.width]

    if isinstance(data_type, ir.FloatType):
        return ctypes.c_float
    if isinstance(data_type, ir.DoubleType):
        return ctypes.c_double
    if isinstance(data_type, ir.VoidType):
        return None  # Void type is not supported by ctypes

    raise NotImplementedError('Data type %s of %s is not supported yet' % (type(data_type), data_type))


def to_llvm_type(data_type):
    """
    Translates a type object of this module into an llvmlite IR type

    Pointer types become a pointer to the translated base type; everything else
    is looked up in ``to_llvm_type.map`` by its numpy dtype.

    :param data_type: PointerType or an object with a ``numpyDtype`` attribute
    :return: llvmlite type object
    :raises ImportError: the stored import error, if llvmlite is not available
    """
    if not ir:
        raise _ir_importerror
    if not isinstance(data_type, PointerType):
        return to_llvm_type.map[data_type.numpyDtype]
    pointee = to_llvm_type(data_type.baseType)
    return pointee.as_pointer()

# The lookup table can only be built when llvmlite could be imported.
if ir:
    to_llvm_type.map = {
        # signed integers
        np.dtype(np.int8): ir.IntType(8),
        np.dtype(np.int16): ir.IntType(16),
        np.dtype(np.int32): ir.IntType(32),
        np.dtype(np.int64): ir.IntType(64),

        # unsigned integers use the same IR integer type as the signed ones of equal width
        np.dtype(np.uint8): ir.IntType(8),
        np.dtype(np.uint16): ir.IntType(16),
        np.dtype(np.uint32): ir.IntType(32),
        np.dtype(np.uint64): ir.IntType(64),

        # floating point
        np.dtype(np.float32): ir.FloatType(),
        np.dtype(np.float64): ir.DoubleType(),
    }


def peelOffType(dtype, typeToPeelOff):
    """
    Strips all leading layers of exactly the given class from a type

    :param dtype: type object to unwrap
    :param typeToPeelOff: class whose instances are replaced by their ``baseType``
    :return: first object in the ``baseType`` chain whose class is not exactly ``typeToPeelOff``
    """
    current = dtype
    while True:
        if type(current) is not typeToPeelOff:
            return current
        current = current.baseType


def collateTypes(types):
    """
    Determines the "common type" of a sequence of types, e.g. (float, double, float) -> double

    Scalar types are combined with ``np.result_type``. If vector types occur, the result is a
    vector type of the same width. A single pointer type combined with integer types yields
    that pointer type.

    :param types: sequence of type objects (iterated more than once)
    :return: the collated type
    :raises ValueError: for two pointers, a pointer combined with a non-integer type,
                        or vector types of different width
    """
    # Pointer arithmetic case i.e. pointer + integer is allowed
    if any(type(t) is PointerType for t in types):
        pointer = None
        for candidate in types:
            kind = type(candidate)
            if kind is PointerType:
                if pointer is not None:
                    raise ValueError("Cannot collate the combination of two pointer types")
                pointer = candidate
                continue
            # everything next to the pointer has to be a signed or unsigned integer
            isIntegral = kind is BasicType and (candidate.is_int() or candidate.is_uint())
            if not isIntegral:
                raise ValueError("Invalid pointer arithmetic")
        return pointer

    # remember the vector types, then continue with the scalar types they wrap
    vectors = [t for t in types if type(t) is VectorType]
    if not allEqual(t.width for t in vectors):
        raise ValueError("Collation failed because of vector types with different width")
    scalars = [peelOffType(t, VectorType) for t in types]

    # only basic types may be left at this point - struct types are not yet supported
    assert all(type(t) is BasicType for t in scalars)

    # numpy decides on the common scalar type
    common = BasicType(np.result_type(*(t.numpyDtype for t in scalars)))
    if not vectors:
        return common
    # at least one vector type occurred, so the result is a vector type as well
    return VectorType(common, vectors[0].width)


@memorycache(maxsize=2048)
def getTypeOfExpression(expr):
    """
    Determines the data type of a sympy expression

    :param expr: anything ``sp.sympify`` accepts; all symbols have to be TypedSymbols
    :return: type object of the expression
    :raises ValueError: if the expression contains an untyped ``sp.Symbol``
    :raises NotImplementedError: if no rule matches the expression
    """
    from pystencils.astnodes import ResolvedFieldAccess
    expr = sp.sympify(expr)

    # literals
    if isinstance(expr, sp.Integer):
        return createTypeFromString("int")
    if isinstance(expr, (sp.Rational, sp.Float)):
        return createTypeFromString("double")

    # leaves that carry their own type
    if isinstance(expr, ResolvedFieldAccess):
        return expr.field.dtype
    if isinstance(expr, TypedSymbol):
        return expr.dtype
    if isinstance(expr, sp.Symbol):
        raise ValueError("All symbols inside this expression have to be typed!")

    # castFunc derives from sp.Rel, so it has to be handled before the boolean case below
    func = expr.func if hasattr(expr, 'func') else None
    if func == castFunc:
        return expr.args[1]
    if func == sp.Piecewise:
        resultType = collateTypes(tuple(getTypeOfExpression(branch[0]) for branch in expr.args))
        conditionType = collateTypes(tuple(getTypeOfExpression(branch[1]) for branch in expr.args))
        # a vectorized condition makes the result a vector of the same width
        if type(conditionType) is VectorType and type(resultType) is not VectorType:
            resultType = VectorType(resultType, width=conditionType.width)
        return resultType

    if isinstance(expr, sp.Indexed):
        return expr.base.label.dtype.baseType

    if isinstance(expr, (sp.boolalg.Boolean, sp.boolalg.BooleanFunction)):
        # if any arg is of vector type return a vector boolean, else return a normal scalar boolean
        boolType = createTypeFromString("bool")
        argTypes = (getTypeOfExpression(a) for a in expr.args)
        vectorArgTypes = [t for t in argTypes if isinstance(t, VectorType)]
        if vectorArgTypes:
            boolType = VectorType(boolType, width=vectorArgTypes[0].width)
        return boolType

    if isinstance(expr, sp.Expr):
        return collateTypes(tuple(getTypeOfExpression(a) for a in expr.args))

    raise NotImplementedError("Could not determine type for", expr, type(expr))


class Type(sp.Basic):
    """Common sympy-based base class of BasicType, VectorType and PointerType."""

    def __new__(cls, *args, **kwargs):
        # The constructor arguments are deliberately not forwarded to sp.Basic;
        # subclasses store them in their own __init__.
        return sp.Basic.__new__(cls)

    def __lt__(self, other):  # deprecated
        """Ordering between basic and pointer types.

        :raises NotImplementedError: if a type other than BasicType or PointerType is involved
        """
        # Needed for sorting the types inside an expression
        if isinstance(self, BasicType):
            if isinstance(other, BasicType):
                return self.numpyDtype > other.numpyDtype  # TODO const
            elif isinstance(other, PointerType):
                return False
            else:  # isinstance(other, StructType):
                raise NotImplementedError("Struct type comparison is not yet implemented")
        elif isinstance(self, PointerType):
            if isinstance(other, BasicType):
                return True
            elif isinstance(other, PointerType):
                return self.baseType > other.baseType  # TODO const, restrict
            else:  # isinstance(other, StructType):
                raise NotImplementedError("Struct type comparison is not yet implemented")
        elif isinstance(self, StructType):
            raise NotImplementedError("Struct type comparison is not yet implemented")
        else:
            raise NotImplementedError

    def _sympystr(self, *args, **kwargs):
        # sympy's string printer shows a type exactly like str() does
        return str(self)


class BasicType(Type):
    """Scalar type described by a numpy dtype plus a ``const`` flag."""

    @staticmethod
    def numpyNameToC(name):
        """
        Translates the name of a numpy dtype into the name of the corresponding C type

        :param name: numpy dtype name, e.g. 'float64', 'int32', 'uint8' or 'bool'
        :return: C type name, e.g. 'double', 'int32_t', 'uint8_t' or 'bool'
        :raises NotImplementedError: if there is no mapping for the given name
        """
        if name == 'float64':
            return 'double'
        elif name == 'float32':
            return 'float'
        elif name.startswith('int'):
            width = int(name[len("int"):])
            return "int%d_t" % (width,)
        elif name.startswith('uint'):
            width = int(name[len("uint"):])
            return "uint%d_t" % (width,)
        elif name == 'bool':
            return 'bool'
        else:
            # NotImplemented is a comparison sentinel and not callable -
            # the exception class is NotImplementedError
            raise NotImplementedError("Can't map numpy to C name for %s" % (name,))

    def __init__(self, dtype, const=False):
        """
        :param dtype: anything ``np.dtype`` accepts, or a type object providing ``numpyDtype``
        :param const: True if the type is const-qualified
        """
        self.const = const
        if isinstance(dtype, Type):
            self._dtype = dtype.numpyDtype
        else:
            self._dtype = np.dtype(dtype)
        assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type"
        assert self._dtype.hasobject is False
        assert self._dtype.subdtype is None

    def __getnewargs__(self):
        return self.numpyDtype, self.const

    @property
    def baseType(self):
        """A basic type wraps no other type, so this is always None."""
        return None

    @property
    def numpyDtype(self):
        """The numpy dtype this type was created from."""
        return self._dtype

    @property
    def itemSize(self):
        """Always 1 for basic types."""
        return 1

    # The predicates below check ``dtype.kind`` instead of the ``np.sctypes`` lists,
    # because ``np.sctypes`` was removed in NumPy 2.0.

    def is_int(self):
        """True for signed integer dtypes."""
        return self.numpyDtype.kind == 'i'

    def is_float(self):
        """True for floating point dtypes."""
        return self.numpyDtype.kind == 'f'

    def is_uint(self):
        """True for unsigned integer dtypes."""
        return self.numpyDtype.kind == 'u'

    def is_complex(self):
        """True for complex floating point dtypes."""
        return self.numpyDtype.kind == 'c'

    # misspelled original name, kept so that existing callers continue to work
    is_comlex = is_complex

    def is_other(self):
        """True if the dtype equals one of the entries the former ``np.sctypes['others']`` list held."""
        return self.numpyDtype in (bool, object, bytes, str, np.void)

    @property
    def baseName(self):
        """C name of the type without any qualifier."""
        return BasicType.numpyNameToC(str(self._dtype))

    def __str__(self):
        result = BasicType.numpyNameToC(str(self._dtype))
        if self.const:
            result += " const"
        return result

    def __repr__(self):
        return str(self)

    def __eq__(self, other):
        if not isinstance(other, BasicType):
            return False
        else:
            return (self.numpyDtype, self.const) == (other.numpyDtype, other.const)

    def __hash__(self):
        # Hash exactly what __eq__ compares. Hashing str(self) would raise for
        # dtypes that have no C name in numpyNameToC.
        return hash((self.numpyDtype, self.const))


class VectorType(Type):
    """Type made of ``width`` elements of a base type."""

    # Optional mapping with the keys 'int', 'double', 'float' and 'bool'; when set,
    # __str__ returns the entry that belongs to the base type.
    instructionSet = None

    def __init__(self, baseType, width=4):
        """
        :param baseType: type of a single element
        :param width: number of elements
        """
        self._baseType = baseType
        self.width = width

    @property
    def baseType(self):
        """Type of a single element."""
        return self._baseType

    @property
    def itemSize(self):
        """Item size of the base type multiplied by the width."""
        return self.width * self.baseType.itemSize

    def __eq__(self, other):
        if isinstance(other, VectorType):
            return (self.baseType, self.width) == (other.baseType, other.width)
        return False

    def __str__(self):
        if self.instructionSet is None:
            return "%s[%d]" % (self.baseType, self.width)

        # (type specification, key into instructionSet), checked in this order
        # NOTE(review): numpy reads "float" as a 64 bit float, so that entry may denote
        # the same type as "double" - confirm.
        candidates = (
            ("int64", 'int'),
            ("double", 'double'),
            ("float", 'float'),
            ("bool", 'bool'),
        )
        for specification, key in candidates:
            if self.baseType == createTypeFromString(specification):
                return self.instructionSet[key]
        raise NotImplementedError()

    def __hash__(self):
        return hash((self.baseType, self.width))


class PointerType(Type):
    """Pointer to a base type, with ``const`` and ``restrict`` flags."""

    def __init__(self, baseType, const=False, restrict=True):
        """
        :param baseType: type the pointer points to
        :param const: True if the pointer is const-qualified
        :param restrict: True if the pointer is restrict-qualified
        """
        self._baseType = baseType
        self.const = const
        self.restrict = restrict

    def __getnewargs__(self):
        return self.baseType, self.const, self.restrict

    @property
    def alias(self):
        """Negation of the ``restrict`` flag."""
        return not self.restrict

    @property
    def baseType(self):
        """Type the pointer points to."""
        return self._baseType

    @property
    def itemSize(self):
        """Item size of the base type."""
        return self.baseType.itemSize

    def __eq__(self, other):
        if isinstance(other, PointerType):
            return (self.baseType, self.const, self.restrict) == (other.baseType, other.const, other.restrict)
        return False

    def __str__(self):
        restrictPart = " RESTRICT " if self.restrict else ""
        constPart = " const " if self.const else ""
        return "%s *%s%s" % (self.baseType, restrictPart, constPart)

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self._baseType, self.const, self.restrict))


class StructType(object):
    """Type described by a structured numpy dtype plus a ``const`` flag.

    Note that this class does not derive from Type.
    """

    def __init__(self, numpyType, const=False):
        """
        :param numpyType: anything ``np.dtype`` accepts
        :param const: True if the type is const-qualified
        """
        self.const = const
        self._dtype = np.dtype(numpyType)

    def __getnewargs__(self):
        return self.numpyDtype, self.const

    @property
    def baseType(self):
        """A struct type wraps no other type, so this is always None."""
        return None

    @property
    def numpyDtype(self):
        """The numpy dtype this type was created from."""
        return self._dtype

    @property
    def itemSize(self):
        """Size of one struct in bytes, as reported by the numpy dtype."""
        return self.numpyDtype.itemsize

    def getElementOffset(self, elementName):
        """Byte offset of the given field inside the struct."""
        return self.numpyDtype.fields[elementName][1]

    def getElementType(self, elementName):
        """BasicType of the given field; it inherits the const flag of the struct."""
        npElementType = self.numpyDtype.fields[elementName][0]
        return BasicType(npElementType, self.const)

    def hasElement(self, elementName):
        """True if the struct has a field with the given name."""
        return elementName in self.numpyDtype.fields

    def __eq__(self, other):
        if not isinstance(other, StructType):
            return False
        else:
            return (self.numpyDtype, self.const) == (other.numpyDtype, other.const)

    def __str__(self):
        # structs are handled byte-wise
        result = "uint8_t"
        if self.const:
            result += " const"
        return result

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self.numpyDtype, self.const))

    def __gt__(self, other):
        # The former implementation read the attributes 'ptr' and 'dtype', which this class
        # does not have, so every call ended in an AttributeError. Fail explicitly instead,
        # in the same way Type.__lt__ does for struct types.
        raise NotImplementedError("Struct type comparison is not yet implemented")

def get_type_from_sympy(node):
    """
    Determines type and plain Python value of a sympy number

    :param node: Sympy number
    :return: tuple (type object, Python value): ('double', float) for floats and rationals,
             ('int', int) for integers
    :raises TypeError: if node is not a ``sp.Number`` or of an unsupported number class
    """
    # Only plain numbers are handled here - no number symbols (pi, E, ...), no infinities,
    # no compound expressions.
    if not isinstance(node, sp.Number):
        raise TypeError(node, 'is not a sp.Number')

    if isinstance(node, sp.Float) or isinstance(node, sp.RealNumber):
        return createType('double'), float(node)

    # Integer has to be tested before Rational (see the notes in the original: Zero, One,
    # NegativeOne are Integers, Half is a Rational)
    if isinstance(node, sp.Integer):
        return createType('int'), int(node)

    if isinstance(node, sp.Rational):
        # TODO is it always float?
        numerator, denominator = node.p, node.q
        return createType('double'), float(numerator / denominator)

    raise TypeError(node, ' is not a supported type (yet)!')