data_types.py 22.9 KB
Newer Older
1
import ctypes
2
3
from collections import defaultdict
from functools import partial
4
from typing import Tuple
Martin Bauer's avatar
Martin Bauer committed
5

6
import numpy as np
Martin Bauer's avatar
Martin Bauer committed
7
import sympy as sp
8
import sympy.codegen.ast
Martin Bauer's avatar
Martin Bauer committed
9
10
11
from sympy.core.cache import cacheit
from sympy.logic.boolalg import Boolean

12
import pystencils
13
from pystencils.cache import memorycache, memorycache_if_hashable
Martin Bauer's avatar
Martin Bauer committed
14
from pystencils.utils import all_equal
15

16
17
18
19
20
try:
    import llvmlite.ir as ir
except ImportError as e:
    ir = None
    _ir_importerror = e
21

22

23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
def typed_symbols(names, dtype, *args):
    """Creates one or several :class:`TypedSymbol` objects that share a data type.

    Args:
        names: symbol specification, forwarded unchanged to :func:`sympy.symbols`
        dtype: data type attached to every created symbol
        *args: further positional arguments forwarded to :func:`sympy.symbols`

    Returns:
        a tuple of TypedSymbols if :func:`sympy.symbols` produced a tuple,
        otherwise a single TypedSymbol
    """
    symbols = sp.symbols(names, *args)
    # test against the builtin ``tuple``: ``typing.Tuple`` is an annotation alias, not a runtime class
    if isinstance(symbols, tuple):
        return tuple(TypedSymbol(str(s), dtype) for s in symbols)
    return TypedSymbol(str(symbols), dtype)


def matrix_symbols(names, dtype, rows, cols):
    """Creates matrices whose entries are typed symbols.

    Args:
        names: iterable of matrix names, or a single comma separated string (spaces are ignored)
        dtype: data type of all matrix entries
        rows: number of rows of each matrix
        cols: number of columns of each matrix

    Returns:
        tuple with one ``rows x cols`` sympy matrix per name; the entries of a matrix
        are the symbols ``<name>0 ... <name>(rows*cols-1)`` laid out row by row
    """
    if isinstance(names, str):
        names = names.replace(' ', '').split(',')

    def build_matrix(name):
        entries = typed_symbols("%s:%i" % (name, rows * cols), dtype)
        return sp.Matrix(rows, cols, lambda row, col: entries[row * cols + col])

    return tuple(build_matrix(name) for name in names)


43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# noinspection PyPep8Naming
class address_of(sp.Function):
    """Sympy function node wrapping a single argument; its ``dtype`` is a pointer to the argument's type."""
    is_Atom = True

    def __new__(cls, arg):
        return sp.Function.__new__(cls, arg)

    @property
    def canonical(self):
        """Canonical form of the wrapped argument; NotImplementedError if the argument has none."""
        operand = self.args[0]
        if not hasattr(operand, 'canonical'):
            raise NotImplementedError()
        return operand.canonical

    @property
    def is_commutative(self):
        """Commutativity is taken over from the wrapped argument."""
        return self.args[0].is_commutative

    @property
    def dtype(self):
        """Restrict pointer to the argument's dtype, or to ``'void'`` if the argument has no dtype."""
        operand = self.args[0]
        pointee = operand.dtype if hasattr(operand, 'dtype') else 'void'
        return PointerType(pointee, restrict=True)


Martin Bauer's avatar
Martin Bauer committed
70
# noinspection PyPep8Naming
71
class cast_func(sp.Function):
    """Sympy function node representing a cast of an expression to a data type.

    The first argument is the expression, the second one the target type (a Type
    object or anything accepted by :func:`create_type`). Additional arguments are
    passed through unchanged, which is used by subclasses with more arguments.
    """
    is_Atom = True

    def __new__(cls, *args, **kwargs):
        # expression and type are mandatory, unpacking fails with a ValueError otherwise
        expr, dtype, *other_args = args
        if not isinstance(dtype, Type):
            dtype = create_type(dtype)
        # to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well
        # however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads
        # to problems when for example comparing cast_func's for equality
        #
        # lhs = bitwise_and(a, cast_func(1, 'int'))
        # rhs = cast_func(0, 'int')
        # print( sp.Ne(lhs, rhs) ) # would give true if all cast_funcs are booleans
        # -> thus a separate class boolean_cast_func is introduced
        if isinstance(expr, Boolean):
            cls = boolean_cast_func
        return sp.Function.__new__(cls, expr, dtype, *other_args, **kwargs)

    @property
    def canonical(self):
        """Canonical form of the casted expression; NotImplementedError if the expression has none."""
        if hasattr(self.args[0], 'canonical'):
            return self.args[0].canonical
        else:
            raise NotImplementedError()

    @property
    def is_commutative(self):
        """Commutativity is taken over from the casted expression."""
        return self.args[0].is_commutative

    def _eval_evalf(self, *args, **kwargs):
        # numeric evaluation ignores the cast and evaluates the wrapped expression
        return self.args[0].evalf()

    @property
    def dtype(self):
        """Target type of the cast (second function argument)."""
        return self.args[1]

    @property
    def is_integer(self):
        """
        Uses Numpy type hierarchy to determine :func:`sympy.Expr.is_integer` predicate

        For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
        """
        if hasattr(self.dtype, 'numpy_dtype'):
            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or super().is_integer
        else:
            return super().is_integer

    @property
    def is_negative(self):
        """
        See :func:`.TypedSymbol.is_integer`
        """
        if hasattr(self.dtype, 'numpy_dtype'):
            # unsigned integer types can never hold a negative value
            if np.issubdtype(self.dtype.numpy_dtype, np.unsignedinteger):
                return False

        return super().is_negative

    @property
    def is_nonnegative(self):
        """
        See :func:`.TypedSymbol.is_integer`
        """
        if self.is_negative is False:
            return True
        else:
            return super().is_nonnegative

    @property
    def is_real(self):
        """
        See :func:`.TypedSymbol.is_integer`
        """
        if hasattr(self.dtype, 'numpy_dtype'):
            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or \
                np.issubdtype(self.dtype.numpy_dtype, np.floating) or \
                super().is_real
        else:
            return super().is_real

Martin Bauer's avatar
Martin Bauer committed
155

156
157
158
159
160
# noinspection PyPep8Naming
class boolean_cast_func(cast_func, Boolean):
    """Cast node that is also a sympy ``Boolean``; ``cast_func`` switches to it for boolean expressions."""


Martin Bauer's avatar
Martin Bauer committed
161
162
163
164
# noinspection PyPep8Naming
class vector_memory_access(cast_func):
    """``cast_func`` variant that is declared with exactly four arguments."""
    nargs = (4,)

165

166
167
168
169
170
# noinspection PyPep8Naming
class reinterpret_cast_func(cast_func):
    """Separate node class for reinterpreting casts; adds nothing to ``cast_func`` itself."""


Martin Bauer's avatar
Martin Bauer committed
171
172
# noinspection PyPep8Naming
class pointer_arithmetic_func(sp.Function, Boolean):
    """Sympy function node that is also a ``Boolean``; ``canonical`` is forwarded to its first argument."""

    @property
    def canonical(self):
        """Canonical form of the first argument; NotImplementedError if the argument has none."""
        operand = self.args[0]
        if not hasattr(operand, 'canonical'):
            raise NotImplementedError()
        return operand.canonical


181
class TypedSymbol(sp.Symbol):
    """Sympy symbol that additionally stores a data type, available as ``dtype``."""

    def __new__(cls, *args, **kwds):
        # construction goes through the cached second stage defined below
        return TypedSymbol.__xnew_cached_(cls, *args, **kwds)

    def __new_stage2__(cls, name, dtype, *args, **kwargs):
        obj = super(TypedSymbol, cls).__xnew__(cls, name, *args, **kwargs)
        try:
            resolved_dtype = create_type(dtype)
        except (TypeError, ValueError):
            # on error keep the string
            resolved_dtype = dtype
        obj._dtype = resolved_dtype
        return obj

    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    @property
    def dtype(self):
        """Data type given at construction (converted by ``create_type`` where that succeeded)."""
        return self._dtype

    def _hashable_content(self):
        # the dtype takes part in hashing, in addition to the content of the plain symbol
        return super()._hashable_content(), hash(self._dtype)

    def __getnewargs__(self):
        return self.name, self.dtype

    # For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
    @property
    def is_integer(self):
        """
        Uses Numpy type hierarchy to determine :func:`sympy.Expr.is_integer` predicate

        For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
        """
        dtype = self.dtype
        if hasattr(dtype, 'numpy_dtype') and np.issubdtype(dtype.numpy_dtype, np.integer):
            return True
        return super().is_integer

    @property
    def is_negative(self):
        """
        See :func:`.TypedSymbol.is_integer`
        """
        dtype = self.dtype
        # unsigned integer types can never hold a negative value
        if hasattr(dtype, 'numpy_dtype') and np.issubdtype(dtype.numpy_dtype, np.unsignedinteger):
            return False
        return super().is_negative

    @property
    def is_nonnegative(self):
        """
        See :func:`.TypedSymbol.is_integer`
        """
        if self.is_negative is not False:
            return super().is_nonnegative
        return True

    @property
    def is_real(self):
        """
        See :func:`.TypedSymbol.is_integer`
        """
        dtype = self.dtype
        if hasattr(dtype, 'numpy_dtype'):
            numpy_dtype = dtype.numpy_dtype
            if np.issubdtype(numpy_dtype, np.integer) or np.issubdtype(numpy_dtype, np.floating):
                return True
        return super().is_real
253

254

Martin Bauer's avatar
Martin Bauer committed
255
def create_type(specification):
    """Returns a Type object for the given specification.

    Args:
        specification: a Type object, or anything that ``numpy.dtype`` accepts (e.g. a string)

    Returns:
        the specification itself if it already is a Type object; otherwise a non-const
        BasicType, or a non-const StructType if the numpy dtype has fields
    """
    if isinstance(specification, Type):
        return specification

    numpy_dtype = np.dtype(specification)
    type_class = BasicType if numpy_dtype.fields is None else StructType
    return type_class(numpy_dtype, const=False)
272
273


274
@memorycache(maxsize=64)
def create_composite_type_from_string(specification):
    """Creates a new Type object from a c-like string specification.

    The specification is lower-cased and split at whitespace. The first group of
    tokens describes the base type (a numpy type name, optionally with ``const``);
    every following ``*`` token starts a pointer level that may carry the
    qualifiers ``restrict`` and ``const``. A single ``*`` may also be glued to
    the base type name, e.g. ``"double*"``.

    Args:
        specification: Specification string

    Returns:
        Type object: a BasicType, wrapped in one PointerType per pointer level
    """
    tokens = specification.lower().split()
    parts = []
    current = []
    for token in tokens:
        if token == '*':
            # a star closes the current group and opens a new pointer group
            parts.append(current)
            current = [token]
        else:
            current.append(token)
    if current:
        parts.append(current)

    # Parse native part
    base_part = parts.pop(0)
    const = 'const' in base_part
    if const:
        base_part.remove('const')
    assert len(base_part) == 1
    if base_part[0][-1] == "*":
        base_part[0] = base_part[0][:-1]
        # the star glued to the base type is the innermost pointer level, so it goes to the
        # front, and as a token list like all other pointer groups
        parts.insert(0, ['*'])
    current_type = BasicType(np.dtype(base_part[0]), const)

    # Parse pointer parts
    for part in parts:
        restrict = 'restrict' in part
        if restrict:
            part.remove('restrict')
        const = 'const' in part
        if const:
            part.remove("const")
        # after removing the qualifiers only the star itself may remain
        assert len(part) == 1 and part[0] == '*'
        current_type = PointerType(current_type, const, restrict)
    return current_type
319
320


Martin Bauer's avatar
Martin Bauer committed
321
322
323
324
def get_base_type(data_type):
    """Follows the ``base_type`` chain and returns the innermost type, i.e. the one whose ``base_type`` is None."""
    inner = data_type.base_type
    while inner is not None:
        data_type, inner = inner, inner.base_type
    return data_type
325
326


Martin Bauer's avatar
Martin Bauer committed
327
def to_ctypes(data_type):
    """
    Maps a given Type to the corresponding ctypes type

    :param data_type: Subclass of Type
    :return: ctypes type object
    """
    if isinstance(data_type, PointerType):
        pointee = to_ctypes(data_type.base_type)
        return ctypes.POINTER(pointee)
    if isinstance(data_type, StructType):
        # structs are represented as a pointer to unsigned bytes
        return ctypes.POINTER(ctypes.c_uint8)
    # everything else is looked up by its numpy dtype
    return to_ctypes.map[data_type.numpy_dtype]
339

Martin Bauer's avatar
Martin Bauer committed
340
341

# lookup table used by to_ctypes: numpy dtype -> ctypes type
to_ctypes.map = {
    np.dtype(numpy_type): c_type
    for numpy_type, c_type in (
        (np.int8, ctypes.c_int8),
        (np.int16, ctypes.c_int16),
        (np.int32, ctypes.c_int32),
        (np.int64, ctypes.c_int64),

        (np.uint8, ctypes.c_uint8),
        (np.uint16, ctypes.c_uint16),
        (np.uint32, ctypes.c_uint32),
        (np.uint64, ctypes.c_uint64),

        (np.float32, ctypes.c_float),
        (np.float64, ctypes.c_double),
    )
}


357
def ctypes_from_llvm(data_type):
    """Maps an llvmlite IR type to the corresponding ctypes type.

    Args:
        data_type: llvmlite IR type (pointer, integer, float, double or void)

    Returns:
        the ctypes type, or None for the void type; a pointer to void becomes ``ctypes.c_void_p``

    Raises:
        the ImportError stored at module import if llvmlite is not available,
        ValueError for integer widths other than 8, 16, 32 and 64,
        NotImplementedError for all other IR types
    """
    if not ir:
        raise _ir_importerror

    if isinstance(data_type, ir.PointerType):
        pointee = ctypes_from_llvm(data_type.pointee)
        return ctypes.c_void_p if pointee is None else ctypes.POINTER(pointee)

    if isinstance(data_type, ir.IntType):
        int_type_for_width = {
            8: ctypes.c_int8,
            16: ctypes.c_int16,
            32: ctypes.c_int32,
            64: ctypes.c_int64,
        }
        if data_type.width not in int_type_for_width:
            raise ValueError("Int width %d is not supported" % data_type.width)
        return int_type_for_width[data_type.width]

    if isinstance(data_type, ir.FloatType):
        return ctypes.c_float
    if isinstance(data_type, ir.DoubleType):
        return ctypes.c_double
    if isinstance(data_type, ir.VoidType):
        return None  # Void type is not supported by ctypes

    raise NotImplementedError('Data type %s of %s is not supported yet' % (type(data_type), data_type))


387
def to_llvm_type(data_type, nvvm_target=False):
    """
    Transforms a given type into an llvmlite type
    :param data_type: Subclass of Type
    :param nvvm_target: if true, a pointer is created in address space 1 instead of 0
    :return: llvmlite type object
    """
    if not ir:
        raise _ir_importerror

    if not isinstance(data_type, PointerType):
        return to_llvm_type.map[data_type.numpy_dtype]

    address_space = 1 if nvvm_target else 0
    # the pointee is converted with the default nvvm_target, as the flag is not passed on
    return to_llvm_type(data_type.base_type).as_pointer(address_space)

400

401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
# lookup table used by to_llvm_type: numpy dtype -> llvmlite type (only available with llvmlite)
if ir:
    to_llvm_type.map = {
        # signed and unsigned integers of the same width share one llvm integer type
        **{np.dtype('int%d' % width): ir.IntType(width) for width in (8, 16, 32, 64)},
        **{np.dtype('uint%d' % width): ir.IntType(width) for width in (8, 16, 32, 64)},

        np.dtype(np.float32): ir.FloatType(),
        np.dtype(np.float64): ir.DoubleType(),
    }
416

417

Martin Bauer's avatar
Martin Bauer committed
418
419
420
def peel_off_type(dtype, type_to_peel_off):
    """Strips all leading layers whose class is exactly ``type_to_peel_off`` by following ``base_type``."""
    current = dtype
    while True:
        # exact class match on purpose, subclasses are not peeled off
        if type(current) is not type_to_peel_off:
            return current
        current = current.base_type


424
def collate_types(types, forbid_collation_to_float=False):
    """
    Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
    Uses the collation rules from numpy.

    Args:
        types: sequence of Type objects (iterated more than once)
        forbid_collation_to_float: if true, floating point types are dropped before collation;
                                   if nothing is left, the 'int32' type is returned

    Returns:
        the common type: the pointer type for pointer arithmetic, otherwise a BasicType,
        wrapped into a VectorType if at least one vector type was passed

    Raises:
        ValueError: for invalid pointer arithmetic or vector types of different width
    """
    if forbid_collation_to_float:
        types = [t for t in types if not (hasattr(t, 'is_float') and t.is_float())]
        if not types:
            # return the type itself (not a list containing it), like all other code paths
            return create_type('int32')

    # Pointer arithmetic case i.e. pointer + integer is allowed
    if any(type(t) is PointerType for t in types):
        pointer_type = None
        for t in types:
            if type(t) is PointerType:
                if pointer_type is not None:
                    raise ValueError("Cannot collate the combination of two pointer types")
                pointer_type = t
            elif type(t) is BasicType:
                if not (t.is_int() or t.is_uint()):
                    raise ValueError("Invalid pointer arithmetic")
            else:
                raise ValueError("Invalid pointer arithmetic")
        return pointer_type

    # peel of vector types, if at least one vector type occurred the result will also be the vector type
    vector_type = [t for t in types if type(t) is VectorType]
    if not all_equal(t.width for t in vector_type):
        raise ValueError("Collation failed because of vector types with different width")
    types = [peel_off_type(t, VectorType) for t in types]

    # now we should have a list of basic types - struct types are not yet supported
    assert all(type(t) is BasicType for t in types)

    # as soon as one float is involved only the float types determine the result
    if any(t.is_float() for t in types):
        types = tuple(t for t in types if t.is_float())
    # use numpy collation -> create type from numpy type -> and, put vector type around if necessary
    result_numpy_type = np.result_type(*(t.numpy_dtype for t in types))
    result = BasicType(result_numpy_type)
    if vector_type:
        result = VectorType(result, vector_type[0].width)
    return result


469
470
471
472
473
@memorycache_if_hashable(maxsize=2048)
def get_type_of_expression(expr,
                           default_float_type='double',
                           default_int_type='int',
                           symbol_type_dict=None):
    """Determines the data type of a sympy expression.

    Args:
        expr: expression, converted with ``sp.sympify`` first
        default_float_type: type specification used for rational and float constants
        default_int_type: type specification used for integer constants
        symbol_type_dict: mapping from symbol name to type, used for plain (untyped) sympy symbols

    Returns:
        the type of the expression; for composite expressions the collated type of the arguments

    Raises:
        ValueError: for an untyped symbol when no usable ``symbol_type_dict`` is available
        NotImplementedError: if no rule matches the expression
    """
    from pystencils.astnodes import ResolvedFieldAccess
    from pystencils.cpu.vectorization import vec_all, vec_any

    if not symbol_type_dict:
        symbol_type_dict = defaultdict(lambda: create_type('double'))

    # recursive calls reuse the same defaults and symbol mapping
    get_type = partial(get_type_of_expression,
                       default_float_type=default_float_type,
                       default_int_type=default_int_type,
                       symbol_type_dict=symbol_type_dict)

    expr = sp.sympify(expr)
    if isinstance(expr, sp.Integer):
        return create_type(default_int_type)
    elif isinstance(expr, (sp.Rational, sp.Float)):
        return create_type(default_float_type)
    elif isinstance(expr, ResolvedFieldAccess):
        return expr.field.dtype
    elif isinstance(expr, pystencils.field.Field.AbstractAccess):
        return expr.field.dtype
    elif isinstance(expr, TypedSymbol):
        return expr.dtype
    elif isinstance(expr, sp.Symbol):
        # NOTE(review): an empty mapping is falsy, so the defaultdict created above never serves a
        # lookup here and untyped symbols raise unless a non-empty mapping was passed - confirm intent
        if symbol_type_dict:
            return symbol_type_dict[expr.name]
        else:
            raise ValueError("All symbols inside this expression have to be typed! ", str(expr))
    elif isinstance(expr, cast_func):
        return expr.args[1]
    elif isinstance(expr, (vec_any, vec_all)):
        return create_type("bool")
    elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
        collated_result_type = collate_types(tuple(get_type(a[0]) for a in expr.args))
        collated_condition_type = collate_types(tuple(get_type(a[1]) for a in expr.args))
        # a vectorized condition makes the result a vector as well
        if type(collated_condition_type) is VectorType and type(collated_result_type) is not VectorType:
            collated_result_type = VectorType(collated_result_type, width=collated_condition_type.width)
        return collated_result_type
    elif isinstance(expr, sp.Indexed):
        typed_symbol = expr.base.label
        return typed_symbol.dtype.base_type
    elif isinstance(expr, sp.boolalg.Boolean) or isinstance(expr, sp.boolalg.BooleanFunction):
        # if any arg is of vector type return a vector boolean, else return a normal scalar boolean
        result = create_type("bool")
        # determine the type of every argument only once
        arg_types = [get_type(a) for a in expr.args]
        vec_args = [t for t in arg_types if isinstance(t, VectorType)]
        if vec_args:
            result = VectorType(result, width=vec_args[0].width)
        return result
    elif isinstance(expr, (sp.Pow, sp.Sum, sp.Product)):
        # the type is determined by the first argument alone
        return get_type(expr.args[0])
    elif isinstance(expr, sp.Expr):
        if expr.args:
            types = tuple(get_type(a) for a in expr.args)
            return collate_types(types)
        else:
            if expr.is_integer:
                return create_type(default_int_type)
            else:
                return create_type(default_float_type)

    raise NotImplementedError("Could not determine type for", expr, type(expr))
535
536


537
class Type(sp.Basic):
    """Common base class of the data type classes in this module; a sympy atom."""
    is_Atom = True

    def __new__(cls, *args, **kwargs):
        # constructor arguments are not passed on to sp.Basic, subclasses store them in __init__
        instance = sp.Basic.__new__(cls)
        return instance

    def _sympystr(self, *args, **kwargs):
        # sympy string printing uses the plain string representation of the type
        return str(self)
545
546
547
548


class BasicType(Type):
    """Scalar data type backed by an unstructured numpy dtype, optionally const-qualified."""

    @staticmethod
    def numpy_name_to_c(name):
        """Translates a numpy type name into the corresponding C type name.

        Args:
            name: numpy type name, e.g. 'float64', 'int32', 'uint8' or 'bool'

        Returns:
            C type name, e.g. 'double', 'int32_t', 'uint8_t' or 'bool'

        Raises:
            NotImplementedError: if there is no mapping for the name
        """
        if name == 'float64':
            return 'double'
        elif name == 'float32':
            return 'float'
        elif name.startswith('int'):
            width = int(name[len("int"):])
            return "int%d_t" % (width,)
        elif name.startswith('uint'):
            width = int(name[len("uint"):])
            return "uint%d_t" % (width,)
        elif name == 'bool':
            return 'bool'
        else:
            raise NotImplementedError("Cannot map numpy to C name for %s" % (name,))

    def __init__(self, dtype, const=False):
        """
        Args:
            dtype: a Type object (its numpy dtype is taken over) or anything accepted by ``numpy.dtype``
            const: whether the type is const-qualified
        """
        self.const = const
        if isinstance(dtype, Type):
            self._dtype = dtype.numpy_dtype
        else:
            self._dtype = np.dtype(dtype)
        assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type"
        assert self._dtype.hasobject is False
        assert self._dtype.subdtype is None

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        """Always None: a basic type does not wrap another type."""
        return None

    @property
    def numpy_dtype(self):
        """The underlying numpy dtype."""
        return self._dtype

    @property
    def sympy_dtype(self):
        """Attribute of ``sympy.codegen.ast`` that has the same name as the numpy dtype."""
        return getattr(sympy.codegen.ast, str(self.numpy_dtype))

    @property
    def item_size(self):
        return 1

    # The predicates below use the numpy type hierarchy, since ``np.sctypes`` was removed in NumPy 2.0

    def is_int(self):
        """True for signed and unsigned integer types."""
        return np.issubdtype(self.numpy_dtype, np.integer)

    def is_float(self):
        """True for floating point types."""
        return np.issubdtype(self.numpy_dtype, np.floating)

    def is_uint(self):
        """True for unsigned integer types."""
        return np.issubdtype(self.numpy_dtype, np.unsignedinteger)

    def is_complex(self):
        """True for complex floating point types."""
        return np.issubdtype(self.numpy_dtype, np.complexfloating)

    def is_other(self):
        """True if the dtype equals one of the types formerly listed in ``np.sctypes['others']``."""
        return self.numpy_dtype in (bool, object, bytes, str, np.void)

    @property
    def base_name(self):
        """C name of the type without qualifiers."""
        return BasicType.numpy_name_to_c(str(self._dtype))

    def __str__(self):
        result = BasicType.numpy_name_to_c(str(self._dtype))
        if self.const:
            result += " const"
        return result

    def __repr__(self):
        return str(self)

    def __eq__(self, other):
        if not isinstance(other, BasicType):
            return False
        else:
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)

    def __hash__(self):
        # hashes the C string representation, i.e. type name plus const qualifier
        return hash(str(self))


632
class VectorType(Type):
    """Vector of ``width`` elements of a base type.

    The class attribute ``instruction_set`` controls printing: while it is None a
    generic ``base[width]`` string is used, otherwise the type name is looked up in it.
    """
    instruction_set = None

    def __init__(self, base_type, width=4):
        self._base_type = base_type
        self.width = width

    @property
    def base_type(self):
        """Type of a single vector element."""
        return self._base_type

    @property
    def item_size(self):
        """Item size of the base type multiplied by the vector width."""
        return self.width * self.base_type.item_size

    def __eq__(self, other):
        if isinstance(other, VectorType):
            return (self.base_type, self.width) == (other.base_type, other.width)
        return False

    def __str__(self):
        if self.instruction_set is None:
            return "%s[%d]" % (self.base_type, self.width)

        # (type specification, key in the instruction set), checked in this order
        lookup = (
            ("int64", 'int'),
            ("float64", 'double'),
            ("float32", 'float'),
            ("bool", 'bool'),
        )
        for specification, key in lookup:
            if self.base_type == create_type(specification):
                return self.instruction_set[key]
        raise NotImplementedError()

    def __hash__(self):
        return hash((self.base_type, self.width))

    def __getnewargs__(self):
        return self._base_type, self.width

674

675
class PointerType(Type):
    """Pointer to a base type, with optional ``const`` and ``restrict`` qualifiers."""

    def __init__(self, base_type, const=False, restrict=True):
        self._base_type = base_type
        self.const = const
        self.restrict = restrict

    def __getnewargs__(self):
        return self.base_type, self.const, self.restrict

    @property
    def alias(self):
        """Opposite of ``restrict``."""
        return not self.restrict

    @property
    def base_type(self):
        """The pointed-to type."""
        return self._base_type

    @property
    def item_size(self):
        """Item size of the pointed-to type."""
        return self.base_type.item_size

    def __eq__(self, other):
        if isinstance(other, PointerType):
            return (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict)
        return False

    def __str__(self):
        # base type and star first, then the qualifiers that are set
        qualifiers = []
        if self.restrict:
            qualifiers.append('RESTRICT')
        if self.const:
            qualifiers.append("const")
        return " ".join([str(self.base_type), '*'] + qualifiers)

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self._base_type, self.const, self.restrict))
715

Jan Hoenig's avatar
Jan Hoenig committed
716

717
class StructType:
    """Data type backed by a structured numpy dtype, optionally const-qualified.

    NOTE(review): unlike the other type classes of this module this class does not
    derive from ``Type`` - confirm that this is intended.
    """

    def __init__(self, numpy_type, const=False):
        self.const = const
        self._dtype = np.dtype(numpy_type)

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        """Always None: a struct type does not wrap another type."""
        return None

    @property
    def numpy_dtype(self):
        """The underlying numpy dtype."""
        return self._dtype

    @property
    def item_size(self):
        """Size in bytes of one struct, as reported by numpy."""
        return self.numpy_dtype.itemsize

    def get_element_offset(self, element_name):
        """Byte offset of the named field inside the struct."""
        _, offset = self.numpy_dtype.fields[element_name][:2]
        return offset

    def get_element_type(self, element_name):
        """BasicType of the named field; it takes over the constness of the struct."""
        field_dtype = self.numpy_dtype.fields[element_name][0]
        return BasicType(field_dtype, self.const)

    def has_element(self, element_name):
        """Whether the struct has a field with the given name."""
        return element_name in self.numpy_dtype.fields

    def __eq__(self, other):
        if isinstance(other, StructType):
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)
        return False

    def __str__(self):
        # structs are handled byte-wise
        return "uint8_t const" if self.const else "uint8_t"

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self.numpy_dtype, self.const))