data_types.py 21.1 KB
Newer Older
1
import ctypes
2
3
from collections import defaultdict
from functools import partial
Martin Bauer's avatar
Martin Bauer committed
4

5
import numpy as np
Martin Bauer's avatar
Martin Bauer committed
6
7
8
9
import sympy as sp
from sympy.core.cache import cacheit
from sympy.logic.boolalg import Boolean

10
from pystencils.cache import memorycache, memorycache_if_hashable
Martin Bauer's avatar
Martin Bauer committed
11
from pystencils.utils import all_equal
12

13
14
15
16
17
try:
    import llvmlite.ir as ir
except ImportError as e:
    ir = None
    _ir_importerror = e
18

19

20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# noinspection PyPep8Naming
class address_of(sp.Function):
    """Sympy function wrapping a single argument whose address is taken.

    The resulting ``dtype`` is a restrict-qualified pointer to the argument's ``dtype``,
    or a restrict pointer to ``'void'`` when the argument has no ``dtype`` attribute.
    """
    is_Atom = True

    def __new__(cls, arg):
        return sp.Function.__new__(cls, arg)

    @property
    def canonical(self):
        target = self.args[0]
        if not hasattr(target, 'canonical'):
            raise NotImplementedError()
        return target.canonical

    @property
    def is_commutative(self):
        target = self.args[0]
        return target.is_commutative

    @property
    def dtype(self):
        target = self.args[0]
        pointee = target.dtype if hasattr(target, 'dtype') else 'void'
        return PointerType(pointee, restrict=True)


Martin Bauer's avatar
Martin Bauer committed
47
# noinspection PyPep8Naming
48
class cast_func(sp.Function):
    """Sympy function representing a cast of an expression to a pystencils type.

    ``args[0]`` is the expression being cast. ``args[1]`` is the target type, run through
    ``create_type`` unless it is already a ``Type``. Any further positional arguments are
    forwarded to ``sp.Function`` unchanged.
    """
    is_Atom = True

    def __new__(cls, *args, **kwargs):
        if len(args) < 2:
            # explicit message instead of an opaque tuple-unpacking error (same exception type)
            raise ValueError("cast_func needs at least an expression and a target type, got %d argument(s)"
                             % len(args))
        expr, dtype, *other_args = args
        if not isinstance(dtype, Type):
            dtype = create_type(dtype)
        # to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well
        # however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads
        # to problems when for example comparing cast_func's for equality
        #
        # lhs = bitwise_and(a, cast_func(1, 'int'))
        # rhs = cast_func(0, 'int')
        # print( sp.Ne(lhs, rhs) ) # would give true if all cast_funcs are booleans
        # -> thus a separate class boolean_cast_func is introduced
        if isinstance(expr, Boolean):
            cls = boolean_cast_func
        return sp.Function.__new__(cls, expr, dtype, *other_args, **kwargs)

    @property
    def canonical(self):
        if hasattr(self.args[0], 'canonical'):
            return self.args[0].canonical
        else:
            raise NotImplementedError()

    @property
    def is_commutative(self):
        return self.args[0].is_commutative

    def _eval_evalf(self, *args, **kwargs):
        # numeric evaluation drops the cast and evaluates the inner expression;
        # the requested precision arguments are not forwarded
        return self.args[0].evalf()

    @property
    def dtype(self):
        """The target type of the cast (second argument)."""
        return self.args[1]

    @property
    def is_integer(self):
        # numpy-backed integer target types make the cast integer; otherwise defer to sympy
        if hasattr(self.dtype, 'numpy_dtype'):
            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or super().is_integer
        else:
            return super().is_integer

    @property
    def is_negative(self):
        # a cast to an unsigned integer type can never be negative
        if hasattr(self.dtype, 'numpy_dtype'):
            if np.issubdtype(self.dtype.numpy_dtype, np.unsignedinteger):
                return False

        return super().is_negative

    @property
    def is_nonnegative(self):
        if self.is_negative is False:
            return True
        else:
            return super().is_nonnegative

    @property
    def is_real(self):
        # integer and floating point target types are real
        if hasattr(self.dtype, 'numpy_dtype'):
            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or \
                np.issubdtype(self.dtype.numpy_dtype, np.floating) or \
                super().is_real
        else:
            return super().is_real

Martin Bauer's avatar
Martin Bauer committed
118

119
120
121
122
123
# noinspection PyPep8Naming
class boolean_cast_func(cast_func, Boolean):
    """cast_func variant selected by ``cast_func.__new__`` when the cast expression is a sympy Boolean.

    Also deriving from Boolean lets such casts be used as sympy.Piecewise conditions.
    """


Martin Bauer's avatar
Martin Bauer committed
124
125
126
127
# noinspection PyPep8Naming
class vector_memory_access(cast_func):
    """cast_func variant that takes exactly four arguments.

    As in cast_func, the first two are the expression and the target type. The meaning of
    the two extra arguments is defined by the code that creates these nodes (not visible here).
    """
    nargs = (4,)

128

129
130
131
132
133
# noinspection PyPep8Naming
class reinterpret_cast_func(cast_func):
    """Marker subclass of cast_func with no behavior of its own.

    NOTE(review): the name suggests a reinterpret-style cast; how it is emitted is decided
    by code outside this file.
    """


Martin Bauer's avatar
Martin Bauer committed
134
135
# noinspection PyPep8Naming
class pointer_arithmetic_func(sp.Function, Boolean):
    """Sympy function node used for pointer arithmetic expressions.

    NOTE(review): it also derives from Boolean; the reason is not visible in this file.
    """

    @property
    def canonical(self):
        if hasattr(self.args[0], 'canonical'):
            return self.args[0].canonical
        else:
            raise NotImplementedError()


144
class TypedSymbol(sp.Symbol):
    """Sympy symbol that additionally carries a data type.

    The type is converted with ``create_type`` when possible. If that raises TypeError or
    ValueError, the given specification (e.g. a string) is stored unchanged. The type is part
    of the hashable content, so two symbols with the same name but different types compare unequal.
    """

    def __new__(cls, *args, **kwds):
        obj = TypedSymbol.__xnew_cached_(cls, *args, **kwds)
        return obj

    def __new_stage2__(cls, name, dtype, *args, **kwargs):
        obj = super(TypedSymbol, cls).__xnew__(cls, name, *args, **kwargs)
        try:
            obj._dtype = create_type(dtype)
        except (TypeError, ValueError):
            # on error keep the string
            obj._dtype = dtype
        return obj

    # uncached and sympy-cached (cacheit) constructors, mirroring sp.Symbol's own pattern
    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    @property
    def dtype(self):
        return self._dtype

    def _hashable_content(self):
        return super()._hashable_content(), hash(self._dtype)

    def __getnewargs__(self):
        # used for pickling: rebuild from name and type
        return self.name, self.dtype

    # For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
    @property
    def is_integer(self):
        if hasattr(self.dtype, 'numpy_dtype'):
            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or super().is_integer
        else:
            return super().is_integer

    @property
    def is_negative(self):
        # symbols of unsigned integer type can never be negative
        if hasattr(self.dtype, 'numpy_dtype'):
            if np.issubdtype(self.dtype.numpy_dtype, np.unsignedinteger):
                return False

        return super().is_negative

    @property
    def is_nonnegative(self):
        if self.is_negative is False:
            return True
        else:
            return super().is_nonnegative

    @property
    def is_real(self):
        if hasattr(self.dtype, 'numpy_dtype'):
            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or \
                np.issubdtype(self.dtype.numpy_dtype, np.floating) or \
                super().is_real
        else:
            return super().is_real
202

203

Martin Bauer's avatar
Martin Bauer committed
204
def create_type(specification):
    """Creates a subclass of Type according to a string or an object of subclass Type.

    Args:
        specification: Type object, or anything ``np.dtype`` accepts (e.g. ``'float64'``, ``'double'``)

    Returns:
        the given Type object unchanged; otherwise a non-const BasicType for a plain dtype,
        or a non-const StructType for a structured dtype (one with fields)

    Raises:
        whatever ``np.dtype`` raises (e.g. TypeError) for specifications it does not understand
    """
    if isinstance(specification, Type):
        return specification
    numpy_dtype = np.dtype(specification)
    if numpy_dtype.fields is None:
        return BasicType(numpy_dtype, const=False)
    return StructType(numpy_dtype, const=False)
221
222


223
@memorycache(maxsize=64)
def create_composite_type_from_string(specification):
    """Creates a new Type object from a c-like string specification.

    The (lower-cased) specification is split on whitespace. Each standalone ``*`` token
    starts a new pointer level, which may carry ``const`` and/or ``restrict``. The leading
    part names the base type, optionally with ``const``. A ``*`` glued directly to the base
    type name (e.g. ``double*``) is the innermost pointer level.

    Args:
        specification: Specification string, e.g. ``"const double * restrict"``

    Returns:
        Type object
    """
    tokens = specification.lower().split()
    parts = []
    current = []
    for token in tokens:
        if token == '*':
            parts.append(current)
            current = [token]
        else:
            current.append(token)
    if len(current) > 0:
        parts.append(current)

    # Parse native part
    base_part = parts.pop(0)
    const = False
    if 'const' in base_part:
        const = True
        base_part.remove('const')
    assert len(base_part) == 1
    if base_part[0][-1] == "*":
        base_part[0] = base_part[0][:-1]
        # the glued '*' points directly at the base type -> innermost pointer level, so it must
        # be processed before any standalone '*' parts (previously it was appended as the outermost)
        parts.insert(0, ['*'])
    current_type = BasicType(np.dtype(base_part[0]), const)

    # Parse pointer parts, innermost first
    for part in parts:
        restrict = False
        const = False
        if 'restrict' in part:
            restrict = True
            part.remove('restrict')
        if 'const' in part:
            const = True
            part.remove("const")
        assert len(part) == 1 and part[0] == '*'
        current_type = PointerType(current_type, const, restrict)
    return current_type
268
269


Martin Bauer's avatar
Martin Bauer committed
270
271
272
273
def get_base_type(data_type):
    """Return the innermost type, following ``base_type`` links until one is None."""
    current = data_type
    while True:
        inner = current.base_type
        if inner is None:
            return current
        current = inner
274
275


Martin Bauer's avatar
Martin Bauer committed
276
def to_ctypes(data_type):
    """
    Transforms a given Type into ctypes
    :param data_type: Subclass of Type
    :return: ctypes type object
    """
    if isinstance(data_type, PointerType):
        return ctypes.POINTER(to_ctypes(data_type.base_type))
    elif isinstance(data_type, StructType):
        # structs are passed as raw bytes
        return ctypes.POINTER(ctypes.c_uint8)
    else:
        return to_ctypes.map[data_type.numpy_dtype]


to_ctypes.map = {
    np.dtype(np.int8): ctypes.c_int8,
    np.dtype(np.int16): ctypes.c_int16,
    np.dtype(np.int32): ctypes.c_int32,
    np.dtype(np.int64): ctypes.c_int64,

    np.dtype(np.uint8): ctypes.c_uint8,
    np.dtype(np.uint16): ctypes.c_uint16,
    np.dtype(np.uint32): ctypes.c_uint32,
    np.dtype(np.uint64): ctypes.c_uint64,

    np.dtype(np.float32): ctypes.c_float,
    np.dtype(np.float64): ctypes.c_double,

    # BasicType supports 'bool' (emitted as C 'bool'); without this entry to_ctypes raised KeyError
    np.dtype(np.bool_): ctypes.c_bool,
}


306
def ctypes_from_llvm(data_type):
    """Map an llvmlite IR type to the corresponding ctypes type.

    Args:
        data_type: llvmlite ``ir`` type (pointer, integer, float, double or void)

    Returns:
        ctypes type; ``ctypes.c_void_p`` for a pointer to void, and ``None`` for void itself

    Raises:
        the stored llvmlite ImportError if llvmlite could not be imported,
        ValueError for integer widths other than 8/16/32/64,
        NotImplementedError for any other IR type
    """
    if not ir:
        raise _ir_importerror

    if isinstance(data_type, ir.PointerType):
        ctype = ctypes_from_llvm(data_type.pointee)
        # a pointer to void has no typed ctypes equivalent -> generic void pointer
        if ctype is None:
            return ctypes.c_void_p
        return ctypes.POINTER(ctype)
    elif isinstance(data_type, ir.IntType):
        # IR integers carry no signedness here, they are mapped to signed ctypes
        if data_type.width == 8:
            return ctypes.c_int8
        elif data_type.width == 16:
            return ctypes.c_int16
        elif data_type.width == 32:
            return ctypes.c_int32
        elif data_type.width == 64:
            return ctypes.c_int64
        else:
            raise ValueError("Int width %d is not supported" % data_type.width)
    elif isinstance(data_type, ir.FloatType):
        return ctypes.c_float
    elif isinstance(data_type, ir.DoubleType):
        return ctypes.c_double
    elif isinstance(data_type, ir.VoidType):
        return None  # Void type is not supported by ctypes
    else:
        raise NotImplementedError('Data type %s of %s is not supported yet' % (type(data_type), data_type))


336
def to_llvm_type(data_type, nvvm_target=False):
    """
    Transforms a given type into an llvmlite IR type
    :param data_type: Subclass of Type
    :param nvvm_target: if True, the outermost pointer is created with address space 1 instead of 0
    :return: llvmlite type object
    :raises: the stored llvmlite ImportError if llvmlite could not be imported
    """
    if not ir:
        raise _ir_importerror
    if isinstance(data_type, PointerType):
        # NOTE(review): nvvm_target is not forwarded to the recursive call, so nested
        # pointer levels stay in address space 0 - confirm this is intended
        return to_llvm_type(data_type.base_type).as_pointer(1 if nvvm_target else 0)
    else:
        return to_llvm_type.map[data_type.numpy_dtype]

349

350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
if ir:
    # signed and unsigned numpy integers of equal width map to the same IR integer type
    to_llvm_type.map = {
        np.dtype(prefix + str(bits)): ir.IntType(bits)
        for prefix in ('int', 'uint')
        for bits in (8, 16, 32, 64)
    }
    to_llvm_type.map[np.dtype(np.float32)] = ir.FloatType()
    to_llvm_type.map[np.dtype(np.float64)] = ir.DoubleType()
365

366

Martin Bauer's avatar
Martin Bauer committed
367
368
369
def peel_off_type(dtype, type_to_peel_off):
    """Strip wrapper layers of class ``type_to_peel_off`` from ``dtype``.

    ``dtype`` is replaced by its ``base_type`` for as long as its class is exactly
    ``type_to_peel_off``. Instances of subclasses are not peeled.

    Args:
        dtype: type object (anything with a ``base_type`` attribute when it is peeled)
        type_to_peel_off: class whose instances are unwrapped

    Returns:
        the first object in the ``base_type`` chain whose class is not ``type_to_peel_off``
    """
    while type(dtype) is type_to_peel_off:
        dtype = dtype.base_type
    return dtype


Martin Bauer's avatar
Martin Bauer committed
373
def collate_types(types):
    """
    Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
    Uses the collation rules from numpy.

    Special cases:
      * exactly one pointer type combined with integer BasicTypes collates to that pointer type
        (pointer arithmetic); two pointers or non-integer partners raise ValueError
      * if any VectorType occurs, all of them must share one width and the result is a VectorType
        of that width
      * if any float type occurs, integer types are ignored, so the result is the widest float type
    """
    # materialize once: `types` is iterated several times below, which would silently
    # break for a generator argument
    types = tuple(types)

    # Pointer arithmetic case i.e. pointer + integer is allowed
    if any(type(t) is PointerType for t in types):
        pointer_type = None
        for t in types:
            if type(t) is PointerType:
                if pointer_type is not None:
                    raise ValueError("Cannot collate the combination of two pointer types")
                pointer_type = t
            elif type(t) is BasicType:
                if not (t.is_int() or t.is_uint()):
                    raise ValueError("Invalid pointer arithmetic")
            else:
                raise ValueError("Invalid pointer arithmetic")
        return pointer_type

    # peel off vector types, if at least one vector type occurred the result will also be the vector type
    vector_type = [t for t in types if type(t) is VectorType]
    if not all_equal(t.width for t in vector_type):
        raise ValueError("Collation failed because of vector types with different width")
    types = [peel_off_type(t, VectorType) for t in types]

    # now we should have a list of basic types - struct types are not yet supported
    assert all(type(t) is BasicType for t in types)

    if any(t.is_float() for t in types):
        types = tuple(t for t in types if t.is_float())
    # use numpy collation -> create type from numpy type -> and, put vector type around if necessary
    result_numpy_type = np.result_type(*(t.numpy_dtype for t in types))
    result = BasicType(result_numpy_type)
    if vector_type:
        result = VectorType(result, vector_type[0].width)
    return result


413
414
415
416
417
@memorycache_if_hashable(maxsize=2048)
def get_type_of_expression(expr,
                           default_float_type='double',
                           default_int_type='int',
                           symbol_type_dict=None):
    """Determine the type of a sympy expression.

    Args:
        expr: expression (anything accepted by ``sp.sympify``)
        default_float_type: type specification for rational/float literals and non-integer atoms
        default_int_type: type specification for integer literals and integer atoms
        symbol_type_dict: mapping from symbol name to type, used for plain (untyped) ``sp.Symbol`` s

    Returns:
        Type object

    Raises:
        ValueError: for a plain ``sp.Symbol`` when ``symbol_type_dict`` is empty or not given
        NotImplementedError: if no typing rule matches ``expr``
    """
    from pystencils.astnodes import ResolvedFieldAccess
    from pystencils.cpu.vectorization import vec_all, vec_any

    if not symbol_type_dict:
        # NOTE(review): an empty defaultdict is falsy, so the sp.Symbol branch below still raises
        # for untyped symbols and this 'double' default is never used - confirm intended behavior
        symbol_type_dict = defaultdict(lambda: create_type('double'))

    get_type = partial(get_type_of_expression,
                       default_float_type=default_float_type,
                       default_int_type=default_int_type,
                       symbol_type_dict=symbol_type_dict)

    expr = sp.sympify(expr)
    if isinstance(expr, sp.Integer):
        return create_type(default_int_type)
    elif isinstance(expr, (sp.Rational, sp.Float)):
        return create_type(default_float_type)
    elif isinstance(expr, ResolvedFieldAccess):
        return expr.field.dtype
    elif isinstance(expr, TypedSymbol):
        return expr.dtype
    elif isinstance(expr, sp.Symbol):
        if symbol_type_dict:
            return symbol_type_dict[expr.name]
        else:
            raise ValueError("All symbols inside this expression have to be typed! ", str(expr))
    elif isinstance(expr, cast_func):
        return expr.args[1]
    elif isinstance(expr, (vec_any, vec_all)):
        return create_type("bool")
    elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
        collated_result_type = collate_types(tuple(get_type(a[0]) for a in expr.args))
        collated_condition_type = collate_types(tuple(get_type(a[1]) for a in expr.args))
        # vector conditions force a vector result of the same width
        if type(collated_condition_type) is VectorType and type(collated_result_type) is not VectorType:
            collated_result_type = VectorType(collated_result_type, width=collated_condition_type.width)
        return collated_result_type
    elif isinstance(expr, sp.Indexed):
        typed_symbol = expr.base.label
        return typed_symbol.dtype.base_type
    elif isinstance(expr, Boolean):
        # BooleanFunction is a subclass of Boolean, so this covers both.
        # if any arg is of vector type return a vector boolean, else return a normal scalar boolean
        result = create_type("bool")
        # type each argument only once: the recursive call is not memoized when
        # symbol_type_dict is an (unhashable) dict, so typing twice was exponential in nesting depth
        arg_types = (get_type(a) for a in expr.args)
        vec_args = [t for t in arg_types if isinstance(t, VectorType)]
        if vec_args:
            result = VectorType(result, width=vec_args[0].width)
        return result
    elif isinstance(expr, (sp.Pow, sp.Sum, sp.Product)):
        return get_type(expr.args[0])
    elif isinstance(expr, sp.Expr):
        if expr.args:
            types = tuple(get_type(a) for a in expr.args)
            return collate_types(types)
        else:
            if expr.is_integer:
                return create_type(default_int_type)
            else:
                return create_type(default_float_type)

    raise NotImplementedError("Could not determine type for", expr, type(expr))
477
478


479
class Type(sp.Basic):
    """Base class for the type objects in this module, usable as atoms in sympy expressions.

    Constructor arguments are not stored as sympy ``args``: ``__new__`` ignores them and the
    subclasses' ``__init__`` consumes them.
    """
    is_Atom = True

    def __new__(cls, *args, **kwargs):
        return sp.Basic.__new__(cls)

    def _sympystr(self, *args, **kwargs):
        # the sympy string printer uses the type's own C-like string representation
        return str(self)
487
488
489
490


class BasicType(Type):
    """Scalar (non-structured) type backed by a numpy dtype, optionally ``const``."""

    @staticmethod
    def numpy_name_to_c(name):
        """Map a numpy dtype name (e.g. ``'float64'``, ``'int32'``) to its C type name.

        Raises:
            NotImplementedError: for names without a C mapping here
        """
        if name == 'float64':
            return 'double'
        elif name == 'float32':
            return 'float'
        elif name.startswith('int'):
            width = int(name[len("int"):])
            return "int%d_t" % (width,)
        elif name.startswith('uint'):
            width = int(name[len("uint"):])
            return "uint%d_t" % (width,)
        elif name == 'bool':
            return 'bool'
        else:
            raise NotImplementedError("Cannot map numpy to C name for %s" % (name,))

    def __init__(self, dtype, const=False):
        self.const = const
        if isinstance(dtype, Type):
            self._dtype = dtype.numpy_dtype
        else:
            self._dtype = np.dtype(dtype)
        assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type"
        assert self._dtype.hasobject is False
        assert self._dtype.subdtype is None

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def item_size(self):
        # a scalar counts as one item (VectorType multiplies this by its width)
        return 1

    # np.sctypes was removed in NumPy 2.0; np.issubdtype gives the same classification
    # for the scalar dtypes a BasicType holds
    def is_int(self):
        """True for signed and unsigned integer dtypes."""
        return np.issubdtype(self.numpy_dtype, np.integer)

    def is_float(self):
        return np.issubdtype(self.numpy_dtype, np.floating)

    def is_uint(self):
        return np.issubdtype(self.numpy_dtype, np.unsignedinteger)

    def is_complex(self):
        return np.issubdtype(self.numpy_dtype, np.complexfloating)

    def is_other(self):
        # same members as NumPy 1.x np.sctypes['others']
        return self.numpy_dtype in (bool, object, bytes, str, np.void)

    @property
    def base_name(self):
        """C name of the type without the const qualifier."""
        return BasicType.numpy_name_to_c(str(self._dtype))

    def __str__(self):
        result = BasicType.numpy_name_to_c(str(self._dtype))
        if self.const:
            result += " const"
        return result

    def __repr__(self):
        return str(self)

    def __eq__(self, other):
        if not isinstance(other, BasicType):
            return False
        else:
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)

    def __hash__(self):
        return hash(str(self))


570
class VectorType(Type):
    """Vector of ``width`` elements of a base type.

    ``instruction_set`` is a class attribute. While it is None, ``str()`` gives ``base[width]``.
    Otherwise it must be a mapping, and ``str()`` looks up the entry for the base type under the
    keys 'int' (int64), 'double', 'float' or 'bool'.
    """
    instruction_set = None

    def __init__(self, base_type, width=4):
        self._base_type = base_type
        self.width = width

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        return self.width * self.base_type.item_size

    def __eq__(self, other):
        if not isinstance(other, VectorType):
            return False
        else:
            return (self.base_type, self.width) == (other.base_type, other.width)

    def __str__(self):
        if self.instruction_set is None:
            return "%s[%d]" % (self.base_type, self.width)
        else:
            if self.base_type == create_type("int64"):
                return self.instruction_set['int']
            elif self.base_type == create_type("float64"):
                return self.instruction_set['double']
            elif self.base_type == create_type("float32"):
                return self.instruction_set['float']
            elif self.base_type == create_type("bool"):
                return self.instruction_set['bool']
            else:
                raise NotImplementedError()

    def __hash__(self):
        return hash((self.base_type, self.width))

    def __getnewargs__(self):
        return self._base_type, self.width

612

613
class PointerType(Type):
    """Pointer to ``base_type``, optionally ``const`` and ``restrict`` (restrict by default).

    ``str()`` renders as ``"<base> * [RESTRICT] [const]"``.
    """

    def __init__(self, base_type, const=False, restrict=True):
        self._base_type = base_type
        self.const = const
        self.restrict = restrict

    def __getnewargs__(self):
        return self.base_type, self.const, self.restrict

    @property
    def alias(self):
        """True if the pointer is not restrict-qualified."""
        return not self.restrict

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        return self.base_type.item_size

    def __eq__(self, other):
        if not isinstance(other, PointerType):
            return False
        else:
            return (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict)

    def __str__(self):
        components = [str(self.base_type), '*']
        if self.restrict:
            components.append('RESTRICT')
        if self.const:
            components.append("const")
        return " ".join(components)

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self._base_type, self.const, self.restrict))
653

Jan Hoenig's avatar
Jan Hoenig committed
654

655
class StructType:
    """Type backed by a structured numpy dtype (a dtype with named fields).

    In generated code a struct is handled byte-wise: its ``str()`` is ``uint8_t`` (plus ``const``).

    NOTE(review): unlike the other type classes this does not derive from Type, so
    ``create_type`` does not return an existing StructType instance unchanged - confirm intended.
    """

    def __init__(self, numpy_type, const=False):
        self.const = const
        self._dtype = np.dtype(numpy_type)

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def item_size(self):
        # size of one struct in bytes, as reported by numpy
        return self.numpy_dtype.itemsize

    def get_element_offset(self, element_name):
        """Byte offset of field ``element_name`` inside the struct."""
        return self.numpy_dtype.fields[element_name][1]

    def get_element_type(self, element_name):
        """BasicType of field ``element_name``; inherits the struct's const-ness."""
        np_element_type = self.numpy_dtype.fields[element_name][0]
        return BasicType(np_element_type, self.const)

    def has_element(self, element_name):
        return element_name in self.numpy_dtype.fields

    def __eq__(self, other):
        if not isinstance(other, StructType):
            return False
        else:
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)

    def __str__(self):
        # structs are handled byte-wise
        result = "uint8_t"
        if self.const:
            result += " const"
        return result

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self.numpy_dtype, self.const))