"""Data types of pystencils: typed symbols, cast functions and conversions to C, ctypes and LLVM types."""
import ctypes
import sympy as sp
import numpy as np
try:
    import llvmlite.ir as ir
except ImportError as e:
    ir = None
    _ir_importerror = e
from sympy.core.cache import cacheit
from pystencils.cache import memorycache
from pystencils.utils import all_equal
from sympy.logic.boolalg import Boolean


# noinspection PyPep8Naming
class cast_func(sp.Function):
    """Sympy function node representing the cast of ``args[0]`` to the data type ``args[1]``.

    If the first argument is a sympy ``Boolean``, construction yields a ``boolean_cast_func`` instead,
    so that the cast can be used where sympy requires booleans (e.g. ``sp.Piecewise`` conditions).
    """

    def __new__(cls, *args, **kwargs):
        # to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well
        # however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads
        # to problems when for example comparing cast_func's for equality
        #
        # lhs = bitwise_and(a, cast_func(1, 'int'))
        # rhs = cast_func(0, 'int')
        # print( sp.Ne(lhs, rhs) ) # would give true if all cast_funcs are booleans
        #
        # -> thus a separate class boolean_cast_func is introduced
        if isinstance(args[0], Boolean):
            cls = boolean_cast_func
        return sp.Function.__new__(cls, *args, **kwargs)

    @property
    def canonical(self):
        """Canonical form of the cast argument; raises NotImplementedError if the argument has none."""
        if hasattr(self.args[0], 'canonical'):
            return self.args[0].canonical
        else:
            raise NotImplementedError()

    @property
    def is_commutative(self):
        """A cast is commutative exactly when its argument is."""
        return self.args[0].is_commutative

    @property
    def dtype(self):
        """Target data type of the cast (the second argument)."""
        return self.args[1]


# noinspection PyPep8Naming
class boolean_cast_func(cast_func, Boolean):
    """Variant of cast_func that is also a sympy Boolean.

    cast_func.__new__ switches to this class when the value being cast is a Boolean.
    """


# noinspection PyPep8Naming
class vector_memory_access(cast_func):
    """cast_func subclass that sympy only accepts with exactly four arguments (see ``nargs``)."""

    nargs = (4,)

# noinspection PyPep8Naming
class reinterpret_cast_func(cast_func):
    """Distinct cast_func subclass (by its name, a reinterpreting cast); it adds no behavior of its own."""


# noinspection PyPep8Naming
class pointer_arithmetic_func(sp.Function, Boolean):
    """Sympy function node that is also a sympy Boolean.

    The only customization is ``canonical``, which forwards to the first argument.
    """

    @property
    def canonical(self):
        """Return ``args[0].canonical``; raise NotImplementedError if the first argument has none."""
        first = self.args[0]
        if not hasattr(first, 'canonical'):
            raise NotImplementedError()
        return first.canonical


class TypedSymbol(sp.Symbol):
    """Sympy symbol that additionally carries a data type.

    ``TypedSymbol(name, dtype)`` converts ``dtype`` with ``create_type``; if that raises TypeError, the given
    specification is stored unchanged.  Construction goes through sympy's ``cacheit`` cache.
    """

    def __new__(cls, *args, **kwds):
        obj = TypedSymbol.__xnew_cached_(cls, *args, **kwds)
        return obj

    def __new_stage2__(cls, name, dtype):
        obj = super(TypedSymbol, cls).__xnew__(cls, name)
        try:
            obj._dtype = create_type(dtype)
        except TypeError:
            # create_type could not interpret the specification (e.g. an unknown type string): keep it as given
            obj._dtype = dtype
        return obj

    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    @property
    def dtype(self):
        """Data type of the symbol (a Type, or the original specification if it could not be converted)."""
        return self._dtype

    def _hashable_content(self):
        # the data type is part of the identity: equally named symbols with different types differ
        return super()._hashable_content(), hash(self._dtype)


def create_type(specification):
    """Creates a subclass of Type according to a string or an object of subclass Type.

    Args:
        specification: Type object, or anything ``np.dtype`` accepts (e.g. a string such as ``'float64'``)

    Returns:
        the given Type object unchanged; otherwise a new non-const BasicType, or a StructType if the numpy
        dtype has named fields

    Raises:
        TypeError: raised by ``np.dtype`` for specifications it cannot interpret
    """
    if isinstance(specification, Type):
        return specification

    numpy_dtype = np.dtype(specification)
    if numpy_dtype.fields is None:
        return BasicType(numpy_dtype, const=False)
    else:
        return StructType(numpy_dtype, const=False)
@memorycache(maxsize=64)
def create_composite_type_from_string(specification):
    """Creates a new Type object from a c-like string specification.

    The specification is lower-cased and split at whitespace.  The tokens before the first ``*`` token form
    the base type (optionally qualified with ``const``).  Every ``*`` token starts a new pointer level, which
    may be qualified with ``const`` and/or ``restrict``.  A ``*`` glued to the base type name (``'double*'``)
    is the innermost pointer level.

    Examples: ``'double'``, ``'const double * restrict'``, ``'int64* * const'``

    Args:
        specification: Specification string

    Returns:
        Type object
    """
    tokens = specification.lower().split()

    # group tokens: everything before the first '*' is the base part, each '*' token opens a new pointer part
    parts = []
    current = []
    for token in tokens:
        if token == '*':
            parts.append(current)
            current = [token]
        else:
            current.append(token)
    if current:
        parts.append(current)

    # Parse native part
    base_part = parts.pop(0)
    const = False
    if 'const' in base_part:
        const = True
        base_part.remove('const')
    assert len(base_part) == 1
    if base_part[0][-1] == "*":
        base_part[0] = base_part[0][:-1]
        # a '*' attached to the base type name is the innermost pointer level, so it must be applied first
        parts.insert(0, ['*'])
    current_type = BasicType(np.dtype(base_part[0]), const)

    # Parse pointer parts, innermost level first
    for part in parts:
        restrict = False
        const = False
        if 'restrict' in part:
            restrict = True
            part.remove('restrict')
        if 'const' in part:
            const = True
            part.remove("const")
        assert len(part) == 1 and part[0] == '*'
        current_type = PointerType(current_type, const, restrict)
    return current_type


def get_base_type(data_type):
    """Follow ``base_type`` links until reaching a type whose ``base_type`` is None, and return that type."""
    current = data_type
    inner = current.base_type
    while inner is not None:
        current = inner
        inner = current.base_type
    return current


def to_ctypes(data_type):
    """
    Transforms a given Type into ctypes
    :param data_type: Subclass of Type
    :return: ctypes type object: a PointerType becomes ``ctypes.POINTER`` of its converted base type, a
             StructType becomes ``ctypes.POINTER(ctypes.c_uint8)``, any other type is looked up by its numpy
             dtype in ``to_ctypes.map``
    :raises KeyError: if the numpy dtype has no entry in ``to_ctypes.map``
    """
    if isinstance(data_type, PointerType):
        return ctypes.POINTER(to_ctypes(data_type.base_type))
    elif isinstance(data_type, StructType):
        return ctypes.POINTER(ctypes.c_uint8)
    else:
        return to_ctypes.map[data_type.numpy_dtype]

# Lookup table used by to_ctypes for non-pointer, non-struct types: numpy dtype -> ctypes type
to_ctypes.map = {
    np.dtype(numpy_type): ctypes_type
    for numpy_type, ctypes_type in (
        # signed integers
        (np.int8, ctypes.c_int8),
        (np.int16, ctypes.c_int16),
        (np.int32, ctypes.c_int32),
        (np.int64, ctypes.c_int64),
        # unsigned integers
        (np.uint8, ctypes.c_uint8),
        (np.uint16, ctypes.c_uint16),
        (np.uint32, ctypes.c_uint32),
        (np.uint64, ctypes.c_uint64),
        # floating point
        (np.float32, ctypes.c_float),
        (np.float64, ctypes.c_double),
    )
}


def ctypes_from_llvm(data_type):
    """Convert an llvmlite IR type into the matching ctypes type.

    Pointers become ``ctypes.POINTER`` of the converted pointee, or ``ctypes.c_void_p`` when the pointee is
    void.  Integer types become the signed ctypes integer of the same width (8/16/32/64 bit).  Float and
    double become ``c_float``/``c_double``, and void becomes None.

    Raises:
        the stored llvmlite ImportError if llvmlite could not be imported,
        ValueError for integer widths other than 8/16/32/64,
        NotImplementedError for any other IR type.
    """
    if not ir:
        raise _ir_importerror

    if isinstance(data_type, ir.PointerType):
        pointee = ctypes_from_llvm(data_type.pointee)
        return ctypes.c_void_p if pointee is None else ctypes.POINTER(pointee)

    if isinstance(data_type, ir.IntType):
        signed_by_width = {
            8: ctypes.c_int8,
            16: ctypes.c_int16,
            32: ctypes.c_int32,
            64: ctypes.c_int64,
        }
        int_ctype = signed_by_width.get(data_type.width)
        if int_ctype is None:
            raise ValueError("Int width %d is not supported" % data_type.width)
        return int_ctype

    if isinstance(data_type, ir.FloatType):
        return ctypes.c_float
    if isinstance(data_type, ir.DoubleType):
        return ctypes.c_double
    if isinstance(data_type, ir.VoidType):
        return None  # Void type is not supported by ctypes

    raise NotImplementedError('Data type %s of %s is not supported yet' % (type(data_type), data_type))


def to_llvm_type(data_type):
    """
    Transforms a given type into an llvmlite type
    :param data_type: Subclass of Type
    :return: llvmlite type object: a PointerType becomes ``.as_pointer()`` of its converted base type, any
             other type is looked up by its numpy dtype in ``to_llvm_type.map``
    :raises ImportError: the stored llvmlite import error if llvmlite is not available
    """
    if not ir:
        raise _ir_importerror
    if isinstance(data_type, PointerType):
        return to_llvm_type(data_type.base_type).as_pointer()
    else:
        return to_llvm_type.map[data_type.numpy_dtype]

# Lookup table used by to_llvm_type, only defined when llvmlite is available: numpy dtype -> llvmlite type.
# Signed and unsigned numpy integers of the same width map to the same ir.IntType.
if ir:
    to_llvm_type.map = {
        np.dtype(numpy_type): llvm_type
        for numpy_type, llvm_type in (
            (np.int8, ir.IntType(8)),
            (np.int16, ir.IntType(16)),
            (np.int32, ir.IntType(32)),
            (np.int64, ir.IntType(64)),
            (np.uint8, ir.IntType(8)),
            (np.uint16, ir.IntType(16)),
            (np.uint32, ir.IntType(32)),
            (np.uint64, ir.IntType(64)),
            (np.float32, ir.FloatType()),
            (np.float64, ir.DoubleType()),
        )
    }


def peel_off_type(dtype, type_to_peel_off):
    """Strip wrappers of exactly the class ``type_to_peel_off`` (e.g. VectorType) from ``dtype``.

    Replaces ``dtype`` by its ``base_type`` as long as ``type(dtype) is type_to_peel_off`` and returns the
    first type that is not exactly of that class (instances of subclasses are not peeled).
    """
    while type(dtype) is type_to_peel_off:
        dtype = dtype.base_type
    return dtype


def collate_types(types):
    """
    Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
    Uses the collation rules from numpy, with these special cases:

    - pointer arithmetic: exactly one PointerType together with integer BasicTypes yields that PointerType
    - vector types: if any VectorType occurs (all of them must have the same width), the result is a
      VectorType of that width around the collated base type
    - floats dominate: if any float type occurs, integer types are ignored, e.g. (int64, float) -> float

    Raises:
        ValueError: for invalid pointer arithmetic or vector types of different widths
    """
    # the types are traversed several times below, so a one-shot iterable must be materialized first
    types = tuple(types)

    # Pointer arithmetic case i.e. pointer + integer is allowed
    if any(type(t) is PointerType for t in types):
        pointer_type = None
        for t in types:
            if type(t) is PointerType:
                if pointer_type is not None:
                    raise ValueError("Cannot collate the combination of two pointer types")
                pointer_type = t
            elif type(t) is BasicType:
                if not (t.is_int() or t.is_uint()):
                    raise ValueError("Invalid pointer arithmetic")
            else:
                raise ValueError("Invalid pointer arithmetic")
        return pointer_type

    # peel off vector types, if at least one vector type occurred the result will also be the vector type
    vector_types = [t for t in types if type(t) is VectorType]
    if not all_equal(t.width for t in vector_types):
        raise ValueError("Collation failed because of vector types with different width")
    types = [peel_off_type(t, VectorType) for t in types]

    # now we should have a list of basic types - struct types are not yet supported
    assert all(type(t) is BasicType for t in types)

    # integer types do not take part in the collation as soon as a float type is present
    if any(t.is_float() for t in types):
        types = tuple(t for t in types if t.is_float())
    # use numpy collation -> create type from numpy type -> and, put vector type around if necessary
    result_numpy_type = np.result_type(*(t.numpy_dtype for t in types))
    result = BasicType(result_numpy_type)
    if vector_types:
        result = VectorType(result, vector_types[0].width)
    return result


@memorycache(maxsize=2048)
def get_type_of_expression(expr):
    """Determine the data type of a sympy expression.

    - Integers get ``create_type("int")``; other rationals and floats get ``create_type("double")``.
    - Typed symbols, resolved field accesses, casts and indexed pointers report their stored type.
    - Piecewise, powers and other expressions derive it from their arguments (see collate_types).
    - Boolean expressions are "bool", vectorized if any argument has a vector type.

    Raises:
        ValueError: if the expression is a plain (untyped) sympy Symbol
        NotImplementedError: if no rule matches the expression
    """
    from pystencils.astnodes import ResolvedFieldAccess
    expr = sp.sympify(expr)
    if isinstance(expr, sp.Integer):
        return create_type("int")
    elif isinstance(expr, (sp.Rational, sp.Float)):
        return create_type("double")
    elif isinstance(expr, ResolvedFieldAccess):
        return expr.field.dtype
    elif isinstance(expr, TypedSymbol):
        return expr.dtype
    elif isinstance(expr, sp.Symbol):
        raise ValueError("All symbols inside this expression have to be typed! ", str(expr))
    elif isinstance(expr, cast_func):
        return expr.args[1]
    elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
        collated_result_type = collate_types(tuple(get_type_of_expression(a[0]) for a in expr.args))
        collated_condition_type = collate_types(tuple(get_type_of_expression(a[1]) for a in expr.args))
        # a vectorized condition turns a scalar result into a vector of the same width
        if type(collated_condition_type) is VectorType and type(collated_result_type) is not VectorType:
            collated_result_type = VectorType(collated_result_type, width=collated_condition_type.width)
        return collated_result_type
    elif isinstance(expr, sp.Indexed):
        # assumes the label is a pointer-typed TypedSymbol: the element type is the pointer's base type
        typed_symbol = expr.base.label
        return typed_symbol.dtype.base_type
    elif isinstance(expr, Boolean):
        # if any arg is of vector type return a vector boolean, else return a normal scalar boolean
        result = create_type("bool")
        arg_types = [get_type_of_expression(a) for a in expr.args]
        vec_args = [t for t in arg_types if isinstance(t, VectorType)]
        if vec_args:
            result = VectorType(result, width=vec_args[0].width)
        return result
    elif isinstance(expr, sp.Pow):
        # a power has the type of its base
        return get_type_of_expression(expr.args[0])
    elif isinstance(expr, sp.Expr):
        types = tuple(get_type_of_expression(a) for a in expr.args)
        return collate_types(types)
    raise NotImplementedError("Could not determine type for", expr, type(expr))
class Type(sp.Basic):
    """Common base class of BasicType, VectorType and PointerType.

    Types are sympy atoms (``is_Atom``), which lets them be stored as sympy arguments, e.g. as the dtype
    argument of cast_func.
    """
    is_Atom = True

    def __new__(cls, *args, **kwargs):
        # constructor arguments are consumed by the subclasses' __init__; sympy only receives the class
        return sp.Basic.__new__(cls)

    def _sympystr(self, *args, **kwargs):
        # sympy's printer shows a type via the type's own __str__
        return str(self)


class BasicType(Type):
    """Scalar data type backed by a non-structured numpy dtype, optionally const-qualified."""

    @staticmethod
    def numpy_name_to_c(name):
        """Map a numpy dtype name ('float64', 'float32', 'intN', 'uintN', 'bool') to its C type name."""
        if name == 'float64':
            return 'double'
        elif name == 'float32':
            return 'float'
        elif name.startswith('int'):
            width = int(name[len("int"):])
            return "int%d_t" % (width,)
        elif name.startswith('uint'):
            width = int(name[len("uint"):])
            return "uint%d_t" % (width,)
        elif name == 'bool':
            return 'bool'
        else:
            raise NotImplementedError("Can't map numpy to C name for %s" % (name,))

    def __init__(self, dtype, const=False):
        self.const = const
        if isinstance(dtype, Type):
            self._dtype = dtype.numpy_dtype
        else:
            self._dtype = np.dtype(dtype)
        assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type"
        assert self._dtype.hasobject is False
        assert self._dtype.subdtype is None

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def item_size(self):
        # a scalar is a single element; VectorType.item_size scales this by the vector width
        return 1

    # The is_* predicates use dtype.kind because np.sctypes was removed in NumPy 2.0.
    # kind 'i'/'f'/'u'/'c' matches exactly the dtypes listed in np.sctypes['int'/'float'/'uint'/'complex'].
    def is_int(self):
        return self.numpy_dtype.kind == 'i'

    def is_float(self):
        return self.numpy_dtype.kind == 'f'

    def is_uint(self):
        return self.numpy_dtype.kind == 'u'

    def is_complex(self):
        return self.numpy_dtype.kind == 'c'

    def is_other(self):
        # same members as NumPy 1.x np.sctypes['others']
        return self.numpy_dtype in (bool, object, bytes, str, np.void)

    @property
    def base_name(self):
        """C name of the type without qualifiers."""
        return BasicType.numpy_name_to_c(str(self._dtype))

    def __str__(self):
        result = BasicType.numpy_name_to_c(str(self._dtype))
        if self.const:
            result += " const"
        return result

    def __repr__(self):
        return str(self)

    def __eq__(self, other):
        if not isinstance(other, BasicType):
            return False
        else:
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)

    def __hash__(self):
        return hash(str(self))


class VectorType(Type):
    """Vector of ``width`` elements of ``base_type``.

    When the class attribute ``instruction_set`` is set (a mapping with the keys 'int', 'double', 'float'
    and 'bool'), ``str()`` returns the name stored there for the base type.  Otherwise ``str()`` returns the
    generic form ``'<base>[<width>]'``.
    """
    instruction_set = None

    def __init__(self, base_type, width=4):
        self._base_type = base_type
        self.width = width

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        return self.width * self.base_type.item_size

    def __eq__(self, other):
        if not isinstance(other, VectorType):
            return False
        else:
            return (self.base_type, self.width) == (other.base_type, other.width)

    def __str__(self):
        if self.instruction_set is None:
            return "%s[%d]" % (self.base_type, self.width)
        else:
            if self.base_type == create_type("int64"):
                return self.instruction_set['int']
            elif self.base_type == create_type("float64"):
                return self.instruction_set['double']
            elif self.base_type == create_type("float32"):
                return self.instruction_set['float']
            elif self.base_type == create_type("bool"):
                return self.instruction_set['bool']
            else:
                raise NotImplementedError()

    def __hash__(self):
        return hash((self.base_type, self.width))

    def __getnewargs__(self):
        return self._base_type, self.width

class PointerType(Type):
    """Pointer to ``base_type`` with optional const and restrict qualifiers (restrict by default)."""

    def __init__(self, base_type, const=False, restrict=True):
        self._base_type = base_type
        self.const = const
        self.restrict = restrict

    def __getnewargs__(self):
        return self.base_type, self.const, self.restrict

    @property
    def alias(self):
        # only non-restrict pointers may alias
        return not self.restrict

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        return self.base_type.item_size

    def __eq__(self, other):
        if not isinstance(other, PointerType):
            return False
        else:
            return (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict)

    def __str__(self):
        components = [str(self.base_type), '*']
        if self.restrict:
            # emitted verbatim; presumably a macro provided by the generated code's preamble
            components.append('RESTRICT')
        if self.const:
            components.append("const")
        return " ".join(components)

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self._base_type, self.const, self.restrict))
class StructType:
    """Structured numpy dtype (a record with named fields), optionally const-qualified.

    Structs are handled byte-wise: ``str()`` yields ``uint8_t`` and ``item_size`` is the size in bytes.
    """
    # NOTE(review): the class header was lost in the reviewed copy and is restored as a plain class;
    # confirm whether it should derive from Type.

    def __init__(self, numpy_type, const=False):
        self.const = const
        self._dtype = np.dtype(numpy_type)

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def item_size(self):
        return self.numpy_dtype.itemsize

    def get_element_offset(self, element_name):
        """Byte offset of the field ``element_name`` within the struct."""
        return self.numpy_dtype.fields[element_name][1]

    def get_element_type(self, element_name):
        """BasicType of the field ``element_name``; it inherits the struct's const qualifier."""
        np_element_type = self.numpy_dtype.fields[element_name][0]
        return BasicType(np_element_type, self.const)

    def has_element(self, element_name):
        return element_name in self.numpy_dtype.fields

    def __eq__(self, other):
        if not isinstance(other, StructType):
            return False
        else:
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)

    def __str__(self):
        # structs are handled byte-wise
        result = "uint8_t"
        if self.const:
            result += " const"
        return result

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self.numpy_dtype, self.const))