data_types.py 18.8 KB
Newer Older
1
import ctypes
Martin Bauer's avatar
Martin Bauer committed
2

3
import numpy as np
Martin Bauer's avatar
Martin Bauer committed
4
5
6
7
8
9
import sympy as sp
from sympy.core.cache import cacheit
from sympy.logic.boolalg import Boolean

from pystencils.cache import memorycache
from pystencils.utils import all_equal
10

11
12
13
14
15
try:
    import llvmlite.ir as ir
except ImportError as e:
    ir = None
    _ir_importerror = e
16

17

18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
# noinspection PyPep8Naming
class address_of(sp.Function):
    """Symbolic address-of node wrapping a single argument.

    The node is treated as an atom by sympy. Its ``dtype`` is a restrict pointer to the argument's
    ``dtype``, or a restrict pointer to ``'void'`` when the argument has no ``dtype`` attribute.
    """
    is_Atom = True

    def __new__(cls, arg):
        return sp.Function.__new__(cls, arg)

    @property
    def canonical(self):
        """``canonical`` of the wrapped argument; NotImplementedError if the argument has none."""
        wrapped = self.args[0]
        if not hasattr(wrapped, 'canonical'):
            raise NotImplementedError()
        return wrapped.canonical

    @property
    def is_commutative(self):
        """Commutativity is taken over from the wrapped argument."""
        return self.args[0].is_commutative

    @property
    def dtype(self):
        """Restrict pointer to the argument's dtype (``'void'`` for untyped arguments)."""
        wrapped = self.args[0]
        pointee = wrapped.dtype if hasattr(wrapped, 'dtype') else 'void'
        return PointerType(pointee, restrict=True)


Martin Bauer's avatar
Martin Bauer committed
45
# noinspection PyPep8Naming
class cast_func(sp.Function):
    """Symbolic type cast: ``args[0]`` is the expression, ``args[1]`` the target type."""
    is_Atom = True

    def __new__(cls, *args, **kwargs):
        # A cast has to be a sympy Boolean to be usable in conditions of sp.Piecewise.
        # Making *every* cast a Boolean breaks equality comparisons between casts, e.g.
        #
        # lhs = bitwise_and(a, cast_func(1, 'int'))
        # rhs = cast_func(0, 'int')
        # print( sp.Ne(lhs, rhs) ) # would give true if all cast_funcs are booleans
        #
        # -> only casts of Boolean arguments are created as the separate class boolean_cast_func
        target_cls = boolean_cast_func if isinstance(args[0], Boolean) else cls
        return sp.Function.__new__(target_cls, *args, **kwargs)

    @property
    def canonical(self):
        """``canonical`` of the cast expression; NotImplementedError if it has none."""
        casted = self.args[0]
        if not hasattr(casted, 'canonical'):
            raise NotImplementedError()
        return casted.canonical

    @property
    def is_commutative(self):
        """Commutativity is taken over from the cast expression."""
        return self.args[0].is_commutative

    @property
    def dtype(self):
        """The target type of the cast (second argument)."""
        return self.args[1]


78
79
80
81
82
# noinspection PyPep8Naming
class boolean_cast_func(cast_func, Boolean):
    """cast_func variant that is also a sympy Boolean; cast_func.__new__ selects it for Boolean arguments."""


Martin Bauer's avatar
Martin Bauer committed
83
84
85
86
# noinspection PyPep8Naming
class vector_memory_access(cast_func):
    """cast_func subclass restricted to exactly four arguments."""
    nargs = (4,)

87

88
89
90
91
92
# noinspection PyPep8Naming
class reinterpret_cast_func(cast_func):
    """Marker subclass of cast_func; it adds no behavior of its own."""


Martin Bauer's avatar
Martin Bauer committed
93
94
# noinspection PyPep8Naming
class pointer_arithmetic_func(sp.Function, Boolean):
    """Symbolic function node (also a sympy Boolean) whose ``canonical`` comes from its first argument."""

    @property
    def canonical(self):
        """``canonical`` of the first argument; NotImplementedError if it has none."""
        first = self.args[0]
        if not hasattr(first, 'canonical'):
            raise NotImplementedError()
        return first.canonical


103
class TypedSymbol(sp.Symbol):
    """A sympy Symbol that additionally carries a data type (``dtype``).

    The type specification is normalized with ``create_type``. If that raises TypeError or
    ValueError, the specification is stored unchanged. Construction goes through sympy's ``cacheit``,
    and the dtype takes part in the hashable content. Symbols with equal names but different types
    are therefore distinct.
    """

    def __new__(cls, *args, **kwds):
        obj = TypedSymbol.__xnew_cached_(cls, *args, **kwds)
        return obj

    def __new_stage2__(cls, name, dtype):
        obj = super(TypedSymbol, cls).__xnew__(cls, name)
        try:
            obj._dtype = create_type(dtype)
        except (TypeError, ValueError):
            # on error keep the original specification (e.g. a string numpy cannot parse)
            obj._dtype = dtype
        return obj

    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    @property
    def dtype(self):
        """The (normalized, if possible) data type of this symbol."""
        return self._dtype

    def _hashable_content(self):
        # include the data type, so that symbols differing only in type do not compare/hash equal
        return super()._hashable_content(), hash(self._dtype)

    def __getnewargs__(self):
        # arguments to recreate the symbol, e.g. when unpickling
        return self.name, self.dtype

    # For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
    @property
    def is_integer(self):
        """True for an integer numpy dtype; otherwise sympy's own ``is_integer`` assumption."""
        if hasattr(self.dtype, 'numpy_dtype'):
            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or super().is_integer
        else:
            return super().is_integer

    @property
    def is_negative(self):
        """False for unsigned integer dtypes; otherwise sympy's own ``is_negative`` assumption."""
        if hasattr(self.dtype, 'numpy_dtype'):
            if np.issubdtype(self.dtype.numpy_dtype, np.unsignedinteger):
                return False
        # Defer to sympy's is_negative. Returning is_positive here would report symbols that are
        # assumed positive as negative.
        return super().is_negative

    @property
    def is_real(self):
        """True for integer or floating numpy dtypes; otherwise sympy's own ``is_real`` assumption."""
        if hasattr(self.dtype, 'numpy_dtype'):
            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or \
                np.issubdtype(self.dtype.numpy_dtype, np.floating) or \
                super().is_real
        else:
            return super().is_real
154

155

Martin Bauer's avatar
Martin Bauer committed
156
def create_type(specification):
    """Creates a Type object according to a string or an existing type object.

    Args:
        specification: an existing Type or StructType object, which is returned unchanged, or anything
                       ``numpy.dtype`` accepts, e.g. a string like ``'float64'`` or a numpy scalar type

    Returns:
        the given type object itself, a new non-const BasicType for plain numpy dtypes, or a new
        non-const StructType for structured numpy dtypes (dtypes with fields)

    Raises:
        whatever ``numpy.dtype`` raises for specifications it cannot interpret (typically TypeError)
    """
    # StructType does not derive from Type, so it has to be accepted explicitly; otherwise it would
    # be handed to np.dtype, which cannot interpret it
    if isinstance(specification, (Type, StructType)):
        return specification

    numpy_dtype = np.dtype(specification)
    if numpy_dtype.fields is None:
        return BasicType(numpy_dtype, const=False)
    return StructType(numpy_dtype, const=False)
173
174


175
@memorycache(maxsize=64)
def create_composite_type_from_string(specification):
    """Creates a new Type object from a c-like string specification.

    The string is a base type, optionally qualified with ``const``, followed by any number of pointer
    levels ``*``. Each level may be followed by ``const`` and/or ``restrict`` qualifiers for that
    pointer. Parsing is case-insensitive, and a ``*`` may be glued to a neighbouring word. Examples:
    ``"double*"``, ``"const double * restrict"``, ``"int32 ** const"``.

    Args:
        specification: Specification string

    Returns:
        Type object: a BasicType wrapped in one PointerType per ``*``; the first ``*`` is the
        innermost pointer
    """
    # Surround every '*' with spaces so that markers glued to words ('double*', 'const*', '**')
    # become tokens of their own
    tokens = specification.lower().replace('*', ' * ').split()

    # Group tokens: the first group is the base type, every further group starts with a '*'
    parts = []
    current = []
    for token in tokens:
        if token == '*':
            parts.append(current)
            current = [token]
        else:
            current.append(token)
    if current:
        parts.append(current)

    # Parse native part
    base_part = parts.pop(0)
    const = False
    if 'const' in base_part:
        const = True
        base_part.remove('const')
    assert len(base_part) == 1
    current_type = BasicType(np.dtype(base_part[0]), const)

    # Parse pointer parts, innermost first
    for part in parts:
        restrict = False
        const = False
        if 'restrict' in part:
            restrict = True
            part.remove('restrict')
        if 'const' in part:
            const = True
            part.remove("const")
        assert len(part) == 1 and part[0] == '*'
        current_type = PointerType(current_type, const, restrict)
    return current_type
220
221


Martin Bauer's avatar
Martin Bauer committed
222
223
224
225
def get_base_type(data_type):
    """Follow ``base_type`` links until a type whose ``base_type`` is None is reached, and return it."""
    current = data_type
    while True:
        inner = current.base_type
        if inner is None:
            return current
        current = inner
226
227


Martin Bauer's avatar
Martin Bauer committed
228
def to_ctypes(data_type):
    """
    Transforms a given Type into ctypes

    Pointer types are translated recursively into ``ctypes.POINTER``. A StructType becomes a pointer
    to bytes (``POINTER(c_uint8)``). Every other type is looked up by its ``numpy_dtype`` in
    ``to_ctypes.map``.

    :param data_type: Subclass of Type
    :return: ctypes type object
    :raises KeyError: if the numpy dtype has no entry in ``to_ctypes.map``
    """
    if isinstance(data_type, PointerType):
        return ctypes.POINTER(to_ctypes(data_type.base_type))
    elif isinstance(data_type, StructType):
        # structs are handled byte-wise
        return ctypes.POINTER(ctypes.c_uint8)
    else:
        return to_ctypes.map[data_type.numpy_dtype]


to_ctypes.map = {
    np.dtype(np.int8): ctypes.c_int8,
    np.dtype(np.int16): ctypes.c_int16,
    np.dtype(np.int32): ctypes.c_int32,
    np.dtype(np.int64): ctypes.c_int64,

    np.dtype(np.uint8): ctypes.c_uint8,
    np.dtype(np.uint16): ctypes.c_uint16,
    np.dtype(np.uint32): ctypes.c_uint32,
    np.dtype(np.uint64): ctypes.c_uint64,

    np.dtype(np.float32): ctypes.c_float,
    np.dtype(np.float64): ctypes.c_double,

    # bool types (create_type("bool")) are used for masks/conditions; numpy bools are one byte
    np.dtype(np.bool_): ctypes.c_bool,
}


258
def ctypes_from_llvm(data_type):
    """Return the ctypes type corresponding to an llvmlite IR type.

    - pointer: ``ctypes.POINTER`` of the pointee, or ``ctypes.c_void_p`` if the pointee maps to None
    - integer: signed ctypes integer of the same width (8, 16, 32 or 64 bits)
    - float / double: ``ctypes.c_float`` / ``ctypes.c_double``
    - void: None, since ctypes has no void type

    Raises:
        the ImportError stored at module import time, if llvmlite could not be imported
        ValueError: for integer widths other than 8, 16, 32 and 64
        NotImplementedError: for every other llvmlite type
    """
    if not ir:
        raise _ir_importerror

    if isinstance(data_type, ir.PointerType):
        pointee = ctypes_from_llvm(data_type.pointee)
        return ctypes.c_void_p if pointee is None else ctypes.POINTER(pointee)

    if isinstance(data_type, ir.IntType):
        signed_by_width = {
            8: ctypes.c_int8,
            16: ctypes.c_int16,
            32: ctypes.c_int32,
            64: ctypes.c_int64,
        }
        int_ctype = signed_by_width.get(data_type.width)
        if int_ctype is None:
            raise ValueError("Int width %d is not supported" % data_type.width)
        return int_ctype

    if isinstance(data_type, ir.FloatType):
        return ctypes.c_float
    if isinstance(data_type, ir.DoubleType):
        return ctypes.c_double
    if isinstance(data_type, ir.VoidType):
        return None  # Void type is not supported by ctypes

    raise NotImplementedError('Data type %s of %s is not supported yet' % (type(data_type), data_type))


def to_llvm_type(data_type):
    """
    Transforms a given type into an llvmlite IR type

    Pointer types are translated recursively via ``as_pointer()``. Every other type is looked up by
    its ``numpy_dtype`` in ``to_llvm_type.map``. That map sends signed and unsigned integers of the
    same width to the same ``ir.IntType``.

    :param data_type: Subclass of Type
    :return: llvmlite type object
    :raises: the ImportError stored at module import time if llvmlite is unavailable;
             KeyError if the numpy dtype has no entry in ``to_llvm_type.map``
    """
    if not ir:
        raise _ir_importerror
    if isinstance(data_type, PointerType):
        return to_llvm_type(data_type.base_type).as_pointer()
    else:
        return to_llvm_type.map[data_type.numpy_dtype]


# the map can only be built when llvmlite was importable
if ir:
    to_llvm_type.map = {
        np.dtype(np.int8): ir.IntType(8),
        np.dtype(np.int16): ir.IntType(16),
        np.dtype(np.int32): ir.IntType(32),
        np.dtype(np.int64): ir.IntType(64),

        np.dtype(np.uint8): ir.IntType(8),
        np.dtype(np.uint16): ir.IntType(16),
        np.dtype(np.uint32): ir.IntType(32),
        np.dtype(np.uint64): ir.IntType(64),

        np.dtype(np.float32): ir.FloatType(),
        np.dtype(np.float64): ir.DoubleType(),
    }
317

318

Martin Bauer's avatar
Martin Bauer committed
319
320
321
def peel_off_type(dtype, type_to_peel_off):
    """Strip wrapper layers whose class is exactly ``type_to_peel_off`` (subclasses do not count).

    ``base_type`` is followed as long as the current object's class is ``type_to_peel_off``. The
    first object of a different class is returned.
    """
    result = dtype
    while True:
        if type(result) is not type_to_peel_off:
            return result
        result = result.base_type


Martin Bauer's avatar
Martin Bauer committed
325
def collate_types(types):
    """
    Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
    Uses the collation rules from numpy.

    Rules applied, in order:
      - pointer arithmetic: exactly one PointerType together with signed/unsigned integer BasicTypes
        yields that PointerType; two pointers, or a pointer with any other type, raise ValueError
      - VectorTypes are unwrapped; they must all have the same width (ValueError otherwise), and the
        result is wrapped into a VectorType of that width again
      - if any float type occurs, only the float types take part in the numpy collation

    Args:
        types: iterable of type objects; it is materialized once, so generators are accepted too

    Returns:
        the common type
    """
    # materialize once: the input is traversed several times below
    types = tuple(types)

    # Pointer arithmetic case i.e. pointer + integer is allowed
    if any(type(t) is PointerType for t in types):
        pointer_type = None
        for t in types:
            if type(t) is PointerType:
                if pointer_type is not None:
                    raise ValueError("Cannot collate the combination of two pointer types")
                pointer_type = t
            elif type(t) is BasicType:
                if not (t.is_int() or t.is_uint()):
                    raise ValueError("Invalid pointer arithmetic")
            else:
                raise ValueError("Invalid pointer arithmetic")
        return pointer_type

    # peel off vector types, if at least one vector type occurred the result will also be the vector type
    vector_type = [t for t in types if type(t) is VectorType]
    if not all_equal(t.width for t in vector_type):
        raise ValueError("Collation failed because of vector types with different width")
    types = [peel_off_type(t, VectorType) for t in types]

    # now we should have a list of basic types - struct types are not yet supported
    assert all(type(t) is BasicType for t in types)

    # integers are dropped when floats are present, so that e.g. (float32, int64) stays float32
    # instead of numpy's promotion to float64
    if any(t.is_float() for t in types):
        types = tuple(t for t in types if t.is_float())

    # use numpy collation -> create type from numpy type -> and, put vector type around if necessary
    result_numpy_type = np.result_type(*(t.numpy_dtype for t in types))
    result = BasicType(result_numpy_type)
    if vector_type:
        result = VectorType(result, vector_type[0].width)
    return result


@memorycache(maxsize=2048)
def get_type_of_expression(expr):
    """Determine the data type of a (sympified) expression; results are memoized.

    - Integer literals: ``create_type("int")``; other rationals and floats: ``create_type("double")``
    - ResolvedFieldAccess: the field's dtype; TypedSymbol: its dtype; cast_func: its target type
    - vec_any / vec_all: bool
    - Piecewise: collated type of the values, widened to a vector if the conditions are vectors
    - Indexed: base type of the indexed symbol's (pointer) dtype
    - other Booleans: bool, or a vector of bool if any argument has a vector type
    - Pow: type of the base; other expressions: collated type of all arguments

    Raises:
        ValueError: if an untyped sympy Symbol is encountered
        NotImplementedError: if no rule applies
    """
    from pystencils.astnodes import ResolvedFieldAccess
    from pystencils.cpu.vectorization import vec_all, vec_any

    expr = sp.sympify(expr)
    # sp.Integer is a subclass of sp.Rational, so it has to be checked first
    if isinstance(expr, sp.Integer):
        return create_type("int")
    elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float):
        return create_type("double")
    elif isinstance(expr, ResolvedFieldAccess):
        return expr.field.dtype
    elif isinstance(expr, TypedSymbol):
        return expr.dtype
    elif isinstance(expr, sp.Symbol):
        raise ValueError("All symbols inside this expression have to be typed! ", str(expr))
    elif isinstance(expr, cast_func):
        # checked before the generic Boolean branch, since boolean_cast_func is a Boolean as well
        return expr.args[1]
    elif isinstance(expr, vec_any) or isinstance(expr, vec_all):
        return create_type("bool")
    elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
        collated_result_type = collate_types(tuple(get_type_of_expression(a[0]) for a in expr.args))
        collated_condition_type = collate_types(tuple(get_type_of_expression(a[1]) for a in expr.args))
        if type(collated_condition_type) is VectorType and type(collated_result_type) is not VectorType:
            collated_result_type = VectorType(collated_result_type, width=collated_condition_type.width)
        return collated_result_type
    elif isinstance(expr, sp.Indexed):
        typed_symbol = expr.base.label
        return typed_symbol.dtype.base_type
    elif isinstance(expr, sp.boolalg.Boolean) or isinstance(expr, sp.boolalg.BooleanFunction):
        # if any arg is of vector type return a vector boolean, else return a normal scalar boolean
        result = create_type("bool")
        # compute each argument type once
        arg_types = [get_type_of_expression(a) for a in expr.args]
        vec_args = [t for t in arg_types if isinstance(t, VectorType)]
        if vec_args:
            result = VectorType(result, width=vec_args[0].width)
        return result
    elif isinstance(expr, sp.Pow):
        return get_type_of_expression(expr.args[0])
    elif isinstance(expr, sp.Expr):
        types = tuple(get_type_of_expression(a) for a in expr.args)
        return collate_types(types)

    raise NotImplementedError("Could not determine type for", expr, type(expr))
408
409


410
class Type(sp.Basic):
    """Base class of the data types; a sympy atom, so type objects can be arguments of sympy nodes.

    The constructor arguments are not handed to sympy. Subclasses store their state in ``__init__``.
    """
    is_Atom = True

    def __new__(cls, *args, **kwargs):
        # sympy only sees an argument-less Basic; __init__ of the subclass gets the real arguments
        return sp.Basic.__new__(cls)

    def _sympystr(self, *args, **kwargs):
        # sympy's printer uses the type's own str()
        return str(self)
418
419
420
421


class BasicType(Type):
    """Scalar data type: a non-structured numpy dtype plus a ``const`` qualifier."""

    @staticmethod
    def numpy_name_to_c(name):
        """Translate a numpy dtype name ('float64', 'int32', 'uint8', 'bool', ...) into a C type name.

        Raises:
            NotImplementedError: for names without a known C equivalent
        """
        if name == 'float64':
            return 'double'
        elif name == 'float32':
            return 'float'
        elif name.startswith('int'):
            width = int(name[len("int"):])
            return "int%d_t" % (width,)
        elif name.startswith('uint'):
            width = int(name[len("uint"):])
            return "uint%d_t" % (width,)
        elif name == 'bool':
            return 'bool'
        else:
            raise NotImplementedError("Cannot map numpy to C name for %s" % (name,))

    def __init__(self, dtype, const=False):
        self.const = const
        if isinstance(dtype, Type):
            # take over the numpy dtype of another type object
            self._dtype = dtype.numpy_dtype
        else:
            self._dtype = np.dtype(dtype)
        assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type"
        assert self._dtype.hasobject is False
        assert self._dtype.subdtype is None

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        """Scalars wrap no other type."""
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def item_size(self):
        """Always 1: a scalar counts as a single element."""
        return 1

    # The is_* predicates classify by numpy's one-character dtype.kind code. np.sctypes is not used
    # because NumPy 2.0 removed it.
    def is_int(self):
        """True for signed integer dtypes."""
        return self.numpy_dtype.kind == 'i'

    def is_float(self):
        """True for floating point dtypes."""
        return self.numpy_dtype.kind == 'f'

    def is_uint(self):
        """True for unsigned integer dtypes."""
        return self.numpy_dtype.kind == 'u'

    def is_complex(self):
        """True for complex dtypes."""
        return self.numpy_dtype.kind == 'c'

    def is_other(self):
        """True for bool, object, bytes, str and void dtypes (numpy's former 'others' group)."""
        return self.numpy_dtype.kind in ('b', 'O', 'S', 'U', 'V')

    @property
    def base_name(self):
        """C name of the type without qualifiers."""
        return BasicType.numpy_name_to_c(str(self._dtype))

    def __str__(self):
        result = BasicType.numpy_name_to_c(str(self._dtype))
        if self.const:
            result += " const"
        return result

    def __repr__(self):
        return str(self)

    def __eq__(self, other):
        if not isinstance(other, BasicType):
            return False
        else:
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)

    def __hash__(self):
        return hash(str(self))


501
class VectorType(Type):
    """A vector of ``width`` elements of ``base_type``.

    ``instruction_set`` is a class attribute and None by default. While it is None, ``str()`` gives
    ``"<base>[<width>]"``. Otherwise ``str()`` returns the entry of ``instruction_set`` under 'int',
    'double', 'float' or 'bool' for base types int64, float64, float32 and bool respectively.
    NOTE(review): presumably the vectorizer assigns the instruction set; confirm.
    """
    instruction_set = None

    def __init__(self, base_type, width=4):
        self._base_type = base_type
        self.width = width

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        """``width`` times the item size of the base type."""
        return self.width * self.base_type.item_size

    def __eq__(self, other):
        if isinstance(other, VectorType):
            return (self.base_type, self.width) == (other.base_type, other.width)
        return False

    def __str__(self):
        if self.instruction_set is None:
            return "%s[%d]" % (self.base_type, self.width)
        # first matching scalar type selects the instruction set key
        scalar_to_key = (
            ("int64", 'int'),
            ("float64", 'double'),
            ("float32", 'float'),
            ("bool", 'bool'),
        )
        for scalar_name, key in scalar_to_key:
            if self.base_type == create_type(scalar_name):
                return self.instruction_set[key]
        raise NotImplementedError()

    def __hash__(self):
        return hash((self.base_type, self.width))

    def __getnewargs__(self):
        return self._base_type, self.width

543

544
class PointerType(Type):
    """Pointer to ``base_type`` with ``const`` and ``restrict`` qualifiers (restrict by default)."""

    def __init__(self, base_type, const=False, restrict=True):
        self._base_type = base_type
        self.const = const
        self.restrict = restrict

    def __getnewargs__(self):
        return self.base_type, self.const, self.restrict

    @property
    def alias(self):
        """True when the pointer is not declared restrict."""
        return not self.restrict

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        """Item size of the pointee."""
        return self.base_type.item_size

    def __eq__(self, other):
        if isinstance(other, PointerType):
            return (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict)
        return False

    def __str__(self):
        # e.g. "double * RESTRICT const"
        qualifiers = [word for word, enabled in (('RESTRICT', self.restrict), ("const", self.const)) if enabled]
        return " ".join([str(self.base_type), '*'] + qualifiers)

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self._base_type, self.const, self.restrict))
584

Jan Hoenig's avatar
Jan Hoenig committed
585

586
class StructType:
    """Structured numpy dtype with a ``const`` flag.

    Unlike the other type classes this one does not derive from Type. Structs are handled byte-wise,
    so ``str()`` is ``"uint8_t"`` (with ``" const"`` appended for const structs).
    """

    def __init__(self, numpy_type, const=False):
        self.const = const
        self._dtype = np.dtype(numpy_type)

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        """Structs wrap no other type."""
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def item_size(self):
        """Size of one struct in bytes (numpy ``itemsize``)."""
        return self.numpy_dtype.itemsize

    def get_element_offset(self, element_name):
        """Byte offset of field ``element_name`` within the struct."""
        _, offset = self.numpy_dtype.fields[element_name][:2]
        return offset

    def get_element_type(self, element_name):
        """BasicType of field ``element_name``; it inherits the struct's const flag."""
        field_dtype = self.numpy_dtype.fields[element_name][0]
        return BasicType(field_dtype, self.const)

    def has_element(self, element_name):
        return element_name in self.numpy_dtype.fields

    def __eq__(self, other):
        if isinstance(other, StructType):
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)
        return False

    def __str__(self):
        # structs are handled byte-wise
        return "uint8_t const" if self.const else "uint8_t"

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self.numpy_dtype, self.const))