data_types.py 19.6 KB
Newer Older
1
import ctypes
Martin Bauer's avatar
Martin Bauer committed
2

3
import numpy as np
Martin Bauer's avatar
Martin Bauer committed
4
5
6
7
8
9
import sympy as sp
from sympy.core.cache import cacheit
from sympy.logic.boolalg import Boolean

from pystencils.cache import memorycache
from pystencils.utils import all_equal
10

11
12
13
14
15
try:
    import llvmlite.ir as ir
except ImportError as e:
    ir = None
    _ir_importerror = e
16

17

18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
# noinspection PyPep8Naming
class address_of(sp.Function):
    """Symbolic "address of" node wrapping a single expression.

    The node is flagged as a sympy atom. Its ``dtype`` is a restrict-qualified
    ``PointerType`` to the dtype of the wrapped expression, or to ``'void'`` when the
    wrapped expression has no ``dtype`` attribute.
    """
    is_Atom = True

    def __new__(cls, arg):
        return sp.Function.__new__(cls, arg)

    @property
    def canonical(self):
        """``canonical`` of the wrapped expression; NotImplementedError if it has none."""
        operand = self.args[0]
        if not hasattr(operand, 'canonical'):
            raise NotImplementedError()
        return operand.canonical

    @property
    def is_commutative(self):
        """Commutativity is inherited from the wrapped expression."""
        return self.args[0].is_commutative

    @property
    def dtype(self):
        """Restrict pointer to the operand's dtype, falling back to a 'void' pointee."""
        operand = self.args[0]
        pointee = operand.dtype if hasattr(operand, 'dtype') else 'void'
        return PointerType(pointee, restrict=True)


Martin Bauer's avatar
Martin Bauer committed
45
# noinspection PyPep8Naming
46
class cast_func(sp.Function):
    """Symbolic cast node: ``cast_func(expr, dtype, ...)``.

    ``args[0]`` is the casted expression and ``args[1]`` the target type. A target that is
    not already a ``Type`` instance is converted with ``create_type``. Any further
    positional arguments are forwarded unchanged to ``sp.Function.__new__``.
    """
    is_Atom = True

    def __new__(cls, *args, **kwargs):
        expr, dtype, *other_args = args
        target_type = dtype if isinstance(dtype, Type) else create_type(dtype)

        # To be usable inside the conditions of sp.Piecewise a cast has to be a Boolean.
        # It must only be a Boolean when its argument is one, though - otherwise comparing
        # casts for equality misbehaves, e.g.
        #
        #   lhs = bitwise_and(a, cast_func(1, 'int'))
        #   rhs = cast_func(0, 'int')
        #   print( sp.Ne(lhs, rhs) ) # would give true if all cast_funcs are booleans
        #
        # -> hence the separate class boolean_cast_func, selected here for Boolean arguments.
        node_cls = boolean_cast_func if isinstance(expr, Boolean) else cls
        return sp.Function.__new__(node_cls, expr, target_type, *other_args, **kwargs)

    @property
    def canonical(self):
        """``canonical`` of the casted expression; NotImplementedError if it has none."""
        inner = self.args[0]
        if not hasattr(inner, 'canonical'):
            raise NotImplementedError()
        return inner.canonical

    @property
    def is_commutative(self):
        """Commutativity is inherited from the casted expression."""
        return self.args[0].is_commutative

    def _eval_evalf(self, *args, **kwargs):
        # numerical evaluation ignores the cast and evaluates the inner expression
        return self.args[0].evalf()

    @property
    def dtype(self):
        """Target type of the cast (second argument)."""
        return self.args[1]


86
87
88
89
90
# noinspection PyPep8Naming
class boolean_cast_func(cast_func, Boolean):
    """A ``cast_func`` that is additionally a sympy ``Boolean``.

    ``cast_func.__new__`` switches to this class when the casted expression is itself a
    ``Boolean``; no behavior is added beyond the extra base class.
    """
    pass


Martin Bauer's avatar
Martin Bauer committed
91
92
93
94
# noinspection PyPep8Naming
class vector_memory_access(cast_func):
    """``cast_func`` subclass that declares exactly four arguments via ``nargs``.

    The first two arguments are treated by ``cast_func.__new__`` as expression and target
    type. NOTE(review): the meaning of the remaining two arguments is not visible in this
    file - confirm against the code that creates these nodes.
    """
    nargs = (4,)

95

96
97
98
99
100
# noinspection PyPep8Naming
class reinterpret_cast_func(cast_func):
    """Distinct ``cast_func`` subclass without any added behavior.

    NOTE(review): presumably consumers tell it apart from a plain ``cast_func`` by its
    class (the name suggests a reinterpreting cast) - confirm against the code that
    handles these nodes.
    """
    pass


Martin Bauer's avatar
Martin Bauer committed
101
102
# noinspection PyPep8Naming
class pointer_arithmetic_func(sp.Function, Boolean):
    """sympy function node that also derives from ``Boolean``.

    It only adds the ``canonical`` property, which is delegated to the first argument.
    """

    @property
    def canonical(self):
        """``canonical`` of the first argument; NotImplementedError if it has none."""
        operand = self.args[0]
        if not hasattr(operand, 'canonical'):
            raise NotImplementedError()
        return operand.canonical


111
class TypedSymbol(sp.Symbol):
    """sympy symbol that carries a data type in addition to its name.

    Construct with ``TypedSymbol(name, dtype)``. ``dtype`` is converted with
    ``create_type``; if that raises ``TypeError`` or ``ValueError`` the given value is
    stored unchanged. The type takes part in the hashable content, and the sympy
    assumptions ``is_integer``, ``is_negative``, ``is_nonnegative`` and ``is_real`` are
    refined from the numpy dtype when the stored type has one.
    """

    def __new__(cls, *args, **kwds):
        return TypedSymbol.__xnew_cached_(cls, *args, **kwds)

    def __new_stage2__(cls, name, dtype, *args, **kwargs):
        symbol = super(TypedSymbol, cls).__xnew__(cls, name, *args, **kwargs)
        try:
            resolved_dtype = create_type(dtype)
        except (TypeError, ValueError):
            # on error keep the given specification (e.g. a string) as it is
            resolved_dtype = dtype
        symbol._dtype = resolved_dtype
        return symbol

    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    @property
    def dtype(self):
        """The stored type object (or the raw specification if conversion failed)."""
        return self._dtype

    def _hashable_content(self):
        # the dtype is part of the identity: same name with different type -> different symbol
        return super()._hashable_content(), hash(self._dtype)

    def __getnewargs__(self):
        return self.name, self.dtype

    # For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
    @property
    def is_integer(self):
        """True for integer numpy dtypes, otherwise whatever the sympy assumptions say."""
        if not hasattr(self.dtype, 'numpy_dtype'):
            return super().is_integer
        return np.issubdtype(self.dtype.numpy_dtype, np.integer) or super().is_integer

    @property
    def is_negative(self):
        """False for unsigned integer dtypes, otherwise whatever the sympy assumptions say."""
        if hasattr(self.dtype, 'numpy_dtype') and np.issubdtype(self.dtype.numpy_dtype, np.unsignedinteger):
            return False
        return super().is_negative

    @property
    def is_nonnegative(self):
        """True when the symbol is known not to be negative, else the sympy assumption."""
        return True if self.is_negative is False else super().is_nonnegative

    @property
    def is_real(self):
        """True for integer and floating numpy dtypes, otherwise the sympy assumption."""
        if not hasattr(self.dtype, 'numpy_dtype'):
            return super().is_real
        np_dtype = self.dtype.numpy_dtype
        if np.issubdtype(np_dtype, np.integer) or np.issubdtype(np_dtype, np.floating):
            return True
        return super().is_real
169

170

Martin Bauer's avatar
Martin Bauer committed
171
def create_type(specification):
    """Return the type object described by ``specification``.

    Args:
        specification: an existing ``Type`` or ``StructType`` instance, or anything that
            ``numpy.dtype`` accepts (e.g. ``'float64'``, ``np.int32``, a structured dtype)

    Returns:
        ``specification`` itself if it already is a type object; otherwise a new non-const
        ``BasicType`` for plain dtypes or ``StructType`` for dtypes that have fields

    Raises:
        Whatever ``numpy.dtype`` raises for a specification it cannot interpret
        (``TypedSymbol`` relies on ``TypeError`` / ``ValueError`` here).
    """
    # StructType does not derive from Type, so it has to be accepted explicitly -
    # otherwise an already constructed struct type would be handed to np.dtype and fail.
    if isinstance(specification, (Type, StructType)):
        return specification

    numpy_dtype = np.dtype(specification)
    if numpy_dtype.fields is None:
        return BasicType(numpy_dtype, const=False)
    return StructType(numpy_dtype, const=False)
188
189


190
@memorycache(maxsize=64)
def create_composite_type_from_string(specification):
    """Build a type object from a C-like specification string.

    The string is lower-cased and split on whitespace. The tokens before the first
    stand-alone ``*`` describe the base type (one numpy type name, optionally ``const``);
    every ``*`` starts a pointer level that may be qualified with ``restrict`` and/or
    ``const``. A ``*`` glued to the base type name (``"double*"``) adds one unqualified
    pointer level.

    Args:
        specification: Specification string

    Returns:
        ``BasicType`` wrapped in one ``PointerType`` per pointer level
    """
    tokens = specification.lower().split()

    # group the tokens: first group is the base type, each following group starts with '*'
    groups = []
    group = []
    for token in tokens:
        if token == '*':
            groups.append(group)
            group = [token]
        else:
            group.append(token)
    if group:
        groups.append(group)

    # Parse native part
    base_tokens = groups.pop(0)
    base_is_const = 'const' in base_tokens
    if base_is_const:
        base_tokens.remove('const')
    assert len(base_tokens) == 1
    if base_tokens[0].endswith("*"):
        base_tokens[0] = base_tokens[0][:-1]
        groups.append('*')
    result_type = BasicType(np.dtype(base_tokens[0]), base_is_const)

    # Parse pointer parts
    for pointer_tokens in groups:
        is_restrict = 'restrict' in pointer_tokens
        if is_restrict:
            pointer_tokens.remove('restrict')
        is_const = 'const' in pointer_tokens
        if is_const:
            pointer_tokens.remove("const")
        assert len(pointer_tokens) == 1 and pointer_tokens[0] == '*'
        result_type = PointerType(result_type, is_const, is_restrict)
    return result_type
235
236


Martin Bauer's avatar
Martin Bauer committed
237
238
239
240
def get_base_type(data_type):
    """Follow the ``base_type`` chain of ``data_type`` and return its innermost type.

    The innermost type is the first one whose ``base_type`` is ``None``.
    """
    current = data_type
    while True:
        inner = current.base_type
        if inner is None:
            return current
        current = inner
241
242


Martin Bauer's avatar
Martin Bauer committed
243
def to_ctypes(data_type):
    """Translate a type object of this module into the matching ctypes type.

    Pointer types become ``ctypes.POINTER`` of their translated base type, struct types
    are exposed as pointer to ``c_uint8``, everything else is looked up by its
    numpy dtype in ``to_ctypes.map``.

    :param data_type: Subclass of Type
    :return: ctypes type object
    """
    if isinstance(data_type, PointerType):
        return ctypes.POINTER(to_ctypes(data_type.base_type))
    if isinstance(data_type, StructType):
        return ctypes.POINTER(ctypes.c_uint8)
    return to_ctypes.map[data_type.numpy_dtype]


# lookup table numpy dtype -> ctypes scalar type, used by to_ctypes for non-pointer types
to_ctypes.map = {
    np.dtype(numpy_scalar): ctypes_scalar
    for numpy_scalar, ctypes_scalar in (
        (np.int8, ctypes.c_int8),
        (np.int16, ctypes.c_int16),
        (np.int32, ctypes.c_int32),
        (np.int64, ctypes.c_int64),

        (np.uint8, ctypes.c_uint8),
        (np.uint16, ctypes.c_uint16),
        (np.uint32, ctypes.c_uint32),
        (np.uint64, ctypes.c_uint64),

        (np.float32, ctypes.c_float),
        (np.float64, ctypes.c_double),
    )
}


273
def ctypes_from_llvm(data_type):
    """Translate an llvmlite IR type into the matching ctypes type.

    Args:
        data_type: llvmlite ``ir`` type (pointer, int of width 8/16/32/64, float, double, void)

    Returns:
        ctypes type, or ``None`` for the void type which ctypes cannot express

    Raises:
        ImportError: the error stored at import time, if llvmlite is not available
        ValueError: for integer widths other than 8, 16, 32 and 64
        NotImplementedError: for any other IR type
    """
    if not ir:
        raise _ir_importerror

    if isinstance(data_type, ir.PointerType):
        pointee_ctype = ctypes_from_llvm(data_type.pointee)
        # a pointer to void has no typed ctypes pointer, use the generic void pointer
        return ctypes.c_void_p if pointee_ctype is None else ctypes.POINTER(pointee_ctype)

    if isinstance(data_type, ir.IntType):
        ctypes_by_width = {
            8: ctypes.c_int8,
            16: ctypes.c_int16,
            32: ctypes.c_int32,
            64: ctypes.c_int64,
        }
        if data_type.width not in ctypes_by_width:
            raise ValueError("Int width %d is not supported" % data_type.width)
        return ctypes_by_width[data_type.width]

    if isinstance(data_type, ir.FloatType):
        return ctypes.c_float
    if isinstance(data_type, ir.DoubleType):
        return ctypes.c_double
    if isinstance(data_type, ir.VoidType):
        return None  # Void type is not supported by ctypes

    raise NotImplementedError('Data type %s of %s is not supported yet' % (type(data_type), data_type))


def to_llvm_type(data_type):
    """Translate a type object of this module into an llvmlite IR type.

    Pointer types are translated recursively via ``as_pointer()`` of the translated base
    type; all other types are looked up by their numpy dtype in ``to_llvm_type.map``.

    :param data_type: Subclass of Type
    :return: llvmlite type object
    """
    if not ir:
        raise _ir_importerror
    if not isinstance(data_type, PointerType):
        return to_llvm_type.map[data_type.numpy_dtype]
    return to_llvm_type(data_type.base_type).as_pointer()

316

317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
if ir:
    # lookup table numpy dtype -> llvmlite IR type; only available when llvmlite imported.
    # Signed and unsigned integers of the same width share the same IR integer width.
    to_llvm_type.map = {
        np.dtype(numpy_scalar): llvm_type
        for numpy_scalar, llvm_type in (
            (np.int8, ir.IntType(8)),
            (np.int16, ir.IntType(16)),
            (np.int32, ir.IntType(32)),
            (np.int64, ir.IntType(64)),

            (np.uint8, ir.IntType(8)),
            (np.uint16, ir.IntType(16)),
            (np.uint32, ir.IntType(32)),
            (np.uint64, ir.IntType(64)),

            (np.float32, ir.FloatType()),
            (np.float64, ir.DoubleType()),
        )
    }
332

333

Martin Bauer's avatar
Martin Bauer committed
334
335
336
def peel_off_type(dtype, type_to_peel_off):
    """Strip all outer layers of exactly type ``type_to_peel_off`` from ``dtype``.

    While the type of the current object is exactly ``type_to_peel_off`` (subclasses do
    not count) it is replaced by its ``base_type``; the first object of another type is
    returned.
    """
    current = dtype
    while True:
        if type(current) is not type_to_peel_off:
            return current
        current = current.base_type


Martin Bauer's avatar
Martin Bauer committed
340
def collate_types(types):
    """Determine the "common type" of a sequence of types, e.g. (float, double, float) -> double.

    Rules, in this order:
      * if a pointer type is present, the result is that pointer type; exactly one pointer
        may be combined with integer basic types only (pointer arithmetic)
      * vector types are peeled off; if any occurred the result is a vector type again,
        and all vector types must have the same width
      * if any float type is present only the float types are collated
      * the remaining basic types are collated with numpy's ``result_type``

    Raises:
        ValueError: for invalid pointer arithmetic or vector types of different width
    """
    # Pointer arithmetic case i.e. pointer + integer is allowed
    if any(type(candidate) is PointerType for candidate in types):
        found_pointer = None
        for candidate in types:
            candidate_kind = type(candidate)
            if candidate_kind is PointerType:
                if found_pointer is not None:
                    raise ValueError("Cannot collate the combination of two pointer types")
                found_pointer = candidate
                continue
            if candidate_kind is not BasicType:
                raise ValueError("Invalid pointer arithmetic")
            if not candidate.is_int() and not candidate.is_uint():
                raise ValueError("Invalid pointer arithmetic")
        return found_pointer

    # peel off vector types, if at least one vector type occurred the result will also be a vector type
    vector_operands = [t for t in types if type(t) is VectorType]
    if not all_equal(v.width for v in vector_operands):
        raise ValueError("Collation failed because of vector types with different width")
    scalar_types = [peel_off_type(t, VectorType) for t in types]

    # now we should have a list of basic types - struct types are not yet supported
    assert all(type(t) is BasicType for t in scalar_types)

    # as soon as one float takes part, integer operands no longer influence the result
    float_types = tuple(t for t in scalar_types if t.is_float())
    if float_types:
        scalar_types = float_types

    # use numpy collation -> create type from numpy type -> and, put vector type around if necessary
    collated = BasicType(np.result_type(*(t.numpy_dtype for t in scalar_types)))
    if vector_operands:
        collated = VectorType(collated, vector_operands[0].width)
    return collated


@memorycache(maxsize=2048)
def get_type_of_expression(expr, default_float_type='double', default_int_type='int'):
    """Determine the data type of a (sympified) expression.

    Args:
        expr: expression; it is passed through ``sp.sympify`` first
        default_float_type: type specification used for untyped non-integer numbers
        default_int_type: type specification used for untyped integer numbers

    Returns:
        Type object of the expression. Numbers get the default types, typed nodes
        (``TypedSymbol``, ``cast_func``, resolved field accesses, indexed typed symbols)
        report their own type, boolean expressions are ``bool`` (vectorized if any argument
        is a vector), powers take the type of their base and all other expressions with
        arguments get the collated type of their arguments.

    Raises:
        ValueError: if the expression contains a plain, untyped ``sp.Symbol``
        NotImplementedError: if no rule applies to the expression
    """
    from pystencils.astnodes import ResolvedFieldAccess
    from pystencils.cpu.vectorization import vec_all, vec_any

    def type_of(sub_expr):
        # Recurse with the defaults of this call - otherwise custom default types
        # would only apply to the outermost node of the expression tree.
        return get_type_of_expression(sub_expr, default_float_type, default_int_type)

    expr = sp.sympify(expr)
    if isinstance(expr, sp.Integer):
        return create_type(default_int_type)
    elif isinstance(expr, (sp.Rational, sp.Float)):
        return create_type(default_float_type)
    elif isinstance(expr, ResolvedFieldAccess):
        return expr.field.dtype
    elif isinstance(expr, TypedSymbol):
        return expr.dtype
    elif isinstance(expr, sp.Symbol):
        raise ValueError("All symbols inside this expression have to be typed! ", str(expr))
    elif isinstance(expr, cast_func):
        return expr.args[1]
    elif isinstance(expr, (vec_any, vec_all)):
        return create_type("bool")
    elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
        # each Piecewise argument is a (value, condition) pair
        collated_result_type = collate_types(tuple(type_of(a[0]) for a in expr.args))
        collated_condition_type = collate_types(tuple(type_of(a[1]) for a in expr.args))
        if type(collated_condition_type) is VectorType and type(collated_result_type) is not VectorType:
            collated_result_type = VectorType(collated_result_type, width=collated_condition_type.width)
        return collated_result_type
    elif isinstance(expr, sp.Indexed):
        typed_symbol = expr.base.label
        return typed_symbol.dtype.base_type
    elif isinstance(expr, (sp.boolalg.Boolean, sp.boolalg.BooleanFunction)):
        # if any arg is of vector type return a vector boolean, else return a normal scalar boolean
        result = create_type("bool")
        arg_types = [type_of(a) for a in expr.args]
        vec_args = [t for t in arg_types if isinstance(t, VectorType)]
        if vec_args:
            result = VectorType(result, width=vec_args[0].width)
        return result
    elif isinstance(expr, sp.Pow):
        return type_of(expr.args[0])
    elif isinstance(expr, sp.Expr):
        if expr.args:
            return collate_types(tuple(type_of(a) for a in expr.args))
        # atom without arguments (e.g. a symbolic constant): decide by its integer assumption
        if expr.is_integer:
            return create_type(default_int_type)
        return create_type(default_float_type)

    raise NotImplementedError("Could not determine type for", expr, type(expr))
430
431


432
class Type(sp.Basic):
    """Common base class of the type objects in this module; a sympy atom.

    Constructor arguments are not forwarded to ``sp.Basic``; subclasses store them
    in their ``__init__``.
    """
    is_Atom = True

    def __new__(cls, *args, **kwargs):
        instance = sp.Basic.__new__(cls)
        return instance

    def _sympystr(self, *args, **kwargs):
        # sympy's string printer uses the plain str() of the type
        return str(self)
440
441
442
443


class BasicType(Type):
    """Scalar type backed by a plain (non-structured) numpy dtype, optionally ``const``.

    Args:
        dtype: another type object exposing ``numpy_dtype``, or anything ``numpy.dtype`` accepts
        const: whether the type is const-qualified
    """

    @staticmethod
    def numpy_name_to_c(name):
        """Translate a numpy dtype name (e.g. ``'float64'``, ``'int32'``) to its C spelling.

        Raises:
            NotImplementedError: if the name has no known C equivalent
        """
        if name == 'float64':
            return 'double'
        elif name == 'float32':
            return 'float'
        elif name.startswith('int'):
            width = int(name[len("int"):])
            return "int%d_t" % (width,)
        elif name.startswith('uint'):
            width = int(name[len("uint"):])
            return "uint%d_t" % (width,)
        elif name == 'bool':
            return 'bool'
        else:
            raise NotImplementedError("Cannot map numpy to C name for %s" % (name,))

    def __init__(self, dtype, const=False):
        self.const = const
        if isinstance(dtype, Type):
            self._dtype = dtype.numpy_dtype
        else:
            self._dtype = np.dtype(dtype)
        assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type"
        assert self._dtype.hasobject is False
        assert self._dtype.subdtype is None

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        """Basic types wrap nothing, so there is no base type."""
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def item_size(self):
        """Always 1 for a basic type."""
        return 1

    # The predicates below test ``dtype.kind`` instead of membership in ``np.sctypes``:
    # ``np.sctypes`` was removed in NumPy 2.0, the kind codes exist in all versions.

    def is_int(self):
        """True for signed and unsigned integer types."""
        return self.numpy_dtype.kind in ('i', 'u')

    def is_float(self):
        """True for floating point types."""
        return self.numpy_dtype.kind == 'f'

    def is_uint(self):
        """True for unsigned integer types."""
        return self.numpy_dtype.kind == 'u'

    def is_complex(self):
        """True for complex floating point types."""
        return self.numpy_dtype.kind == 'c'

    def is_other(self):
        """True for the non-numeric kinds: bool, object, bytes, str and void."""
        return self.numpy_dtype.kind in ('b', 'O', 'S', 'U', 'V')

    @property
    def base_name(self):
        """C name of the type without qualifiers."""
        return BasicType.numpy_name_to_c(str(self._dtype))

    def __str__(self):
        result = self.base_name
        if self.const:
            result += " const"
        return result

    def __repr__(self):
        return str(self)

    def __eq__(self, other):
        if not isinstance(other, BasicType):
            return False
        else:
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)

    def __hash__(self):
        return hash(str(self))


523
class VectorType(Type):
    """Fixed-width vector of a base type.

    The string form is ``"<base>[<width>]"`` as long as the class attribute
    ``instruction_set`` is ``None``. Otherwise the name is taken from ``instruction_set``
    under the key ``'int'``, ``'double'``, ``'float'`` or ``'bool'``, depending on the
    base type.
    """
    # mapping with the keys listed above, or None for the generic string form
    instruction_set = None

    def __init__(self, base_type, width=4):
        self._base_type = base_type
        self.width = width

    def __getnewargs__(self):
        return self._base_type, self.width

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        """Width times the item size of the base type."""
        return self.width * self.base_type.item_size

    def __eq__(self, other):
        if isinstance(other, VectorType):
            return (self.base_type, self.width) == (other.base_type, other.width)
        return False

    def __hash__(self):
        return hash((self.base_type, self.width))

    def __str__(self):
        if self.instruction_set is None:
            return "%s[%d]" % (self.base_type, self.width)

        # ordered (numpy type name, instruction set key) pairs; first matching base type wins
        lookup_order = (
            ("int64", 'int'),
            ("float64", 'double'),
            ("float32", 'float'),
            ("bool", 'bool'),
        )
        for numpy_name, instruction_set_key in lookup_order:
            if self.base_type == create_type(numpy_name):
                return self.instruction_set[instruction_set_key]
        raise NotImplementedError()

565

566
class PointerType(Type):
    """Pointer to a base type with optional ``const`` and ``restrict`` qualifiers.

    Note that ``restrict`` defaults to True. The string form is
    ``"<base> * [RESTRICT] [const]"``.
    """

    def __init__(self, base_type, const=False, restrict=True):
        self._base_type = base_type
        self.const = const
        self.restrict = restrict

    def __getnewargs__(self):
        return self.base_type, self.const, self.restrict

    @property
    def alias(self):
        """A pointer may alias exactly when it is not restrict-qualified."""
        return not self.restrict

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        """Item size of the pointed-to type."""
        return self.base_type.item_size

    def __eq__(self, other):
        if isinstance(other, PointerType):
            return (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict)
        return False

    def __hash__(self):
        return hash((self._base_type, self.const, self.restrict))

    def __str__(self):
        text = "%s *" % (self.base_type,)
        if self.restrict:
            text += " RESTRICT"
        if self.const:
            text += " const"
        return text

    def __repr__(self):
        return str(self)
606

Jan Hoenig's avatar
Jan Hoenig committed
607

608
class StructType:
    """Type backed by a structured numpy dtype, optionally ``const``.

    Structs are handled byte-wise: the string form is always ``uint8_t`` (plus ``const``),
    and elements are addressed through their byte offset inside the struct.

    Args:
        numpy_type: anything ``numpy.dtype`` accepts
        const: whether the type is const-qualified
    """

    def __init__(self, numpy_type, const=False):
        self._dtype = np.dtype(numpy_type)
        self.const = const

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        """Struct types wrap nothing, so there is no base type."""
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def item_size(self):
        """Size of one struct in bytes (``itemsize`` of the numpy dtype)."""
        return self.numpy_dtype.itemsize

    def get_element_offset(self, element_name):
        """Byte offset of the named element inside the struct."""
        field_info = self.numpy_dtype.fields[element_name]
        return field_info[1]

    def get_element_type(self, element_name):
        """``BasicType`` of the named element; it inherits the const-ness of the struct."""
        field_info = self.numpy_dtype.fields[element_name]
        return BasicType(field_info[0], self.const)

    def has_element(self, element_name):
        return element_name in self.numpy_dtype.fields

    def __eq__(self, other):
        if isinstance(other, StructType):
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)
        return False

    def __hash__(self):
        return hash((self.numpy_dtype, self.const))

    def __str__(self):
        # structs are handled byte-wise
        return "uint8_t const" if self.const else "uint8_t"

    def __repr__(self):
        return str(self)