import ctypes

import numpy as np
import sympy as sp
from sympy.core.cache import cacheit
from sympy.logic.boolalg import Boolean

from pystencils.cache import memorycache
from pystencils.utils import all_equal

# llvmlite is optional: if it is missing, ``ir`` is None and the original ImportError
# is kept in ``_ir_importerror`` so it can be re-raised when LLVM functionality is used.
try:
    import llvmlite.ir as ir
except ImportError as e:
    ir = None
    _ir_importerror = e

# noinspection PyPep8Naming
class address_of(sp.Function):
    """Sympy node for the address of its single argument.

    Its ``dtype`` is a restrict pointer to the argument's ``dtype``, or a
    ``void`` pointer when the argument carries no ``dtype``.
    """
    is_Atom = True

    def __new__(cls, arg):
        return sp.Function.__new__(cls, arg)

    @property
    def canonical(self):
        target = self.args[0]
        if not hasattr(target, 'canonical'):
            raise NotImplementedError()
        return target.canonical

    @property
    def is_commutative(self):
        return self.args[0].is_commutative

    @property
    def dtype(self):
        target = self.args[0]
        pointee = target.dtype if hasattr(target, 'dtype') else 'void'
        return PointerType(pointee, restrict=True)


# noinspection PyPep8Naming
class cast_func(sp.Function):
    """Symbolic cast ``cast_func(expr, dtype, *other_args)`` of an expression to a data type.

    ``dtype`` is converted with ``create_type`` unless it already is a ``Type``.
    Further positional arguments are forwarded unchanged to ``sp.Function``
    (subclasses such as ``vector_memory_access`` declare additional arguments).
    If ``expr`` is a sympy ``Boolean`` a ``boolean_cast_func`` is created instead.

    Raises:
        ValueError: if fewer than two positional arguments are given
    """
    is_Atom = True

    def __new__(cls, *args, **kwargs):
        # expression and target type are mandatory; fail with a clear message instead of
        # a bare tuple-unpacking error
        if len(args) < 2:
            raise ValueError("%s requires at least an expression and a data type, got %d argument(s)"
                             % (cls.__name__, len(args)))
        expr, dtype, *other_args = args
        if not isinstance(dtype, Type):
            dtype = create_type(dtype)
        # to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well
        # however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads
        # to problems when for example comparing cast_func's for equality
        #
        # lhs = bitwise_and(a, cast_func(1, 'int'))
        # rhs = cast_func(0, 'int')
        # print( sp.Ne(lhs, rhs) ) # would give true if all cast_funcs are booleans
        # -> thus a separate class boolean_cast_func is introduced
        if isinstance(expr, Boolean):
            cls = boolean_cast_func
        return sp.Function.__new__(cls, expr, dtype, *other_args, **kwargs)

    @property
    def canonical(self):
        if hasattr(self.args[0], 'canonical'):
            return self.args[0].canonical
        else:
            raise NotImplementedError()

    @property
    def is_commutative(self):
        return self.args[0].is_commutative

    def _eval_evalf(self, *args, **kwargs):
        # numeric evaluation ignores the cast and evaluates the wrapped expression
        return self.args[0].evalf()

    @property
    def dtype(self):
        return self.args[1]


# noinspection PyPep8Naming
class boolean_cast_func(cast_func, Boolean):
    """``cast_func`` variant that is also a sympy ``Boolean``.

    ``cast_func.__new__`` switches to this class when the cast expression is a
    ``Boolean``, so that such casts are accepted where sympy expects booleans
    (e.g. in ``sp.Piecewise`` conditions).
    """


# noinspection PyPep8Naming
class vector_memory_access(cast_func):
    """``cast_func`` subclass that takes exactly four arguments.

    The first two are the expression and its data type, as for ``cast_func``.
    NOTE(review): the meaning of the remaining two arguments is defined by the
    code that creates these nodes — confirm there.
    """
    nargs = (4,)

# noinspection PyPep8Naming
class reinterpret_cast_func(cast_func):
    """Marker subclass of ``cast_func`` that adds no behavior of its own.

    NOTE(review): presumably printed as a bit-pattern reinterpretation instead of a
    value conversion — confirm with the code printer.
    """


# noinspection PyPep8Naming
class pointer_arithmetic_func(sp.Function, Boolean):
    """Sympy function node used for pointer arithmetic.

    Its ``canonical`` form is that of its first argument.
    NOTE(review): it also derives from ``Boolean`` — presumably so it is accepted in
    boolean contexts like ``cast_func`` — confirm.
    """

    @property
    def canonical(self):
        first = self.args[0]
        if not hasattr(first, 'canonical'):
            raise NotImplementedError()
        return first.canonical


class TypedSymbol(sp.Symbol):
    """Sympy symbol that additionally carries a data type.

    The type given at construction is converted with ``create_type``; if that fails
    with ``TypeError``/``ValueError`` the specification (e.g. a string) is stored as is.
    The type is part of the hashable content, so equally named symbols with different
    types are different symbols. Sympy assumption queries (``is_integer``,
    ``is_negative``, ...) take the numpy dtype into account when one is available.
    """

    def __new__(cls, *args, **kwds):
        obj = TypedSymbol.__xnew_cached_(cls, *args, **kwds)
        return obj

    def __new_stage2__(cls, name, dtype, *args, **kwargs):
        obj = super(TypedSymbol, cls).__xnew__(cls, name, *args, **kwargs)
        try:
            obj._dtype = create_type(dtype)
        except (TypeError, ValueError):
            # on error keep the string
            obj._dtype = dtype
        return obj

    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    @property
    def dtype(self):
        return self._dtype

    def _hashable_content(self):
        return super()._hashable_content(), hash(self._dtype)

    def __getnewargs__(self):
        return self.name, self.dtype

    def __getnewargs_ex__(self):
        # For pickle protocol >= 2 (and copy/deepcopy) __getnewargs_ex__ takes precedence over
        # __getnewargs__. Newer sympy versions define it on Symbol returning only the name, which
        # would call __new__ without a dtype. Pass dtype and the symbol's assumptions explicitly.
        return (self.name, self.dtype), self.assumptions0

    # For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
    @property
    def is_integer(self):
        if hasattr(self.dtype, 'numpy_dtype'):
            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or super().is_integer
        else:
            return super().is_integer

    @property
    def is_negative(self):
        # unsigned integer types can never be negative
        if hasattr(self.dtype, 'numpy_dtype'):
            if np.issubdtype(self.dtype.numpy_dtype, np.unsignedinteger):
                return False
        return super().is_negative

    @property
    def is_nonnegative(self):
        if self.is_negative is False:
            return True
        else:
            return super().is_nonnegative

    @property
    def is_real(self):
        if hasattr(self.dtype, 'numpy_dtype'):
            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or \
                np.issubdtype(self.dtype.numpy_dtype, np.floating) or \
                super().is_real
        else:
            return super().is_real

def create_type(specification):
    """Creates a subclass of Type according to a string or an object of subclass Type.

    Args:
        specification: Type object (returned unchanged), or anything ``np.dtype`` accepts

    Returns:
        the given Type object, a non-const ``BasicType`` for plain dtypes, or a non-const
        ``StructType`` for dtypes with fields
    """
    if isinstance(specification, Type):
        return specification
    numpy_dtype = np.dtype(specification)
    type_class = BasicType if numpy_dtype.fields is None else StructType
    return type_class(numpy_dtype, const=False)


@memorycache(maxsize=64)
def create_composite_type_from_string(specification):
    """Creates a new Type object from a c-like string specification.

    The specification is a base type, optionally qualified with ``const``, followed by
    any number of pointer levels, each optionally qualified with ``const`` and/or
    ``restrict``, e.g. ``"double const * restrict"``. As in C, qualifiers written after
    a ``*`` belong to that pointer level. A ``*`` may be attached to neighbouring
    tokens (``"double*"``, ``"int **"``). Parsing is case-insensitive.

    Args:
        specification: Specification string

    Returns:
        Type object

    Raises:
        ValueError: if the specification cannot be parsed
    """
    # Put spaces around every '*' so that "double*" and "int**" tokenize like "double *" / "int * *".
    tokens = specification.lower().replace('*', ' * ').split()

    # parts[0] holds the base type tokens, every further part one pointer level starting with '*'
    parts = []
    current = []
    for token in tokens:
        if token == '*':
            parts.append(current)
            current = [token]
        else:
            current.append(token)
    parts.append(current)

    # Parse native part
    base_part = parts[0]
    const = 'const' in base_part
    if const:
        base_part.remove('const')
    if len(base_part) != 1:
        raise ValueError("Cannot parse type specification %r: expected exactly one base type"
                         % (specification,))
    current_type = BasicType(np.dtype(base_part[0]), const)

    # Parse pointer parts, innermost level first
    for part in parts[1:]:
        restrict = 'restrict' in part
        if restrict:
            part.remove('restrict')
        const = 'const' in part
        if const:
            part.remove('const')
        if part != ['*']:
            raise ValueError("Cannot parse type specification %r: unexpected tokens %s"
                             % (specification, part[1:]))
        current_type = PointerType(current_type, const, restrict)
    return current_type


def get_base_type(data_type):
    """Returns the innermost type reached by following ``base_type`` links.

    Args:
        data_type: type object with a ``base_type`` attribute (``None`` when it wraps no other type)

    Returns:
        the first type in the chain whose ``base_type`` is ``None``
    """
    inner = data_type.base_type
    if inner is None:
        return data_type
    return get_base_type(inner)


def to_ctypes(data_type):
    """Transforms a given Type into ctypes.

    Args:
        data_type: Subclass of Type

    Returns:
        ctypes type object. Pointer types become ``ctypes.POINTER`` of the converted base
        type; struct types map to ``ctypes.POINTER(ctypes.c_uint8)`` (structs are handled
        byte-wise); basic types are looked up in ``to_ctypes.map`` by numpy dtype.

    Raises:
        KeyError: if the numpy dtype of a basic type has no ctypes equivalent in the map
    """
    if isinstance(data_type, PointerType):
        return ctypes.POINTER(to_ctypes(data_type.base_type))
    elif isinstance(data_type, StructType):
        return ctypes.POINTER(ctypes.c_uint8)
    else:
        return to_ctypes.map[data_type.numpy_dtype]


to_ctypes.map = {
    np.dtype(np.int8): ctypes.c_int8,
    np.dtype(np.int16): ctypes.c_int16,
    np.dtype(np.int32): ctypes.c_int32,
    np.dtype(np.int64): ctypes.c_int64,

    np.dtype(np.uint8): ctypes.c_uint8,
    np.dtype(np.uint16): ctypes.c_uint16,
    np.dtype(np.uint32): ctypes.c_uint32,
    np.dtype(np.uint64): ctypes.c_uint64,

    np.dtype(np.float32): ctypes.c_float,
    np.dtype(np.float64): ctypes.c_double,

    # bool types are created elsewhere (e.g. create_type("bool")); numpy bools are one byte like c_bool
    np.dtype(np.bool_): ctypes.c_bool,
}


def ctypes_from_llvm(data_type):
    """Converts an llvmlite IR type into the corresponding ctypes type.

    Args:
        data_type: llvmlite ``ir`` type (pointer, integer of width 8/16/32/64, float, double or void)

    Returns:
        ctypes type object; ``None`` for the void type, and ``ctypes.c_void_p`` for pointers to void

    Raises:
        the stored llvmlite ImportError if llvmlite is not installed,
        ValueError for unsupported integer widths,
        NotImplementedError for any other IR type
    """
    if not ir:
        raise _ir_importerror

    if isinstance(data_type, ir.PointerType):
        pointee = ctypes_from_llvm(data_type.pointee)
        return ctypes.c_void_p if pointee is None else ctypes.POINTER(pointee)

    if isinstance(data_type, ir.IntType):
        int_types_by_width = {8: ctypes.c_int8, 16: ctypes.c_int16, 32: ctypes.c_int32, 64: ctypes.c_int64}
        int_ctype = int_types_by_width.get(data_type.width)
        if int_ctype is None:
            raise ValueError("Int width %d is not supported" % data_type.width)
        return int_ctype

    if isinstance(data_type, ir.FloatType):
        return ctypes.c_float
    if isinstance(data_type, ir.DoubleType):
        return ctypes.c_double
    if isinstance(data_type, ir.VoidType):
        return None  # Void type is not supported by ctypes
    raise NotImplementedError('Data type %s of %s is not supported yet' % (type(data_type), data_type))


def to_llvm_type(data_type):
    """Transforms a given Type into an llvmlite IR type.

    Args:
        data_type: Subclass of Type

    Returns:
        llvmlite type object; pointer types become ``as_pointer()`` of the converted base type,
        other types are looked up in ``to_llvm_type.map`` by numpy dtype

    Raises:
        the stored llvmlite ImportError if llvmlite is not installed
    """
    if not ir:
        raise _ir_importerror
    if not isinstance(data_type, PointerType):
        return to_llvm_type.map[data_type.numpy_dtype]
    return to_llvm_type(data_type.base_type).as_pointer()


if ir:
    # signed and unsigned integers of equal bit width share the same LLVM integer type
    to_llvm_type.map = {
        np.dtype(int_type): ir.IntType(8 * np.dtype(int_type).itemsize)
        for int_type in (np.int8, np.int16, np.int32, np.int64,
                         np.uint8, np.uint16, np.uint32, np.uint64)
    }
    to_llvm_type.map[np.dtype(np.float32)] = ir.FloatType()
    to_llvm_type.map[np.dtype(np.float64)] = ir.DoubleType()

def peel_off_type(dtype, type_to_peel_off):
    """Strips all outer layers whose class is exactly ``type_to_peel_off``.

    The class is compared with ``is``, so instances of subclasses are not peeled off.

    Args:
        dtype: type object; wrapper layers expose the wrapped type as ``base_type``
        type_to_peel_off: class of the wrapper layers to remove

    Returns:
        the first type in the ``base_type`` chain that is not exactly of ``type_to_peel_off``
    """
    if type(dtype) is not type_to_peel_off:
        return dtype
    return peel_off_type(dtype.base_type, type_to_peel_off)


def collate_types(types):
    """
    Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
    Uses the collation rules from numpy.

    Pointer arithmetic is allowed: one pointer type combined with (unsigned) integers
    yields that pointer type. If any vector type occurs, the result is a vector type of
    the same width. If any floating point type occurs, integer types are ignored.

    Args:
        types: iterable of Type objects (any iterable, including generators)

    Returns:
        the collated Type

    Raises:
        ValueError: for two pointer types, pointers combined with non-integer types,
                    or vector types of different widths
    """
    # the input is traversed several times below; materialize it so generators work as well
    types = tuple(types)

    # Pointer arithmetic case i.e. pointer + integer is allowed
    if any(type(t) is PointerType for t in types):
        pointer_type = None
        for t in types:
            if type(t) is PointerType:
                if pointer_type is not None:
                    raise ValueError("Cannot collate the combination of two pointer types")
                pointer_type = t
            elif type(t) is BasicType:
                if not (t.is_int() or t.is_uint()):
                    raise ValueError("Invalid pointer arithmetic")
            else:
                raise ValueError("Invalid pointer arithmetic")
        return pointer_type

    # peel off vector types, if at least one vector type occurred the result will also be the vector type
    vector_type = [t for t in types if type(t) is VectorType]
    if not all_equal(t.width for t in vector_type):
        raise ValueError("Collation failed because of vector types with different width")
    types = [peel_off_type(t, VectorType) for t in types]

    # now we should have a list of basic types - struct types are not yet supported
    assert all(type(t) is BasicType for t in types)

    # floating point types dominate: integer operands are ignored if any float is present
    if any(t.is_float() for t in types):
        types = tuple(t for t in types if t.is_float())
    # use numpy collation -> create type from numpy type -> and, put vector type around if necessary
    result_numpy_type = np.result_type(*(t.numpy_dtype for t in types))
    result = BasicType(result_numpy_type)
    if vector_type:
        result = VectorType(result, vector_type[0].width)
    return result


@memorycache(maxsize=2048)
def get_type_of_expression(expr):
    """Determines the data type of an expression.

    Args:
        expr: sympy expression (or anything ``sp.sympify`` accepts); all symbols in it
              have to be ``TypedSymbol`` instances

    Returns:
        Type object: ``create_type("int")`` for integer literals, ``create_type("double")``
        for rational/float literals, a (possibly vectorized) ``bool`` for boolean
        expressions, the type of the base for powers, and the collated operand types
        for other compound expressions

    Raises:
        ValueError: if an untyped ``sp.Symbol`` is encountered
        NotImplementedError: if the type cannot be determined
    """
    # NOTE(review): imported locally, presumably to avoid circular imports — confirm
    from pystencils.astnodes import ResolvedFieldAccess
    from pystencils.cpu.vectorization import vec_all, vec_any

    expr = sp.sympify(expr)
    if isinstance(expr, sp.Integer):
        return create_type("int")
    elif isinstance(expr, (sp.Rational, sp.Float)):
        return create_type("double")
    elif isinstance(expr, ResolvedFieldAccess):
        return expr.field.dtype
    elif isinstance(expr, TypedSymbol):
        return expr.dtype
    elif isinstance(expr, sp.Symbol):
        raise ValueError("All symbols inside this expression have to be typed! ", str(expr))
    elif isinstance(expr, cast_func):
        return expr.args[1]
    elif isinstance(expr, (vec_any, vec_all)):
        return create_type("bool")
    elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
        collated_result_type = collate_types(tuple(get_type_of_expression(a[0]) for a in expr.args))
        collated_condition_type = collate_types(tuple(get_type_of_expression(a[1]) for a in expr.args))
        # a vectorized condition turns the whole Piecewise into a vector expression
        if type(collated_condition_type) is VectorType and type(collated_result_type) is not VectorType:
            collated_result_type = VectorType(collated_result_type, width=collated_condition_type.width)
        return collated_result_type
    elif isinstance(expr, sp.Indexed):
        typed_symbol = expr.base.label
        return typed_symbol.dtype.base_type
    elif isinstance(expr, Boolean):
        # ``Boolean`` (imported from sympy.logic.boolalg) is the base class of BooleanFunction,
        # so this single check covers both. It does not depend on a ``sp.boolalg`` attribute,
        # which newer sympy versions do not provide.
        # if any arg is of vector type return a vector boolean, else return a normal scalar boolean
        result = create_type("bool")
        arg_types = [get_type_of_expression(a) for a in expr.args]
        vec_args = [t for t in arg_types if isinstance(t, VectorType)]
        if vec_args:
            result = VectorType(result, width=vec_args[0].width)
        return result
    elif isinstance(expr, sp.Pow):
        return get_type_of_expression(expr.args[0])
    elif isinstance(expr, sp.Expr):
        types = tuple(get_type_of_expression(a) for a in expr.args)
        return collate_types(types)

    raise NotImplementedError("Could not determine type for", expr, type(expr))


class Type(sp.Basic):
    """Common base class of the pystencils data types.

    Types are sympy atoms, so they can appear as arguments of sympy functions
    (``cast_func`` stores its target type as an argument).
    """
    is_Atom = True

    def __new__(cls, *args, **kwargs):
        # Constructor arguments are handled by the subclasses' ``__init__``;
        # none of them are forwarded to ``sp.Basic.__new__``.
        instance = sp.Basic.__new__(cls)
        return instance

    def _sympystr(self, *args, **kwargs):
        # sympy's string printer uses the type's own ``__str__``
        text = str(self)
        return text


class BasicType(Type):
    """Scalar data type described by a non-structured numpy dtype and a ``const`` flag."""

    @staticmethod
    def numpy_name_to_c(name):
        """Maps a numpy dtype name ('float64', 'float32', 'intN', 'uintN', 'bool') to its C type name.

        Raises:
            NotImplementedError: for other names (e.g. 'float16', 'complex128')
        """
        if name == 'float64':
            return 'double'
        elif name == 'float32':
            return 'float'
        elif name.startswith('int'):
            width = int(name[len("int"):])
            return "int%d_t" % (width,)
        elif name.startswith('uint'):
            width = int(name[len("uint"):])
            return "uint%d_t" % (width,)
        elif name == 'bool':
            return 'bool'
        else:
            raise NotImplementedError("Cannot map numpy to C name for %s" % (name,))

    def __init__(self, dtype, const=False):
        self.const = const
        if isinstance(dtype, Type):
            self._dtype = dtype.numpy_dtype
        else:
            self._dtype = np.dtype(dtype)
        assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type"
        assert self._dtype.hasobject is False
        assert self._dtype.subdtype is None

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def item_size(self):
        # NOTE(review): always 1 (apparently counted in elements, while StructType.item_size is in bytes) — confirm
        return 1

    # The predicates below use np.issubdtype / dtype kinds instead of membership in
    # ``np.sctypes``, which was removed in NumPy 2.0.
    def is_int(self):
        """True for signed and unsigned integer types."""
        return np.issubdtype(self.numpy_dtype, np.integer)

    def is_float(self):
        return np.issubdtype(self.numpy_dtype, np.floating)

    def is_uint(self):
        return np.issubdtype(self.numpy_dtype, np.unsignedinteger)

    def is_complex(self):
        return np.issubdtype(self.numpy_dtype, np.complexfloating)

    def is_other(self):
        """True for bool, object, bytes, str and void dtypes (numpy's former ``sctypes['others']``)."""
        return self.numpy_dtype.kind in ('b', 'O', 'S', 'U', 'V')

    @property
    def base_name(self):
        return BasicType.numpy_name_to_c(str(self._dtype))

    def __str__(self):
        result = BasicType.numpy_name_to_c(str(self._dtype))
        if self.const:
            result += " const"
        return result

    def __repr__(self):
        return str(self)

    def __eq__(self, other):
        if not isinstance(other, BasicType):
            return False
        else:
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)

    def __hash__(self):
        return hash(str(self))


class VectorType(Type):
    """Vector of ``width`` elements of ``base_type``.

    ``instruction_set`` is a class attribute, ``None`` by default. When it is set to a
    mapping, ``__str__`` looks up the vector type name there (keys 'int', 'double',
    'float', 'bool'); otherwise the type prints as ``base[width]``.
    """
    instruction_set = None

    def __init__(self, base_type, width=4):
        self._base_type = base_type
        self.width = width

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        return self.width * self.base_type.item_size

    def __eq__(self, other):
        return isinstance(other, VectorType) and \
            (self.base_type, self.width) == (other.base_type, other.width)

    def __str__(self):
        if self.instruction_set is None:
            return "%s[%d]" % (self.base_type, self.width)
        # checked in order; the first matching base type determines the instruction set key
        candidates = (("int64", 'int'), ("float64", 'double'), ("float32", 'float'), ("bool", 'bool'))
        for type_name, key in candidates:
            if self.base_type == create_type(type_name):
                return self.instruction_set[key]
        raise NotImplementedError()

    def __hash__(self):
        return hash((self.base_type, self.width))

    def __getnewargs__(self):
        return self._base_type, self.width

class PointerType(Type):
    """Pointer to ``base_type`` with optional ``const`` and ``restrict`` qualifiers.

    Pointers are ``restrict`` by default; ``alias`` is the negation of ``restrict``.
    """

    def __init__(self, base_type, const=False, restrict=True):
        self._base_type = base_type
        self.const = const
        self.restrict = restrict

    def __getnewargs__(self):
        return self.base_type, self.const, self.restrict

    @property
    def alias(self):
        return not self.restrict

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        return self.base_type.item_size

    def __eq__(self, other):
        if isinstance(other, PointerType):
            return (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict)
        return False

    def __str__(self):
        qualifiers = (['RESTRICT'] if self.restrict else []) + (["const"] if self.const else [])
        return " ".join([str(self.base_type), '*'] + qualifiers)

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self._base_type, self.const, self.restrict))

class StructType:
    """Structured data type described by a numpy dtype with named fields.

    Structs are handled byte-wise in generated code, so the type prints as ``uint8_t``.
    """

    def __init__(self, numpy_type, const=False):
        self.const = const
        self._dtype = np.dtype(numpy_type)

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def item_size(self):
        # size of one struct in bytes
        return self.numpy_dtype.itemsize

    def get_element_offset(self, element_name):
        """Byte offset of field ``element_name`` within the struct."""
        _, byte_offset = self.numpy_dtype.fields[element_name][:2]
        return byte_offset

    def get_element_type(self, element_name):
        """``BasicType`` of field ``element_name``, inheriting the struct's ``const`` flag."""
        field_dtype = self.numpy_dtype.fields[element_name][0]
        return BasicType(field_dtype, self.const)

    def has_element(self, element_name):
        return element_name in self.numpy_dtype.fields

    def __eq__(self, other):
        if isinstance(other, StructType):
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)
        return False

    def __str__(self):
        # structs are handled byte-wise
        return "uint8_t const" if self.const else "uint8_t"

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self.numpy_dtype, self.const))