data_types.py 19.1 KB
Newer Older
1
import ctypes
Martin Bauer's avatar
Martin Bauer committed
2

3
import numpy as np
Martin Bauer's avatar
Martin Bauer committed
4
5
6
7
8
9
import sympy as sp
from sympy.core.cache import cacheit
from sympy.logic.boolalg import Boolean

from pystencils.cache import memorycache
from pystencils.utils import all_equal
10

11
12
13
14
15
try:
    import llvmlite.ir as ir
except ImportError as e:
    ir = None
    _ir_importerror = e
16

17

18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
# noinspection PyPep8Naming
class address_of(sp.Function):
    """Symbolic "address-of" operation applied to a single argument.

    The resulting ``dtype`` is a restrict-qualified pointer to the argument's ``dtype``, or a
    restrict-qualified pointer to ``'void'`` when the argument carries no ``dtype``.
    """
    is_Atom = True

    def __new__(cls, arg):
        return sp.Function.__new__(cls, arg)

    @property
    def canonical(self):
        """Forward ``canonical`` to the wrapped argument; raise if it has none."""
        target = self.args[0]
        if not hasattr(target, 'canonical'):
            raise NotImplementedError()
        return target.canonical

    @property
    def is_commutative(self):
        return self.args[0].is_commutative

    @property
    def dtype(self):
        target = self.args[0]
        pointee = target.dtype if hasattr(target, 'dtype') else 'void'
        return PointerType(pointee, restrict=True)


# noinspection PyPep8Naming
class cast_func(sp.Function):
    """Symbolic type cast ``cast_func(expr, dtype)``.

    ``args[0]`` is the expression being cast and ``args[1]`` the target type, exposed as ``dtype``.
    """
    is_Atom = True

    def __new__(cls, *args, **kwargs):
        # To work in conditions of sp.Piecewise a cast_func has to be a Boolean as well.
        # However, a cast should only be a Boolean if its argument is one; otherwise comparing
        # casts for equality goes wrong, e.g.
        #
        # lhs = bitwise_and(a, cast_func(1, 'int'))
        # rhs = cast_func(0, 'int')
        # print( sp.Ne(lhs, rhs) ) # would give true if all cast_funcs are booleans
        #
        # -> casts of Boolean arguments are therefore created as the separate class boolean_cast_func
        target_cls = boolean_cast_func if isinstance(args[0], Boolean) else cls
        return sp.Function.__new__(target_cls, *args, **kwargs)

    @property
    def canonical(self):
        """Forward ``canonical`` to the cast expression; raise if it has none."""
        wrapped = self.args[0]
        if not hasattr(wrapped, 'canonical'):
            raise NotImplementedError()
        return wrapped.canonical

    @property
    def is_commutative(self):
        return self.args[0].is_commutative

    def _eval_evalf(self, *args, **kwargs):
        # Numerical evaluation drops the cast and evaluates only the wrapped expression.
        return self.args[0].evalf()

    @property
    def dtype(self):
        """Target type of the cast (the second argument)."""
        return self.args[1]


# noinspection PyPep8Naming
class boolean_cast_func(cast_func, Boolean):
    """Cast of a Boolean argument; also a sympy ``Boolean`` so it may appear in Piecewise conditions.

    ``cast_func.__new__`` creates instances of this class automatically whenever the cast
    argument is a ``Boolean``.
    """


# noinspection PyPep8Naming
class vector_memory_access(cast_func):
    """``cast_func`` subclass that accepts exactly four arguments.

    ``dtype`` is inherited, so ``args[1]`` is the type. NOTE(review): the meaning of the remaining
    arguments is defined by the vectorization code that builds these nodes — confirm there.
    """
    nargs = (4,)


# noinspection PyPep8Naming
class reinterpret_cast_func(cast_func):
    """``cast_func`` variant that differs only by its class; all behavior is inherited.

    NOTE(review): presumably printed as a bit-reinterpreting cast by the code backends — confirm.
    """


# noinspection PyPep8Naming
class pointer_arithmetic_func(sp.Function, Boolean):
    """Symbolic function node for pointer arithmetic; it also derives from sympy's ``Boolean``.

    ``canonical`` is forwarded to the first argument.
    """

    @property
    def canonical(self):
        first = self.args[0]
        if not hasattr(first, 'canonical'):
            raise NotImplementedError()
        return first.canonical


class TypedSymbol(sp.Symbol):
    """Sympy symbol annotated with a data type.

    The ``dtype`` argument is converted with ``create_type``. If that raises ``TypeError`` or
    ``ValueError``, the specification is stored unchanged. The dtype is part of the hashable
    content, so symbols with equal names but different dtypes are distinct. The assumption
    properties ``is_integer``, ``is_negative``, ``is_nonnegative`` and ``is_real`` are refined
    from the numpy dtype when the stored dtype exposes one.
    """

    def __new__(cls, *args, **kwds):
        obj = TypedSymbol.__xnew_cached_(cls, *args, **kwds)
        return obj

    def __new_stage2__(cls, name, dtype, *args, **kwargs):
        obj = super(TypedSymbol, cls).__xnew__(cls, name, *args, **kwargs)
        try:
            obj._dtype = create_type(dtype)
        except (TypeError, ValueError):
            # on error keep the string
            obj._dtype = dtype
        return obj

    # Uncached (__xnew__) and cached (__xnew_cached_, used by __new__) constructors.
    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    @property
    def dtype(self):
        """The symbol's type: a Type if ``create_type`` succeeded, else the original specification."""
        return self._dtype

    def _hashable_content(self):
        return super()._hashable_content(), hash(self._dtype)

    def __getnewargs__(self):
        return self.name, self.dtype

    # For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
    @property
    def is_integer(self):
        if hasattr(self.dtype, 'numpy_dtype'):
            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or super().is_integer
        else:
            return super().is_integer

    @property
    def is_negative(self):
        # Unsigned integer dtypes can never hold negative values.
        if hasattr(self.dtype, 'numpy_dtype'):
            if np.issubdtype(self.dtype.numpy_dtype, np.unsignedinteger):
                return False
        return super().is_negative

    @property
    def is_nonnegative(self):
        # In sympy "nonnegative" means "real and not negative". is_negative is also False for
        # non-real values, so realness has to be established too before answering True.
        if self.is_negative is False and self.is_real:
            return True
        return super().is_nonnegative

    @property
    def is_real(self):
        if hasattr(self.dtype, 'numpy_dtype'):
            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or \
                np.issubdtype(self.dtype.numpy_dtype, np.floating) or \
                super().is_real
        else:
            return super().is_real


def create_type(specification):
    """Creates a subclass of Type according to a string or an object of subclass Type.

    Args:
        specification: Type or StructType object (returned unchanged), or anything ``np.dtype``
                       accepts, e.g. a string such as ``'float64'`` or a numpy scalar type

    Returns:
        Type object, or a new non-const type parsed via ``np.dtype``: a BasicType for plain
        dtypes, a StructType for structured dtypes (those with ``fields``)

    Raises:
        whatever ``np.dtype`` raises for an uninterpretable specification (typically TypeError)
    """
    # StructType does not derive from Type, so it must be accepted explicitly; otherwise an
    # existing StructType would be handed to np.dtype() and rejected.
    if isinstance(specification, (Type, StructType)):
        return specification
    numpy_dtype = np.dtype(specification)
    if numpy_dtype.fields is None:
        return BasicType(numpy_dtype, const=False)
    return StructType(numpy_dtype, const=False)


@memorycache(maxsize=64)
def create_composite_type_from_string(specification):
    """Creates a new Type object from a c-like string specification.

    The specification has two parts:
      * a base type name, optionally qualified with ``const``;
      * any number of pointer levels, each a ``*`` optionally followed by ``restrict`` and/or
        ``const``.

    Examples are ``"const double * restrict"`` and ``"int64 * const *"``. A ``*`` may be attached
    to the type name (``"double*"``) or separated by whitespace. As in C, qualifiers written after
    a ``*`` belong to that pointer level.

    Args:
        specification: Specification string

    Returns:
        Type object

    Raises:
        ValueError: if the specification is empty or contains unexpected tokens
    """
    # Detach every '*' from neighbouring words so "double*" and "double *" tokenize identically
    # and qualifiers following an attached '*' land in that pointer level.
    tokens = specification.lower().replace('*', ' * ').split()
    if not tokens:
        raise ValueError("Empty type specification")

    # Group tokens: the base part first, then one group per '*' holding the '*' and its qualifiers.
    parts = []
    current = []
    for token in tokens:
        if token == '*':
            parts.append(current)
            current = [token]
        else:
            current.append(token)
    parts.append(current)

    # Parse native part
    base_part = parts.pop(0)
    const = 'const' in base_part
    if const:
        base_part.remove('const')
    if len(base_part) != 1:
        raise ValueError("Invalid base type in type specification %r" % (specification,))
    current_type = BasicType(np.dtype(base_part[0]), const)

    # Parse pointer parts, innermost pointer level first
    for part in parts:
        restrict = 'restrict' in part
        if restrict:
            part.remove('restrict')
        const = 'const' in part
        if const:
            part.remove('const')
        if part != ['*']:
            raise ValueError("Invalid pointer qualifiers in type specification %r" % (specification,))
        current_type = PointerType(current_type, const, restrict)
    return current_type


def get_base_type(data_type):
    """Follows the ``base_type`` chain of *data_type* and returns the innermost type.

    The innermost type is the first one in the chain whose ``base_type`` is ``None``, e.g. the
    scalar type behind a pointer to a pointer.
    """
    current = data_type
    while True:
        inner = current.base_type
        if inner is None:
            return current
        current = inner


def to_ctypes(data_type):
    """Transforms a given Type into ctypes.

    Args:
        data_type: Subclass of Type

    Returns:
        ctypes type object:
          * pointer types give ``ctypes.POINTER`` of their base type's ctypes type;
          * struct types give ``POINTER(c_uint8)``, since structs are handled byte-wise;
          * all other types give the ``to_ctypes.map`` entry for their numpy dtype.

    Raises:
        KeyError: if the numpy dtype has no entry in ``to_ctypes.map``
    """
    if isinstance(data_type, PointerType):
        return ctypes.POINTER(to_ctypes(data_type.base_type))
    elif isinstance(data_type, StructType):
        return ctypes.POINTER(ctypes.c_uint8)
    else:
        return to_ctypes.map[data_type.numpy_dtype]


to_ctypes.map = {
    np.dtype(np.int8): ctypes.c_int8,
    np.dtype(np.int16): ctypes.c_int16,
    np.dtype(np.int32): ctypes.c_int32,
    np.dtype(np.int64): ctypes.c_int64,

    np.dtype(np.uint8): ctypes.c_uint8,
    np.dtype(np.uint16): ctypes.c_uint16,
    np.dtype(np.uint32): ctypes.c_uint32,
    np.dtype(np.uint64): ctypes.c_uint64,

    np.dtype(np.float32): ctypes.c_float,
    np.dtype(np.float64): ctypes.c_double,

    # BasicType supports 'bool' (printed as C 'bool'); ctypes.c_bool corresponds to C's _Bool/bool.
    np.dtype(np.bool_): ctypes.c_bool,
}


def ctypes_from_llvm(data_type):
    """Maps an llvmlite IR type to the corresponding ctypes type.

    Mapping rules:
      * pointers map to ``ctypes.POINTER`` of the pointee's ctypes type, or to
        ``ctypes.c_void_p`` when the pointee has no ctypes equivalent (void);
      * integer types map to signed fixed-width ctypes integers;
      * void maps to ``None``.

    Raises:
        the ImportError caught at module import time if llvmlite is unavailable
        ValueError: for integer widths other than 8, 16, 32 and 64
        NotImplementedError: for any other llvmlite type
    """
    if not ir:
        raise _ir_importerror

    if isinstance(data_type, ir.PointerType):
        pointee = ctypes_from_llvm(data_type.pointee)
        return ctypes.c_void_p if pointee is None else ctypes.POINTER(pointee)

    if isinstance(data_type, ir.IntType):
        int_types_by_width = {
            8: ctypes.c_int8,
            16: ctypes.c_int16,
            32: ctypes.c_int32,
            64: ctypes.c_int64,
        }
        if data_type.width not in int_types_by_width:
            raise ValueError("Int width %d is not supported" % data_type.width)
        return int_types_by_width[data_type.width]

    if isinstance(data_type, ir.FloatType):
        return ctypes.c_float
    if isinstance(data_type, ir.DoubleType):
        return ctypes.c_double
    if isinstance(data_type, ir.VoidType):
        return None  # Void type is not supported by ctypes

    raise NotImplementedError('Data type %s of %s is not supported yet' % (type(data_type), data_type))


def to_llvm_type(data_type):
    """Transforms a given type into an llvmlite type.

    Args:
        data_type: Subclass of Type

    Returns:
        llvmlite type object. Pointer types become ``.as_pointer()`` of their converted base
        type; all other types are looked up by numpy dtype in ``to_llvm_type.map``.

    Raises:
        the ImportError caught at module import time if llvmlite is unavailable
        KeyError: if the numpy dtype has no entry in ``to_llvm_type.map``
    """
    if not ir:
        raise _ir_importerror
    if isinstance(data_type, PointerType):
        return to_llvm_type(data_type.base_type).as_pointer()
    else:
        return to_llvm_type.map[data_type.numpy_dtype]


if ir:
    # LLVM integer types carry no signedness, so signed and unsigned numpy types of equal width
    # share the same IR integer type.
    to_llvm_type.map = {
        np.dtype(np.int8): ir.IntType(8),
        np.dtype(np.int16): ir.IntType(16),
        np.dtype(np.int32): ir.IntType(32),
        np.dtype(np.int64): ir.IntType(64),

        np.dtype(np.uint8): ir.IntType(8),
        np.dtype(np.uint16): ir.IntType(16),
        np.dtype(np.uint32): ir.IntType(32),
        np.dtype(np.uint64): ir.IntType(64),

        np.dtype(np.float32): ir.FloatType(),
        np.dtype(np.float64): ir.DoubleType(),
    }


def peel_off_type(dtype, type_to_peel_off):
    """Strips wrapper layers whose class is exactly *type_to_peel_off* from *dtype*.

    Subclasses of *type_to_peel_off* are not stripped. Each stripped layer is replaced by its
    ``base_type``; the first non-matching type is returned.
    """
    if type(dtype) is not type_to_peel_off:
        return dtype
    return peel_off_type(dtype.base_type, type_to_peel_off)


def collate_types(types):
    """Takes a sequence of types and returns their "common type", e.g. (float, double, float) -> double.

    Rules:
      * pointer arithmetic: exactly one PointerType combined with signed/unsigned integer
        BasicTypes yields that pointer type;
      * if any VectorType is present (all must have the same width), the result is a VectorType
        of that width around the collated scalar types;
      * if any float type is present, only the float types take part in the collation, e.g.
        float32 combined with int64 gives float32;
      * otherwise numpy's ``np.result_type`` promotion rules decide.

    Args:
        types: iterable of Type objects; it is materialized once, so generators are fine

    Returns:
        the common Type

    Raises:
        ValueError: on invalid pointer arithmetic or on vector types with different widths
    """
    # Materialize once: the checks below iterate several times, and a generator would be
    # silently exhausted after the first pass.
    types = tuple(types)

    # Pointer arithmetic case i.e. pointer + integer is allowed
    if any(type(t) is PointerType for t in types):
        pointer_type = None
        for t in types:
            if type(t) is PointerType:
                if pointer_type is not None:
                    raise ValueError("Cannot collate the combination of two pointer types")
                pointer_type = t
            elif type(t) is BasicType:
                if not (t.is_int() or t.is_uint()):
                    raise ValueError("Invalid pointer arithmetic")
            else:
                raise ValueError("Invalid pointer arithmetic")
        return pointer_type

    # peel off vector types, if at least one vector type occurred the result will also be the vector type
    vector_types = [t for t in types if type(t) is VectorType]
    if not all_equal(t.width for t in vector_types):
        raise ValueError("Collation failed because of vector types with different width")
    types = [peel_off_type(t, VectorType) for t in types]

    # now we should have a list of basic types - struct types are not yet supported
    assert all(type(t) is BasicType for t in types)

    # integer operands never widen a floating point result
    if any(t.is_float() for t in types):
        types = [t for t in types if t.is_float()]

    # use numpy collation -> create type from numpy type -> and, put vector type around if necessary
    result_numpy_type = np.result_type(*(t.numpy_dtype for t in types))
    result = BasicType(result_numpy_type)
    if vector_types:
        result = VectorType(result, vector_types[0].width)
    return result


@memorycache(maxsize=2048)
def get_type_of_expression(expr):
    """Determines the data type of a sympy expression.

    Typing rules:
      * integers are typed ``int``; other rationals and floats ``double``;
      * typed symbols, resolved field accesses and casts report their own type;
      * Piecewise collates its values (vectorized if the conditions are);
      * Boolean expressions give ``bool``, or a bool vector if any argument is vector-typed;
      * powers take the type of their base;
      * any other expression takes the collation of its arguments' types.

    Raises:
        ValueError: if the expression contains an untyped symbol
        NotImplementedError: if no rule applies
    """
    from pystencils.astnodes import ResolvedFieldAccess
    from pystencils.cpu.vectorization import vec_all, vec_any

    expr = sp.sympify(expr)
    if isinstance(expr, sp.Integer):
        return create_type("int")
    elif isinstance(expr, (sp.Rational, sp.Float)):
        return create_type("double")
    elif isinstance(expr, ResolvedFieldAccess):
        return expr.field.dtype
    elif isinstance(expr, TypedSymbol):
        return expr.dtype
    elif isinstance(expr, sp.Symbol):
        raise ValueError("All symbols inside this expression have to be typed! %s" % (expr,))
    elif isinstance(expr, cast_func):
        return expr.args[1]
    elif isinstance(expr, (vec_any, vec_all)):
        return create_type("bool")
    elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
        collated_result_type = collate_types(tuple(get_type_of_expression(a[0]) for a in expr.args))
        collated_condition_type = collate_types(tuple(get_type_of_expression(a[1]) for a in expr.args))
        if type(collated_condition_type) is VectorType and type(collated_result_type) is not VectorType:
            collated_result_type = VectorType(collated_result_type, width=collated_condition_type.width)
        return collated_result_type
    elif isinstance(expr, sp.Indexed):
        typed_symbol = expr.base.label
        return typed_symbol.dtype.base_type
    elif isinstance(expr, Boolean):
        # Boolean also covers BooleanFunction (a subclass); sp.boolalg is not a public sympy
        # attribute in current releases, so the imported class is used directly.
        # if any arg is of vector type return a vector boolean, else return a normal scalar boolean
        result = create_type("bool")
        arg_types = [get_type_of_expression(a) for a in expr.args]
        vector_arg_types = [t for t in arg_types if isinstance(t, VectorType)]
        if vector_arg_types:
            result = VectorType(result, width=vector_arg_types[0].width)
        return result
    elif isinstance(expr, sp.Pow):
        return get_type_of_expression(expr.args[0])
    elif isinstance(expr, sp.Expr):
        types = tuple(get_type_of_expression(a) for a in expr.args)
        return collate_types(types)

    raise NotImplementedError("Could not determine type for %s (%s)" % (expr, type(expr)))


class Type(sp.Basic):
    """Common base class of BasicType, VectorType and PointerType (derives from ``sp.Basic``).

    Constructor arguments are not forwarded to ``sp.Basic``, so instances carry no sympy args;
    subclasses keep their state in attributes set by ``__init__``. The sympy string printer
    uses ``str(self)``.
    """
    is_Atom = True

    def __new__(cls, *args, **kwargs):
        # args/kwargs are intentionally dropped here; subclasses consume them in __init__.
        instance = sp.Basic.__new__(cls)
        return instance

    def _sympystr(self, *args, **kwargs):
        return str(self)


class BasicType(Type):
    """Scalar type backed by a plain (non-structured) numpy dtype, with an optional ``const`` flag."""

    @staticmethod
    def numpy_name_to_c(name):
        """Maps a numpy dtype name such as ``'float64'`` or ``'int32'`` to a C type name.

        Raises:
            NotImplementedError: for dtype names without a C mapping (e.g. complex types)
        """
        if name == 'float64':
            return 'double'
        elif name == 'float32':
            return 'float'
        elif name.startswith('int'):
            width = int(name[len("int"):])
            return "int%d_t" % (width,)
        elif name.startswith('uint'):
            width = int(name[len("uint"):])
            return "uint%d_t" % (width,)
        elif name == 'bool':
            return 'bool'
        else:
            raise NotImplementedError("Cannot map numpy to C name for %s" % (name,))

    def __init__(self, dtype, const=False):
        self.const = const
        if isinstance(dtype, Type):
            self._dtype = dtype.numpy_dtype
        else:
            self._dtype = np.dtype(dtype)
        assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type"
        assert self._dtype.hasobject is False
        assert self._dtype.subdtype is None

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def item_size(self):
        # NOTE(review): counts one element (VectorType multiplies it by its width), whereas
        # StructType.item_size is in bytes — confirm the intended unit.
        return 1

    # The predicates below classify by dtype kind instead of ``np.sctypes``, which was removed
    # in NumPy 2.0. Kinds 'i', 'f', 'u', 'c' select the same scalar types the old
    # np.sctypes['int'/'float'/'uint'/'complex'] lists did.
    def is_int(self):
        return self.numpy_dtype.kind == 'i'

    def is_float(self):
        return self.numpy_dtype.kind == 'f'

    def is_uint(self):
        return self.numpy_dtype.kind == 'u'

    def is_complex(self):
        return self.numpy_dtype.kind == 'c'

    def is_other(self):
        # Same members as the former np.sctypes['others'] list.
        return any(self.numpy_dtype == t for t in (bool, object, bytes, str, np.void))

    @property
    def base_name(self):
        return BasicType.numpy_name_to_c(str(self._dtype))

    def __str__(self):
        result = BasicType.numpy_name_to_c(str(self._dtype))
        if self.const:
            result += " const"
        return result

    def __repr__(self):
        return str(self)

    def __eq__(self, other):
        if not isinstance(other, BasicType):
            return False
        else:
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)

    def __hash__(self):
        return hash(str(self))


class VectorType(Type):
    """Vector of ``width`` elements of ``base_type``.

    ``instruction_set`` is a class-level mapping, ``None`` by default. When it is set, ``str()``
    returns its entry for the base type (keys ``'int'``, ``'double'``, ``'float'``, ``'bool'``).
    """
    instruction_set = None

    def __init__(self, base_type, width=4):
        self._base_type = base_type
        self.width = width

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        return self.width * self.base_type.item_size

    def __eq__(self, other):
        if isinstance(other, VectorType):
            return (self.base_type, self.width) == (other.base_type, other.width)
        return False

    def __str__(self):
        if self.instruction_set is None:
            return "%s[%d]" % (self.base_type, self.width)
        # Only these scalar base types have an entry in the instruction set; checked in order.
        for spec, key in (("int64", 'int'), ("float64", 'double'), ("float32", 'float'), ("bool", 'bool')):
            if self.base_type == create_type(spec):
                return self.instruction_set[key]
        raise NotImplementedError()

    def __hash__(self):
        return hash((self.base_type, self.width))

    def __getnewargs__(self):
        return self._base_type, self.width


class PointerType(Type):
    """Pointer to ``base_type`` with ``const`` and ``restrict`` qualifiers (restrict by default)."""

    def __init__(self, base_type, const=False, restrict=True):
        self._base_type = base_type
        self.const = const
        self.restrict = restrict

    def __getnewargs__(self):
        return self.base_type, self.const, self.restrict

    @property
    def alias(self):
        """True when the pointer is not restrict-qualified."""
        return not self.restrict

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        return self.base_type.item_size

    def __eq__(self, other):
        return isinstance(other, PointerType) and \
            (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict)

    def __str__(self):
        # Rendered as "<base> *", then "RESTRICT" and/or "const" when set.
        qualifiers = ['RESTRICT'] if self.restrict else []
        if self.const:
            qualifiers.append("const")
        return " ".join([str(self.base_type), '*'] + qualifiers)

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self._base_type, self.const, self.restrict))


class StructType:
    """Structured numpy dtype (record type) with an optional ``const`` flag.

    Unlike the other types this class does not derive from ``Type``. In generated C code structs
    are handled byte-wise, so ``str()`` yields ``uint8_t``.
    """

    def __init__(self, numpy_type, const=False):
        self.const = const
        self._dtype = np.dtype(numpy_type)

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def item_size(self):
        """Size of one struct in bytes."""
        return self.numpy_dtype.itemsize

    def get_element_offset(self, element_name):
        """Byte offset of field ``element_name`` within the struct."""
        return self.numpy_dtype.fields[element_name][1]

    def get_element_type(self, element_name):
        """Type of field ``element_name``, carrying over this struct's ``const`` flag.

        Nested structured fields yield a StructType; all other fields yield a BasicType.
        """
        np_element_type = self.numpy_dtype.fields[element_name][0]
        if np_element_type.fields is not None:
            # BasicType rejects structured dtypes, so nested records stay StructTypes.
            return StructType(np_element_type, self.const)
        return BasicType(np_element_type, self.const)

    def has_element(self, element_name):
        return element_name in self.numpy_dtype.fields

    def __eq__(self, other):
        if not isinstance(other, StructType):
            return False
        else:
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)

    def __str__(self):
        # structs are handled byte-wise
        result = "uint8_t"
        if self.const:
            result += " const"
        return result

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self.numpy_dtype, self.const))