types.py 8.56 KB
Newer Older
1
import ctypes
2
import sympy as sp
3
import numpy as np
Jan Hoenig's avatar
Jan Hoenig committed
4
import llvmlite.ir as ir
5
6
7
8
9
10
11
12
13
14
from sympy.core.cache import cacheit


class TypedSymbol(sp.Symbol):
    """
    A SymPy Symbol that additionally carries a data type.

    The type is normalised through createType and is folded into the symbol's
    hashable content, so symbols that share a name but differ in type are distinct.
    """

    def __new__(cls, name, *args, **kwds):
        # Route construction through the cached second stage, so repeated requests
        # with the same arguments can be served from the cache.
        return TypedSymbol.__xnew_cached_(cls, name, *args, **kwds)

    def __new_stage2__(cls, name, dtype):
        # Second construction stage: build the plain Symbol, then attach the type.
        symbol = super(TypedSymbol, cls).__xnew__(cls, name)
        symbol._dtype = createType(dtype)
        return symbol

    # Uncached and cached entry points to the second stage (same pattern as sympy.Symbol).
    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    @property
    def dtype(self):
        """The Type object attached to this symbol."""
        return self._dtype

    def _hashable_content(self):
        # Symbol's own hashable content, extended by a hash of the type's repr.
        inherited = tuple(super(TypedSymbol, self)._hashable_content())
        return inherited + (hash(repr(self._dtype)),)

    def __getnewargs__(self):
        # Arguments needed to re-create this symbol (used by the pickle protocol).
        return self.name, self.dtype

33

34
def createType(specification):
    """
    Normalises a type specification into a Type instance
    :param specification: an existing Type (returned unchanged), a C-like type string
                          (parsed by createTypeFromString), or anything numpy.dtype() accepts
    :return: Type object
    """
    if isinstance(specification, Type):
        return specification
    if isinstance(specification, str):
        return createTypeFromString(specification)
    # Fallback: let numpy interpret the specification (e.g. np.float64, np.dtype objects).
    return BasicType(np.dtype(specification), const=False)


def createTypeFromString(specification):
    """
    Creates a new Type object from a C-like string specification,
    e.g. "double", "const double *", "double * restrict" or "double**".
    A 'const' before the first '*' qualifies the base type; 'const' / 'restrict'
    written after a '*' qualify that pointer level. The first '*' is the innermost pointer.
    :param specification: Specification string (case-insensitive)
    :return: Type object (a BasicType, wrapped in one PointerType per '*')
    """
    # Surround every '*' with spaces so "double*", "double* restrict" and "double**"
    # tokenize exactly like their spaced-out spellings, and pointer levels keep their order.
    tokens = specification.lower().replace('*', ' * ').split()

    # Group tokens: first the base-type tokens, then one group per '*' holding that
    # '*' followed by its qualifiers.
    parts = []
    current = []
    for token in tokens:
        if token == '*':
            parts.append(current)
            current = [token]
        else:
            current.append(token)
    if current:
        parts.append(current)
    assert parts, "Empty type specification"

    # Parse native part
    basePart = parts.pop(0)
    const = 'const' in basePart
    if const:
        basePart.remove('const')
    assert len(basePart) == 1, "Could not parse base type of %r" % (specification,)
    # NOTE: the base name is interpreted by numpy.dtype(), so e.g. "float" yields float64.
    currentType = BasicType(basePart[0], const)

    # Parse pointer parts, innermost pointer first
    for part in parts:
        restrict = 'restrict' in part
        if restrict:
            part.remove('restrict')
        pointerConst = 'const' in part
        if pointerConst:
            part.remove('const')
        assert len(part) == 1 and part[0] == '*', \
            "Could not parse pointer qualifiers of %r" % (specification,)
        currentType = PointerType(currentType, pointerConst, restrict)
    return currentType


def getBaseType(type):
    """
    Follows baseType links (e.g. from a pointer to its pointee) down to the innermost type
    :param type: Type object whose chain of baseType attributes ends in None
    :return: the first object in the chain whose baseType is None
    """
    node = type
    inner = node.baseType
    while inner is not None:
        node, inner = inner, inner.baseType
    return node


def toCtypes(dataType):
    """
    Transforms a given Type into ctypes
    :param dataType: Subclass of Type
    :return: ctypes type object
    :raises KeyError: if the (innermost) numpy dtype has no entry in toCtypes.map
    """
    if isinstance(dataType, PointerType):
        # Convert the pointee recursively, so multi-level pointers work;
        # const/restrict qualifiers have no ctypes equivalent and are dropped.
        return ctypes.POINTER(toCtypes(dataType.baseType))
    else:
        return toCtypes.map[dataType.numpyDtype]


# numpy dtype -> ctypes scalar type used by toCtypes
toCtypes.map = {
    np.dtype(np.int8): ctypes.c_int8,
    np.dtype(np.int16): ctypes.c_int16,
    np.dtype(np.int32): ctypes.c_int32,
    np.dtype(np.int64): ctypes.c_int64,

    np.dtype(np.uint8): ctypes.c_uint8,
    np.dtype(np.uint16): ctypes.c_uint16,
    np.dtype(np.uint32): ctypes.c_uint32,
    np.dtype(np.uint64): ctypes.c_uint64,

    np.dtype(np.float32): ctypes.c_float,
    np.dtype(np.float64): ctypes.c_double,

    # BasicType also supports numpy's bool dtype
    np.dtype(np.bool_): ctypes.c_bool,
}


Jan Hoenig's avatar
Jan Hoenig committed
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
def to_llvmlite_type(data_type):
    """
    Transforms a given Type into an llvmlite IR type
    :param data_type: Subclass of Type
    :return: llvmlite type object
    :raises KeyError: if the (innermost) numpy dtype has no entry in to_llvmlite_type.map
    """
    if isinstance(data_type, PointerType):
        # The pointee is itself a Type (not a numpy dtype), so convert it recursively;
        # this also handles pointers to pointers.
        return to_llvmlite_type(data_type.baseType).as_pointer()
    else:
        return to_llvmlite_type.map[data_type.numpyDtype]


# numpy dtype -> llvmlite IR type used by to_llvmlite_type
to_llvmlite_type.map = {
    np.dtype(np.int8): ir.IntType(8),
    np.dtype(np.int16): ir.IntType(16),
    np.dtype(np.int32): ir.IntType(32),
    np.dtype(np.int64): ir.IntType(64),

    # TODO llvmlite doesn't seem to differentiate between Int types
    np.dtype(np.uint8): ir.IntType(8),
    np.dtype(np.uint16): ir.IntType(16),
    np.dtype(np.uint32): ir.IntType(32),
    np.dtype(np.uint64): ir.IntType(64),

    np.dtype(np.float32): ir.FloatType(),
    np.dtype(np.float64): ir.DoubleType(),
    # TODO const, restrict, void
}


156
class Type(object):
    """Common base class of the data types in this module (BasicType, PointerType)."""

    def __lt__(self, other):
        # Ordering is needed for sorting the types inside an expression:
        # basic types sort before pointer types, two basic types compare their numpy
        # dtypes, two pointers compare their base types. Struct types are unsupported,
        # and combinations not listed here fall through and return None.
        # NOTE(review): numpy's dtype '<' tests safe castability, which is only a
        # partial order -- confirm sorting tolerates incomparable dtypes.
        structError = "Struct type comparison is not yet implemented"
        if isinstance(self, BasicType):
            if isinstance(other, BasicType):
                return self.numpyDtype < other.numpyDtype  # TODO const
            elif isinstance(other, PointerType):
                return True  # TODO test
            elif isinstance(other, StructType):
                raise NotImplementedError(structError)
        if isinstance(self, PointerType):
            if isinstance(other, BasicType):
                return False  # TODO test
            elif isinstance(other, PointerType):
                return self.baseType < other.baseType  # TODO const, restrict
            elif isinstance(other, StructType):
                raise NotImplementedError(structError)
        if isinstance(self, StructType):
            raise NotImplementedError(structError)
        return None
175
176
177
178
179


class BasicType(Type):
    """A scalar type described by a numpy dtype, optionally const-qualified."""

    @staticmethod
    def numpyNameToC(name):
        """
        Maps a numpy dtype name to the corresponding C type name
        :param name: numpy dtype name, e.g. 'float64', 'int32', 'uint8' or 'bool'
        :return: C type name, e.g. 'double', 'int32_t', 'uint8_t', 'bool'
        :raises NotImplementedError: if there is no C mapping for the name
        """
        if name == 'float64':
            return 'double'
        elif name == 'float32':
            return 'float'
        elif name.startswith('int'):
            width = int(name[len("int"):])
            return "int%d_t" % (width,)
        elif name.startswith('uint'):
            # Strip the whole 'uint' prefix; stripping only len("int") left e.g. 't8'.
            width = int(name[len("uint"):])
            return "uint%d_t" % (width,)
        elif name == 'bool':
            return 'bool'
        else:
            # NotImplemented is a constant, not an exception class -- calling it raised TypeError.
            raise NotImplementedError("Cannot map numpy to C name for %s" % (name,))

    def __init__(self, dtype, const=False):
        """
        :param dtype: anything accepted by numpy.dtype(); must be a plain scalar dtype
                      (no structured fields, no Python objects, no sub-array)
        :param const: whether the type is const-qualified
        """
        self.const = const
        self._dtype = np.dtype(dtype)
        assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type"
        assert self._dtype.hasobject is False
        assert self._dtype.subdtype is None

    @property
    def baseType(self):
        # Basic types are leaves of the type chain.
        return None

    @property
    def numpyDtype(self):
        """The underlying numpy dtype"""
        return self._dtype

    def __repr__(self):
        # C spelling of the type, with a trailing " const" when const-qualified.
        result = BasicType.numpyNameToC(str(self._dtype))
        if self.const:
            result += " const"
        return result

    def __eq__(self, other):
        if not isinstance(other, BasicType):
            return False
        else:
            return (self.numpyDtype, self.const) == (other.numpyDtype, other.const)

    def __hash__(self):
        # Consistent with __eq__: equal (dtype, const) pairs produce the same repr.
        return hash(str(self))


class PointerType(Type):
    """A pointer to another Type, with optional const and restrict qualifiers."""

    def __init__(self, baseType, const=False, restrict=True):
        """
        :param baseType: the pointed-to Type
        :param const: whether the pointer itself is const
        :param restrict: whether the pointer is declared restrict (non-aliasing)
        """
        self._baseType = baseType
        self.const = const
        self.restrict = restrict

    @property
    def alias(self):
        """True if the pointer may alias, i.e. it is not restrict."""
        return not self.restrict

    @property
    def baseType(self):
        """The pointed-to Type."""
        return self._baseType

    def __eq__(self, other):
        if isinstance(other, PointerType):
            mine = (self.baseType, self.const, self.restrict)
            theirs = (other.baseType, other.const, other.restrict)
            return mine == theirs
        return False

    def __repr__(self):
        restrictPart = "RESTRICT " if self.restrict else ""
        constPart = "const " if self.const else ""
        return "%s * %s%s" % (self.baseType, restrictPart, constPart)

    def __hash__(self):
        # The repr covers base type, const and restrict, matching __eq__.
        return hash(str(self))

Jan Hoenig's avatar
Jan Hoenig committed
252

253
254
255
class StructType(object):
    """
    A structured (record) numpy dtype
    Ordering comparisons between struct types are not implemented yet.
    """

    def __init__(self, numpyType):
        """
        :param numpyType: anything accepted by numpy.dtype(), typically a list of (name, type) fields
        """
        self._dtype = np.dtype(numpyType)

    @property
    def numpyDtype(self):
        """The underlying numpy dtype (same accessor name as BasicType)"""
        return self._dtype

    def __gt__(self, other):
        # Struct ordering is undefined. StructType never sets .ptr or .dtype, so the
        # previous body could only fail with AttributeError. Fail explicitly instead,
        # using the same message Type.__lt__ uses for struct comparisons.
        raise NotImplementedError("Struct type comparison is not yet implemented")

    def __hash__(self):
        # No __repr__ is defined, so this hashes the default object repr (which
        # contains the object id): distinct instances generally hash differently.
        return hash(repr(self))

Jan Hoenig's avatar
Jan Hoenig committed
267

Jan Hoenig's avatar
Jan Hoenig committed
268
def get_type_from_sympy(node):
    """
    Determines the Type and the plain Python value of a SymPy number
    :param node: Sympy number (instance of sp.Number)
    :return: tuple (Type object, value): (double, float(node)) for floats,
             (int, int(node)) for integers
    :raises TypeError: if node is not a sp.Number, or is a kind of number that is not supported
    :raises NotImplementedError: for non-integer rationals
    """
    # Notes on other SymPy objects (not handled yet):
    # Rational, NumberSymbol?
    # Zero, One, NegativeOne are subclasses of Integer
    # Half is a subclass of Rational
    # NAN, Infinity, Negative Infinity,
    # Exp1, Imaginary Unit, Pi, EulerGamma, Catalan, Golden Ratio
    # Pow, Mul, Add, Mod, Relational
    if not isinstance(node, sp.Number):
        raise TypeError("%s is not a sp.Number" % (node,))

    if isinstance(node, (sp.Float, sp.RealNumber)):
        return createType('double'), float(node)
    elif isinstance(node, sp.Integer):
        return createType('int'), int(node)
    elif isinstance(node, sp.Rational):
        # Checked after sp.Integer (an Integer is also a Rational), so only true fractions get here.
        raise NotImplementedError('Rationals are not supported yet')
    else:
        raise TypeError("%s is not a supported type (yet)!" % (node,))