# data_types.py
1
import ctypes
2
import sympy as sp
3
import numpy as np
4
5
6
7
8
try:
    import llvmlite.ir as ir
except ImportError as e:
    ir = None
    _ir_importerror = e
9
from sympy.core.cache import cacheit
10

11
from pystencils.cache import memorycache
12
13
from pystencils.utils import allEqual

14
15
16
17
18
19
20
21
22

# to work in conditions of sp.Piecewise castFunc has to be of type Relational as well
class castFunc(sp.Function, sp.Rel):
    """
    Sympy function representing a cast: ``castFunc(expression, targetType)``.

    It also derives from ``sp.Rel`` so that a cast may be used as a condition
    inside ``sp.Piecewise``.
    """

    @property
    def canonical(self):
        # Delegate to the wrapped expression; the cast itself has no canonical form.
        if hasattr(self.args[0], 'canonical'):
            return self.args[0].canonical
        else:
            raise NotImplementedError()

    @property
    def is_commutative(self):
        # A cast is commutative exactly when the wrapped expression is.
        return self.args[0].is_commutative
28

29
30
31
32
33
34
35
36
37
38
class pointerArithmeticFunc(sp.Function, sp.Rel):
    """
    Sympy function node for pointer arithmetic. Like ``castFunc`` it also derives
    from ``sp.Rel`` and forwards ``canonical`` to its first argument.
    """

    @property
    def canonical(self):
        first_arg = self.args[0]
        if not hasattr(first_arg, 'canonical'):
            raise NotImplementedError()
        return first_arg.canonical


39
class TypedSymbol(sp.Symbol):
    """
    Sympy symbol that additionally carries a data type, available as ``dtype``.

    The given dtype is converted with ``createType``. If that raises ``TypeError``,
    the original specification is stored unchanged.
    """

    def __new__(cls, *args, **kwds):
        # Construction goes through the version wrapped by sympy's ``cacheit``.
        obj = TypedSymbol.__xnew_cached_(cls, *args, **kwds)
        return obj

    def __new_stage2__(cls, name, dtype):
        obj = super(TypedSymbol, cls).__xnew__(cls, name)
        try:
            obj._dtype = createType(dtype)
        except TypeError:
            # on error keep the string
            obj._dtype = dtype
        return obj

    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    @property
    def dtype(self):
        return self._dtype

    def _hashable_content(self):
        # Include the dtype so that equally named symbols with different types
        # are not considered identical by sympy.
        superClassContents = list(super(TypedSymbol, self)._hashable_content())
        return tuple(superClassContents + [hash(self._dtype)])

    def __getnewargs__(self):
        # Arguments needed to reconstruct the symbol (e.g. for pickling).
        return self.name, self.dtype


68
def createType(specification):
    """
    Create a subclass of Type according to a string or an object of subclass Type
    :param specification: Type object, or anything accepted by ``np.dtype`` (e.g. a string)
    :return: the given Type object unchanged, or a new non-const BasicType
             (plain numpy dtype) / StructType (numpy dtype with named fields)
    :raises TypeError: from ``np.dtype`` if the specification is not understood
    """
    if isinstance(specification, Type):
        return specification

    npDataType = np.dtype(specification)
    # structured numpy dtypes (with named fields) become StructType
    if npDataType.fields is None:
        return BasicType(npDataType, const=False)
    return StructType(npDataType, const=False)
82
83


84
@memorycache(maxsize=64)
def createCompositeTypeFromString(specification):
    """
    Creates a new Type object from a c-like string specification,
    e.g. ``"const double * restrict"``, ``"double*"`` or ``"int64 **"``.

    The tokens before the first ``*`` form the base type, optionally qualified
    with ``const``. Each ``*`` adds one pointer level. The ``const`` / ``restrict``
    tokens that follow a ``*`` qualify that pointer.

    :param specification: Specification string
    :return: Type object
    """
    # A '*' may be glued to neighbouring words ("double*", "int**", "*restrict").
    # Pad it with spaces so that every '*' becomes a token of its own.
    tokens = specification.lower().replace('*', ' * ').split()

    # Group the tokens: the first group is the base type; each later group starts with '*'.
    parts = []
    current = []
    for token in tokens:
        if token == '*':
            parts.append(current)
            current = [token]
        else:
            current.append(token)
    if len(current) > 0:
        parts.append(current)

    # Parse native part
    basePart = parts.pop(0)
    const = False
    if 'const' in basePart:
        const = True
        basePart.remove('const')
    assert len(basePart) == 1
    currentType = BasicType(np.dtype(basePart[0]), const)

    # Parse pointer parts
    for part in parts:
        restrict = False
        const = False
        if 'restrict' in part:
            restrict = True
            part.remove('restrict')
        if 'const' in part:
            const = True
            part.remove("const")
        assert len(part) == 1 and part[0] == '*'
        currentType = PointerType(currentType, const, restrict)
    return currentType


def getBaseType(type):
    """
    Follow the ``baseType`` chain starting at *type*.

    :return: the first type in the chain whose ``baseType`` is None
    """
    result = type
    parent = result.baseType
    while parent is not None:
        result, parent = parent, parent.baseType
    return result


def toCtypes(dataType):
    """
    Transforms a given Type into ctypes
    :param dataType: Subclass of Type
    :return: ctypes type object:
             - pointers become ``ctypes.POINTER`` of the converted base type;
             - structs become ``ctypes.POINTER(ctypes.c_uint8)``, since structs are handled byte-wise;
             - basic types are looked up in ``toCtypes.map`` by their numpy dtype.
    :raises KeyError: if the numpy dtype has no entry in ``toCtypes.map``
    """
    if isinstance(dataType, PointerType):
        return ctypes.POINTER(toCtypes(dataType.baseType))
    elif isinstance(dataType, StructType):
        return ctypes.POINTER(ctypes.c_uint8)
    else:
        return toCtypes.map[dataType.numpyDtype]


# numpy dtype -> ctypes type, used for basic (non-pointer, non-struct) types
toCtypes.map = {
    np.dtype(np.int8): ctypes.c_int8,
    np.dtype(np.int16): ctypes.c_int16,
    np.dtype(np.int32): ctypes.c_int32,
    np.dtype(np.int64): ctypes.c_int64,

    np.dtype(np.uint8): ctypes.c_uint8,
    np.dtype(np.uint16): ctypes.c_uint16,
    np.dtype(np.uint32): ctypes.c_uint32,
    np.dtype(np.uint64): ctypes.c_uint64,

    np.dtype(np.float32): ctypes.c_float,
    np.dtype(np.float64): ctypes.c_double,

    # createType("bool") is a valid BasicType; without this entry it raised KeyError here
    np.dtype(np.bool_): ctypes.c_bool,
}


163
def ctypes_from_llvm(data_type):
    """
    Transforms an llvmlite type into the corresponding ctypes type
    :param data_type: llvmlite ``ir`` type object
    :return: ctypes type object; ``None`` for ``ir.VoidType``; ``ctypes.c_void_p`` for
             a pointer to void
    :raises ImportError: the error from importing llvmlite, if it is not available
    :raises ValueError: for integer widths other than 8, 16, 32 and 64
    :raises NotImplementedError: for any other llvmlite type
    """
    if not ir:
        raise _ir_importerror

    if isinstance(data_type, ir.PointerType):
        ctype = ctypes_from_llvm(data_type.pointee)
        if ctype is None:
            # pointee is void
            return ctypes.c_void_p
        else:
            return ctypes.POINTER(ctype)
    elif isinstance(data_type, ir.IntType):
        if data_type.width == 8:
            return ctypes.c_int8
        elif data_type.width == 16:
            return ctypes.c_int16
        elif data_type.width == 32:
            return ctypes.c_int32
        elif data_type.width == 64:
            return ctypes.c_int64
        else:
            raise ValueError("Int width %d is not supported" % data_type.width)
    elif isinstance(data_type, ir.FloatType):
        return ctypes.c_float
    elif isinstance(data_type, ir.DoubleType):
        return ctypes.c_double
    elif isinstance(data_type, ir.VoidType):
        return None  # Void type is not supported by ctypes
    else:
        raise NotImplementedError('Data type %s of %s is not supported yet' % (type(data_type), data_type))


def to_llvm_type(data_type):
    """
    Transforms a given type into an llvmlite type
    :param data_type: Subclass of Type
    :return: llvmlite type object; pointers become ``.as_pointer()`` of the converted
             base type, other types are looked up in ``to_llvm_type.map`` by numpy dtype
    :raises ImportError: the error from importing llvmlite, if it is not available
    :raises KeyError: if the numpy dtype has no entry in ``to_llvm_type.map``
    """
    if not ir:
        raise _ir_importerror
    if isinstance(data_type, PointerType):
        return to_llvm_type(data_type.baseType).as_pointer()
    else:
        return to_llvm_type.map[data_type.numpyDtype]


if ir:
    # numpy dtype -> llvmlite type. LLVM integer types carry no signedness, so
    # signed and unsigned types of the same width map to the same IntType.
    to_llvm_type.map = {
        np.dtype(np.int8): ir.IntType(8),
        np.dtype(np.int16): ir.IntType(16),
        np.dtype(np.int32): ir.IntType(32),
        np.dtype(np.int64): ir.IntType(64),

        np.dtype(np.uint8): ir.IntType(8),
        np.dtype(np.uint16): ir.IntType(16),
        np.dtype(np.uint32): ir.IntType(32),
        np.dtype(np.uint64): ir.IntType(64),

        np.dtype(np.float32): ir.FloatType(),
        np.dtype(np.float64): ir.DoubleType(),
    }
221

222

223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
def peelOffType(dtype, typeToPeelOff):
    """
    Unwrap *dtype* while it is exactly of class *typeToPeelOff* (subclasses do not
    count), replacing it by its ``baseType`` each time.

    :return: the first type in the chain that is not exactly *typeToPeelOff*
    """
    current = dtype
    while True:
        if type(current) is not typeToPeelOff:
            return current
        current = current.baseType


def collateTypes(types):
    """
    Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
    Uses the collation rules from numpy.

    Pointer arithmetic: if a PointerType is present, exactly one pointer is allowed,
    all other entries must be integer (signed or unsigned) BasicTypes, and the
    pointer type is returned.

    Vector types: all VectorTypes must have the same width. Collation is done on
    the base types, and the result is wrapped in a VectorType of that width.

    :raises ValueError: for two pointers, invalid pointer arithmetic, or mixed vector widths
    """
    if any(type(t) is PointerType for t in types):
        found_pointer = None
        for candidate in types:
            kind = type(candidate)
            if kind is PointerType:
                if found_pointer is not None:
                    raise ValueError("Cannot collate the combination of two pointer types")
                found_pointer = candidate
                continue
            # every non-pointer operand must be an integer BasicType
            if kind is not BasicType or not (candidate.is_int() or candidate.is_uint()):
                raise ValueError("Invalid pointer arithmetic")
        return found_pointer

    vector_types = [t for t in types if type(t) is VectorType]
    if not allEqual(vt.width for vt in vector_types):
        raise ValueError("Collation failed because of vector types with different width")
    scalar_types = [peelOffType(t, VectorType) for t in types]

    # struct types are not yet supported
    assert all(type(t) is BasicType for t in scalar_types)

    common = BasicType(np.result_type(*(t.numpyDtype for t in scalar_types)))
    if vector_types:
        return VectorType(common, vector_types[0].width)
    return common


@memorycache(maxsize=2048)
def getTypeOfExpression(expr):
    """
    Determines the data type of a sympy expression.

    :param expr: sympy expression (or anything ``sp.sympify`` accepts)
    :return: Type object
    :raises ValueError: if the expression contains a symbol that is not a TypedSymbol
    :raises NotImplementedError: if no rule applies to the expression
    """
    # local import, presumably to avoid a circular import with pystencils.astnodes
    from pystencils.astnodes import ResolvedFieldAccess
    expr = sp.sympify(expr)
    if isinstance(expr, sp.Integer):
        return createType("int")
    elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float):
        return createType("double")
    elif isinstance(expr, ResolvedFieldAccess):
        return expr.field.dtype
    elif isinstance(expr, TypedSymbol):
        return expr.dtype
    elif isinstance(expr, sp.Symbol):
        raise ValueError("All symbols inside this expression have to be typed!")
    elif hasattr(expr, 'func') and expr.func == castFunc:
        # the second argument of a cast is the target type
        return expr.args[1]
    elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
        # Piecewise args are (value, condition) pairs
        collatedResultType = collateTypes(tuple(getTypeOfExpression(a[0]) for a in expr.args))
        collatedConditionType = collateTypes(tuple(getTypeOfExpression(a[1]) for a in expr.args))
        # a vector condition widens a scalar result to a vector of the same width
        if type(collatedConditionType) is VectorType and type(collatedResultType) is not VectorType:
            collatedResultType = VectorType(collatedResultType, width=collatedConditionType.width)
        return collatedResultType
    elif isinstance(expr, sp.Indexed):
        typedSymbol = expr.base.label
        return typedSymbol.dtype.baseType
    elif isinstance(expr, sp.boolalg.Boolean) or isinstance(expr, sp.boolalg.BooleanFunction):
        # if any arg is of vector type return a vector boolean, else return a normal scalar boolean
        result = createType("bool")
        argTypes = [getTypeOfExpression(a) for a in expr.args]
        vecArgs = [t for t in argTypes if isinstance(t, VectorType)]
        if vecArgs:
            result = VectorType(result, width=vecArgs[0].width)
        return result
    elif isinstance(expr, sp.Expr):
        types = tuple(getTypeOfExpression(a) for a in expr.args)
        return collateTypes(types)

    raise NotImplementedError("Could not determine type for %s (%s)" % (expr, type(expr)))
304
305


306
307
308
class Type(sp.Basic):
    """
    Base class of the data type classes in this module. It derives from ``sp.Basic``
    so that type objects can be used as arguments of sympy expressions.
    """

    def __new__(cls, *args, **kwargs):
        # Constructor arguments are handled by the subclasses' __init__;
        # sympy's Basic is created without args.
        return sp.Basic.__new__(cls)

    def __lt__(self, other):  # deprecated
        # Needed for sorting the types inside an expression
        # NOTE(review): compares with '>' for basic and pointer types; confirm the intended ordering
        if isinstance(self, BasicType):
            if isinstance(other, BasicType):
                return self.numpyDtype > other.numpyDtype  # TODO const
            elif isinstance(other, PointerType):
                return False
            else:  # isinstance(other, StructType):
                raise NotImplementedError("Struct type comparison is not yet implemented")
        elif isinstance(self, PointerType):
            if isinstance(other, BasicType):
                return True
            elif isinstance(other, PointerType):
                return self.baseType > other.baseType  # TODO const, restrict
            else:  # isinstance(other, StructType):
                raise NotImplementedError("Struct type comparison is not yet implemented")
        elif isinstance(self, StructType):
            # NOTE(review): StructType does not derive from Type in this module,
            # so this branch looks unreachable
            raise NotImplementedError("Struct type comparison is not yet implemented")
        else:
            raise NotImplementedError

    def _sympystr(self, *args, **kwargs):
        # sympy's string printer renders types via their own __str__
        return str(self)

337
338
339
340

class BasicType(Type):
    """Scalar (non-structured) data type wrapping a numpy dtype, optionally ``const``."""

    @staticmethod
    def numpyNameToC(name):
        """
        Map a numpy dtype name (``str(dtype)``) to the C type name used in generated code
        :raises NotImplementedError: for names without a mapping (e.g. complex types)
        """
        if name == 'float64':
            return 'double'
        elif name == 'float32':
            return 'float'
        elif name.startswith('int'):
            width = int(name[len("int"):])
            return "int%d_t" % (width,)
        elif name.startswith('uint'):
            width = int(name[len("uint"):])
            return "uint%d_t" % (width,)
        elif name == 'bool':
            return 'bool'
        else:
            # ``NotImplemented`` is a constant, not an exception; calling it raised TypeError
            raise NotImplementedError("Cannot map numpy to C name for %s" % (name,))

    def __init__(self, dtype, const=False):
        self.const = const
        if isinstance(dtype, Type):
            self._dtype = dtype.numpyDtype
        else:
            self._dtype = np.dtype(dtype)
        assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type"
        assert self._dtype.hasobject is False
        assert self._dtype.subdtype is None

    def __getnewargs__(self):
        return self.numpyDtype, self.const

    @property
    def baseType(self):
        return None

    @property
    def numpyDtype(self):
        return self._dtype

    @property
    def itemSize(self):
        return 1

    # The category checks use numpy dtype kind codes. They are equivalent to the
    # former ``np.sctypes`` lookups, and ``np.sctypes`` no longer exists in NumPy 2.0.
    def is_int(self):
        return self.numpyDtype.kind == 'i'

    def is_float(self):
        return self.numpyDtype.kind == 'f'

    def is_uint(self):
        return self.numpyDtype.kind == 'u'

    def is_complex(self):
        return self.numpyDtype.kind == 'c'

    # misspelled name kept for backward compatibility
    is_comlex = is_complex

    def is_other(self):
        # bool, object, bytes, str and void (the former np.sctypes['others'])
        return self.numpyDtype.kind in 'bOSUV'

    @property
    def baseName(self):
        return BasicType.numpyNameToC(str(self._dtype))

    def __str__(self):
        result = BasicType.numpyNameToC(str(self._dtype))
        if self.const:
            result += " const"
        return result

    def __repr__(self):
        return str(self)

    def __eq__(self, other):
        if not isinstance(other, BasicType):
            return False
        else:
            return (self.numpyDtype, self.const) == (other.numpyDtype, other.const)

    def __hash__(self):
        return hash(str(self))


419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
class VectorType(Type):
    """
    Vector of ``width`` elements of ``baseType``.

    ``instructionSet`` is a class attribute and is None by default. When it is set
    to a mapping with the keys 'int', 'double', 'float' and 'bool', ``str`` returns
    the mapped name for int64, float64, float32 and bool base types respectively.
    """
    instructionSet = None

    def __init__(self, baseType, width=4):
        self._baseType = baseType
        self.width = width

    @property
    def baseType(self):
        return self._baseType

    @property
    def itemSize(self):
        return self.width * self.baseType.itemSize

    def __eq__(self, other):
        if not isinstance(other, VectorType):
            return False
        else:
            return (self.baseType, self.width) == (other.baseType, other.width)

    def __str__(self):
        if self.instructionSet is None:
            return "%s[%d]" % (self.baseType, self.width)
        else:
            if self.baseType == createType("int64"):
                return self.instructionSet['int']
            elif self.baseType == createType("float64"):
                return self.instructionSet['double']
            elif self.baseType == createType("float32"):
                return self.instructionSet['float']
            elif self.baseType == createType("bool"):
                return self.instructionSet['bool']
            else:
                raise NotImplementedError()

    def __hash__(self):
        return hash((self.baseType, self.width))
457
458


459
460
461
462
463
464
class PointerType(Type):
    """
    Pointer to ``baseType``. ``const`` and ``restrict`` qualify the pointer itself;
    ``__str__`` prints them after the ``*``.
    """

    def __init__(self, baseType, const=False, restrict=True):
        self._baseType = baseType
        self.const = const
        self.restrict = restrict

    def __getnewargs__(self):
        return self.baseType, self.const, self.restrict

    @property
    def alias(self):
        # a pointer may alias others exactly when it is not marked restrict
        return not self.restrict

    @property
    def baseType(self):
        return self._baseType

    @property
    def itemSize(self):
        return self.baseType.itemSize

    def __eq__(self, other):
        if not isinstance(other, PointerType):
            return False
        else:
            return (self.baseType, self.const, self.restrict) == (other.baseType, other.const, other.restrict)

    def __str__(self):
        return "%s *%s%s" % (self.baseType, " RESTRICT " if self.restrict else "", " const " if self.const else "")

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self._baseType, self.const, self.restrict))
494

Jan Hoenig's avatar
Jan Hoenig committed
495

496
class StructType(object):
    """
    Structured data type wrapping a numpy dtype with named fields, optionally ``const``.
    ``__str__`` renders it as ``uint8_t`` because structs are handled byte-wise.
    """

    def __init__(self, numpyType, const=False):
        self.const = const
        self._dtype = np.dtype(numpyType)

    def __getnewargs__(self):
        return self.numpyDtype, self.const

    @property
    def baseType(self):
        return None

    @property
    def numpyDtype(self):
        return self._dtype

    @property
    def itemSize(self):
        return self.numpyDtype.itemsize

    def getElementOffset(self, elementName):
        # byte offset of the named field inside the struct
        return self.numpyDtype.fields[elementName][1]

    def getElementType(self, elementName):
        npElementType = self.numpyDtype.fields[elementName][0]
        return BasicType(npElementType, self.const)

    def hasElement(self, elementName):
        return elementName in self.numpyDtype.fields

    def __eq__(self, other):
        if not isinstance(other, StructType):
            return False
        else:
            return (self.numpyDtype, self.const) == (other.numpyDtype, other.const)

    def __str__(self):
        # structs are handled byte-wise
        result = "uint8_t"
        if self.const:
            result += " const"
        return result

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self.numpyDtype, self.const))

    def __gt__(self, other):
        # Ordering of struct types is not defined. The old body read the
        # non-existent attributes ``ptr``/``dtype`` and always failed with
        # AttributeError. Raise the same error as ``Type.__lt__`` does for structs.
        raise NotImplementedError("Struct type comparison is not yet implemented")


Jan Hoenig's avatar
Jan Hoenig committed
553
def get_type_from_sympy(node):
    """
    Creates a Type object from a Sympy object
    :param node: Sympy number
    :return: tuple ``(Type object, Python value)``: ``('double', float)`` for floats and
             rationals, ``('int', int)`` for integers
    :raises TypeError: if ``node`` is not a ``sp.Number`` or is an unsupported kind of number
    """
    # Rational, NumberSymbol?
    # Zero, One, NegativeOne )= Integer
    # Half )= Rational
    # NAN, Infinity, Negative Infinity,
    # Exp1, Imaginary Unit, Pi, EulerGamma, Catalan, Golden Ratio
    # Pow, Mul, Add, Mod, Relational
    if not isinstance(node, sp.Number):
        raise TypeError(node, 'is not a sp.Number')

    if isinstance(node, sp.Float) or isinstance(node, sp.RealNumber):
        return createType('double'), float(node)
    elif isinstance(node, sp.Integer):
        # checked before Rational because sp.Integer is a subclass of sp.Rational
        return createType('int'), int(node)
    elif isinstance(node, sp.Rational):
        # TODO is it always float?
        return createType('double'), float(node.p/node.q)
    else:
        raise TypeError(node, ' is not a supported type (yet)!')