data_types.py 20.5 KB
Newer Older
1
import ctypes
Martin Bauer's avatar
Martin Bauer committed
2

3
import numpy as np
Martin Bauer's avatar
Martin Bauer committed
4
5
6
7
8
9
import sympy as sp
from sympy.core.cache import cacheit
from sympy.logic.boolalg import Boolean

from pystencils.cache import memorycache
from pystencils.utils import all_equal
10

11
12
13
14
15
try:
    import llvmlite.ir as ir
except ImportError as e:
    ir = None
    _ir_importerror = e
16

17

18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
# noinspection PyPep8Naming
class address_of(sp.Function):
    """Symbolic function node wrapping a single argument whose address is taken.

    The resulting data type is a restrict-qualified pointer to the argument's type.
    """
    is_Atom = True

    def __new__(cls, arg):
        return sp.Function.__new__(cls, arg)

    @property
    def canonical(self):
        """Canonical form of the wrapped argument.

        Raises:
            NotImplementedError: if the argument has no ``canonical`` attribute
        """
        operand = self.args[0]
        if not hasattr(operand, 'canonical'):
            raise NotImplementedError()
        return operand.canonical

    @property
    def is_commutative(self):
        """Commutativity is taken over from the wrapped argument."""
        operand = self.args[0]
        return operand.is_commutative

    @property
    def dtype(self):
        """Restrict pointer to the argument's dtype, or to ``'void'`` if the argument is untyped."""
        operand = self.args[0]
        pointee = operand.dtype if hasattr(operand, 'dtype') else 'void'
        return PointerType(pointee, restrict=True)


Martin Bauer's avatar
Martin Bauer committed
45
# noinspection PyPep8Naming
46
class cast_func(sp.Function):
    """Symbolic cast node: ``cast_func(expr, dtype)``.

    The first argument is the expression being cast, the second one the target data type.
    The target type may be given as a Type object or as anything accepted by ``create_type``.
    """
    is_Atom = True

    def __new__(cls, *args, **kwargs):
        """Creates the cast node, normalizing the data type argument to a Type object.

        Args:
            *args: expression, data type and optionally further arguments that are
                   forwarded unchanged to ``sp.Function.__new__``
            **kwargs: forwarded unchanged to ``sp.Function.__new__``
        """
        expr, dtype, *other_args = args
        if not isinstance(dtype, Type):
            dtype = create_type(dtype)
        # to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well
        # however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads
        # to problems when for example comparing cast_func's for equality
        #
        # lhs = bitwise_and(a, cast_func(1, 'int'))
        # rhs = cast_func(0, 'int')
        # print( sp.Ne(lhs, rhs) ) # would give true if all cast_funcs are booleans
        # -> thus a separate class boolean_cast_func is introduced
        if isinstance(expr, Boolean):
            cls = boolean_cast_func
        return sp.Function.__new__(cls, expr, dtype, *other_args, **kwargs)

    @property
    def canonical(self):
        """Canonical form of the cast expression.

        Raises:
            NotImplementedError: if the expression has no ``canonical`` attribute
        """
        if hasattr(self.args[0], 'canonical'):
            return self.args[0].canonical
        else:
            raise NotImplementedError()

    @property
    def is_commutative(self):
        """Commutativity is taken over from the cast expression."""
        return self.args[0].is_commutative

    def _eval_evalf(self, *args, **kwargs):
        # numeric evaluation ignores the cast and evaluates the wrapped expression
        return self.args[0].evalf()

    @property
    def dtype(self):
        """Target data type of the cast (second argument)."""
        return self.args[1]

    @property
    def is_integer(self):
        """True if the target dtype is a numpy integer type, otherwise the sympy assumption."""
        if hasattr(self.dtype, 'numpy_dtype'):
            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or super().is_integer
        else:
            return super().is_integer

    @property
    def is_negative(self):
        """False if the target dtype is an unsigned integer type, otherwise the sympy assumption."""
        if hasattr(self.dtype, 'numpy_dtype'):
            if np.issubdtype(self.dtype.numpy_dtype, np.unsignedinteger):
                return False

        return super().is_negative

    @property
    def is_nonnegative(self):
        """True if the value is known not to be negative, otherwise the sympy assumption."""
        if self.is_negative is False:
            return True
        else:
            return super().is_nonnegative

    @property
    def is_real(self):
        """True if the target dtype is a numpy integer or floating type, otherwise the sympy assumption."""
        if hasattr(self.dtype, 'numpy_dtype'):
            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or \
                np.issubdtype(self.dtype.numpy_dtype, np.floating) or \
                super().is_real
        else:
            return super().is_real

Martin Bauer's avatar
Martin Bauer committed
116

117
118
119
120
121
# noinspection PyPep8Naming
class boolean_cast_func(cast_func, Boolean):
    """cast_func that is additionally a sympy Boolean.

    cast_func.__new__ switches to this class when the expression to cast is itself a Boolean.
    """
    pass


Martin Bauer's avatar
Martin Bauer committed
122
123
124
125
# noinspection PyPep8Naming
class vector_memory_access(cast_func):
    """cast_func variant that takes exactly four arguments.

    The first two arguments are expression and data type as for cast_func.
    NOTE(review): the meaning of the two additional arguments is not visible in this file - confirm with callers.
    """
    # number of arguments accepted by this sympy function
    nargs = (4,)

126

127
128
129
130
131
# noinspection PyPep8Naming
class reinterpret_cast_func(cast_func):
    """Distinct subclass of cast_func without additional behavior.

    NOTE(review): presumably marks casts that should be emitted as reinterpreting casts - confirm in the backends.
    """
    pass


Martin Bauer's avatar
Martin Bauer committed
132
133
# noinspection PyPep8Naming
class pointer_arithmetic_func(sp.Function, Boolean):
    """Symbolic function node that is also a sympy Boolean and forwards ``canonical`` to its first argument."""

    @property
    def canonical(self):
        """Canonical form of the first argument.

        Raises:
            NotImplementedError: if the first argument has no ``canonical`` attribute
        """
        operand = self.args[0]
        if not hasattr(operand, 'canonical'):
            raise NotImplementedError()
        return operand.canonical


142
class TypedSymbol(sp.Symbol):
    """sympy Symbol that additionally carries a data type, available as ``dtype``."""

    def __new__(cls, *args, **kwds):
        # construction goes through the cached constructor defined below
        obj = TypedSymbol.__xnew_cached_(cls, *args, **kwds)
        return obj

    def __new_stage2__(cls, name, dtype, *args, **kwargs):
        """Creates the symbol and attaches the data type.

        Args:
            name: symbol name, forwarded to the sympy Symbol constructor
            dtype: Type object or a specification understood by ``create_type``
        """
        obj = super(TypedSymbol, cls).__xnew__(cls, name, *args, **kwargs)
        try:
            obj._dtype = create_type(dtype)
        except (TypeError, ValueError):
            # on error keep the string
            obj._dtype = dtype
        return obj

    # uncached and cached variants of the constructor above
    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    @property
    def dtype(self):
        """Data type attached to this symbol."""
        return self._dtype

    def _hashable_content(self):
        # the data type takes part in hashing, in addition to the content used by sp.Symbol
        return super()._hashable_content(), hash(self._dtype)

    def __getnewargs__(self):
        return self.name, self.dtype

    # For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
    @property
    def is_integer(self):
        """True if the dtype is a numpy integer type, otherwise the sympy assumption."""
        if hasattr(self.dtype, 'numpy_dtype'):
            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or super().is_integer
        else:
            return super().is_integer

    @property
    def is_negative(self):
        """False if the dtype is an unsigned integer type, otherwise the sympy assumption."""
        if hasattr(self.dtype, 'numpy_dtype'):
            if np.issubdtype(self.dtype.numpy_dtype, np.unsignedinteger):
                return False

        return super().is_negative

    @property
    def is_nonnegative(self):
        """True if the symbol is known not to be negative, otherwise the sympy assumption."""
        if self.is_negative is False:
            return True
        else:
            return super().is_nonnegative

    @property
    def is_real(self):
        """True if the dtype is a numpy integer or floating type, otherwise the sympy assumption."""
        if hasattr(self.dtype, 'numpy_dtype'):
            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or \
                np.issubdtype(self.dtype.numpy_dtype, np.floating) or \
                super().is_real
        else:
            return super().is_real
200

201

Martin Bauer's avatar
Martin Bauer committed
202
def create_type(specification):
    """Returns a Type object for the given specification.

    Args:
        specification: a Type object, or anything ``np.dtype`` accepts (e.g. a string or numpy type)

    Returns:
        the specification itself if it already is a Type object; otherwise a non-const
        BasicType for plain numpy dtypes or a non-const StructType for numpy dtypes with fields
    """
    if isinstance(specification, Type):
        return specification

    numpy_dtype = np.dtype(specification)
    type_class = BasicType if numpy_dtype.fields is None else StructType
    return type_class(numpy_dtype, const=False)
219
220


221
@memorycache(maxsize=64)
def create_composite_type_from_string(specification):
    """Creates a new Type object from a c-like string specification.

    The specification is lower-cased and split at whitespace. The tokens before the first
    stand-alone ``*`` describe the base type (a numpy type name, optionally with ``const``).
    Every following ``*`` starts a pointer level which may carry ``const`` and ``restrict``.
    A ``*`` directly attached to the base type name (e.g. ``"double*"``) adds a pointer
    level without qualifiers.

    Args:
        specification: Specification string

    Returns:
        Type object: a BasicType, wrapped in one PointerType per pointer level
    """
    tokens = specification.lower().split()

    # group tokens: first group is the base type, each further group starts with '*'
    parts = []
    current = []
    for s in tokens:
        if s == '*':
            parts.append(current)
            current = [s]
        else:
            current.append(s)
    if len(current) > 0:
        parts.append(current)

    # Parse native part
    base_part = parts.pop(0)
    const = False
    if 'const' in base_part:
        const = True
        base_part.remove('const')
    assert len(base_part) == 1, "Base type has to consist of a single type name"
    if base_part[0][-1] == "*":
        base_part[0] = base_part[0][:-1]
        # stored as a token list like every other pointer part
        parts.append(['*'])
    current_type = BasicType(np.dtype(base_part[0]), const)

    # Parse pointer parts
    for part in parts:
        restrict = False
        const = False
        if 'restrict' in part:
            restrict = True
            part.remove('restrict')
        if 'const' in part:
            const = True
            part.remove("const")
        assert len(part) == 1 and part[0] == '*', "Unknown qualifier in pointer specification"
        current_type = PointerType(current_type, const, restrict)
    return current_type
266
267


Martin Bauer's avatar
Martin Bauer committed
268
269
270
271
def get_base_type(data_type):
    """Follows the ``base_type`` chain of a type until a type without base type is reached.

    Args:
        data_type: object with a ``base_type`` attribute

    Returns:
        the innermost type, i.e. the first one whose ``base_type`` is None
    """
    while True:
        inner = data_type.base_type
        if inner is None:
            return data_type
        data_type = inner
272
273


Martin Bauer's avatar
Martin Bauer committed
274
def to_ctypes(data_type):
    """
    Transforms a given Type into ctypes
    :param data_type: Subclass of Type
    :return: ctypes type object
    """
    if isinstance(data_type, PointerType):
        pointee = to_ctypes(data_type.base_type)
        return ctypes.POINTER(pointee)
    if isinstance(data_type, StructType):
        # struct types are mapped to byte pointers
        return ctypes.POINTER(ctypes.c_uint8)
    return to_ctypes.map[data_type.numpy_dtype]


# lookup table from numpy dtype to the corresponding ctypes scalar type
to_ctypes.map = {
    np.dtype(numpy_type): c_type
    for numpy_type, c_type in (
        (np.int8, ctypes.c_int8),
        (np.int16, ctypes.c_int16),
        (np.int32, ctypes.c_int32),
        (np.int64, ctypes.c_int64),

        (np.uint8, ctypes.c_uint8),
        (np.uint16, ctypes.c_uint16),
        (np.uint32, ctypes.c_uint32),
        (np.uint64, ctypes.c_uint64),

        (np.float32, ctypes.c_float),
        (np.float64, ctypes.c_double),
    )
}


304
def ctypes_from_llvm(data_type):
    """
    Transforms an llvmlite type into ctypes
    :param data_type: llvmlite type object
    :return: ctypes type object, or None for the llvmlite void type
    """
    if not ir:
        # llvmlite could not be imported - re-raise the stored import error
        raise _ir_importerror

    if isinstance(data_type, ir.PointerType):
        pointee = ctypes_from_llvm(data_type.pointee)
        # pointer to void has no pointee ctype, use the generic void pointer
        return ctypes.c_void_p if pointee is None else ctypes.POINTER(pointee)

    if isinstance(data_type, ir.IntType):
        int_types_by_width = {
            8: ctypes.c_int8,
            16: ctypes.c_int16,
            32: ctypes.c_int32,
            64: ctypes.c_int64,
        }
        if data_type.width not in int_types_by_width:
            raise ValueError("Int width %d is not supported" % data_type.width)
        return int_types_by_width[data_type.width]

    if isinstance(data_type, ir.FloatType):
        return ctypes.c_float
    if isinstance(data_type, ir.DoubleType):
        return ctypes.c_double
    if isinstance(data_type, ir.VoidType):
        return None  # Void type is not supported by ctypes

    raise NotImplementedError('Data type %s of %s is not supported yet' % (type(data_type), data_type))


def to_llvm_type(data_type):
    """
    Transforms a given Type into an llvmlite type
    :param data_type: Subclass of Type
    :return: llvmlite type object
    """
    if not ir:
        # llvmlite could not be imported - re-raise the stored import error
        raise _ir_importerror

    if not isinstance(data_type, PointerType):
        return to_llvm_type.map[data_type.numpy_dtype]

    pointee = to_llvm_type(data_type.base_type)
    return pointee.as_pointer()


if ir:
    # lookup table from numpy dtype to llvmlite type, signed and unsigned integers share the llvm integer types
    to_llvm_type.map = {
        np.dtype(numpy_type): llvm_type
        for numpy_type, llvm_type in (
            (np.int8, ir.IntType(8)),
            (np.int16, ir.IntType(16)),
            (np.int32, ir.IntType(32)),
            (np.int64, ir.IntType(64)),

            (np.uint8, ir.IntType(8)),
            (np.uint16, ir.IntType(16)),
            (np.uint32, ir.IntType(32)),
            (np.uint64, ir.IntType(64)),

            (np.float32, ir.FloatType()),
            (np.float64, ir.DoubleType()),
        )
    }
363

364

Martin Bauer's avatar
Martin Bauer committed
365
366
367
def peel_off_type(dtype, type_to_peel_off):
    """Strips all outer layers of exactly the given type class from a data type.

    Args:
        dtype: data type object, wrapped types expose the wrapped type as ``base_type``
        type_to_peel_off: class to remove; subclasses of it are not peeled off

    Returns:
        the first type in the ``base_type`` chain whose class is not ``type_to_peel_off``
    """
    while True:
        if type(dtype) is not type_to_peel_off:
            return dtype
        dtype = dtype.base_type


Martin Bauer's avatar
Martin Bauer committed
371
def collate_types(types):
    """
    Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
    Uses the collation rules from numpy.

    If a pointer type is contained, all other types have to be integer BasicTypes and the pointer
    type is returned. If vector types are contained they all need the same width and the result
    is a vector type of that width.
    """

    # Pointer arithmetic case i.e. pointer + integer is allowed
    if any(type(t) is PointerType for t in types):
        pointer_type = None
        for candidate in types:
            candidate_class = type(candidate)
            if candidate_class is PointerType:
                if pointer_type is not None:
                    raise ValueError("Cannot collate the combination of two pointer types")
                pointer_type = candidate
                continue
            # everything next to the pointer has to be an integer offset
            is_integral = candidate_class is BasicType and (candidate.is_int() or candidate.is_uint())
            if not is_integral:
                raise ValueError("Invalid pointer arithmetic")
        return pointer_type

    # peel of vector types, if at least one vector type occurred the result will also be the vector type
    vector_types = [t for t in types if type(t) is VectorType]
    if not all_equal(v.width for v in vector_types):
        raise ValueError("Collation failed because of vector types with different width")
    scalar_types = [peel_off_type(t, VectorType) for t in types]

    # now we should have a list of basic types - struct types are not yet supported
    assert all(type(t) is BasicType for t in scalar_types)

    # as soon as one float type is present, only the float types decide the result
    float_types = tuple(t for t in scalar_types if t.is_float())
    if float_types:
        scalar_types = float_types

    # use numpy collation -> create type from numpy type -> and, put vector type around if necessary
    numpy_dtypes = (t.numpy_dtype for t in scalar_types)
    collated = BasicType(np.result_type(*numpy_dtypes))
    if vector_types:
        collated = VectorType(collated, vector_types[0].width)
    return collated


@memorycache(maxsize=2048)
def get_type_of_expression(expr, default_float_type='double', default_int_type='int'):
    """Determines the data type of a sympy expression.

    Args:
        expr: sympy expression (is passed through ``sp.sympify`` first); all symbols in it
              have to be TypedSymbols
        default_float_type: type specification used for rational and float numbers and for
                            argument-less expressions that are not integer
        default_int_type: type specification used for integer numbers and argument-less
                          integer expressions

    Returns:
        Type object of the expression; types of sub-expressions are combined with ``collate_types``

    Raises:
        ValueError: if the expression contains an untyped symbol
        NotImplementedError: if no rule matches the expression
    """
    from pystencils.astnodes import ResolvedFieldAccess
    from pystencils.cpu.vectorization import vec_all, vec_any

    def get_type(sub_expr):
        # recursion has to use the same default types as the top level call
        return get_type_of_expression(sub_expr,
                                      default_float_type=default_float_type,
                                      default_int_type=default_int_type)

    expr = sp.sympify(expr)
    if isinstance(expr, sp.Integer):
        return create_type(default_int_type)
    elif isinstance(expr, (sp.Rational, sp.Float)):
        return create_type(default_float_type)
    elif isinstance(expr, ResolvedFieldAccess):
        return expr.field.dtype
    elif isinstance(expr, TypedSymbol):
        return expr.dtype
    elif isinstance(expr, sp.Symbol):
        raise ValueError("All symbols inside this expression have to be typed! ", str(expr))
    elif isinstance(expr, cast_func):
        return expr.args[1]
    elif isinstance(expr, (vec_any, vec_all)):
        return create_type("bool")
    elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
        collated_result_type = collate_types(tuple(get_type(a[0]) for a in expr.args))
        collated_condition_type = collate_types(tuple(get_type(a[1]) for a in expr.args))
        # a vectorized condition makes the result a vector as well
        if type(collated_condition_type) is VectorType and type(collated_result_type) is not VectorType:
            collated_result_type = VectorType(collated_result_type, width=collated_condition_type.width)
        return collated_result_type
    elif isinstance(expr, sp.Indexed):
        typed_symbol = expr.base.label
        return typed_symbol.dtype.base_type
    elif isinstance(expr, (sp.boolalg.Boolean, sp.boolalg.BooleanFunction)):
        # if any arg is of vector type return a vector boolean, else return a normal scalar boolean
        result = create_type("bool")
        arg_types = [get_type(a) for a in expr.args]
        vec_args = [t for t in arg_types if isinstance(t, VectorType)]
        if vec_args:
            result = VectorType(result, width=vec_args[0].width)
        return result
    elif isinstance(expr, sp.Pow):
        # the type of a power is the type of its base
        return get_type(expr.args[0])
    elif isinstance(expr, sp.Expr):
        if expr.args:
            types = tuple(get_type(a) for a in expr.args)
            return collate_types(types)
        else:
            if expr.is_integer:
                return create_type(default_int_type)
            else:
                return create_type(default_float_type)

    raise NotImplementedError("Could not determine type for", expr, type(expr))
461
462


463
class Type(sp.Basic):
    """Common sympy-based base class of BasicType, VectorType and PointerType."""
    is_Atom = True

    def __new__(cls, *args, **kwargs):
        # constructor arguments are not passed on to sp.Basic, subclasses store them in __init__
        return sp.Basic.__new__(cls)

    def _sympystr(self, *args, **kwargs):
        # sympy string printing uses the plain string representation of the type
        return str(self)
471
472
473
474


class BasicType(Type):
    """Scalar data type backed by a plain (non-structured) numpy dtype, optionally const-qualified."""

    @staticmethod
    def numpy_name_to_c(name):
        """Translates a numpy type name to the corresponding C type name.

        Args:
            name: numpy type name, e.g. ``'float64'``, ``'int32'``, ``'uint8'`` or ``'bool'``

        Returns:
            C type name, e.g. ``'double'``, ``'int32_t'``, ``'uint8_t'`` or ``'bool'``

        Raises:
            NotImplementedError: for all other numpy type names
        """
        if name == 'float64':
            return 'double'
        elif name == 'float32':
            return 'float'
        elif name.startswith('int'):
            width = int(name[len("int"):])
            return "int%d_t" % (width,)
        elif name.startswith('uint'):
            width = int(name[len("uint"):])
            return "uint%d_t" % (width,)
        elif name == 'bool':
            return 'bool'
        else:
            raise NotImplementedError("Can't map numpy to C name for %s" % (name,))

    def __init__(self, dtype, const=False):
        """
        Args:
            dtype: another Type with a ``numpy_dtype``, or anything ``np.dtype`` accepts
            const: True for a const-qualified type
        """
        self.const = const
        if isinstance(dtype, Type):
            self._dtype = dtype.numpy_dtype
        else:
            self._dtype = np.dtype(dtype)
        assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type"
        assert self._dtype.hasobject is False
        assert self._dtype.subdtype is None

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        """Basic types do not wrap another type, so this is always None."""
        return None

    @property
    def numpy_dtype(self):
        """The underlying numpy dtype."""
        return self._dtype

    @property
    def item_size(self):
        """Number of scalar items, always 1 for a basic type."""
        return 1

    # The category checks below use ``dtype.kind`` since ``np.sctypes`` is not available any more
    # in NumPy 2.0. Unlike the former ``np.sctypes`` lookup, dtypes with non-native byte order are
    # categorized as well.

    def is_int(self):
        """True for signed and unsigned integer types."""
        return self.numpy_dtype.kind in ('i', 'u')

    def is_float(self):
        """True for floating point types."""
        return self.numpy_dtype.kind == 'f'

    def is_uint(self):
        """True for unsigned integer types."""
        return self.numpy_dtype.kind == 'u'

    def is_complex(self):
        """True for complex floating point types."""
        return self.numpy_dtype.kind == 'c'

    def is_other(self):
        """True for the numpy bool, object, bytes, str and void scalar types."""
        # same scalar types that np.sctypes['others'] used to list
        return self.numpy_dtype in (np.bool_, np.object_, np.bytes_, np.str_, np.void)

    @property
    def base_name(self):
        """C type name without qualifiers."""
        return BasicType.numpy_name_to_c(str(self._dtype))

    def __str__(self):
        result = BasicType.numpy_name_to_c(str(self._dtype))
        if self.const:
            result += " const"
        return result

    def __repr__(self):
        return str(self)

    def __eq__(self, other):
        if not isinstance(other, BasicType):
            return False
        else:
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)

    def __hash__(self):
        return hash(str(self))


554
class VectorType(Type):
    """Data type consisting of ``width`` items of a base type."""

    # class-wide mapping from 'int', 'double', 'float', 'bool' to the type names used by __str__;
    # None selects the generic "base[width]" notation
    instruction_set = None

    def __init__(self, base_type, width=4):
        """
        Args:
            base_type: type of a single vector item
            width: number of items in the vector
        """
        self.width = width
        self._base_type = base_type

    @property
    def base_type(self):
        """Type of a single vector item."""
        return self._base_type

    @property
    def item_size(self):
        """Item size of the base type multiplied by the vector width."""
        return self.width * self.base_type.item_size

    def __eq__(self, other):
        if isinstance(other, VectorType):
            return (self.base_type, self.width) == (other.base_type, other.width)
        return False

    def __str__(self):
        if self.instruction_set is None:
            return "%s[%d]" % (self.base_type, self.width)

        # numpy type name of the base type -> key into the instruction set mapping
        known_base_types = (
            ("int64", 'int'),
            ("float64", 'double'),
            ("float32", 'float'),
            ("bool", 'bool'),
        )
        for numpy_name, instruction_set_key in known_base_types:
            if self.base_type == create_type(numpy_name):
                return self.instruction_set[instruction_set_key]
        raise NotImplementedError()

    def __hash__(self):
        key = (self.base_type, self.width)
        return hash(key)

    def __getnewargs__(self):
        return self._base_type, self.width

596

597
class PointerType(Type):
    """Pointer to a base type, optionally const- and/or restrict-qualified."""

    def __init__(self, base_type, const=False, restrict=True):
        """
        Args:
            base_type: type the pointer points to
            const: True for a const-qualified pointer
            restrict: True for a restrict-qualified pointer
        """
        self.restrict = restrict
        self.const = const
        self._base_type = base_type

    def __getnewargs__(self):
        return self.base_type, self.const, self.restrict

    @property
    def alias(self):
        """Opposite of ``restrict``."""
        return not self.restrict

    @property
    def base_type(self):
        """Type the pointer points to."""
        return self._base_type

    @property
    def item_size(self):
        """Item size of the pointed-to type."""
        return self.base_type.item_size

    def __eq__(self, other):
        if isinstance(other, PointerType):
            return (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict)
        return False

    def __str__(self):
        # "<base> *" followed by the qualifiers that are set
        qualifiers = []
        if self.restrict:
            qualifiers.append('RESTRICT')
        if self.const:
            qualifiers.append("const")
        return " ".join([str(self.base_type), '*'] + qualifiers)

    def __repr__(self):
        return str(self)

    def __hash__(self):
        key = (self._base_type, self.const, self.restrict)
        return hash(key)
637

Jan Hoenig's avatar
Jan Hoenig committed
638

639
class StructType:
    """Data type backed by a structured numpy dtype, optionally const-qualified."""

    def __init__(self, numpy_type, const=False):
        """
        Args:
            numpy_type: anything ``np.dtype`` accepts, describing the struct layout
            const: True for a const-qualified type
        """
        self._dtype = np.dtype(numpy_type)
        self.const = const

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        """Struct types do not wrap another type, so this is always None."""
        return None

    @property
    def numpy_dtype(self):
        """The underlying numpy dtype."""
        return self._dtype

    @property
    def item_size(self):
        """Size of one struct in bytes, as reported by numpy."""
        return self.numpy_dtype.itemsize

    def get_element_offset(self, element_name):
        """Byte offset of the named element inside the struct."""
        field_info = self.numpy_dtype.fields[element_name]
        return field_info[1]

    def get_element_type(self, element_name):
        """BasicType of the named element, carrying the const qualifier of the struct."""
        field_info = self.numpy_dtype.fields[element_name]
        return BasicType(field_info[0], self.const)

    def has_element(self, element_name):
        """True if the struct has an element with the given name."""
        return element_name in self.numpy_dtype.fields

    def __eq__(self, other):
        if isinstance(other, StructType):
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)
        return False

    def __str__(self):
        # structs are handled byte-wise
        return "uint8_t const" if self.const else "uint8_t"

    def __repr__(self):
        return str(self)

    def __hash__(self):
        key = (self.numpy_dtype, self.const)
        return hash(key)