data_types.py 18.1 KB
Newer Older
1
import ctypes
2
import sympy as sp
3
import numpy as np
4
5
6
7
8
try:
    import llvmlite.ir as ir
except ImportError as e:
    ir = None
    _ir_importerror = e
9
from sympy.core.cache import cacheit
10

11
from pystencils.cache import memorycache
12
13
from pystencils.utils import allEqual

14
15
16
17
18
19
20
21
22

# to work in conditions of sp.Piecewise castFunc has to be of type Relational as well
class castFunc(sp.Function, sp.Rel):
    """
    Symbolic cast ``castFunc(expression, targetType)``: the second argument is the Type the first one is cast to.
    Canonical form and commutativity are delegated to the expression being cast.
    """

    @property
    def canonical(self):
        castee = self.args[0]
        if not hasattr(castee, 'canonical'):
            raise NotImplementedError()
        return castee.canonical

    @property
    def is_commutative(self):
        # a cast commutes exactly when the casted expression does
        return self.args[0].is_commutative


class pointerArithmeticFunc(sp.Function, sp.Rel):
    """
    Symbolic pointer arithmetic expression. Like castFunc it also derives from sp.Rel, presumably so that it
    can appear inside sp.Piecewise conditions - TODO confirm. The canonical form is delegated to the first argument.
    """

    @property
    def canonical(self):
        first = self.args[0]
        if not hasattr(first, 'canonical'):
            raise NotImplementedError()
        return first.canonical


class TypedSymbol(sp.Symbol):
    """
    sympy Symbol that additionally carries a data type, available as ``dtype``.

    The dtype is converted with createType; if that conversion fails with a TypeError the raw
    specification (e.g. an unparsable string) is stored instead.
    """

    def __new__(cls, *args, **kwds):
        # instances are created through sympy's cache (cacheit), see __xnew_cached_ below
        return TypedSymbol.__xnew_cached_(cls, *args, **kwds)

    def __new_stage2__(cls, name, dtype):
        obj = super(TypedSymbol, cls).__xnew__(cls, name)
        try:
            obj._dtype = createType(dtype)
        except TypeError:
            # dtype could not be turned into a Type object: keep the specification as given
            obj._dtype = dtype
        return obj

    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    @property
    def dtype(self):
        return self._dtype

    def _hashable_content(self):
        # the dtype takes part in sympy's hashing/equality, so equally named symbols of different type differ
        inherited = tuple(super(TypedSymbol, self)._hashable_content())
        return inherited + (hash(str(self._dtype)),)

    def __getnewargs__(self):
        return self.name, self.dtype


def createType(specification):
    """
    Create a subclass of Type according to a string or an object of subclass Type
    :param specification: Type object (returned unchanged), a c-like type string (parsed by
                          createTypeFromString) or anything np.dtype accepts
    :return: Type object; a structured numpy dtype yields a StructType, any other dtype a BasicType
             (both non-const)
    """
    if isinstance(specification, Type):
        return specification
    if isinstance(specification, str):
        return createTypeFromString(specification)

    npDataType = np.dtype(specification)
    isStructured = npDataType.fields is not None
    typeClass = StructType if isStructured else BasicType
    return typeClass(npDataType, const=False)


@memorycache(maxsize=64)
def createTypeFromString(specification):
    """
    Creates a new Type object from a c-like string specification
    :param specification: Specification string, e.g. "double", "int64 const", "double * restrict" or "double*".
                          It consists of a base part (type name, optionally with 'const') followed by any
                          number of pointer parts ('*', optionally with 'const' and/or 'restrict').
                          Stars glued to the base type name ("double*") are the innermost pointer levels.
    :return: Type object
    :raises TypeError: if the base type name is neither a key of createTypeFromString.map nor a name numpy
                       understands
    """
    tokens = specification.lower().split()

    # Split the tokens into the base part and pointer parts; every pointer part starts with '*'
    parts = []
    current = []
    for token in tokens:
        if token == '*':
            parts.append(current)
            current = [token]
        else:
            current.append(token)
    if current:
        parts.append(current)

    # Parse native part
    basePart = parts.pop(0)
    const = False
    if 'const' in basePart:
        const = True
        basePart.remove('const')
    assert len(basePart) == 1

    # Stars glued to the base name ("double*") belong to the innermost pointer levels, so they have to be
    # applied before the separately written pointer parts (which may carry const/restrict qualifiers)
    baseName = basePart[0].rstrip('*')
    gluedStars = len(basePart[0]) - len(baseName)
    parts[0:0] = [['*'] for _ in range(gluedStars)]

    # The explicit short names are looked up first: numpy would read e.g. 'i8' as an 8 *byte* integer (int64),
    # whereas the map defines it as an 8 bit integer. Other names go to numpy, which raises a TypeError for
    # names it does not understand.
    baseType = BasicType(createTypeFromString.map.get(baseName, baseName), const)

    currentType = baseType
    # Parse pointer parts
    for part in parts:
        restrict = False
        const = False
        if 'restrict' in part:
            restrict = True
            part.remove('restrict')
        if 'const' in part:
            const = True
            part.remove("const")
        assert len(part) == 1 and part[0] == '*'
        currentType = PointerType(currentType, const, restrict)
    return currentType


# short (bit width based) integer type names accepted by createTypeFromString
createTypeFromString.map = {
    'i64': np.int64,
    'i32': np.int32,
    'i16': np.int16,
    'i8': np.int8,
}


def getBaseType(type):
    """
    Follows the ``baseType`` chain of the given type down to its innermost element and returns it,
    e.g. a pointer to a pointer to double yields the double type.
    Types whose ``baseType`` is None are returned unchanged.
    """
    innermost = type
    nextLevel = innermost.baseType
    while nextLevel is not None:
        innermost, nextLevel = nextLevel, nextLevel.baseType
    return innermost


def toCtypes(dataType):
    """
    Transforms a given Type into ctypes
    :param dataType: Subclass of Type. Pointer types are translated recursively, struct types are passed as
                     byte pointers and basic types are looked up by their numpy dtype in toCtypes.map
    :return: ctypes type object
    :raises KeyError: if a basic type's numpy dtype has no entry in toCtypes.map
    """
    if isinstance(dataType, PointerType):
        return ctypes.POINTER(toCtypes(dataType.baseType))
    elif isinstance(dataType, StructType):
        # structs are handled byte-wise, so they are passed as a pointer to bytes
        return ctypes.POINTER(ctypes.c_uint8)
    else:
        return toCtypes.map[dataType.numpyDtype]


toCtypes.map = {
    np.dtype(np.int8): ctypes.c_int8,
    np.dtype(np.int16): ctypes.c_int16,
    np.dtype(np.int32): ctypes.c_int32,
    np.dtype(np.int64): ctypes.c_int64,

    np.dtype(np.uint8): ctypes.c_uint8,
    np.dtype(np.uint16): ctypes.c_uint16,
    np.dtype(np.uint32): ctypes.c_uint32,
    np.dtype(np.uint64): ctypes.c_uint64,

    np.dtype(np.float32): ctypes.c_float,
    np.dtype(np.float64): ctypes.c_double,

    # bool is a supported BasicType (see BasicType.numpyNameToC)
    np.dtype(np.bool_): ctypes.c_bool,
}


def ctypes_from_llvm(data_type):
    """
    Maps an llvmlite IR type to the corresponding ctypes type.
    :param data_type: llvmlite.ir type (pointer, int of width 8/16/32/64, float, double or void)
    :return: ctypes type; None for void, ctypes.c_void_p for pointers to void
    :raises ImportError: the error recorded at import time if llvmlite is not available
    :raises ValueError: for integer widths other than 8, 16, 32 and 64
    :raises NotImplementedError: for any other llvmlite type
    """
    if not ir:
        raise _ir_importerror

    if isinstance(data_type, ir.PointerType):
        pointee = ctypes_from_llvm(data_type.pointee)
        # a void pointee has no ctypes type, fall back to an untyped pointer
        return ctypes.c_void_p if pointee is None else ctypes.POINTER(pointee)

    if isinstance(data_type, ir.IntType):
        # integers are always mapped to the signed ctypes type of the same width
        signed_int_by_width = {
            8: ctypes.c_int8,
            16: ctypes.c_int16,
            32: ctypes.c_int32,
            64: ctypes.c_int64,
        }
        int_ctype = signed_int_by_width.get(data_type.width)
        if int_ctype is None:
            raise ValueError("Int width %d is not supported" % data_type.width)
        return int_ctype

    if isinstance(data_type, ir.FloatType):
        return ctypes.c_float
    if isinstance(data_type, ir.DoubleType):
        return ctypes.c_double
    if isinstance(data_type, ir.VoidType):
        return None  # Void type is not supported by ctypes

    raise NotImplementedError('Data type %s of %s is not supported yet' % (type(data_type), data_type))


def to_llvm_type(data_type):
    """
    Transforms a given type into an llvmlite IR type
    :param data_type: Subclass of Type. Pointer types are translated recursively, all other types are looked
                      up by their numpy dtype in to_llvm_type.map
    :return: llvmlite type object
    :raises ImportError: the error recorded at import time if llvmlite is not available
    :raises KeyError: if a basic type's numpy dtype has no entry in to_llvm_type.map
    """
    if not ir:
        raise _ir_importerror

    if isinstance(data_type, PointerType):
        return to_llvm_type(data_type.baseType).as_pointer()
    else:
        return to_llvm_type.map[data_type.numpyDtype]


if ir:
    # LLVM integer types carry no signedness, hence signed and unsigned numpy types map to the same IntType
    to_llvm_type.map = {
        np.dtype(np.int8): ir.IntType(8),
        np.dtype(np.int16): ir.IntType(16),
        np.dtype(np.int32): ir.IntType(32),
        np.dtype(np.int64): ir.IntType(64),

        np.dtype(np.uint8): ir.IntType(8),
        np.dtype(np.uint16): ir.IntType(16),
        np.dtype(np.uint32): ir.IntType(32),
        np.dtype(np.uint64): ir.IntType(64),

        np.dtype(np.float32): ir.FloatType(),
        np.dtype(np.float64): ir.DoubleType(),
    }


def peelOffType(dtype, typeToPeelOff):
    """
    Strips all outer wrapper levels whose class is exactly ``typeToPeelOff`` (subclasses do not count)
    by following ``baseType``, and returns the first type that is not such a wrapper.
    """
    if type(dtype) is not typeToPeelOff:
        return dtype
    return peelOffType(dtype.baseType, typeToPeelOff)


def collateTypes(types):
    """
    Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
    Uses the collation rules from numpy.

    Special cases:
      - pointer arithmetic: a single PointerType combined with (unsigned) integer BasicTypes collates to that
        PointerType; two pointers, or a pointer combined with any other type, raise ValueError
      - vector types: if any VectorType occurs (all of them must have the same width, otherwise ValueError)
        the collated scalar type is wrapped into a VectorType of that width
    Struct types are not supported.
    """
    # Pointer arithmetic case i.e. pointer + integer is allowed
    if any(type(t) is PointerType for t in types):
        pointerType = None
        for t in types:
            tType = type(t)
            if tType is PointerType:
                if pointerType is not None:
                    raise ValueError("Cannot collate the combination of two pointer types")
                pointerType = t
            elif tType is not BasicType or not (t.is_int() or t.is_uint()):
                raise ValueError("Invalid pointer arithmetic")
        return pointerType

    vectorTypes = [t for t in types if type(t) is VectorType]
    if not allEqual(v.width for v in vectorTypes):
        raise ValueError("Collation failed because of vector types with different width")
    scalarTypes = [peelOffType(t, VectorType) for t in types]

    # now we should have a list of basic types - struct types are not yet supported
    assert all(type(t) is BasicType for t in scalarTypes)

    collated = BasicType(np.result_type(*[t.numpyDtype for t in scalarTypes]))
    return VectorType(collated, vectorTypes[0].width) if vectorTypes else collated


@memorycache(maxsize=2048)
def getTypeOfExpression(expr):
    """
    Determines the Type of a sympy expression.

    Integers are "int", other rationals and floats "double"; field accesses and typed symbols report their dtype;
    castFunc expressions report their target type; Piecewise collates the types of its values (vectorized if
    the conditions are); Indexed reports the base type of the indexed pointer symbol; booleans are "bool"
    (vectorized if any argument is); all other expressions collate the types of their arguments.

    :param expr: sympy expression (or anything sp.sympify accepts)
    :return: Type object
    :raises ValueError: if the expression contains a sympy Symbol that is not a TypedSymbol
    :raises NotImplementedError: if no rule matches the expression
    """
    from pystencils.astnodes import ResolvedFieldAccess
    expr = sp.sympify(expr)
    if isinstance(expr, sp.Integer):
        return createTypeFromString("int")
    elif isinstance(expr, (sp.Rational, sp.Float)):
        return createTypeFromString("double")
    elif isinstance(expr, ResolvedFieldAccess):
        return expr.field.dtype
    elif isinstance(expr, TypedSymbol):
        return expr.dtype
    elif isinstance(expr, sp.Symbol):
        raise ValueError("All symbols inside this expression have to be typed!")
    elif hasattr(expr, 'func') and expr.func == castFunc:
        # castFunc(value, targetType): the type is the second argument
        return expr.args[1]
    elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
        collatedResultType = collateTypes(tuple(getTypeOfExpression(a[0]) for a in expr.args))
        collatedConditionType = collateTypes(tuple(getTypeOfExpression(a[1]) for a in expr.args))
        # vectorized conditions select between vector lanes, so the result has to be vectorized as well
        if type(collatedConditionType) is VectorType and type(collatedResultType) is not VectorType:
            collatedResultType = VectorType(collatedResultType, width=collatedConditionType.width)
        return collatedResultType
    elif isinstance(expr, sp.Indexed):
        typedSymbol = expr.base.label
        return typedSymbol.dtype.baseType
    elif isinstance(expr, sp.boolalg.Boolean) or isinstance(expr, sp.boolalg.BooleanFunction):
        # if any arg is of vector type return a vector boolean, else return a normal scalar boolean
        result = createTypeFromString("bool")
        argTypes = [getTypeOfExpression(a) for a in expr.args]
        vecArgs = [t for t in argTypes if isinstance(t, VectorType)]
        if vecArgs:
            result = VectorType(result, width=vecArgs[0].width)
        return result
    elif isinstance(expr, sp.Expr):
        types = tuple(getTypeOfExpression(a) for a in expr.args)
        return collateTypes(types)

    raise NotImplementedError("Could not determine type for %s of type %s" % (expr, type(expr)))


class Type(sp.Basic):
    """
    Base class of the data types BasicType, VectorType and PointerType.

    Derives from sp.Basic so that type objects can be used as arguments of sympy expressions
    (e.g. the target type of castFunc). Constructor arguments are not forwarded to sp.Basic;
    the subclasses store their state in __init__ instead.
    """

    def __new__(cls, *args, **kwargs):
        return sp.Basic.__new__(cls)

    def __lt__(self, other):  # deprecated
        # Needed for sorting the types inside an expression
        if isinstance(self, BasicType):
            if isinstance(other, BasicType):
                return self.numpyDtype > other.numpyDtype  # TODO const
            elif isinstance(other, PointerType):
                return False
            else:  # isinstance(other, StructType):
                raise NotImplementedError("Struct type comparison is not yet implemented")
        elif isinstance(self, PointerType):
            if isinstance(other, BasicType):
                return True
            elif isinstance(other, PointerType):
                return self.baseType > other.baseType  # TODO const, restrict
            else:  # isinstance(other, StructType):
                raise NotImplementedError("Struct type comparison is not yet implemented")
        elif isinstance(self, StructType):
            raise NotImplementedError("Struct type comparison is not yet implemented")
        else:
            raise NotImplementedError

    def _sympystr(self, *args, **kwargs):
        # hook used by sympy's string printer: print types with their own (C-like) string representation
        return str(self)


class BasicType(Type):
    """
    Scalar data type described by a (non-structured) numpy dtype, optionally const qualified.
    """

    @staticmethod
    def numpyNameToC(name):
        """
        Maps a numpy dtype name to the corresponding C type name, e.g. 'float64' -> 'double',
        'float32' -> 'float', 'int32' -> 'int32_t', 'uint8' -> 'uint8_t', 'bool' -> 'bool'
        :raises NotImplementedError: for any other dtype name
        """
        if name == 'float64':
            return 'double'
        elif name == 'float32':
            return 'float'
        elif name.startswith('int'):
            width = int(name[len("int"):])
            return "int%d_t" % (width,)
        elif name.startswith('uint'):
            width = int(name[len("uint"):])
            return "uint%d_t" % (width,)
        elif name == 'bool':
            return 'bool'
        else:
            # NotImplemented is a constant, not an exception - raising it only produced a TypeError
            raise NotImplementedError("Cannot map numpy to C name for %s" % (name,))

    def __init__(self, dtype, const=False):
        self.const = const
        if isinstance(dtype, Type):
            self._dtype = dtype.numpyDtype
        else:
            self._dtype = np.dtype(dtype)
        assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type"
        assert self._dtype.hasobject is False
        assert self._dtype.subdtype is None

    def __getnewargs__(self):
        return self.numpyDtype, self.const

    @property
    def baseType(self):
        return None

    @property
    def numpyDtype(self):
        return self._dtype

    @property
    def itemSize(self):
        # NOTE(review): counts elements (always 1), whereas StructType.itemSize returns bytes - confirm intent
        return 1

    # The kind checks use numpy's dtype.kind codes instead of np.sctypes, which was removed in NumPy 2.0
    def is_int(self):
        return self.numpyDtype.kind == 'i'

    def is_float(self):
        return self.numpyDtype.kind == 'f'

    def is_uint(self):
        return self.numpyDtype.kind == 'u'

    def is_complex(self):
        return self.numpyDtype.kind == 'c'

    # misspelled name kept for backwards compatibility
    is_comlex = is_complex

    def is_other(self):
        # bool, object, bytes, str and void - the members of the former np.sctypes['others']
        return self.numpyDtype.kind in ('b', 'O', 'S', 'U', 'V')

    @property
    def baseName(self):
        return BasicType.numpyNameToC(str(self._dtype))

    def __str__(self):
        result = BasicType.numpyNameToC(str(self._dtype))
        if self.const:
            result += " const"
        return result

    def __repr__(self):
        return str(self)

    def __eq__(self, other):
        if not isinstance(other, BasicType):
            return False
        else:
            return (self.numpyDtype, self.const) == (other.numpyDtype, other.const)

    def __hash__(self):
        return hash(str(self))


class VectorType(Type):
    """
    SIMD vector of ``width`` elements of ``baseType``.
    """
    # Maps 'int', 'double', 'float' and 'bool' to the C vector type names printed by __str__; when None the
    # generic "<baseType>[<width>]" form is printed. Presumably set by the vectorization backend - TODO confirm
    instructionSet = None

    def __init__(self, baseType, width=4):
        self._baseType = baseType
        self.width = width

    @property
    def baseType(self):
        return self._baseType

    @property
    def itemSize(self):
        return self.width * self.baseType.itemSize

    def __eq__(self, other):
        if not isinstance(other, VectorType):
            return False
        else:
            return (self.baseType, self.width) == (other.baseType, other.width)

    def __str__(self):
        if self.instructionSet is None:
            return "%s[%d]" % (self.baseType, self.width)
        else:
            if self.baseType == createTypeFromString("int64"):
                return self.instructionSet['int']
            elif self.baseType == createTypeFromString("double"):
                return self.instructionSet['double']
            # "float32" instead of "float": numpy reads 'float' as float64, which made this branch unreachable
            # (it equals "double") and single precision vectors unprintable
            elif self.baseType == createTypeFromString("float32"):
                return self.instructionSet['float']
            elif self.baseType == createTypeFromString("bool"):
                return self.instructionSet['bool']
            else:
                raise NotImplementedError()

    def __hash__(self):
        return hash(str(self))


class PointerType(Type):
    """
    Pointer to ``baseType``; const and/or restrict qualified (restrict unless ``restrict=False`` is passed).
    """

    def __init__(self, baseType, const=False, restrict=True):
        self._baseType = baseType
        self.const = const
        self.restrict = restrict

    def __getnewargs__(self):
        return self.baseType, self.const, self.restrict

    @property
    def alias(self):
        # a pointer may alias other pointers exactly when it is not declared restrict
        return not self.restrict

    @property
    def baseType(self):
        return self._baseType

    @property
    def itemSize(self):
        # the item size is the one of the pointed-to type
        return self.baseType.itemSize

    def __eq__(self, other):
        if isinstance(other, PointerType):
            mine = (self.baseType, self.const, self.restrict)
            theirs = (other.baseType, other.const, other.restrict)
            return mine == theirs
        return False

    def __str__(self):
        restrictQualifier = " RESTRICT " if self.restrict else ""
        constQualifier = " const " if self.const else ""
        return "%s *%s%s" % (self.baseType, restrictQualifier, constQualifier)

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash(str(self))


class StructType(object):
    """
    Data type backed by a structured numpy dtype, optionally const qualified.
    In generated code structs are handled byte-wise, hence __str__ is "uint8_t".
    """

    def __init__(self, numpyType, const=False):
        self.const = const
        self._dtype = np.dtype(numpyType)

    def __getnewargs__(self):
        return self.numpyDtype, self.const

    @property
    def baseType(self):
        return None

    @property
    def numpyDtype(self):
        return self._dtype

    @property
    def itemSize(self):
        # size of one struct in bytes
        return self.numpyDtype.itemsize

    def getElementOffset(self, elementName):
        """Byte offset of the given field inside the struct."""
        return self.numpyDtype.fields[elementName][1]

    def getElementType(self, elementName):
        """BasicType of the given field, inheriting the struct's const qualifier."""
        npElementType = self.numpyDtype.fields[elementName][0]
        return BasicType(npElementType, self.const)

    def hasElement(self, elementName):
        return elementName in self.numpyDtype.fields

    def __eq__(self, other):
        if not isinstance(other, StructType):
            return False
        else:
            return (self.numpyDtype, self.const) == (other.numpyDtype, other.const)

    def __str__(self):
        # structs are handled byte-wise
        result = "uint8_t"
        if self.const:
            result += " const"
        return result

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self.numpyDtype, self.const))

    def __gt__(self, other):
        # Ordering of struct types is not defined (consistent with Type.__lt__). The former implementation
        # accessed the non-existent attributes 'ptr' and 'dtype' and therefore always failed.
        raise NotImplementedError("Struct type comparison is not yet implemented")


def get_type_from_sympy(node):
    """
    Creates a Type object from a sympy number
    :param node: sympy Number
    :return: tuple (Type object, python value): ('double', float) for Float and non-integer Rational,
             ('int', int) for Integer
    :raises TypeError: if node is not a sympy Number, or is a Number subclass other than
                       Float/Integer/Rational (e.g. NaN or the infinities)
    """
    if not isinstance(node, sp.Number):
        raise TypeError(node, 'is not a sp.Number')

    # Float has to be checked first; Integer before Rational because Integer is a Rational subclass
    if isinstance(node, (sp.Float, sp.RealNumber)):
        return createType('double'), float(node)
    if isinstance(node, sp.Integer):
        return createType('int'), int(node)
    if isinstance(node, sp.Rational):
        # TODO is it always float?
        return createType('double'), float(node.p / node.q)
    raise TypeError(node, ' is not a supported type (yet)!')