import ctypes

import numpy as np
import sympy as sp
from sympy.core.cache import cacheit
from sympy.logic.boolalg import Boolean

from pystencils.cache import memorycache
from pystencils.utils import all_equal

try:
    import llvmlite.ir as ir
except ImportError as e:
    ir = None
    _ir_importerror = e


# noinspection PyPep8Naming
class address_of(sp.Function):
    """Sympy function taking the address of its single argument.

    Its ``dtype`` is a restrict pointer to the argument's dtype (or to ``'void'`` if the argument has none).
    """
    is_Atom = True

    def __new__(cls, arg):
        # the signature enforces exactly one argument
        return sp.Function.__new__(cls, arg)

    @property
    def canonical(self):
        """Canonical form of the wrapped argument; NotImplementedError if the argument has none."""
        wrapped = self.args[0]
        if not hasattr(wrapped, 'canonical'):
            raise NotImplementedError()
        return wrapped.canonical

    @property
    def is_commutative(self):
        """Commutativity is taken over from the wrapped argument."""
        wrapped = self.args[0]
        return wrapped.is_commutative

    @property
    def dtype(self):
        """Restrict pointer to the argument's dtype, falling back to a pointer to 'void'."""
        pointee = getattr(self.args[0], 'dtype', 'void')
        return PointerType(pointee, restrict=True)


# noinspection PyPep8Naming
class cast_func(sp.Function):
    """Sympy function casting ``args[0]`` to the data type stored in ``args[1]``."""
    is_Atom = True

    def __new__(cls, *args, **kwargs):
        # To be usable inside the conditions of sp.Piecewise a cast_func has to be a sympy Boolean as well.
        # Making *every* cast_func a Boolean is wrong though - it breaks e.g. equality comparisons:
        #
        #     lhs = bitwise_and(a, cast_func(1, 'int'))
        #     rhs = cast_func(0, 'int')
        #     print( sp.Ne(lhs, rhs) )  # would give true if all cast_funcs were booleans
        #
        # Therefore only casts of Boolean arguments are created as the separate class boolean_cast_func.
        target_cls = boolean_cast_func if isinstance(args[0], Boolean) else cls
        return sp.Function.__new__(target_cls, *args, **kwargs)

    @property
    def canonical(self):
        """Canonical form of the cast argument; NotImplementedError if the argument has none."""
        casted = self.args[0]
        if not hasattr(casted, 'canonical'):
            raise NotImplementedError()
        return casted.canonical

    @property
    def is_commutative(self):
        """Commutativity is taken over from the cast argument."""
        casted = self.args[0]
        return casted.is_commutative

    @property
    def dtype(self):
        """Target type of the cast (the second argument)."""
        target_type = self.args[1]
        return target_type


# noinspection PyPep8Naming
class boolean_cast_func(cast_func, Boolean):
    """cast_func of a Boolean argument - cast_func.__new__ creates this class automatically for such arguments."""


# noinspection PyPep8Naming
class vector_memory_access(cast_func):
    """cast_func subclass that is constructed with exactly four arguments (sympy ``nargs``).

    NOTE(review): the meaning of the arguments beyond the inherited cast_func ones (args[1] is the dtype)
    is not visible in this module - check the vectorizer that creates these nodes.
    """
    nargs = (4,)


# noinspection PyPep8Naming
class reinterpret_cast_func(cast_func):
    """Marker subclass of cast_func; it adds no behavior of its own, so code can distinguish it by its class."""


# noinspection PyPep8Naming
class pointer_arithmetic_func(sp.Function, Boolean):
    """Sympy function node (per its name, for pointer arithmetic) that also derives from sympy's Boolean."""

    @property
    def canonical(self):
        """Canonical form of the first argument; NotImplementedError if it has none."""
        first = self.args[0]
        if not hasattr(first, 'canonical'):
            raise NotImplementedError()
        return first.canonical


class TypedSymbol(sp.Symbol):
    """Sympy symbol that additionally carries a data type ``dtype``.

    Construction is cached by sympy's ``cacheit`` on (name, dtype). The dtype is part of the hashable content,
    so symbols with equal names but different dtypes are distinct.
    """

    def __new__(cls, *args, **kwds):
        return TypedSymbol.__xnew_cached_(cls, *args, **kwds)

    def __new_stage2__(cls, name, dtype):
        instance = super(TypedSymbol, cls).__xnew__(cls, name)
        try:
            parsed_dtype = create_type(dtype)
        except TypeError:
            # the specification could not be turned into a type - keep it unchanged (e.g. as string)
            parsed_dtype = dtype
        instance._dtype = parsed_dtype
        return instance

    # uncached and cached constructors, following sympy's own Symbol implementation
    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    @property
    def dtype(self):
        """Data type given at construction (a type object, or the raw value if it could not be parsed)."""
        return self._dtype

    def _hashable_content(self):
        symbol_content = super()._hashable_content()
        return symbol_content, hash(self._dtype)

    def __getnewargs__(self):
        # arguments used by pickle to re-create the symbol
        return (self.name, self.dtype)


def create_type(specification):
    """Creates a subclass of Type according to a string or an object of subclass Type.

    Args:
        specification: Type object (a StructType is accepted as well), or anything ``np.dtype`` understands,
                       e.g. a string like ``"float64"`` or a numpy structured dtype

    Returns:
        the specification itself if it already is a Type or StructType, otherwise a new BasicType (for
        non-structured numpy dtypes) or StructType (for structured numpy dtypes)

    Raises:
        TypeError: if numpy cannot interpret the specification as a data type
    """
    # StructType does not derive from Type, but it is one of the results of this function - return it unchanged
    # so that create_type(create_type(x)) also works for structured dtypes
    if isinstance(specification, (Type, StructType)):
        return specification
    numpy_dtype = np.dtype(specification)
    if numpy_dtype.fields is None:
        return BasicType(numpy_dtype, const=False)
    return StructType(numpy_dtype, const=False)


@memorycache(maxsize=64)
def create_composite_type_from_string(specification):
    """Creates a new Type object from a c-like string specification.

    The specification is a base type name, optionally qualified with ``const``, followed by any number of
    pointer declarators ``*``. As in C, the qualifiers ``const`` / ``restrict`` written after a ``*`` belong to
    that pointer:

    - ``"double const *"`` is a pointer to const double.
    - ``"double * const"`` is a const pointer to double.
    - In ``"int * * const"`` only the outer pointer is const.

    A ``*`` does not need to be separated by whitespace: ``"double*"``, ``"double**"`` and ``"int *restrict"``
    are accepted. Parsing is case-insensitive, so ``RESTRICT`` is understood as well.

    Args:
        specification: Specification string

    Returns:
        Type object

    Raises:
        ValueError: if the base part does not consist of exactly one type name (plus an optional ``const``),
                    or a ``*`` is followed by something other than ``const`` / ``restrict``
        TypeError: if numpy does not understand the base type name
    """
    # give every '*' its own token, so that 'double*', 'double *' and 'double * const' tokenize alike
    tokens = specification.lower().replace('*', ' * ').split()

    # parts[0] collects the tokens before the first '*' (the base type); every further part starts with one '*'
    # followed by the qualifiers of that pointer
    parts = [[]]
    for token in tokens:
        if token == '*':
            parts.append([token])
        else:
            parts[-1].append(token)

    # Parse native part
    base_part = parts.pop(0)
    base_names = [token for token in base_part if token != 'const']
    if len(base_names) != 1:
        raise ValueError("Invalid type specification '%s': expected exactly one base type name" % (specification,))
    current_type = BasicType(np.dtype(base_names[0]), 'const' in base_part)

    # Parse pointer parts - from the innermost (leftmost) pointer to the outermost one
    for part in parts:
        qualifiers = part[1:]
        invalid = [q for q in qualifiers if q not in ('const', 'restrict')]
        if invalid:
            raise ValueError("Invalid type specification '%s': unexpected token(s) %s after '*'"
                             % (specification, ", ".join(invalid)))
        current_type = PointerType(current_type, const='const' in qualifiers, restrict='restrict' in qualifiers)
    return current_type


def get_base_type(data_type):
    """Follows the ``base_type`` chain starting at ``data_type`` and returns its last element.

    The last element is the first type whose ``base_type`` is None, e.g. the element type behind a
    pointer to a pointer.
    """
    current = data_type
    while True:
        inner = current.base_type
        if inner is None:
            return current
        current = inner


def to_ctypes(data_type):
    """Transforms a given Type into ctypes.

    Args:
        data_type: Subclass of Type

    Returns:
        ctypes type object:

        - pointers become ``ctypes.POINTER`` of their converted base type;
        - structs are passed as byte pointers (``POINTER(c_uint8)``);
        - all other types are looked up by numpy dtype in ``to_ctypes.map``.
    """
    if isinstance(data_type, PointerType):
        pointee = to_ctypes(data_type.base_type)
        return ctypes.POINTER(pointee)
    if isinstance(data_type, StructType):
        return ctypes.POINTER(ctypes.c_uint8)
    return to_ctypes.map[data_type.numpy_dtype]


# numpy dtype -> ctypes type, used by to_ctypes for all non-pointer, non-struct types
to_ctypes.map = {
    np.dtype(np.int8): ctypes.c_int8,
    np.dtype(np.int16): ctypes.c_int16,
    np.dtype(np.int32): ctypes.c_int32,
    np.dtype(np.int64): ctypes.c_int64,

    np.dtype(np.uint8): ctypes.c_uint8,
    np.dtype(np.uint16): ctypes.c_uint16,
    np.dtype(np.uint32): ctypes.c_uint32,
    np.dtype(np.uint64): ctypes.c_uint64,

    np.dtype(np.float32): ctypes.c_float,
    np.dtype(np.float64): ctypes.c_double,

    # BasicType supports 'bool' as well; numpy booleans are one byte wide, like ctypes.c_bool
    np.dtype(np.bool_): ctypes.c_bool,
}


def ctypes_from_llvm(data_type):
    """Converts an llvmlite IR type into the corresponding ctypes type.

    Mapping rules:

    - ``ir.VoidType`` maps to ``None``, so pointers to void become ``ctypes.c_void_p``.
    - Integer types map to the signed ctypes integer of the same width.

    Raises:
        ImportError: the error from importing llvmlite, if llvmlite is not available
        ValueError: for integer widths other than 8, 16, 32 and 64
        NotImplementedError: for any other IR type
    """
    if not ir:
        raise _ir_importerror

    if isinstance(data_type, ir.PointerType):
        pointee_ctype = ctypes_from_llvm(data_type.pointee)
        return ctypes.c_void_p if pointee_ctype is None else ctypes.POINTER(pointee_ctype)

    if isinstance(data_type, ir.IntType):
        signed_int_by_width = {8: ctypes.c_int8, 16: ctypes.c_int16, 32: ctypes.c_int32, 64: ctypes.c_int64}
        int_ctype = signed_int_by_width.get(data_type.width)
        if int_ctype is None:
            raise ValueError("Int width %d is not supported" % data_type.width)
        return int_ctype

    if isinstance(data_type, ir.FloatType):
        return ctypes.c_float
    if isinstance(data_type, ir.DoubleType):
        return ctypes.c_double
    if isinstance(data_type, ir.VoidType):
        return None  # Void type is not supported by ctypes

    raise NotImplementedError('Data type %s of %s is not supported yet' % (type(data_type), data_type))


def to_llvm_type(data_type):
    """Transforms a given Type into an llvmlite IR type.

    Args:
        data_type: Subclass of Type. Pointers are converted recursively; all other types are looked up by
                   their numpy dtype in ``to_llvm_type.map``.

    Returns:
        llvmlite type object

    Raises:
        ImportError: the error from importing llvmlite, if llvmlite is not available
    """
    if not ir:
        raise _ir_importerror
    if isinstance(data_type, PointerType):
        return to_llvm_type(data_type.base_type).as_pointer()
    return to_llvm_type.map[data_type.numpy_dtype]


if ir:
    # LLVM integer types are signless: signed and unsigned numpy integers of one width share an IR type
    to_llvm_type.map = {
        np.dtype('%sint%d' % (sign, width)): ir.IntType(width)
        for sign in ('', 'u')
        for width in (8, 16, 32, 64)
    }
    to_llvm_type.map[np.dtype(np.float32)] = ir.FloatType()
    to_llvm_type.map[np.dtype(np.float64)] = ir.DoubleType()


def peel_off_type(dtype, type_to_peel_off):
    """Strips outer layers of exactly the class ``type_to_peel_off`` from ``dtype``.

    Each layer is unwrapped via its ``base_type``. Subclasses of ``type_to_peel_off`` are not stripped.
    Returns the first type in the chain that is not of that class.
    """
    current = dtype
    while True:
        if type(current) is not type_to_peel_off:
            return current
        current = current.base_type


def collate_types(types):
    """
    Takes an iterable of types and returns their "common type" e.g. (float, double, float) -> double
    Uses the collation rules from numpy.

    Args:
        types: iterable of types; it is consumed exactly once, so generators are accepted as well

    Returns:
        - pointer arithmetic (one PointerType, all other types integer BasicTypes): the pointer type
        - otherwise the numpy-collated BasicType, wrapped into a VectorType of the common width if any input
          was a VectorType. If any input is a float type, only the float types take part in the collation
          (e.g. int64 and float32 collate to float32, not to float64).

    Raises:
        ValueError: for two pointer types, invalid pointer arithmetic, or vector types of different widths
    """
    # the types are traversed several times below - materialize them once so that iterators work as well
    types = tuple(types)

    # Pointer arithmetic case i.e. pointer + integer is allowed
    if any(type(t) is PointerType for t in types):
        pointer_type = None
        for t in types:
            if type(t) is PointerType:
                if pointer_type is not None:
                    raise ValueError("Cannot collate the combination of two pointer types")
                pointer_type = t
            elif type(t) is BasicType:
                if not (t.is_int() or t.is_uint()):
                    raise ValueError("Invalid pointer arithmetic")
            else:
                raise ValueError("Invalid pointer arithmetic")
        return pointer_type

    # peel off vector types, if at least one vector type occurred the result will also be the vector type
    vector_types = [t for t in types if type(t) is VectorType]
    if not all_equal(t.width for t in vector_types):
        raise ValueError("Collation failed because of vector types with different width")
    types = [peel_off_type(t, VectorType) for t in types]

    # now we should have a list of basic types - struct types are not yet supported
    assert all(type(t) is BasicType for t in types)

    # floats dominate: integer types do not take part in the collation if any float type is present
    if any(t.is_float() for t in types):
        types = [t for t in types if t.is_float()]

    # use numpy collation -> create type from numpy type -> and, put vector type around if necessary
    result_numpy_type = np.result_type(*(t.numpy_dtype for t in types))
    result = BasicType(result_numpy_type)
    if vector_types:
        result = VectorType(result, vector_types[0].width)
    return result


@memorycache(maxsize=2048)
def get_type_of_expression(expr):
    """Determines the data type of an expression whose symbols are all TypedSymbols.

    Args:
        expr: sympy expression, or anything ``sp.sympify`` accepts

    Returns:
        the type of the expression:

        - ``int`` for integer literals, ``double`` for other rational / float literals;
        - the declared type of typed symbols, field accesses and casts;
        - ``bool`` (a bool vector if any argument is a vector) for boolean expressions;
        - otherwise the collated type of the arguments (see collate_types).

    Raises:
        ValueError: if the expression contains a symbol that is not a TypedSymbol
        NotImplementedError: if no type can be determined
    """
    from pystencils.astnodes import ResolvedFieldAccess
    from pystencils.cpu.vectorization import vec_all, vec_any

    expr = sp.sympify(expr)
    if isinstance(expr, sp.Integer):
        return create_type("int")
    elif isinstance(expr, (sp.Rational, sp.Float)):
        return create_type("double")
    elif isinstance(expr, ResolvedFieldAccess):
        return expr.field.dtype
    elif isinstance(expr, TypedSymbol):
        return expr.dtype
    elif isinstance(expr, sp.Symbol):
        raise ValueError("All symbols inside this expression have to be typed! %s" % (expr,))
    elif isinstance(expr, cast_func):
        return expr.args[1]
    elif isinstance(expr, (vec_any, vec_all)):
        return create_type("bool")
    elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
        collated_result_type = collate_types(tuple(get_type_of_expression(a[0]) for a in expr.args))
        collated_condition_type = collate_types(tuple(get_type_of_expression(a[1]) for a in expr.args))
        # vector conditions turn a scalar result into a vector of the condition's width
        if type(collated_condition_type) is VectorType and type(collated_result_type) is not VectorType:
            collated_result_type = VectorType(collated_result_type, width=collated_condition_type.width)
        return collated_result_type
    elif isinstance(expr, sp.Indexed):
        typed_symbol = expr.base.label
        return typed_symbol.dtype.base_type
    elif isinstance(expr, Boolean):
        # Boolean is the common base class of all sympy booleans, BooleanFunction included
        # if any arg is of vector type return a vector boolean, else return a normal scalar boolean
        result = create_type("bool")
        arg_types = [get_type_of_expression(a) for a in expr.args]
        vec_args = [t for t in arg_types if isinstance(t, VectorType)]
        if vec_args:
            result = VectorType(result, width=vec_args[0].width)
        return result
    elif isinstance(expr, sp.Pow):
        return get_type_of_expression(expr.args[0])
    elif isinstance(expr, sp.Expr):
        types = tuple(get_type_of_expression(a) for a in expr.args)
        return collate_types(types)

    raise NotImplementedError("Could not determine type for %s (of type %s)" % (expr, type(expr)))


class Type(sp.Basic):
    """Base class of the data types in this module.

    It is a sympy atom, so type objects can be used as arguments of sympy functions, e.g. as the target type
    (``args[1]``) of a cast_func.
    """
    is_Atom = True

    def __new__(cls, *args, **kwargs):
        # constructor arguments are handled by the subclasses' __init__ and are not stored as sympy args
        instance = sp.Basic.__new__(cls)
        return instance

    def _sympystr(self, *args, **kwargs):
        # make sympy's string printer use the __str__ of the concrete type
        return str(self)


class BasicType(Type):
    """Scalar data type backed by a non-structured numpy dtype, optionally const-qualified."""

    # the types numpy 1.x listed in np.sctypes['others'] (np.sctypes was removed in NumPy 2.0)
    _OTHER_SCALAR_TYPES = (bool, object, bytes, str, np.void)

    @staticmethod
    def numpy_name_to_c(name):
        """Maps a numpy dtype name (e.g. 'float64', 'int32', 'uint8', 'bool') to the C type name."""
        if name == 'float64':
            return 'double'
        elif name == 'float32':
            return 'float'
        elif name.startswith('int'):
            width = int(name[len("int"):])
            return "int%d_t" % (width,)
        elif name.startswith('uint'):
            width = int(name[len("uint"):])
            return "uint%d_t" % (width,)
        elif name == 'bool':
            return 'bool'
        else:
            raise NotImplementedError("Cannot map numpy to C name for %s" % (name,))

    def __init__(self, dtype, const=False):
        self.const = const
        if isinstance(dtype, Type):
            self._dtype = dtype.numpy_dtype
        else:
            self._dtype = np.dtype(dtype)
        assert self._dtype.fields is None, "Tried to initialize BasicType with a structured type"
        assert self._dtype.hasobject is False
        assert self._dtype.subdtype is None

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        # basic types are not derived from another type
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def item_size(self):
        # a basic type counts as one item (this is not a size in bytes)
        return 1

    # The classification below uses dtype.kind: it selects the same dtypes as the former np.sctypes lists
    # ('int', 'float', 'uint', 'complex') but also works with NumPy >= 2.0, where np.sctypes does not exist.

    def is_int(self):
        """True for signed integer dtypes."""
        return self.numpy_dtype.kind == 'i'

    def is_float(self):
        """True for floating point dtypes."""
        return self.numpy_dtype.kind == 'f'

    def is_uint(self):
        """True for unsigned integer dtypes."""
        return self.numpy_dtype.kind == 'u'

    def is_complex(self):
        """True for complex dtypes."""
        return self.numpy_dtype.kind == 'c'

    def is_other(self):
        """True for bool, object, bytes, str and void dtypes."""
        return any(self.numpy_dtype == t for t in BasicType._OTHER_SCALAR_TYPES)

    @property
    def base_name(self):
        """C name of the type, without const qualifier."""
        return BasicType.numpy_name_to_c(str(self._dtype))

    def __str__(self):
        result = BasicType.numpy_name_to_c(str(self._dtype))
        if self.const:
            result += " const"
        return result

    def __repr__(self):
        return str(self)

    def __eq__(self, other):
        if not isinstance(other, BasicType):
            return False
        else:
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)

    def __hash__(self):
        return hash(str(self))


class VectorType(Type):
    """Vector of ``width`` elements of ``base_type``.

    ``instruction_set`` is a class-level mapping (None by default) from the keys 'int', 'double', 'float'
    and 'bool' to the names used when printing the vector type. While it is None, vectors print as
    ``base[width]``.
    """
    instruction_set = None

    def __init__(self, base_type, width=4):
        self._base_type = base_type
        self.width = width

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        # width times the item size of a single element
        return self.width * self.base_type.item_size

    def __eq__(self, other):
        if isinstance(other, VectorType):
            return (self.base_type, self.width) == (other.base_type, other.width)
        return False

    def __str__(self):
        if self.instruction_set is None:
            return "%s[%d]" % (self.base_type, self.width)
        # element type name -> key into the instruction set, checked in this order
        for type_name, key in (("int64", 'int'), ("float64", 'double'), ("float32", 'float'), ("bool", 'bool')):
            if self.base_type == create_type(type_name):
                return self.instruction_set[key]
        raise NotImplementedError()

    def __hash__(self):
        return hash((self.base_type, self.width))

    def __getnewargs__(self):
        return self._base_type, self.width


class PointerType(Type):
    """Pointer to ``base_type`` with optional ``const`` and ``restrict`` qualifiers (restrict by default)."""

    def __init__(self, base_type, const=False, restrict=True):
        self._base_type = base_type
        self.const = const
        self.restrict = restrict

    def __getnewargs__(self):
        return self.base_type, self.const, self.restrict

    @property
    def alias(self):
        """True if the pointer is not restrict-qualified."""
        return not self.restrict

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        # item size of the pointed-to type
        return self.base_type.item_size

    def __eq__(self, other):
        if isinstance(other, PointerType):
            return (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict)
        return False

    def __str__(self):
        qualifiers = (['RESTRICT'] if self.restrict else []) + (["const"] if self.const else [])
        return " ".join([str(self.base_type), '*'] + qualifiers)

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self._base_type, self.const, self.restrict))


class StructType:
    """Data type described by a numpy structured dtype, optionally const-qualified.

    In generated code structs are handled byte-wise, so the type prints as ``uint8_t``.
    """

    def __init__(self, numpy_type, const=False):
        self.const = const
        self._dtype = np.dtype(numpy_type)

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        # structs are not derived from another type
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def item_size(self):
        # size of one struct in bytes
        return self.numpy_dtype.itemsize

    def get_element_offset(self, element_name):
        """Byte offset of the field ``element_name`` inside the struct."""
        field_info = self.numpy_dtype.fields[element_name]
        return field_info[1]

    def get_element_type(self, element_name):
        """BasicType of the field ``element_name``; it inherits the constness of the struct."""
        field_dtype = self.numpy_dtype.fields[element_name][0]
        return BasicType(field_dtype, self.const)

    def has_element(self, element_name):
        return element_name in self.numpy_dtype.fields

    def __eq__(self, other):
        if isinstance(other, StructType):
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)
        return False

    def __str__(self):
        # structs are handled byte-wise
        return "uint8_t const" if self.const else "uint8_t"

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self.numpy_dtype, self.const))