# types.py 11.2 KB
# Newer Older
1
import ctypes
2
import sympy as sp
3
import numpy as np
4
# import llvmlite.ir as ir
5
6
7
8
9
10
11
12
from sympy.core.cache import cacheit


class TypedSymbol(sp.Symbol):
    """
    A sympy Symbol that additionally carries a data type.

    The type is normalised through createType() and its repr takes part in the hashable
    content, so symbols with equal names but different type reprs are distinct.
    """

    def __new__(cls, name, *args, **kwds):
        # Construct through the cacheit-memoized stage-2 constructor
        return TypedSymbol.__xnew_cached_(cls, name, *args, **kwds)

    def __new_stage2__(cls, name, dtype):
        symbol = super(TypedSymbol, cls).__xnew__(cls, name)
        symbol._dtype = createType(dtype)
        return symbol

    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    @property
    def dtype(self):
        return self._dtype

    def _hashable_content(self):
        # sympy Symbol's own content, extended by the hash of the type's repr
        baseContent = tuple(super(TypedSymbol, self)._hashable_content())
        return baseContent + (hash(repr(self._dtype)),)

    def __getnewargs__(self):
        # Arguments used to recreate the symbol, e.g. when unpickling
        return self.name, self.dtype


#class IndexedWithCast(sp.tensor.Indexed):
#    def __new__(cls, base, castTo, *args):
#        obj = super(IndexedWithCast, cls).__new__(cls, base, *args)
#        obj._castTo = castTo
#        return obj
#
#    @property
#    def castTo(self):
#        return self._castTo
#
#    def _hashable_content(self):
#        superClassContents = list(super(IndexedWithCast, self)._hashable_content())
#        t = tuple(superClassContents + [hash(repr(self._castTo))])
#        return t
#
#    def __getnewargs__(self):
#        return self.base, self.castTo
51

52

53
def createType(specification):
    """
    Create a type object according to a string, a numpy-compatible dtype, or an existing type object
    :param specification: Type object (BasicType/PointerType) or StructType, which is returned unchanged;
                          a c-like string (parsed by createTypeFromString); or anything np.dtype accepts
    :return: Type object, or a new BasicType / StructType built from the specification
    """
    # StructType does not derive from Type, so it has to be passed through explicitly; otherwise
    # it would be handed to np.dtype(), which does not accept it.
    if isinstance(specification, (Type, StructType)):
        return specification
    if isinstance(specification, str):
        return createTypeFromString(specification)

    npDataType = np.dtype(specification)
    if npDataType.fields is None:
        return BasicType(npDataType, const=False)
    # structured (record) dtypes become struct types
    return StructType(npDataType, const=False)


def createTypeFromString(specification):
    """
    Creates a new Type object from a c-like string specification,
    e.g. "double", "const float *" or "double * restrict const".

    As in C, a 'const' before the first '*' qualifies the base type, while 'const' / 'restrict'
    after a '*' qualify that pointer level. A '*' may be written with or without surrounding
    spaces ("double*", "double *restrict", "double**").

    :param specification: Specification string
    :return: Type object
    """
    # Surround every '*' with spaces so that it always ends up as a token of its own
    tokens = specification.lower().replace('*', ' * ').split()

    # Group tokens: the first group describes the base type, every further group starts with '*'
    parts = []
    current = []
    for token in tokens:
        if token == '*':
            parts.append(current)
            current = [token]
        else:
            current.append(token)
    if current:
        parts.append(current)

    # Parse native part
    basePart = parts.pop(0)
    const = False
    if 'const' in basePart:
        const = True
        basePart.remove('const')
    assert len(basePart) == 1, "Could not parse base type of '%s'" % (specification,)
    currentType = BasicType(basePart[0], const)

    # Parse pointer parts, innermost first
    for part in parts:
        restrict = False
        const = False
        if 'restrict' in part:
            restrict = True
            part.remove('restrict')
        if 'const' in part:
            const = True
            part.remove('const')
        assert len(part) == 1 and part[0] == '*', \
            "Could not parse pointer qualifiers of '%s'" % (specification,)
        currentType = PointerType(currentType, const, restrict)
    return currentType


def getBaseType(type):
    """
    Follow the chain of ``baseType`` attributes (e.g. through pointer levels) and return the
    innermost type, i.e. the first one whose ``baseType`` is None.
    """
    current = type
    while True:
        inner = current.baseType
        if inner is None:
            return current
        current = inner


def toCtypes(dataType):
    """
    Transforms a given Type into ctypes
    :param dataType: BasicType, PointerType or StructType instance
    :return: ctypes type object
    :raises KeyError: if a basic type's numpy dtype has no entry in ``toCtypes.map``
    """
    if isinstance(dataType, PointerType):
        return ctypes.POINTER(toCtypes(dataType.baseType))
    elif isinstance(dataType, StructType):
        # structs are handled byte-wise (StructType prints as "uint8_t")
        return ctypes.POINTER(ctypes.c_uint8)
    else:
        return toCtypes.map[dataType.numpyDtype]


# numpy dtype -> ctypes scalar type, used for basic (non-pointer, non-struct) types
toCtypes.map = {
    np.dtype(np.int8): ctypes.c_int8,
    np.dtype(np.int16): ctypes.c_int16,
    np.dtype(np.int32): ctypes.c_int32,
    np.dtype(np.int64): ctypes.c_int64,

    np.dtype(np.uint8): ctypes.c_uint8,
    np.dtype(np.uint16): ctypes.c_uint16,
    np.dtype(np.uint32): ctypes.c_uint32,
    np.dtype(np.uint64): ctypes.c_uint64,

    np.dtype(np.float32): ctypes.c_float,
    np.dtype(np.float64): ctypes.c_double,

    # BasicType maps numpy 'bool' to C 'bool'; both are one byte wide
    np.dtype(np.bool_): ctypes.c_bool,
}


151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
#def to_llvmlite_type(data_type):
#    """
#    Transforms a given type into ctypes
#    :param data_type: Subclass of Type
#    :return: llvmlite type object
#    """
#    if isinstance(data_type, PointerType):
#        return to_llvmlite_type.map[data_type.baseType].as_pointer()
#    else:
#        return to_llvmlite_type.map[data_type.numpyDType]
#
#to_llvmlite_type.map = {
#    np.dtype(np.int8): ir.IntType(8),
#    np.dtype(np.int16): ir.IntType(16),
#    np.dtype(np.int32): ir.IntType(32),
#    np.dtype(np.int64): ir.IntType(64),
#
#    # TODO llvmlite doesn't seem to differentiate between Int types
#    np.dtype(np.uint8): ir.IntType(8),
#    np.dtype(np.uint16): ir.IntType(16),
#    np.dtype(np.uint32): ir.IntType(32),
#    np.dtype(np.uint64): ir.IntType(64),
#
#    np.dtype(np.float32): ir.FloatType(),
#    np.dtype(np.float64): ir.DoubleType(),
#    # TODO const, restrict, void
#}


180
181
182
class Type(sp.Basic):
    """
    Common base class of BasicType and PointerType; derives from sympy's Basic.
    """

    def __new__(cls, *args, **kwargs):
        # Constructor arguments are consumed by the subclasses' __init__; sympy's Basic gets none.
        return sp.Basic.__new__(cls)

    def __lt__(self, other):
        # Needed for sorting the types inside an expression.
        # Basic types order before pointer types; anything involving a StructType is unsupported.
        # Combinations not handled below yield None (falsy).
        structMessage = "Struct type comparison is not yet implemented"
        if isinstance(self, StructType):
            raise NotImplementedError(structMessage)

        selfIsBasic = isinstance(self, BasicType)
        selfIsPointer = isinstance(self, PointerType)
        if not (selfIsBasic or selfIsPointer):
            return None

        if isinstance(other, StructType):
            raise NotImplementedError(structMessage)
        if isinstance(other, BasicType):
            if selfIsBasic:
                # NOTE(review): numpy's dtype '<' is a "safely castable" partial order, not a total order
                return self.numpyDtype < other.numpyDtype  # TODO const
            return True
        if isinstance(other, PointerType):
            if selfIsPointer:
                return self.baseType < other.baseType  # TODO const, restrict
            return False
        return None


class BasicType(Type):
    """
    A scalar type backed by a non-structured numpy dtype, with an optional const qualifier.
    str() and repr() both give the C spelling, e.g. "double const".
    """

    # dtype kinds of numpy's former np.sctypes['others'] group (bool, object, bytes, str, void)
    _OTHER_KINDS = ('b', 'O', 'S', 'U', 'V')

    @staticmethod
    def numpyNameToC(name):
        """
        Map a numpy dtype name to the name of the corresponding C type
        :param name: numpy dtype name such as 'float64', 'int32', 'uint8' or 'bool'
        :return: C type name, e.g. 'double', 'int32_t', 'uint8_t', 'bool'
        :raises NotImplementedError: if there is no C mapping for the name
        """
        if name == 'float64':
            return 'double'
        elif name == 'float32':
            return 'float'
        elif name.startswith('int'):
            width = int(name[len("int"):])
            return "int%d_t" % (width,)
        elif name.startswith('uint'):
            width = int(name[len("uint"):])
            return "uint%d_t" % (width,)
        elif name == 'bool':
            return 'bool'
        else:
            # NotImplemented is a sentinel constant, not an exception: raising it would fail
            # with an unrelated TypeError instead of reporting the unsupported name.
            raise NotImplementedError("Cannot map numpy to C name for %s" % (name,))

    def __init__(self, dtype, const=False):
        """
        :param dtype: a Type providing numpyDtype (its dtype is copied) or anything np.dtype accepts
        :param const: whether the type is const-qualified
        """
        self.const = const
        if isinstance(dtype, Type):
            self._dtype = dtype.numpyDtype
        else:
            self._dtype = np.dtype(dtype)
        assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type"
        assert self._dtype.hasobject is False
        assert self._dtype.subdtype is None

    def __getnewargs__(self):
        return self.numpyDtype, self.const

    @property
    def baseType(self):
        # a scalar type has no underlying type
        return None

    @property
    def numpyDtype(self):
        return self._dtype

    @property
    def itemSize(self):
        # NOTE(review): presumably counted in elements of this C type (StructType reports its
        # size in bytes, i.e. uint8_t elements) - confirm
        return 1

    # Type classification via dtype.kind; np.sctypes, used previously, was removed in NumPy 2.0.
    def is_int(self):
        return self.numpyDtype.kind == 'i'

    def is_float(self):
        return self.numpyDtype.kind == 'f'

    def is_uint(self):
        return self.numpyDtype.kind == 'u'

    def is_complex(self):
        return self.numpyDtype.kind == 'c'

    # misspelled original name, kept for backward compatibility
    is_comlex = is_complex

    def is_other(self):
        return self.numpyDtype.kind in BasicType._OTHER_KINDS

    def __str__(self):
        result = BasicType.numpyNameToC(str(self._dtype))
        if self.const:
            result += " const"
        return result

    # Without this, str() would fall back to sympy's generic printer instead of the C spelling
    __repr__ = __str__

    def __eq__(self, other):
        if not isinstance(other, BasicType):
            return False
        return (self.numpyDtype, self.const) == (other.numpyDtype, other.const)

    def __hash__(self):
        # hash exactly the fields compared in __eq__
        return hash((self.numpyDtype, self.const))


class PointerType(Type):
    """
    A pointer to baseType, with optional const and restrict qualifiers on the pointer itself.
    str() and repr() give the C spelling, e.g. "double * RESTRICT ".
    """

    def __init__(self, baseType, const=False, restrict=True):
        self._baseType = baseType
        self.const = const
        self.restrict = restrict

    def __getnewargs__(self):
        return self.baseType, self.const, self.restrict

    @property
    def alias(self):
        # a pointer not declared restrict may alias other pointers
        return not self.restrict

    @property
    def baseType(self):
        return self._baseType

    @property
    def itemSize(self):
        return self.baseType.itemSize

    def __eq__(self, other):
        if not isinstance(other, PointerType):
            return False
        return (self.baseType, self.const, self.restrict) == (other.baseType, other.const, other.restrict)

    def __repr__(self):
        base = self.baseType
        # Type subclasses define their C spelling in __repr__ (BasicType does not override __str__,
        # so its str() goes through sympy's generic printer); StructType only defines __str__.
        baseName = repr(base) if isinstance(base, Type) else str(base)
        return "%s * %s%s" % (baseName, "RESTRICT " if self.restrict else "", "const " if self.const else "")

    __str__ = __repr__

    def __hash__(self):
        # hash exactly the fields compared in __eq__, so pointers to different types do not collide
        return hash((self.baseType, self.const, self.restrict))
311

312
class StructType(object):
    """
    A structured (record) numpy dtype, optionally const-qualified.
    In C it is represented byte-wise ("uint8_t"); fields are accessed via byte offsets.
    """

    def __init__(self, numpyType, const=False):
        self.const = const
        self._dtype = np.dtype(numpyType)

    def __getnewargs__(self):
        return self.numpyDtype, self.const

    @property
    def baseType(self):
        return None

    @property
    def numpyDtype(self):
        return self._dtype

    @property
    def itemSize(self):
        # size of one struct in bytes, i.e. in uint8_t elements
        return self.numpyDtype.itemsize

    def getElementOffset(self, elementName):
        """Byte offset of field elementName inside the struct."""
        return self.numpyDtype.fields[elementName][1]

    def getElementType(self, elementName):
        """BasicType of field elementName, inheriting this struct's const qualifier."""
        npElementType = self.numpyDtype.fields[elementName][0]
        return BasicType(npElementType, self.const)

    def hasElement(self, elementName):
        return elementName in self.numpyDtype.fields

    def __eq__(self, other):
        if not isinstance(other, StructType):
            return False
        return (self.numpyDtype, self.const) == (other.numpyDtype, other.const)

    def __str__(self):
        # structs are handled byte-wise
        result = "uint8_t"
        if self.const:
            result += " const"
        return result

    def __repr__(self):
        # Unlike __str__, repr distinguishes struct layouts. TypedSymbol hashes repr(dtype), so the
        # default object repr (which embeds the object id) would make equal struct types unequal.
        return "StructType(%r, const=%r)" % (self.numpyDtype, self.const)

    def __hash__(self):
        return hash((self.numpyDtype, self.const))

    def __gt__(self, other):
        # Ordering of struct types is not defined (the former body referenced nonexistent
        # attributes 'ptr' and 'dtype' and could only fail with AttributeError).
        raise NotImplementedError("Struct type comparison is not yet implemented")
def get_type_from_sympy(node):
    """
    Creates a Type object from a Sympy number
    :param node: Sympy object
    :return: tuple (Type object, equivalent Python value): (double, float) for floats,
             (int, int) for integers
    :raises TypeError: if node is not a sympy Number, or is a Number kind without a mapping
    :raises NotImplementedError: for non-integer rationals
    """
    # Rational, NumberSymbol?
    # Zero, One, NegativeOne )= Integer
    # Half )= Rational
    # NAN, Infinity, Negative Inifinity,
    # Exp1, Imaginary Unit, Pi, EulerGamma, Catalan, Golden Ratio
    # Pow, Mul, Add, Mod, Relational
    if not isinstance(node, sp.Number):
        raise TypeError('%s is not a sp.Number' % (node,))

    if isinstance(node, (sp.Float, sp.RealNumber)):
        return createType('double'), float(node)
    elif isinstance(node, sp.Integer):
        # checked before Rational because sp.Integer is a subclass of sp.Rational
        return createType('int'), int(node)
    elif isinstance(node, sp.Rational):
        raise NotImplementedError('Rationals are not supported yet')
    else:
        raise TypeError('%s is not a supported type (yet)!' % (node,))