data_types.py 19 KB
Newer Older
1
import ctypes
Martin Bauer's avatar
Martin Bauer committed
2

3
import numpy as np
Martin Bauer's avatar
Martin Bauer committed
4
5
6
7
8
9
import sympy as sp
from sympy.core.cache import cacheit
from sympy.logic.boolalg import Boolean

from pystencils.cache import memorycache
from pystencils.utils import all_equal
10

11
12
13
14
15
try:
    import llvmlite.ir as ir
except ImportError as e:
    ir = None
    _ir_importerror = e
16

17

18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
# noinspection PyPep8Naming
class address_of(sp.Function):
    """Sympy function node taking the address of its single argument.

    ``dtype`` is a restrict PointerType to the argument's ``dtype``, or to
    ``'void'`` when the argument has no ``dtype`` attribute.
    """
    is_Atom = True

    def __new__(cls, arg):
        # explicit one-argument constructor, forwarded to sympy
        return sp.Function.__new__(cls, arg)

    @property
    def canonical(self):
        target = self.args[0]
        if not hasattr(target, 'canonical'):
            raise NotImplementedError()
        return target.canonical

    @property
    def is_commutative(self):
        target = self.args[0]
        return target.is_commutative

    @property
    def dtype(self):
        target = self.args[0]
        pointee = target.dtype if hasattr(target, 'dtype') else 'void'
        return PointerType(pointee, restrict=True)


Martin Bauer's avatar
Martin Bauer committed
45
# noinspection PyPep8Naming
class cast_func(sp.Function):
    """Sympy function node ``cast_func(expr, dtype)`` representing a type cast.

    ``args[0]`` is the expression being cast; ``args[1]`` is the target type,
    exposed as ``dtype``.
    """
    is_Atom = True

    def __new__(cls, *args, **kwargs):
        # A cast used in a condition of sp.Piecewise must be a sympy Boolean.
        # Making *every* cast a Boolean breaks equality comparisons, e.g. with
        #     lhs = bitwise_and(a, cast_func(1, 'int'))
        #     rhs = cast_func(0, 'int')
        # sp.Ne(lhs, rhs) would evaluate to true. Therefore only casts whose
        # argument is a Boolean are created as the subclass boolean_cast_func.
        target_cls = boolean_cast_func if isinstance(args[0], Boolean) else cls
        return sp.Function.__new__(target_cls, *args, **kwargs)

    @property
    def canonical(self):
        casted = self.args[0]
        if not hasattr(casted, 'canonical'):
            raise NotImplementedError()
        return casted.canonical

    @property
    def is_commutative(self):
        casted = self.args[0]
        return casted.is_commutative

    @property
    def dtype(self):
        # the target type is the second argument
        return self.args[1]


78
79
80
81
82
# noinspection PyPep8Naming
class boolean_cast_func(cast_func, Boolean):
    """cast_func that is also a sympy Boolean; cast_func.__new__ selects it for Boolean arguments."""


Martin Bauer's avatar
Martin Bauer committed
83
84
85
86
# noinspection PyPep8Naming
class vector_memory_access(cast_func):
    """cast_func subclass restricted to exactly four arguments (``nargs = (4,)``).

    Like every cast_func, ``args[1]`` is reported as ``dtype``.
    NOTE(review): the meaning of the remaining arguments is defined by its users, not here.
    """
    nargs = (4,)

87

88
89
90
91
92
# noinspection PyPep8Naming
class reinterpret_cast_func(cast_func):
    """Distinct cast_func subclass (by its name, presumably a reinterpret-style cast); adds no behavior here."""


Martin Bauer's avatar
Martin Bauer committed
93
94
# noinspection PyPep8Naming
class pointer_arithmetic_func(sp.Function, Boolean):
    """Sympy function node (by its name, for pointer arithmetic) that is also a sympy Boolean.

    Only ``canonical`` is customized: it forwards to the first argument's ``canonical``.
    """

    @property
    def canonical(self):
        first = self.args[0]
        if not hasattr(first, 'canonical'):
            raise NotImplementedError()
        return first.canonical


103
class TypedSymbol(sp.Symbol):
    """Sympy symbol that additionally carries a data type in ``dtype``.

    The constructor's ``dtype`` is converted with create_type; if that raises
    TypeError or ValueError the given value is stored unchanged. When the
    stored type exposes ``numpy_dtype``, it refines sympy's integer, sign and
    realness assumptions.
    """

    def __new__(cls, *args, **kwds):
        return TypedSymbol.__xnew_cached_(cls, *args, **kwds)

    def __new_stage2__(cls, name, dtype, *args, **kwargs):
        symbol = super(TypedSymbol, cls).__xnew__(cls, name, *args, **kwargs)
        try:
            symbol._dtype = create_type(dtype)
        except (TypeError, ValueError):
            # not a specification create_type understands -> keep it verbatim
            symbol._dtype = dtype
        return symbol

    # uncached and cacheit-wrapped construction functions
    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    @property
    def dtype(self):
        return self._dtype

    def _hashable_content(self):
        # the dtype takes part in hashing, so equal names with different dtypes differ
        return (super()._hashable_content(), hash(self._dtype))

    def __getnewargs__(self):
        return (self.name, self.dtype)

    # For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
    @property
    def is_integer(self):
        if hasattr(self.dtype, 'numpy_dtype') and np.issubdtype(self.dtype.numpy_dtype, np.integer):
            return True
        return super().is_integer

    @property
    def is_negative(self):
        unsigned = hasattr(self.dtype, 'numpy_dtype') and \
            np.issubdtype(self.dtype.numpy_dtype, np.unsignedinteger)
        return False if unsigned else super().is_negative

    @property
    def is_nonnegative(self):
        return True if self.is_negative is False else super().is_nonnegative

    @property
    def is_real(self):
        if hasattr(self.dtype, 'numpy_dtype'):
            numpy_dtype = self.dtype.numpy_dtype
            if np.issubdtype(numpy_dtype, np.integer) or np.issubdtype(numpy_dtype, np.floating):
                return True
        return super().is_real
161

162

Martin Bauer's avatar
Martin Bauer committed
163
def create_type(specification):
    """Creates a data type object from a specification, or passes a type object through.

    Args:
        specification: Type or StructType object (returned unchanged), or anything
            ``np.dtype`` accepts, e.g. a string such as ``'float64'``

    Returns:
        The given type object, or a new non-const BasicType (dtype without fields)
        or StructType (dtype with fields)

    Raises:
        TypeError: from ``np.dtype`` when the specification is not understood
    """
    # StructType does not derive from Type, so it has to be listed explicitly;
    # otherwise np.dtype(StructType(...)) would be attempted and fail, making
    # create_type non-idempotent on its own StructType results.
    if isinstance(specification, (Type, StructType)):
        return specification
    numpy_dtype = np.dtype(specification)
    if numpy_dtype.fields is None:
        return BasicType(numpy_dtype, const=False)
    return StructType(numpy_dtype, const=False)
180
181


182
@memorycache(maxsize=64)
def create_composite_type_from_string(specification):
    """Creates a new Type object from a c-like string specification.

    Grammar: a base type with an optional ``const``, followed by any number of
    ``*``, each optionally qualified with ``restrict`` and/or ``const``.
    Every ``*`` wraps the type built so far, so the first ``*`` is the
    innermost pointer. A ``*`` may be attached to neighbouring tokens:
    ``"double*"``, ``"double **"`` and ``"double *restrict"`` are all accepted.

    Args:
        specification: Specification string, e.g. ``"const double * restrict"``

    Returns:
        Type object
    """
    # Surround every '*' with whitespace so that the position of spaces does not
    # matter: "double* const" and "double * const" both denote a const pointer.
    tokens = specification.lower().replace('*', ' * ').split()

    # Group into [base tokens], ['*', qualifiers...], ['*', qualifiers...], ...
    parts = []
    current = []
    for token in tokens:
        if token == '*':
            parts.append(current)
            current = [token]
        else:
            current.append(token)
    if len(current) > 0:
        parts.append(current)

    # Parse native part
    base_part = parts.pop(0)
    const = False
    if 'const' in base_part:
        const = True
        base_part.remove('const')
    assert len(base_part) == 1
    current_type = BasicType(np.dtype(base_part[0]), const)

    # Parse pointer parts, innermost first
    for part in parts:
        restrict = False
        const = False
        if 'restrict' in part:
            restrict = True
            part.remove('restrict')
        if 'const' in part:
            const = True
            part.remove("const")
        assert len(part) == 1 and part[0] == '*'
        current_type = PointerType(current_type, const, restrict)
    return current_type
227
228


Martin Bauer's avatar
Martin Bauer committed
229
230
231
232
def get_base_type(data_type):
    """Follow ``base_type`` links until reaching a type whose ``base_type`` is None, and return it."""
    current = data_type
    while True:
        inner = current.base_type
        if inner is None:
            return current
        current = inner
233
234


Martin Bauer's avatar
Martin Bauer committed
235
def to_ctypes(data_type):
    """Translate a pystencils data type into a ctypes type.

    Args:
        data_type: PointerType (translated recursively), StructType, or a type
            whose ``numpy_dtype`` is a key of ``to_ctypes.map``

    Returns:
        ctypes type object; a StructType becomes ``POINTER(c_uint8)``
    """
    if isinstance(data_type, PointerType):
        return ctypes.POINTER(to_ctypes(data_type.base_type))
    if isinstance(data_type, StructType):
        # struct types are exposed as a pointer to raw bytes
        return ctypes.POINTER(ctypes.c_uint8)
    return to_ctypes.map[data_type.numpy_dtype]
247

Martin Bauer's avatar
Martin Bauer committed
248
249

# numpy scalar dtype -> ctypes type, used by to_ctypes for non-pointer, non-struct types.
to_ctypes.map = {
    np.dtype(np.int8): ctypes.c_int8,
    np.dtype(np.int16): ctypes.c_int16,
    np.dtype(np.int32): ctypes.c_int32,
    np.dtype(np.int64): ctypes.c_int64,

    np.dtype(np.uint8): ctypes.c_uint8,
    np.dtype(np.uint16): ctypes.c_uint16,
    np.dtype(np.uint32): ctypes.c_uint32,
    np.dtype(np.uint64): ctypes.c_uint64,

    np.dtype(np.float32): ctypes.c_float,
    np.dtype(np.float64): ctypes.c_double,

    # BasicType supports 'bool' (see BasicType.numpy_name_to_c); without this
    # entry to_ctypes raised KeyError for bool types.
    np.dtype(np.bool_): ctypes.c_bool,
}


265
def ctypes_from_llvm(data_type):
    """Map an llvmlite IR type to the corresponding ctypes type.

    Pointers become ``POINTER(pointee)``, or ``c_void_p`` when the pointee maps
    to None (void). Integer types map to signed ctypes integers of equal width.

    Raises:
        the stored llvmlite ImportError if llvmlite is unavailable;
        ValueError for unsupported integer widths;
        NotImplementedError for any other IR type.
    """
    if not ir:
        raise _ir_importerror

    if isinstance(data_type, ir.PointerType):
        pointee_ctype = ctypes_from_llvm(data_type.pointee)
        return ctypes.c_void_p if pointee_ctype is None else ctypes.POINTER(pointee_ctype)

    if isinstance(data_type, ir.IntType):
        ints_by_width = {8: ctypes.c_int8, 16: ctypes.c_int16, 32: ctypes.c_int32, 64: ctypes.c_int64}
        if data_type.width not in ints_by_width:
            raise ValueError("Int width %d is not supported" % data_type.width)
        return ints_by_width[data_type.width]

    if isinstance(data_type, ir.FloatType):
        return ctypes.c_float
    if isinstance(data_type, ir.DoubleType):
        return ctypes.c_double
    if isinstance(data_type, ir.VoidType):
        return None  # Void type is not supported by ctypes

    raise NotImplementedError('Data type %s of %s is not supported yet' % (type(data_type), data_type))


def to_llvm_type(data_type):
    """
    Transforms a given type into an llvmlite IR type
    :param data_type: Subclass of Type; PointerTypes are translated recursively
    :return: llvmlite type object
    :raises: the stored llvmlite ImportError if llvmlite is not installed;
             KeyError if the numpy dtype has no entry in ``to_llvm_type.map``
    """
    if not ir:
        raise _ir_importerror
    if isinstance(data_type, PointerType):
        return to_llvm_type(data_type.base_type).as_pointer()
    else:
        return to_llvm_type.map[data_type.numpy_dtype]

308

309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
if ir:
    # numpy integer dtypes: signed and unsigned variants map to the same
    # ir.IntType of matching bit width
    to_llvm_type.map = {
        np.dtype(int_type): ir.IntType(np.dtype(int_type).itemsize * 8)
        for int_type in (np.int8, np.int16, np.int32, np.int64,
                         np.uint8, np.uint16, np.uint32, np.uint64)
    }
    # floating point dtypes
    to_llvm_type.map[np.dtype(np.float32)] = ir.FloatType()
    to_llvm_type.map[np.dtype(np.float64)] = ir.DoubleType()
324

325

Martin Bauer's avatar
Martin Bauer committed
326
327
328
def peel_off_type(dtype, type_to_peel_off):
    """Strip wrappers whose exact type is ``type_to_peel_off`` by following ``base_type``.

    Subclasses of ``type_to_peel_off`` are not stripped (exact type match).
    """
    if type(dtype) is not type_to_peel_off:
        return dtype
    return peel_off_type(dtype.base_type, type_to_peel_off)


Martin Bauer's avatar
Martin Bauer committed
332
def collate_types(types):
    """
    Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
    Uses the collation rules from numpy.

    Rules, in order:
      * If any type is a PointerType, the result is that pointer type. Only one
        pointer is allowed and every other type must be an integer BasicType
        (pointer + integer arithmetic); otherwise ValueError is raised.
      * VectorTypes are unwrapped; all of them must have the same width
        (ValueError otherwise) and the result is wrapped again with that width.
      * If any element type is floating point, non-float types are ignored, so
        integers do not promote e.g. float32 to a wider type.
      * The remaining types are combined with ``np.result_type``.

    Args:
        types: iterable of types; it is materialized once, so generators work too

    Returns:
        the common PointerType, BasicType or VectorType
    """
    # The checks below iterate several times - materialize once so that a
    # generator argument is not silently exhausted by the first pass.
    types = tuple(types)

    # Pointer arithmetic case i.e. pointer + integer is allowed
    if any(type(t) is PointerType for t in types):
        pointer_type = None
        for t in types:
            if type(t) is PointerType:
                if pointer_type is not None:
                    raise ValueError("Cannot collate the combination of two pointer types")
                pointer_type = t
            elif type(t) is BasicType:
                if not (t.is_int() or t.is_uint()):
                    raise ValueError("Invalid pointer arithmetic")
            else:
                raise ValueError("Invalid pointer arithmetic")
        return pointer_type

    # peel off vector types; if at least one vector type occurred the result will also be a vector type
    vector_types = [t for t in types if type(t) is VectorType]
    if not all_equal(t.width for t in vector_types):
        raise ValueError("Collation failed because of vector types with different width")
    types = [peel_off_type(t, VectorType) for t in types]

    # now we should have a list of basic types - struct types are not yet supported
    assert all(type(t) is BasicType for t in types)

    # floats dominate: ignore non-float types when at least one float is present
    if any(t.is_float() for t in types):
        types = [t for t in types if t.is_float()]

    # use numpy collation -> create type from numpy type -> and, put vector type around if necessary
    result = BasicType(np.result_type(*(t.numpy_dtype for t in types)))
    if vector_types:
        result = VectorType(result, vector_types[0].width)
    return result


@memorycache(maxsize=2048)
def get_type_of_expression(expr):
    """Determines the data type of a sympy expression.

    Leaves (numbers, typed symbols, field accesses, casts) report their own
    type; compound expressions collate the types of their arguments with
    collate_types. Boolean expressions yield ``bool``, or a bool VectorType if
    any argument is vector-typed.

    Args:
        expr: sympy expression (or anything ``sp.sympify`` accepts)

    Returns:
        Type object

    Raises:
        ValueError: if an untyped sp.Symbol is encountered
        NotImplementedError: if no rule matches ``expr``
    """
    from pystencils.astnodes import ResolvedFieldAccess
    from pystencils.cpu.vectorization import vec_all, vec_any

    expr = sp.sympify(expr)
    if isinstance(expr, sp.Integer):
        return create_type("int")
    elif isinstance(expr, (sp.Rational, sp.Float)):
        return create_type("double")
    elif isinstance(expr, ResolvedFieldAccess):
        return expr.field.dtype
    elif isinstance(expr, TypedSymbol):
        return expr.dtype
    elif isinstance(expr, sp.Symbol):
        raise ValueError("All symbols inside this expression have to be typed! ", str(expr))
    elif isinstance(expr, cast_func):
        return expr.args[1]
    elif isinstance(expr, (vec_any, vec_all)):
        return create_type("bool")
    elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
        # result: collation of all branch values; a vector-typed condition turns a scalar result into a vector
        collated_result_type = collate_types(tuple(get_type_of_expression(a[0]) for a in expr.args))
        collated_condition_type = collate_types(tuple(get_type_of_expression(a[1]) for a in expr.args))
        if type(collated_condition_type) is VectorType and type(collated_result_type) is not VectorType:
            collated_result_type = VectorType(collated_result_type, width=collated_condition_type.width)
        return collated_result_type
    elif isinstance(expr, sp.Indexed):
        typed_symbol = expr.base.label
        return typed_symbol.dtype.base_type
    elif isinstance(expr, Boolean):
        # sympy's BooleanFunction derives from Boolean, so this single check covers both.
        # If any arg is of vector type return a vector boolean, else a normal scalar boolean.
        result = create_type("bool")
        arg_types = (get_type_of_expression(a) for a in expr.args)
        vec_args = [t for t in arg_types if isinstance(t, VectorType)]
        if vec_args:
            result = VectorType(result, width=vec_args[0].width)
        return result
    elif isinstance(expr, sp.Pow):
        # only the base determines the type; the exponent is ignored
        return get_type_of_expression(expr.args[0])
    elif isinstance(expr, sp.Expr):
        types = tuple(get_type_of_expression(a) for a in expr.args)
        return collate_types(types)

    raise NotImplementedError("Could not determine type for", expr, type(expr))
415
416


417
class Type(sp.Basic):
    """Common base class of the pystencils data types; an atomic sympy object (``is_Atom``)."""
    is_Atom = True

    def __new__(cls, *_args, **_kwargs):
        # constructor arguments are handled by the subclasses' __init__,
        # they are not stored as sympy args
        return sp.Basic.__new__(cls)

    def _sympystr(self, *_args, **_kwargs):
        # sympy printing uses the plain string form of the type
        return str(self)
425
426
427
428


class BasicType(Type):
    """Scalar type described by a non-structured numpy dtype plus a ``const`` flag."""

    @staticmethod
    def numpy_name_to_c(name):
        """Translate a numpy dtype name into the corresponding C type name.

        Args:
            name: numpy dtype name, e.g. ``'float64'``, ``'int32'``, ``'uint8'``, ``'bool'``

        Returns:
            C type name, e.g. ``'double'``, ``'int32_t'``, ``'uint8_t'``, ``'bool'``

        Raises:
            NotImplementedError: for names without a mapping (e.g. complex types)
        """
        if name == 'float64':
            return 'double'
        elif name == 'float32':
            return 'float'
        elif name.startswith('int'):
            width = int(name[len("int"):])
            return "int%d_t" % (width,)
        elif name.startswith('uint'):
            width = int(name[len("uint"):])
            return "uint%d_t" % (width,)
        elif name == 'bool':
            return 'bool'
        else:
            raise NotImplementedError("Cannot map numpy to C name for %s" % (name,))

    def __init__(self, dtype, const=False):
        self.const = const
        if isinstance(dtype, Type):
            self._dtype = dtype.numpy_dtype
        else:
            self._dtype = np.dtype(dtype)
        assert self._dtype.fields is None, "Tried to initialize BasicType with a structured type"
        assert self._dtype.hasobject is False
        assert self._dtype.subdtype is None

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def item_size(self):
        # NOTE(review): always 1 (VectorType multiplies this by its width), whereas
        # StructType.item_size returns bytes - confirm the intended unit.
        return 1

    # The dtype.kind checks below replace np.sctypes lookups: np.sctypes was
    # removed in NumPy 2.0, while dtype.kind is available in all versions.
    def is_int(self):
        return self.numpy_dtype.kind == 'i'

    def is_float(self):
        return self.numpy_dtype.kind == 'f'

    def is_uint(self):
        return self.numpy_dtype.kind == 'u'

    def is_complex(self):
        return self.numpy_dtype.kind == 'c'

    def is_other(self):
        # bool, object, bytes, str and void - the former np.sctypes['others'] group
        return self.numpy_dtype.kind in ('b', 'O', 'S', 'U', 'V')

    @property
    def base_name(self):
        return BasicType.numpy_name_to_c(str(self._dtype))

    def __str__(self):
        result = BasicType.numpy_name_to_c(str(self._dtype))
        if self.const:
            result += " const"
        return result

    def __repr__(self):
        return str(self)

    def __eq__(self, other):
        if not isinstance(other, BasicType):
            return False
        else:
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)

    def __hash__(self):
        # Hash exactly what __eq__ compares. Hashing str(self) failed for dtypes
        # without a C name (e.g. complex), since __str__ raises for them.
        return hash((self.numpy_dtype, self.const))


508
class VectorType(Type):
    """Vector of ``width`` elements of ``base_type``.

    If the class attribute ``instruction_set`` is set (it is indexed with the
    keys ``'int'``, ``'double'``, ``'float'`` and ``'bool'``), ``str()``
    returns the matching entry instead of ``"base[width]"``.
    """
    instruction_set = None

    def __init__(self, base_type, width=4):
        self._base_type = base_type
        self.width = width

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        return self.width * self.base_type.item_size

    def __eq__(self, other):
        if isinstance(other, VectorType):
            return (self.base_type, self.width) == (other.base_type, other.width)
        return False

    def __str__(self):
        if self.instruction_set is None:
            return "%s[%d]" % (self.base_type, self.width)
        # checked in this order; int64 elements use the 'int' entry
        for numpy_name, key in (("int64", 'int'), ("float64", 'double'),
                                ("float32", 'float'), ("bool", 'bool')):
            if self.base_type == create_type(numpy_name):
                return self.instruction_set[key]
        raise NotImplementedError()

    def __hash__(self):
        return hash((self.base_type, self.width))

    def __getnewargs__(self):
        return self._base_type, self.width

550

551
class PointerType(Type):
    """Pointer to ``base_type`` with optional ``const`` and ``restrict`` qualifiers."""

    def __init__(self, base_type, const=False, restrict=True):
        self._base_type = base_type
        self.const = const
        self.restrict = restrict

    def __getnewargs__(self):
        return self.base_type, self.const, self.restrict

    @property
    def alias(self):
        # the pointer is treated as possibly aliasing exactly when it is not restrict
        return not self.restrict

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        return self.base_type.item_size

    def __eq__(self, other):
        if isinstance(other, PointerType):
            return (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict)
        return False

    def __str__(self):
        qualifiers = []
        if self.restrict:
            qualifiers.append('RESTRICT')
        if self.const:
            qualifiers.append("const")
        return " ".join([str(self.base_type), '*'] + qualifiers)

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self._base_type, self.const, self.restrict))
591

Jan Hoenig's avatar
Jan Hoenig committed
592

593
class StructType:
    """Record type backed by a numpy dtype with named fields, optionally const.

    Printed as ``uint8_t`` because structs are handled byte-wise. Unlike the
    other type classes in this file it does not derive from Type.
    """

    def __init__(self, numpy_type, const=False):
        self.const = const
        self._dtype = np.dtype(numpy_type)

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def item_size(self):
        # numpy itemsize, i.e. bytes per record
        return self.numpy_dtype.itemsize

    def get_element_offset(self, element_name):
        """Byte offset of field ``element_name`` as reported by numpy."""
        _, offset = self.numpy_dtype.fields[element_name][:2]
        return offset

    def get_element_type(self, element_name):
        """BasicType of field ``element_name``, inheriting this struct's constness."""
        field_dtype = self.numpy_dtype.fields[element_name][0]
        return BasicType(field_dtype, self.const)

    def has_element(self, element_name):
        return element_name in self.numpy_dtype.fields

    def __eq__(self, other):
        if isinstance(other, StructType):
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)
        return False

    def __str__(self):
        # structs are handled byte-wise
        return "uint8_t const" if self.const else "uint8_t"

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self.numpy_dtype, self.const))