data_types.py 22.4 KB
Newer Older
1
import ctypes
2
3
from collections import defaultdict
from functools import partial
4
from typing import Tuple
Martin Bauer's avatar
Martin Bauer committed
5

6
import numpy as np
Martin Bauer's avatar
Martin Bauer committed
7
import sympy as sp
8
import sympy.codegen.ast
Martin Bauer's avatar
Martin Bauer committed
9
10
11
from sympy.core.cache import cacheit
from sympy.logic.boolalg import Boolean

12
import pystencils
13
from pystencils.cache import memorycache, memorycache_if_hashable
Martin Bauer's avatar
Martin Bauer committed
14
from pystencils.utils import all_equal
15

16
17
18
19
20
try:
    import llvmlite.ir as ir
except ImportError as e:
    ir = None
    _ir_importerror = e
21

22

23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
def typed_symbols(names, dtype, *args):
    """Creates one or several :class:`TypedSymbol` objects, analogous to ``sp.symbols``.

    Args:
        names: name specification, forwarded unchanged to ``sp.symbols``
        dtype: data type attached to every created symbol
        *args: further positional arguments forwarded to ``sp.symbols``

    Returns:
        A tuple of TypedSymbols if ``sp.symbols`` returned a tuple, otherwise a single TypedSymbol.
    """
    symbols = sp.symbols(names, *args)
    # check against the builtin tuple; typing.Tuple is meant for annotations, not for isinstance
    if isinstance(symbols, tuple):
        return tuple(TypedSymbol(str(s), dtype) for s in symbols)
    return TypedSymbol(str(symbols), dtype)


def matrix_symbols(names, dtype, rows, cols):
    """Creates matrices whose entries are typed symbols.

    Args:
        names: iterable of matrix names, or a single comma separated string of names (spaces are ignored)
        dtype: data type of every matrix entry
        rows: number of rows of each matrix
        cols: number of columns of each matrix

    Returns:
        Tuple with one ``sp.Matrix`` per name; the entries are created through :func:`typed_symbols`
        using the range syntax ``<name>:<rows*cols>`` and are placed in row-major order.
    """
    if isinstance(names, str):
        names = names.replace(' ', '').split(',')

    def build_matrix(name):
        entries = typed_symbols("%s:%i" % (name, rows * cols), dtype)
        return sp.Matrix(rows, cols, lambda i, j: entries[i * cols + j])

    return tuple(build_matrix(name) for name in names)


43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
def assumptions_from_dtype(dtype):
    """Derives SymPy assumptions from :class:`BasicType` or a Numpy dtype

    Objects exposing a ``numpy_dtype`` attribute are unwrapped first. Anything that
    numpy cannot interpret as a data type yields an empty dict instead of an error.

    Args:
        dtype (BasicType, np.dtype): a Numpy data type
    Returns:
        A dict of SymPy assumptions (subset of the keys 'integer', 'negative', 'real')
    """
    numpy_type = dtype.numpy_dtype if hasattr(dtype, 'numpy_dtype') else dtype
    result = {}

    try:
        is_integer = np.issubdtype(numpy_type, np.integer)
        if is_integer:
            result['integer'] = True

        if np.issubdtype(numpy_type, np.unsignedinteger):
            result['negative'] = False

        if is_integer or np.issubdtype(numpy_type, np.floating):
            result['real'] = True
    except Exception:
        # best effort: specifications numpy does not understand simply get no assumptions
        pass

    return result


72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# noinspection PyPep8Naming
class address_of(sp.Function):
    """SymPy function node that wraps exactly one expression and reports a pointer dtype for it."""
    is_Atom = True

    def __new__(cls, arg):
        return sp.Function.__new__(cls, arg)

    @property
    def canonical(self):
        """The ``canonical`` attribute of the wrapped argument.

        Raises:
            NotImplementedError: if the argument has no ``canonical`` attribute
        """
        wrapped = self.args[0]
        if not hasattr(wrapped, 'canonical'):
            raise NotImplementedError()
        return wrapped.canonical

    @property
    def is_commutative(self):
        """Commutativity is taken over from the wrapped argument."""
        return self.args[0].is_commutative

    @property
    def dtype(self):
        """Restrict pointer to the dtype of the argument, or to 'void' if the argument has no dtype."""
        wrapped = self.args[0]
        if not hasattr(wrapped, 'dtype'):
            return PointerType('void', restrict=True)
        return PointerType(wrapped.dtype, restrict=True)


Martin Bauer's avatar
Martin Bauer committed
99
# noinspection PyPep8Naming
class cast_func(sp.Function):
    """SymPy function node ``cast_func(expr, dtype)`` that attaches a target data type to an expression.

    The first argument is the expression, the second the data type; a data type that is
    not a :class:`Type` instance is converted with :func:`create_type`.
    """
    is_Atom = True

    def __new__(cls, *args, **kwargs):
        # at least (expr, dtype) are required, unpacking raises ValueError otherwise
        expr, dtype, *other_args = args
        if not isinstance(dtype, Type):
            dtype = create_type(dtype)
        # to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well
        # however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads
        # to problems when for example comparing cast_func's for equality
        #
        # lhs = bitwise_and(a, cast_func(1, 'int'))
        # rhs = cast_func(0, 'int')
        # print( sp.Ne(lhs, rhs) ) # would give true if all cast_funcs are booleans
        # -> thus a separate class boolean_cast_func is introduced
        if isinstance(expr, Boolean):
            cls = boolean_cast_func

        return sp.Function.__new__(cls, expr, dtype, *other_args, **kwargs)

    @property
    def canonical(self):
        """The ``canonical`` attribute of the casted expression.

        Raises:
            NotImplementedError: if the expression has no ``canonical`` attribute
        """
        if hasattr(self.args[0], 'canonical'):
            return self.args[0].canonical
        else:
            raise NotImplementedError()

    @property
    def is_commutative(self):
        """Commutativity is taken over from the casted expression."""
        return self.args[0].is_commutative

    def _eval_evalf(self, *args, **kwargs):
        # numeric evaluation ignores the cast and evaluates the wrapped expression
        return self.args[0].evalf()

    @property
    def dtype(self):
        """Target data type of the cast (second argument)."""
        return self.args[1]

    @property
    def is_integer(self):
        """
        Uses Numpy type hierarchy to determine :func:`sympy.Expr.is_integer` predicate

        For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
        """
        if hasattr(self.dtype, 'numpy_dtype'):
            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or super().is_integer
        else:
            return super().is_integer

    @property
    def is_negative(self):
        """
        False for unsigned integer target types, otherwise the SymPy default.
        See :func:`.cast_func.is_integer`
        """
        if hasattr(self.dtype, 'numpy_dtype'):
            if np.issubdtype(self.dtype.numpy_dtype, np.unsignedinteger):
                return False

        return super().is_negative

    @property
    def is_nonnegative(self):
        """
        True if :func:`.cast_func.is_negative` is known to be False, otherwise the SymPy default.
        """
        if self.is_negative is False:
            return True
        else:
            return super().is_nonnegative

    @property
    def is_real(self):
        """
        True for integer and floating point target types, otherwise the SymPy default.
        See :func:`.cast_func.is_integer`
        """
        if hasattr(self.dtype, 'numpy_dtype'):
            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or \
                np.issubdtype(self.dtype.numpy_dtype, np.floating) or \
                super().is_real
        else:
            return super().is_real

Martin Bauer's avatar
Martin Bauer committed
185

186
187
188
189
190
# noinspection PyPep8Naming
class boolean_cast_func(cast_func, Boolean):
    """cast_func that is additionally a SymPy Boolean.

    cast_func.__new__ switches to this class when the casted expression is itself a Boolean.
    """
    pass


Martin Bauer's avatar
Martin Bauer committed
191
192
193
194
# noinspection PyPep8Naming
class vector_memory_access(cast_func):
    """cast_func subclass declaring four arguments via ``nargs``.

    NOTE(review): the meaning of the two arguments after (expr, dtype) is not visible in this file.
    """
    nargs = (4,)

195

196
197
198
199
200
# noinspection PyPep8Naming
class reinterpret_cast_func(cast_func):
    """Marker subclass of cast_func that adds no behavior of its own.

    NOTE(review): presumably consumers treat it as a reinterpreting cast - confirm at the call sites.
    """
    pass


Martin Bauer's avatar
Martin Bauer committed
201
202
# noinspection PyPep8Naming
class pointer_arithmetic_func(sp.Function, Boolean):
    """SymPy function node (also a Boolean) that forwards ``canonical`` to its first argument."""

    @property
    def canonical(self):
        """The ``canonical`` attribute of the first argument.

        Raises:
            NotImplementedError: if the first argument has no ``canonical`` attribute
        """
        if hasattr(self.args[0], 'canonical'):
            return self.args[0].canonical
        else:
            raise NotImplementedError()


211
class TypedSymbol(sp.Symbol):
    """SymPy symbol that carries a data type.

    SymPy assumptions are derived from the data type via :func:`assumptions_from_dtype`.
    """

    def __new__(cls, *args, **kwds):
        obj = TypedSymbol.__xnew_cached_(cls, *args, **kwds)
        return obj

    def __new_stage2__(cls, name, dtype, *args, **kwargs):
        assumptions = assumptions_from_dtype(dtype)
        # merge instead of passing two ** mappings: an explicitly given assumption that is also
        # derived from the dtype would otherwise raise "got multiple values for keyword argument"
        assumptions.update(kwargs)
        obj = super(TypedSymbol, cls).__xnew__(cls, name, *args, **assumptions)
        try:
            obj._dtype = create_type(dtype)
        except (TypeError, ValueError):
            # on error keep the string
            obj._dtype = dtype
        return obj

    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    @property
    def dtype(self):
        """Data type of the symbol; the raw specification if it could not be converted by create_type."""
        return self._dtype

    def _hashable_content(self):
        # include the data type so that equally named symbols of different type are distinct
        return super()._hashable_content(), hash(self._dtype)

    def __getnewargs__(self):
        return self.name, self.dtype

    @property
    def canonical(self):
        """A typed symbol is its own canonical form."""
        return self

    @property
    def reversed(self):
        """Returns the symbol itself."""
        return self

247

Martin Bauer's avatar
Martin Bauer committed
248
def create_type(specification):
    """Creates a type object according to a specification, or passes through an existing type object.

    Args:
        specification: Type or StructType object, or anything accepted by ``np.dtype`` (e.g. a string)

    Returns:
        The object itself if it already is a Type or StructType, otherwise a new BasicType
        (dtype without fields) or StructType (dtype with fields), both non-const.
    """
    # StructType does not derive from Type, so it has to be listed explicitly to be passed through
    if isinstance(specification, (Type, StructType)):
        return specification

    numpy_dtype = np.dtype(specification)
    if numpy_dtype.fields is None:
        return BasicType(numpy_dtype, const=False)
    else:
        return StructType(numpy_dtype, const=False)
265
266


267
@memorycache(maxsize=64)
def create_composite_type_from_string(specification):
    """Creates a new Type object from a c-like string specification.

    The specification is lower-cased and split at whitespace. The part before the first
    ``*`` token describes the base type (a numpy type name, optionally with ``const``);
    every ``*`` token starts a pointer level that may be followed by ``restrict`` and/or ``const``.
    A star directly attached to the base type name (e.g. ``double*``) adds one more pointer level.

    Args:
        specification: Specification string

    Returns:
        Type object

    Raises:
        AssertionError: if a part contains unexpected tokens
        IndexError: if the specification contains no tokens at all
    """
    tokens = specification.lower().split()
    parts = []
    current = []
    for s in tokens:
        if s == '*':
            parts.append(current)
            current = [s]
        else:
            current.append(s)
    if len(current) > 0:
        parts.append(current)

    # Parse native part
    base_part = parts.pop(0)
    const = False
    if 'const' in base_part:
        const = True
        base_part.remove('const')
    assert len(base_part) == 1, "base type has to consist of exactly one type name"
    if base_part[0][-1] == "*":
        base_part[0] = base_part[0][:-1]
        # append a token list like all other parts, not a bare string
        parts.append(['*'])
    current_type = BasicType(np.dtype(base_part[0]), const)

    # Parse pointer parts
    for part in parts:
        restrict = False
        const = False
        if 'restrict' in part:
            restrict = True
            part.remove('restrict')
        if 'const' in part:
            const = True
            part.remove("const")
        assert len(part) == 1 and part[0] == '*', "unexpected token in pointer part"
        current_type = PointerType(current_type, const, restrict)
    return current_type
312
313


Martin Bauer's avatar
Martin Bauer committed
314
315
316
317
def get_base_type(data_type):
    """Follows the ``base_type`` chain of a type and returns the innermost type.

    Args:
        data_type: object with a ``base_type`` attribute

    Returns:
        The first object in the chain whose ``base_type`` is None
    """
    current = data_type
    while True:
        inner = current.base_type
        if inner is None:
            return current
        current = inner
318
319


Martin Bauer's avatar
Martin Bauer committed
320
def to_ctypes(data_type):
    """Transforms a given Type into ctypes

    Pointer types are converted recursively, struct types are handled as pointer to bytes,
    everything else is looked up by its numpy dtype in ``to_ctypes.map``.

    :param data_type: Subclass of Type
    :return: ctypes type object
    :raises KeyError: if the numpy dtype of a non-pointer, non-struct type is not in ``to_ctypes.map``
    """
    if isinstance(data_type, PointerType):
        return ctypes.POINTER(to_ctypes(data_type.base_type))
    elif isinstance(data_type, StructType):
        return ctypes.POINTER(ctypes.c_uint8)
    else:
        return to_ctypes.map[data_type.numpy_dtype]
332

Martin Bauer's avatar
Martin Bauer committed
333
334

# lookup table of to_ctypes: numpy dtype -> ctypes type
to_ctypes.map = {
    np.dtype(np.int8): ctypes.c_int8,
    np.dtype(np.int16): ctypes.c_int16,
    np.dtype(np.int32): ctypes.c_int32,
    np.dtype(np.int64): ctypes.c_int64,

    np.dtype(np.uint8): ctypes.c_uint8,
    np.dtype(np.uint16): ctypes.c_uint16,
    np.dtype(np.uint32): ctypes.c_uint32,
    np.dtype(np.uint64): ctypes.c_uint64,

    np.dtype(np.float32): ctypes.c_float,
    np.dtype(np.float64): ctypes.c_double,
}


350
def ctypes_from_llvm(data_type):
    """Maps an llvmlite IR type to the corresponding ctypes type.

    Args:
        data_type: llvmlite ``ir`` type object

    Returns:
        ctypes type object, or None for the void type. A pointer to void becomes ``ctypes.c_void_p``.

    Raises:
        ImportError: the error of the failed llvmlite import, if llvmlite is not available
        ValueError: for integer types with a width other than 8, 16, 32 or 64
        NotImplementedError: for all other IR types
    """
    if not ir:
        raise _ir_importerror

    if isinstance(data_type, ir.PointerType):
        ctype = ctypes_from_llvm(data_type.pointee)
        if ctype is None:
            return ctypes.c_void_p
        else:
            return ctypes.POINTER(ctype)
    elif isinstance(data_type, ir.IntType):
        int_types = {
            8: ctypes.c_int8,
            16: ctypes.c_int16,
            32: ctypes.c_int32,
            64: ctypes.c_int64,
        }
        if data_type.width not in int_types:
            raise ValueError("Int width %d is not supported" % data_type.width)
        return int_types[data_type.width]
    elif isinstance(data_type, ir.FloatType):
        return ctypes.c_float
    elif isinstance(data_type, ir.DoubleType):
        return ctypes.c_double
    elif isinstance(data_type, ir.VoidType):
        return None  # Void type is not supported by ctypes
    else:
        raise NotImplementedError('Data type %s of %s is not supported yet' % (type(data_type), data_type))


380
def to_llvm_type(data_type, nvvm_target=False):
    """
    Transforms a given type into an llvmlite IR type
    :param data_type: Subclass of Type
    :param nvvm_target: if True pointers are created in address space 1, otherwise in address space 0
    :return: llvmlite type object
    :raises ImportError: the error of the failed llvmlite import, if llvmlite is not available
    """
    if not ir:
        raise _ir_importerror
    if isinstance(data_type, PointerType):
        # NOTE(review): nvvm_target is not forwarded, so only the outermost pointer level gets
        # address space 1 - confirm that this is intended for pointer-to-pointer types
        return to_llvm_type(data_type.base_type).as_pointer(1 if nvvm_target else 0)
    else:
        return to_llvm_type.map[data_type.numpy_dtype]

393

394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
# lookup table of to_llvm_type: numpy dtype -> llvmlite IR type; only defined if llvmlite was imported
if ir:
    to_llvm_type.map = {
        np.dtype(np.int8): ir.IntType(8),
        np.dtype(np.int16): ir.IntType(16),
        np.dtype(np.int32): ir.IntType(32),
        np.dtype(np.int64): ir.IntType(64),

        # unsigned types map to an IntType of the same width as their signed counterparts
        np.dtype(np.uint8): ir.IntType(8),
        np.dtype(np.uint16): ir.IntType(16),
        np.dtype(np.uint32): ir.IntType(32),
        np.dtype(np.uint64): ir.IntType(64),

        np.dtype(np.float32): ir.FloatType(),
        np.dtype(np.float64): ir.DoubleType(),
    }
409

410

Martin Bauer's avatar
Martin Bauer committed
411
412
413
def peel_off_type(dtype, type_to_peel_off):
    """Strips all outer layers of one wrapper type from a data type.

    Args:
        dtype: data type object
        type_to_peel_off: class whose instances are unwrapped; the class has to match exactly,
                          subclasses are not unwrapped

    Returns:
        The first object in the ``base_type`` chain whose type is not ``type_to_peel_off``
    """
    while type(dtype) is type_to_peel_off:
        dtype = dtype.base_type
    return dtype


417
def collate_types(types, forbid_collation_to_float=False):
    """
    Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
    Uses the collation rules from numpy.

    Args:
        types: sequence of type objects
        forbid_collation_to_float: if True, types whose ``is_float()`` is true are ignored;
                                   if no type is left, int32 is returned

    Returns:
        - the pointer type, if exactly one pointer and otherwise only integer basic types are given
        - otherwise the numpy result type of the basic types (restricted to the float types if at
          least one float type is present), wrapped in a VectorType if any input was a VectorType

    Raises:
        ValueError: for invalid pointer arithmetic or vector types of different width
    """
    if forbid_collation_to_float:
        types = [t for t in types if not (hasattr(t, 'is_float') and t.is_float())]
        if not types:
            return create_type('int32')

    # Pointer arithmetic case i.e. pointer + integer is allowed
    if any(type(t) is PointerType for t in types):
        pointer_type = None
        for t in types:
            if type(t) is PointerType:
                if pointer_type is not None:
                    raise ValueError("Cannot collate the combination of two pointer types")
                pointer_type = t
            elif type(t) is BasicType:
                if not (t.is_int() or t.is_uint()):
                    raise ValueError("Invalid pointer arithmetic")
            else:
                raise ValueError("Invalid pointer arithmetic")
        return pointer_type

    # peel of vector types, if at least one vector type occurred the result will also be the vector type
    vector_type = [t for t in types if type(t) is VectorType]
    if not all_equal(t.width for t in vector_type):
        raise ValueError("Collation failed because of vector types with different width")
    types = [peel_off_type(t, VectorType) for t in types]

    # now we should have a list of basic types - struct types are not yet supported
    assert all(type(t) is BasicType for t in types)

    if any(t.is_float() for t in types):
        types = tuple(t for t in types if t.is_float())
    # use numpy collation -> create type from numpy type -> and, put vector type around if necessary
    result_numpy_type = np.result_type(*(t.numpy_dtype for t in types))
    result = BasicType(result_numpy_type)
    if vector_type:
        result = VectorType(result, vector_type[0].width)
    return result


462
463
464
465
466
@memorycache_if_hashable(maxsize=2048)
def get_type_of_expression(expr,
                           default_float_type='double',
                           default_int_type='int',
                           symbol_type_dict=None):
    """Determines the data type of an expression by recursively inspecting its sub-expressions.

    Args:
        expr: expression, is passed through ``sp.sympify`` first
        default_float_type: type specification used for rational and float numbers and for
                            argument-less expressions that are not known to be integer
        default_int_type: type specification used for integer numbers and for argument-less
                          expressions that are known to be integer
        symbol_type_dict: mapping from symbol name to type, used for symbols that are not TypedSymbols

    Returns:
        type object of the expression

    Raises:
        ValueError: if an untyped symbol is encountered while ``symbol_type_dict`` is empty
                    (this includes the default)
        NotImplementedError: if the expression is of a kind that is not handled
    """
    from pystencils.astnodes import ResolvedFieldAccess
    from pystencils.cpu.vectorization import vec_all, vec_any

    if not symbol_type_dict:
        symbol_type_dict = defaultdict(lambda: create_type('double'))

    get_type = partial(get_type_of_expression,
                       default_float_type=default_float_type,
                       default_int_type=default_int_type,
                       symbol_type_dict=symbol_type_dict)

    # the order of the checks matters: more specific classes have to be tested before their bases
    expr = sp.sympify(expr)
    if isinstance(expr, sp.Integer):
        return create_type(default_int_type)
    elif isinstance(expr, (sp.Rational, sp.Float)):
        return create_type(default_float_type)
    elif isinstance(expr, ResolvedFieldAccess):
        return expr.field.dtype
    elif isinstance(expr, pystencils.field.Field.AbstractAccess):
        return expr.field.dtype
    elif isinstance(expr, TypedSymbol):
        return expr.dtype
    elif isinstance(expr, sp.Symbol):
        if symbol_type_dict:
            return symbol_type_dict[expr.name]
        else:
            raise ValueError("All symbols inside this expression have to be typed! ", str(expr))
    elif isinstance(expr, cast_func):
        return expr.args[1]
    elif isinstance(expr, (vec_any, vec_all)):
        return create_type("bool")
    elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
        collated_result_type = collate_types(tuple(get_type(a[0]) for a in expr.args))
        collated_condition_type = collate_types(tuple(get_type(a[1]) for a in expr.args))
        if type(collated_condition_type) is VectorType and type(collated_result_type) is not VectorType:
            collated_result_type = VectorType(collated_result_type, width=collated_condition_type.width)
        return collated_result_type
    elif isinstance(expr, sp.Indexed):
        typed_symbol = expr.base.label
        return typed_symbol.dtype.base_type
    elif isinstance(expr, Boolean):
        # Boolean (imported from sympy.logic.boolalg) also covers BooleanFunction, which derives from it
        # if any arg is of vector type return a vector boolean, else return a normal scalar boolean
        result = create_type("bool")
        arg_types = (get_type(a) for a in expr.args)
        vec_args = [t for t in arg_types if isinstance(t, VectorType)]
        if vec_args:
            result = VectorType(result, width=vec_args[0].width)
        return result
    elif isinstance(expr, (sp.Pow, sp.Sum, sp.Product)):
        # the type is determined by the first argument only
        return get_type(expr.args[0])
    elif isinstance(expr, sp.Expr):
        expr: sp.Expr
        if expr.args:
            types = tuple(get_type(a) for a in expr.args)
            return collate_types(types)
        else:
            if expr.is_integer:
                return create_type(default_int_type)
            else:
                return create_type(default_float_type)

    raise NotImplementedError("Could not determine type for", expr, type(expr))
528
529


530
class Type(sp.Basic):
    """Base class of the data type objects; a SymPy atom that is printed using ``str``."""
    is_Atom = True

    def __new__(cls, *args, **kwargs):
        # constructor arguments are not stored in the SymPy args, subclasses keep them in __init__
        return sp.Basic.__new__(cls)

    def _sympystr(self, *args, **kwargs):
        return str(self)
538
539
540
541


class BasicType(Type):
    """Scalar data type described by a non-structured numpy dtype plus a const qualifier."""

    @staticmethod
    def numpy_name_to_c(name):
        """Translates a numpy type name to the name of the corresponding C type.

        Args:
            name: numpy type name, e.g. 'float64', 'int32', 'uint8' or 'bool'

        Returns:
            C type name, e.g. 'double', 'int32_t', 'uint8_t' or 'bool'

        Raises:
            NotImplementedError: if no C name is known for the numpy name
        """
        if name == 'float64':
            return 'double'
        elif name == 'float32':
            return 'float'
        elif name.startswith('int'):
            width = int(name[len("int"):])
            return "int%d_t" % (width,)
        elif name.startswith('uint'):
            width = int(name[len("uint"):])
            return "uint%d_t" % (width,)
        elif name == 'bool':
            return 'bool'
        else:
            raise NotImplementedError("Can't map numpy to C name for %s" % (name,))

    def __init__(self, dtype, const=False):
        """
        Args:
            dtype: object with a ``numpy_dtype`` (any Type instance) or anything accepted by ``np.dtype``
            const: True if the type is const qualified
        """
        self.const = const
        if isinstance(dtype, Type):
            self._dtype = dtype.numpy_dtype
        else:
            self._dtype = np.dtype(dtype)
        assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type"
        assert self._dtype.hasobject is False
        assert self._dtype.subdtype is None

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        """Basic types wrap no other type, so this is always None."""
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def sympy_dtype(self):
        """Attribute of ``sympy.codegen.ast`` that has the same name as the numpy dtype."""
        return getattr(sympy.codegen.ast, str(self.numpy_dtype))

    @property
    def item_size(self):
        return 1

    # The predicates below use dtype.kind instead of np.sctypes, which was removed in NumPy 2.0

    def is_int(self):
        """True for signed and unsigned integer types."""
        return self.numpy_dtype.kind in 'iu'

    def is_float(self):
        return self.numpy_dtype.kind == 'f'

    def is_uint(self):
        return self.numpy_dtype.kind == 'u'

    def is_complex(self):
        return self.numpy_dtype.kind == 'c'

    def is_other(self):
        # same type list as the former np.sctypes['others']
        return self.numpy_dtype in (bool, object, bytes, str, np.void)

    @property
    def base_name(self):
        """C name of the type without qualifiers."""
        return BasicType.numpy_name_to_c(str(self._dtype))

    def __str__(self):
        result = BasicType.numpy_name_to_c(str(self._dtype))
        if self.const:
            result += " const"
        return result

    def __repr__(self):
        return str(self)

    def __eq__(self, other):
        if not isinstance(other, BasicType):
            return False
        else:
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)

    def __hash__(self):
        return hash(str(self))


625
class VectorType(Type):
    """Data type consisting of ``width`` elements of a base type."""

    # If set, has to be a mapping with the keys 'int', 'double', 'float' and 'bool'
    # that provides the type names returned by __str__
    instruction_set = None

    def __init__(self, base_type, width=4):
        """
        Args:
            base_type: type of a single element
            width: number of elements
        """
        self._base_type = base_type
        self.width = width

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        """Width multiplied with the item size of the base type."""
        return self.width * self.base_type.item_size

    def __eq__(self, other):
        if not isinstance(other, VectorType):
            return False
        else:
            return (self.base_type, self.width) == (other.base_type, other.width)

    def __str__(self):
        if self.instruction_set is None:
            return "%s[%d]" % (self.base_type, self.width)
        else:
            if self.base_type == create_type("int64"):
                return self.instruction_set['int']
            elif self.base_type == create_type("float64"):
                return self.instruction_set['double']
            elif self.base_type == create_type("float32"):
                return self.instruction_set['float']
            elif self.base_type == create_type("bool"):
                return self.instruction_set['bool']
            else:
                raise NotImplementedError()

    def __hash__(self):
        return hash((self.base_type, self.width))

    def __getnewargs__(self):
        return self._base_type, self.width

667

668
class PointerType(Type):
    """Pointer to a base type with optional const and restrict qualifiers."""

    def __init__(self, base_type, const=False, restrict=True):
        """
        Args:
            base_type: type the pointer points to
            const: True if the pointer is const qualified
            restrict: True if the pointer is restrict qualified
        """
        self._base_type = base_type
        self.const = const
        self.restrict = restrict

    def __getnewargs__(self):
        return self.base_type, self.const, self.restrict

    @property
    def alias(self):
        """Negation of ``restrict``."""
        return not self.restrict

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        """Item size of the base type."""
        return self.base_type.item_size

    def __eq__(self, other):
        if not isinstance(other, PointerType):
            return False
        else:
            return (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict)

    def __str__(self):
        components = [str(self.base_type), '*']
        if self.restrict:
            components.append('RESTRICT')
        if self.const:
            components.append("const")
        return " ".join(components)

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self._base_type, self.const, self.restrict))
708

Jan Hoenig's avatar
Jan Hoenig committed
709

710
class StructType:
    """Data type described by a numpy dtype plus a const qualifier; printed as a byte type."""

    def __init__(self, numpy_type, const=False):
        """
        Args:
            numpy_type: anything accepted by ``np.dtype``
            const: True if the type is const qualified
        """
        self.const = const
        self._dtype = np.dtype(numpy_type)

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        """Struct types wrap no other type, so this is always None."""
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def item_size(self):
        """Size of one item in bytes, as reported by numpy."""
        return self.numpy_dtype.itemsize

    def get_element_offset(self, element_name):
        """Byte offset of the given field inside the struct."""
        return self.numpy_dtype.fields[element_name][1]

    def get_element_type(self, element_name):
        """BasicType of the given field; inherits the const qualifier of the struct."""
        np_element_type = self.numpy_dtype.fields[element_name][0]
        return BasicType(np_element_type, self.const)

    def has_element(self, element_name):
        """True if the struct has a field with the given name."""
        fields = self.numpy_dtype.fields
        # fields is None for dtypes without named fields
        return fields is not None and element_name in fields

    def __eq__(self, other):
        if not isinstance(other, StructType):
            return False
        else:
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)

    def __str__(self):
        # structs are handled byte-wise
        result = "uint8_t"
        if self.const:
            result += " const"
        return result

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self.numpy_dtype, self.const))