data_types.py 21 KB
Newer Older
1
import ctypes
2
3
from collections import defaultdict
from functools import partial
Martin Bauer's avatar
Martin Bauer committed
4

5
import numpy as np
Martin Bauer's avatar
Martin Bauer committed
6
7
8
9
import sympy as sp
from sympy.core.cache import cacheit
from sympy.logic.boolalg import Boolean

10
from pystencils.cache import memorycache, memorycache_if_hashable
Martin Bauer's avatar
Martin Bauer committed
11
from pystencils.utils import all_equal
12

13
14
15
16
17
try:
    import llvmlite.ir as ir
except ImportError as e:
    ir = None
    _ir_importerror = e
18

19

20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# noinspection PyPep8Naming
class address_of(sp.Function):
    """Symbolic address-of node wrapping a single argument.

    Its ``dtype`` is a restrict-qualified pointer to the argument's dtype, or to ``'void'``
    when the argument carries no dtype.
    """
    is_Atom = True

    def __new__(cls, arg):
        return sp.Function.__new__(cls, arg)

    @property
    def canonical(self):
        target = self.args[0]
        if not hasattr(target, 'canonical'):
            raise NotImplementedError()
        return target.canonical

    @property
    def is_commutative(self):
        return self.args[0].is_commutative

    @property
    def dtype(self):
        target = self.args[0]
        pointee = target.dtype if hasattr(target, 'dtype') else 'void'
        return PointerType(pointee, restrict=True)


Martin Bauer's avatar
Martin Bauer committed
47
# noinspection PyPep8Naming
class cast_func(sp.Function):
    """Symbolic cast of ``args[0]`` (the expression) to ``args[1]`` (the target data type).

    Any arguments beyond the first two are passed through to ``sp.Function`` unchanged
    (subclasses may take more arguments). Non-``Type`` dtypes are converted with ``create_type``.
    """
    is_Atom = True

    def __new__(cls, *args, **kwargs):
        # The original check here was a no-op (`pass`), so too few arguments surfaced as an
        # opaque tuple-unpacking error; report it explicitly instead (same exception type).
        if len(args) < 2:
            raise ValueError("cast_func requires at least an expression and a data type, got %d argument(s)"
                             % len(args))
        expr, dtype, *other_args = args
        if not isinstance(dtype, Type):
            dtype = create_type(dtype)
        # to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well
        # however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads
        # to problems when for example comparing cast_func's for equality
        #
        # lhs = bitwise_and(a, cast_func(1, 'int'))
        # rhs = cast_func(0, 'int')
        # print( sp.Ne(lhs, rhs) ) # would give true if all cast_funcs are booleans
        # -> thus a separate class boolean_cast_func is introduced
        if isinstance(expr, Boolean):
            cls = boolean_cast_func
        return sp.Function.__new__(cls, expr, dtype, *other_args, **kwargs)

    @property
    def canonical(self):
        if hasattr(self.args[0], 'canonical'):
            return self.args[0].canonical
        else:
            raise NotImplementedError()

    @property
    def is_commutative(self):
        return self.args[0].is_commutative

    def _eval_evalf(self, *args, **kwargs):
        # numerically evaluating a cast evaluates the wrapped expression
        return self.args[0].evalf()

    @property
    def dtype(self):
        return self.args[1]

    # The assumption properties below derive integer/sign/realness from the target dtype when it
    # has a numpy representation, and fall back to sympy's own assumptions otherwise.
    @property
    def is_integer(self):
        if hasattr(self.dtype, 'numpy_dtype'):
            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or super().is_integer
        else:
            return super().is_integer

    @property
    def is_negative(self):
        if hasattr(self.dtype, 'numpy_dtype'):
            if np.issubdtype(self.dtype.numpy_dtype, np.unsignedinteger):
                return False

        return super().is_negative

    @property
    def is_nonnegative(self):
        if self.is_negative is False:
            return True
        else:
            return super().is_nonnegative

    @property
    def is_real(self):
        if hasattr(self.dtype, 'numpy_dtype'):
            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or \
                np.issubdtype(self.dtype.numpy_dtype, np.floating) or \
                super().is_real
        else:
            return super().is_real

Martin Bauer's avatar
Martin Bauer committed
118

119
120
121
122
123
# noinspection PyPep8Naming
class boolean_cast_func(cast_func, Boolean):
    """cast_func variant chosen automatically when the cast expression is itself a Boolean.

    Additionally deriving from Boolean allows such casts to be used where sympy expects a
    boolean, e.g. as sp.Piecewise conditions.
    """


Martin Bauer's avatar
Martin Bauer committed
124
125
126
127
# noinspection PyPep8Naming
class vector_memory_access(cast_func):
    """cast_func subclass that takes exactly four arguments (expression, dtype and two more)."""
    nargs = (4,)

128

129
130
131
132
133
# noinspection PyPep8Naming
class reinterpret_cast_func(cast_func):
    """Marker subclass of cast_func without additional behavior.

    NOTE(review): presumably distinguished by type elsewhere (e.g. by a printer) to emit a
    reinterpreting cast instead of a value conversion -- not visible in this module.
    """


Martin Bauer's avatar
Martin Bauer committed
134
135
# noinspection PyPep8Naming
class pointer_arithmetic_func(sp.Function, Boolean):
    """sympy function node (also a Boolean) presumably representing pointer arithmetic.

    Its canonical form is delegated to its first argument.
    """

    @property
    def canonical(self):
        first = self.args[0]
        if hasattr(first, 'canonical'):
            return first.canonical
        raise NotImplementedError()


144
class TypedSymbol(sp.Symbol):
    """sympy Symbol that additionally carries a data type, available as ``dtype``.

    The dtype is part of the symbol's hashable content. Integer/sign/realness assumptions are
    derived from the dtype's numpy representation when it has one.
    """

    def __new__(cls, *args, **kwds):
        return TypedSymbol.__xnew_cached_(cls, *args, **kwds)

    def __new_stage2__(cls, name, dtype, *args, **kwargs):
        obj = super(TypedSymbol, cls).__xnew__(cls, name, *args, **kwargs)
        try:
            obj._dtype = create_type(dtype)
        except (TypeError, ValueError):
            # the specification could not be converted -> keep it exactly as given (e.g. a string)
            obj._dtype = dtype
        return obj

    # uncached and sympy-cached construction entry points
    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    @property
    def dtype(self):
        return self._dtype

    def _hashable_content(self):
        return super()._hashable_content(), hash(self._dtype)

    def __getnewargs__(self):
        return self.name, self.dtype

    # For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
    @property
    def is_integer(self):
        if not hasattr(self.dtype, 'numpy_dtype'):
            return super().is_integer
        return np.issubdtype(self.dtype.numpy_dtype, np.integer) or super().is_integer

    @property
    def is_negative(self):
        unsigned = hasattr(self.dtype, 'numpy_dtype') and \
            np.issubdtype(self.dtype.numpy_dtype, np.unsignedinteger)
        return False if unsigned else super().is_negative

    @property
    def is_nonnegative(self):
        return True if self.is_negative is False else super().is_nonnegative

    @property
    def is_real(self):
        if not hasattr(self.dtype, 'numpy_dtype'):
            return super().is_real
        np_dtype = self.dtype.numpy_dtype
        return any(np.issubdtype(np_dtype, kind) for kind in (np.integer, np.floating)) or super().is_real
202

203

Martin Bauer's avatar
Martin Bauer committed
204
def create_type(specification):
    """Creates a type object from a type object, a numpy dtype or a dtype string.

    Args:
        specification: a ``Type`` or ``StructType`` instance (returned unchanged), or anything
            accepted by ``np.dtype`` (e.g. ``'double'``, ``np.int32``, a structured dtype)

    Returns:
        the given type object, a ``BasicType`` for a plain dtype, or a ``StructType`` for a
        structured dtype (one with fields)
    """
    # StructType does not derive from Type; without passing it through explicitly it would be
    # handed to np.dtype(), which cannot interpret it and raises TypeError.
    if isinstance(specification, (Type, StructType)):
        return specification
    numpy_dtype = np.dtype(specification)
    if numpy_dtype.fields is None:
        return BasicType(numpy_dtype, const=False)
    return StructType(numpy_dtype, const=False)
221
222


223
@memorycache(maxsize=64)
def create_composite_type_from_string(specification):
    """Creates a new Type object from a c-like string specification.

    The string is a base type (optionally qualified with ``const``) followed by any number of
    ``*`` pointer levels, each optionally qualified with ``restrict`` and/or ``const``. A ``*`` may
    be written with or without surrounding whitespace, e.g. ``"const double * restrict"``,
    ``"double*"`` or ``"double** const"``. Matching is case-insensitive.

    Args:
        specification: Specification string

    Returns:
        ``BasicType`` for a plain base type, otherwise a (nested) ``PointerType``

    Raises:
        ValueError: if the base part or a pointer part contains unexpected tokens
    """
    # Pad every '*' with whitespace so "double*", "double *" and "double*restrict" tokenize the
    # same way; qualifiers then stay attached to the pointer level they follow.
    tokens = specification.lower().replace('*', ' * ').split()

    # First group: base type tokens. Every further group starts with a '*' (one pointer level).
    parts = [[]]
    for token in tokens:
        if token == '*':
            parts.append([token])
        else:
            parts[-1].append(token)

    # Parse native part
    base_part = parts[0]
    const = 'const' in base_part
    if const:
        base_part.remove('const')
    if len(base_part) != 1:
        raise ValueError("Invalid base type in type specification %r" % (specification,))
    current_type = BasicType(np.dtype(base_part[0]), const)

    # Parse pointer parts, innermost pointer first
    for part in parts[1:]:
        restrict = 'restrict' in part
        if restrict:
            part.remove('restrict')
        const = 'const' in part
        if const:
            part.remove('const')
        if part != ['*']:
            raise ValueError("Invalid pointer qualifiers in type specification %r" % (specification,))
        current_type = PointerType(current_type, const, restrict)
    return current_type
268
269


Martin Bauer's avatar
Martin Bauer committed
270
271
272
273
def get_base_type(data_type):
    """Follow ``base_type`` links until reaching the innermost type (whose ``base_type`` is None)."""
    current = data_type
    nested = current.base_type
    while nested is not None:
        current, nested = nested, nested.base_type
    return current
274
275


Martin Bauer's avatar
Martin Bauer committed
276
def to_ctypes(data_type):
    """
    Transforms a given Type into ctypes
    :param data_type: Subclass of Type
    :return: ctypes type object
    :raises KeyError: if the scalar numpy dtype has no entry in ``to_ctypes.map``
    """
    if isinstance(data_type, PointerType):
        return ctypes.POINTER(to_ctypes(data_type.base_type))
    elif isinstance(data_type, StructType):
        # structs are passed byte-wise
        return ctypes.POINTER(ctypes.c_uint8)
    else:
        return to_ctypes.map[data_type.numpy_dtype]


# scalar numpy dtype -> ctypes type; bool is included because BasicType supports 'bool'
to_ctypes.map = {
    np.dtype(np.int8): ctypes.c_int8,
    np.dtype(np.int16): ctypes.c_int16,
    np.dtype(np.int32): ctypes.c_int32,
    np.dtype(np.int64): ctypes.c_int64,

    np.dtype(np.uint8): ctypes.c_uint8,
    np.dtype(np.uint16): ctypes.c_uint16,
    np.dtype(np.uint32): ctypes.c_uint32,
    np.dtype(np.uint64): ctypes.c_uint64,

    np.dtype(np.float32): ctypes.c_float,
    np.dtype(np.float64): ctypes.c_double,

    np.dtype(np.bool_): ctypes.c_bool,
}


306
def ctypes_from_llvm(data_type):
    """Map an llvmlite IR type to the corresponding ctypes type.

    Pointers become ctypes pointers (``c_void_p`` if the pointee maps to None), integers are
    mapped by bit width, float/double map to ``c_float``/``c_double`` and void maps to None.

    Raises:
        the stored llvmlite ImportError if llvmlite could not be imported
        ValueError: for integer widths other than 8, 16, 32 and 64
        NotImplementedError: for any other llvmlite type
    """
    if not ir:
        raise _ir_importerror

    if isinstance(data_type, ir.PointerType):
        pointee_ctype = ctypes_from_llvm(data_type.pointee)
        return ctypes.c_void_p if pointee_ctype is None else ctypes.POINTER(pointee_ctype)

    if isinstance(data_type, ir.IntType):
        ctypes_by_width = {8: ctypes.c_int8, 16: ctypes.c_int16, 32: ctypes.c_int32, 64: ctypes.c_int64}
        width = data_type.width
        if width not in ctypes_by_width:
            raise ValueError("Int width %d is not supported" % data_type.width)
        return ctypes_by_width[width]

    if isinstance(data_type, ir.FloatType):
        return ctypes.c_float
    if isinstance(data_type, ir.DoubleType):
        return ctypes.c_double
    if isinstance(data_type, ir.VoidType):
        return None  # Void type is not supported by ctypes
    raise NotImplementedError('Data type %s of %s is not supported yet' % (type(data_type), data_type))


def to_llvm_type(data_type):
    """
    Transforms a given type into an llvmlite IR type
    :param data_type: Subclass of Type
    :return: llvmlite type object
    :raises: the stored llvmlite ImportError if llvmlite could not be imported
    """
    if not ir:
        raise _ir_importerror
    if not isinstance(data_type, PointerType):
        return to_llvm_type.map[data_type.numpy_dtype]
    return to_llvm_type(data_type.base_type).as_pointer()


if ir:
    # signed and unsigned integers share the LLVM integer type of the same bit width
    to_llvm_type.map = {
        np.dtype(int_kind): ir.IntType(np.dtype(int_kind).itemsize * 8)
        for int_kind in (np.int8, np.int16, np.int32, np.int64,
                         np.uint8, np.uint16, np.uint32, np.uint64)
    }
    to_llvm_type.map[np.dtype(np.float32)] = ir.FloatType()
    to_llvm_type.map[np.dtype(np.float64)] = ir.DoubleType()
365

366

Martin Bauer's avatar
Martin Bauer committed
367
368
369
def peel_off_type(dtype, type_to_peel_off):
    """Strip wrapper types whose exact class is ``type_to_peel_off`` by following ``base_type``.

    Subclasses of ``type_to_peel_off`` are not stripped (exact class comparison).
    """
    if type(dtype) is not type_to_peel_off:
        return dtype
    return peel_off_type(dtype.base_type, type_to_peel_off)


Martin Bauer's avatar
Martin Bauer committed
373
def collate_types(types):
    """
    Determine the common type of a sequence of types, e.g. (float, double, float) -> double.

    Rules:
      * If a pointer is involved, it may be combined only with integer basic types; the
        pointer type is returned (pointer arithmetic).
      * Vector types are peeled off; if any occurred, all must share the same width and the
        result is wrapped into a vector type of that width.
      * If any scalar is a float type, only the float types take part in the collation.
      * The remaining scalar types are combined with ``np.result_type``.

    Raises:
        ValueError: for two pointers, invalid pointer arithmetic or mismatching vector widths
    """
    # Pointer arithmetic case i.e. pointer + integer is allowed
    if any(type(t) is PointerType for t in types):
        pointer_type = None
        for t in types:
            t_class = type(t)
            if t_class is PointerType:
                if pointer_type is not None:
                    raise ValueError("Cannot collate the combination of two pointer types")
                pointer_type = t
            elif t_class is not BasicType or not (t.is_int() or t.is_uint()):
                raise ValueError("Invalid pointer arithmetic")
        return pointer_type

    vector_types = [t for t in types if type(t) is VectorType]
    if not all_equal(v.width for v in vector_types):
        raise ValueError("Collation failed because of vector types with different width")
    scalar_types = [peel_off_type(t, VectorType) for t in types]

    # only basic types are expected here - struct types are not yet supported
    assert all(type(t) is BasicType for t in scalar_types)

    float_types = tuple(t for t in scalar_types if t.is_float())
    if float_types:
        scalar_types = float_types

    result = BasicType(np.result_type(*(t.numpy_dtype for t in scalar_types)))
    return VectorType(result, vector_types[0].width) if vector_types else result


413
414
415
416
417
@memorycache_if_hashable(maxsize=2048)
def get_type_of_expression(expr,
                           default_float_type='double',
                           default_int_type='int',
                           symbol_type_dict=None):
    """Determine the data type of an expression.

    Args:
        expr: expression to inspect; passed through ``sp.sympify`` first
        default_float_type: type for Rational/Float literals and for argument-less,
            non-integer expressions
        default_int_type: type for Integer literals and for argument-less integer expressions
        symbol_type_dict: maps symbol names to types for plain (untyped) ``sp.Symbol`` objects.
            If None or empty, every untyped symbol is typed as ``double``.

    Returns:
        the type of ``expr``

    Raises:
        NotImplementedError: if no rule applies to ``expr``
    """
    # imported locally, presumably to avoid circular imports
    from pystencils.astnodes import ResolvedFieldAccess
    from pystencils.cpu.vectorization import vec_all, vec_any

    if not symbol_type_dict:
        # NOTE: untyped symbols default to double, independent of default_float_type
        symbol_type_dict = defaultdict(lambda: create_type('double'))

    # recursion helper forwarding all type settings to sub-expressions
    get_type = partial(get_type_of_expression,
                       default_float_type=default_float_type,
                       default_int_type=default_int_type,
                       symbol_type_dict=symbol_type_dict)

    expr = sp.sympify(expr)
    if isinstance(expr, sp.Integer):
        return create_type(default_int_type)
    elif isinstance(expr, (sp.Rational, sp.Float)):
        return create_type(default_float_type)
    elif isinstance(expr, ResolvedFieldAccess):
        return expr.field.dtype
    elif isinstance(expr, TypedSymbol):
        return expr.dtype
    elif isinstance(expr, sp.Symbol):
        return symbol_type_dict[expr.name]
    elif isinstance(expr, cast_func):
        return expr.args[1]
    elif isinstance(expr, (vec_any, vec_all)):
        return create_type("bool")
    elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
        collated_result_type = collate_types(tuple(get_type(a[0]) for a in expr.args))
        collated_condition_type = collate_types(tuple(get_type(a[1]) for a in expr.args))
        # a vectorized condition makes the result vectorized as well
        if type(collated_condition_type) is VectorType and type(collated_result_type) is not VectorType:
            collated_result_type = VectorType(collated_result_type, width=collated_condition_type.width)
        return collated_result_type
    elif isinstance(expr, sp.Indexed):
        # assumes the Indexed label is a typed symbol with a pointer-like dtype
        typed_symbol = expr.base.label
        return typed_symbol.dtype.base_type
    elif isinstance(expr, Boolean):
        # BooleanFunction derives from Boolean, so it is covered by this check as well.
        # If any arg is of vector type return a vector boolean, else return a normal scalar boolean.
        result = create_type("bool")
        arg_types = [get_type(a) for a in expr.args]  # typed once per argument
        vec_args = [t for t in arg_types if isinstance(t, VectorType)]
        if vec_args:
            result = VectorType(result, width=vec_args[0].width)
        return result
    elif isinstance(expr, (sp.Pow, sp.Sum, sp.Product)):
        return get_type(expr.args[0])
    elif isinstance(expr, sp.Expr):
        if expr.args:
            types = tuple(get_type(a) for a in expr.args)
            return collate_types(types)
        else:
            if expr.is_integer:
                return create_type(default_int_type)
            else:
                return create_type(default_float_type)

    raise NotImplementedError("Could not determine type for %s (%s)" % (expr, type(expr)))
475
476


477
class Type(sp.Basic):
    """Base class of BasicType, VectorType and PointerType: an atomic sympy object.

    Constructor arguments are consumed by the subclasses' ``__init__``; sympy itself sees an
    argument-less ``Basic``.
    """
    is_Atom = True

    def __new__(cls, *args, **kwargs):
        return sp.Basic.__new__(cls)

    def _sympystr(self, *args, **kwargs):
        # sympy printers render a type via its own __str__
        return str(self)
485
486
487
488


class BasicType(Type):
    """Scalar type described by a non-structured numpy dtype, optionally ``const`` qualified."""

    @staticmethod
    def numpy_name_to_c(name):
        """Return the C type name for the numpy dtype name ``name`` (e.g. 'float64' -> 'double').

        Raises:
            NotImplementedError: if no C name is known for ``name``
        """
        if name == 'float64':
            return 'double'
        elif name == 'float32':
            return 'float'
        elif name.startswith('int'):
            width = int(name[len("int"):])
            return "int%d_t" % (width,)
        elif name.startswith('uint'):
            width = int(name[len("uint"):])
            return "uint%d_t" % (width,)
        elif name == 'bool':
            return 'bool'
        else:
            raise NotImplementedError("Cannot map numpy to C name for %s" % (name,))

    def __init__(self, dtype, const=False):
        self.const = const
        if isinstance(dtype, Type):
            self._dtype = dtype.numpy_dtype
        else:
            self._dtype = np.dtype(dtype)
        assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type"
        assert self._dtype.hasobject is False
        assert self._dtype.subdtype is None

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def item_size(self):
        # NOTE(review): always 1, whereas StructType.item_size returns a byte count --
        # presumably basic types are counted in elements; confirm with callers
        return 1

    # The type-category checks use dtype.kind instead of np.sctypes, which was removed in
    # NumPy 2.0. The kind codes cover the same scalar types as the former np.sctypes groups:
    # 'i' int, 'u' uint, 'f' float, 'c' complex, and 'b','O','S','U','V' for 'others'.
    def is_int(self):
        """True for signed and unsigned integer types."""
        return self.numpy_dtype.kind in ('i', 'u')

    def is_float(self):
        return self.numpy_dtype.kind == 'f'

    def is_uint(self):
        return self.numpy_dtype.kind == 'u'

    def is_complex(self):
        return self.numpy_dtype.kind == 'c'

    def is_other(self):
        """True for bool, object, bytes, str and void types."""
        return self.numpy_dtype.kind in ('b', 'O', 'S', 'U', 'V')

    @property
    def base_name(self):
        """C type name without qualifiers."""
        return BasicType.numpy_name_to_c(str(self._dtype))

    def __str__(self):
        result = BasicType.numpy_name_to_c(str(self._dtype))
        if self.const:
            result += " const"
        return result

    def __repr__(self):
        return str(self)

    def __eq__(self, other):
        if not isinstance(other, BasicType):
            return False
        else:
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)

    def __hash__(self):
        return hash(str(self))


568
class VectorType(Type):
    """Vector of ``width`` elements of ``base_type``.

    ``instruction_set`` is a class-level attribute, None by default. When it is set to an object
    indexable by 'int', 'double', 'float' and 'bool', ``str()`` returns the entry matching the
    base type instead of the generic ``"<base>[<width>]"`` form.
    """
    instruction_set = None

    def __init__(self, base_type, width=4):
        self._base_type = base_type
        self.width = width

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        return self.width * self.base_type.item_size

    def __eq__(self, other):
        return isinstance(other, VectorType) and \
            (self.base_type, self.width) == (other.base_type, other.width)

    def __str__(self):
        if self.instruction_set is None:
            return "%s[%d]" % (self.base_type, self.width)
        # scalar base type -> key of the vector type name in the instruction set
        for scalar_name, instruction_key in (("int64", 'int'), ("float64", 'double'),
                                             ("float32", 'float'), ("bool", 'bool')):
            if self.base_type == create_type(scalar_name):
                return self.instruction_set[instruction_key]
        raise NotImplementedError()

    def __hash__(self):
        return hash((self.base_type, self.width))

    def __getnewargs__(self):
        return self._base_type, self.width

610

611
class PointerType(Type):
    """Pointer to ``base_type`` with optional ``const`` and ``restrict`` qualifiers (restrict by default)."""

    def __init__(self, base_type, const=False, restrict=True):
        self._base_type = base_type
        self.const = const
        self.restrict = restrict

    def __getnewargs__(self):
        return self.base_type, self.const, self.restrict

    @property
    def alias(self):
        """Inverse of ``restrict``."""
        return not self.restrict

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        return self.base_type.item_size

    def __eq__(self, other):
        if isinstance(other, PointerType):
            return (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict)
        return False

    def __str__(self):
        qualifiers = [word for word, enabled in (('RESTRICT', self.restrict), ("const", self.const)) if enabled]
        return " ".join([str(self.base_type), '*'] + qualifiers)

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self._base_type, self.const, self.restrict))
651

Jan Hoenig's avatar
Jan Hoenig committed
652

653
class StructType:
    """Structured (record) numpy dtype, optionally ``const``; printed byte-wise as ``uint8_t``.

    Note: unlike BasicType/VectorType/PointerType this class does not derive from ``Type``.
    """

    def __init__(self, numpy_type, const=False):
        self.const = const
        self._dtype = np.dtype(numpy_type)

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        # a struct is a leaf type, nothing is wrapped
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def item_size(self):
        """Size of one struct instance in bytes."""
        return self.numpy_dtype.itemsize

    def get_element_offset(self, element_name):
        """Byte offset of the field ``element_name`` inside the struct."""
        field_info = self.numpy_dtype.fields[element_name]
        return field_info[1]

    def get_element_type(self, element_name):
        """Type of the field ``element_name`` as a BasicType with the struct's constness."""
        field_dtype = self.numpy_dtype.fields[element_name][0]
        return BasicType(field_dtype, self.const)

    def has_element(self, element_name):
        return element_name in self.numpy_dtype.fields

    def __eq__(self, other):
        return isinstance(other, StructType) and \
            (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)

    def __str__(self):
        # structs are handled byte-wise
        return "uint8_t const" if self.const else "uint8_t"

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self.numpy_dtype, self.const))