data_types.py 22.3 KB
Newer Older
1
import ctypes
2
3
from collections import defaultdict
from functools import partial
4
from typing import Tuple
Martin Bauer's avatar
Martin Bauer committed
5

6
import numpy as np
Martin Bauer's avatar
Martin Bauer committed
7
import sympy as sp
8
import sympy.codegen.ast
Martin Bauer's avatar
Martin Bauer committed
9
10
11
from sympy.core.cache import cacheit
from sympy.logic.boolalg import Boolean

12
import pystencils
13
from pystencils.cache import memorycache, memorycache_if_hashable
Martin Bauer's avatar
Martin Bauer committed
14
from pystencils.utils import all_equal
15

16
17
18
19
20
try:
    import llvmlite.ir as ir
except ImportError as e:
    ir = None
    _ir_importerror = e
21

22

23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
def typed_symbols(names, dtype, *args):
    """Creates one or several :class:`TypedSymbol` objects that share a data type.

    Args:
        names: symbol specification, forwarded unchanged to :func:`sympy.symbols`
        dtype: data type attached to every created symbol
        *args: additional positional arguments forwarded to :func:`sympy.symbols`

    Returns:
        a tuple of typed symbols if ``sympy.symbols`` produced a tuple or list,
        otherwise a single typed symbol
    """
    symbols = sp.symbols(names, *args)
    # check against the builtin containers instead of the typing.Tuple alias; a list result was
    # previously stringified as a whole into a single, meaningless symbol name
    if isinstance(symbols, (tuple, list)):
        return tuple(TypedSymbol(str(s), dtype) for s in symbols)
    return TypedSymbol(str(symbols), dtype)


def matrix_symbols(names, dtype, rows, cols):
    """Creates matrices whose entries are typed symbols.

    Args:
        names: comma separated string of matrix names (spaces are ignored) or an iterable of names
        dtype: data type of every matrix entry
        rows: number of rows of each matrix
        cols: number of columns of each matrix

    Returns:
        tuple with one ``sympy.Matrix`` per name; entry ``(i, j)`` is the symbol with index ``i * cols + j``
    """
    if isinstance(names, str):
        names = names.replace(' ', '').split(',')

    entry_count = rows * cols

    def build_matrix(name):
        # "<name>:<count>" is expanded to a numbered range of symbols
        entries = typed_symbols("%s:%i" % (name, entry_count), dtype)
        return sp.Matrix(rows, cols, lambda row, col: entries[row * cols + col])

    return tuple(build_matrix(name) for name in names)


43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
def assumptions_from_dtype(dtype):
    """Derives SymPy assumptions from :class:`BasicType` or a Numpy dtype

    Args:
        dtype (BasicType, np.dtype): a data type; objects exposing ``numpy_dtype`` are unwrapped first
    Returns:
        A dict of SymPy assumptions; empty if the dtype cannot be classified
    """
    dtype = getattr(dtype, 'numpy_dtype', dtype)

    assumptions = {}

    try:
        is_integer = np.issubdtype(dtype, np.integer)
        if is_integer:
            assumptions['integer'] = True

        if np.issubdtype(dtype, np.unsignedinteger):
            assumptions['negative'] = False

        if is_integer or np.issubdtype(dtype, np.floating):
            assumptions['real'] = True
    except Exception:
        # best effort: anything numpy cannot interpret as a dtype yields no assumptions
        pass

    return assumptions


72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# noinspection PyPep8Naming
class address_of(sp.Function):
    """Single-argument symbolic function whose ``dtype`` is a restrict pointer to its argument's type.

    ``canonical`` and ``is_commutative`` are delegated to the wrapped argument.
    """
    is_Atom = True

    def __new__(cls, arg):
        return sp.Function.__new__(cls, arg)

    @property
    def canonical(self):
        target = self.args[0]
        if not hasattr(target, 'canonical'):
            raise NotImplementedError()
        return target.canonical

    @property
    def is_commutative(self):
        target = self.args[0]
        return target.is_commutative

    @property
    def dtype(self):
        target = self.args[0]
        # arguments without a dtype are treated as 'void'
        pointee = target.dtype if hasattr(target, 'dtype') else 'void'
        return PointerType(pointee, restrict=True)


Martin Bauer's avatar
Martin Bauer committed
99
# noinspection PyPep8Naming
100
class cast_func(sp.Function):
    """Symbolic cast of an expression to a data type.

    ``args[0]`` is the expression being cast and ``args[1]`` the target type. The target type is converted
    with :func:`create_type` unless it already is a :class:`Type`. The assumption predicates
    (``is_integer``, ``is_negative``, ``is_nonnegative``, ``is_real``) are derived from the numpy dtype
    of the target type where one is available.
    """
    is_Atom = True

    def __new__(cls, *args, **kwargs):
        if len(args) < 2:
            # same exception type that the tuple unpacking below would raise, with a clearer message
            raise ValueError("cast_func needs an expression and a data type, got %d argument(s)" % (len(args),))
        expr, dtype, *other_args = args
        if not isinstance(dtype, Type):
            dtype = create_type(dtype)
        # to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well
        # however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads
        # to problems when for example comparing cast_func's for equality
        #
        # lhs = bitwise_and(a, cast_func(1, 'int'))
        # rhs = cast_func(0, 'int')
        # print( sp.Ne(lhs, rhs) ) # would give true if all cast_funcs are booleans
        # -> thus a separate class boolean_cast_func is introduced
        if isinstance(expr, Boolean):
            cls = boolean_cast_func

        return sp.Function.__new__(cls, expr, dtype, *other_args, **kwargs)

    @property
    def canonical(self):
        if hasattr(self.args[0], 'canonical'):
            return self.args[0].canonical
        else:
            raise NotImplementedError()

    @property
    def is_commutative(self):
        return self.args[0].is_commutative

    def _eval_evalf(self, *args, **kwargs):
        # numerical evaluation ignores the cast and evaluates the wrapped expression
        return self.args[0].evalf()

    @property
    def dtype(self):
        return self.args[1]

    @property
    def is_integer(self):
        """
        Uses Numpy type hierarchy to determine :func:`sympy.Expr.is_integer` predicate

        For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
        """
        if hasattr(self.dtype, 'numpy_dtype'):
            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or super().is_integer
        else:
            return super().is_integer

    @property
    def is_negative(self):
        """
        ``False`` for unsigned integer target types, otherwise the inherited predicate.
        """
        if hasattr(self.dtype, 'numpy_dtype'):
            if np.issubdtype(self.dtype.numpy_dtype, np.unsignedinteger):
                return False

        return super().is_negative

    @property
    def is_nonnegative(self):
        """
        ``True`` whenever :attr:`is_negative` is known to be ``False``, otherwise the inherited predicate.
        """
        if self.is_negative is False:
            return True
        else:
            return super().is_nonnegative

    @property
    def is_real(self):
        """
        ``True`` for integer and floating point target types, otherwise the inherited predicate.
        """
        if hasattr(self.dtype, 'numpy_dtype'):
            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or \
                np.issubdtype(self.dtype.numpy_dtype, np.floating) or \
                super().is_real
        else:
            return super().is_real

Martin Bauer's avatar
Martin Bauer committed
185

186
187
188
189
190
# noinspection PyPep8Naming
class boolean_cast_func(cast_func, Boolean):
    """Variant of :class:`cast_func` that is also a sympy ``Boolean``.

    ``cast_func.__new__`` switches to this class when the expression being cast is a ``Boolean``.
    """
    pass


Martin Bauer's avatar
Martin Bauer committed
191
192
193
194
# noinspection PyPep8Naming
class vector_memory_access(cast_func):
    """:class:`cast_func` subclass declared with four arguments.

    NOTE(review): the meaning of the two arguments after (expression, dtype) is not visible in this
    module - confirm against the callers.
    """
    # declared argument count for this function (sympy ``nargs`` attribute)
    nargs = (4,)

195

196
197
198
199
200
# noinspection PyPep8Naming
class reinterpret_cast_func(cast_func):
    """:class:`cast_func` subclass that adds no behaviour of its own.

    NOTE(review): presumably distinguishes reinterpreting casts from converting casts for
    consumers of the expression tree - confirm against the code that handles it.
    """
    pass


Martin Bauer's avatar
Martin Bauer committed
201
202
# noinspection PyPep8Naming
class pointer_arithmetic_func(sp.Function, Boolean):
    """Symbolic function (also a sympy ``Boolean``) that forwards ``canonical`` to its first argument."""

    @property
    def canonical(self):
        first = self.args[0]
        if not hasattr(first, 'canonical'):
            raise NotImplementedError()
        return first.canonical


211
class TypedSymbol(sp.Symbol):
    """A sympy symbol that carries a data type.

    The data type is converted with :func:`create_type`; if that fails with ``TypeError`` or
    ``ValueError`` the given object is stored unchanged. SymPy assumptions are derived from the data type
    via :func:`assumptions_from_dtype`. The data type takes part in hashing.
    """

    def __new__(cls, *args, **kwds):
        obj = TypedSymbol.__xnew_cached_(cls, *args, **kwds)
        return obj

    def __new_stage2__(cls, name, dtype, *args, **kwargs):
        assumptions = assumptions_from_dtype(dtype)
        # merge instead of passing both dicts as ** arguments: an explicitly given assumption that the
        # dtype also implies would otherwise raise "got multiple values for keyword argument"
        assumptions.update(kwargs)
        obj = super(TypedSymbol, cls).__xnew__(cls, name, *args, **assumptions)
        try:
            obj._dtype = create_type(dtype)
        except (TypeError, ValueError):
            # on error keep the string
            obj._dtype = dtype
        return obj

    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    @property
    def dtype(self):
        return self._dtype

    def _hashable_content(self):
        return super()._hashable_content(), hash(self._dtype)

    def __getnewargs__(self):
        return self.name, self.dtype


Martin Bauer's avatar
Martin Bauer committed
240
def create_type(specification):
    """Creates a type object from a specification, or passes an existing type object through.

    Args:
        specification: :class:`Type` or :class:`StructType` object, or anything accepted by ``numpy.dtype``
                       (e.g. a string such as ``'float64'``)

    Returns:
        the given type object itself, or a new :class:`BasicType` (plain dtype) respectively
        :class:`StructType` (dtype with fields) created from the specification
    """
    # StructType does not derive from Type, so it has to be listed explicitly; otherwise an existing
    # struct type would be handed to numpy.dtype, which cannot interpret it
    if isinstance(specification, (Type, StructType)):
        return specification

    numpy_dtype = np.dtype(specification)
    if numpy_dtype.fields is None:
        return BasicType(numpy_dtype, const=False)
    return StructType(numpy_dtype, const=False)
257
258


259
@memorycache(maxsize=64)
def create_composite_type_from_string(specification):
    """Creates a new Type object from a c-like string specification.

    The specification is lower-cased and split at whitespace. The first group of tokens describes the
    base type (a numpy type name, optionally with ``const``); every following group starts with ``*`` and
    may contain ``const`` and/or ``restrict``. Stars written directly behind the base name
    (``"double*"``) are treated as the innermost pointer levels.

    Args:
        specification: Specification string

    Returns:
        Type object

    Raises:
        AssertionError: if a token group does not have the expected shape
    """
    tokens = specification.lower().split()

    # group the tokens: every '*' starts a new group
    parts = []
    current = []
    for token in tokens:
        if token == '*':
            parts.append(current)
            current = [token]
        else:
            current.append(token)
    if current:
        parts.append(current)

    # Parse native part
    base_part = parts.pop(0)
    const = 'const' in base_part
    if const:
        base_part.remove('const')
    assert len(base_part) == 1, "Expected exactly one base type name, got %r" % (base_part,)

    base_name = base_part[0].rstrip('*')
    glued_stars = len(base_part[0]) - len(base_name)
    # stars glued to the base name bind tighter than the separately written ones, so they go to the
    # front of the pointer list - each as its own token list, like the other groups
    parts[0:0] = [['*'] for _ in range(glued_stars)]

    current_type = BasicType(np.dtype(base_name), const)

    # Parse pointer parts
    for part in parts:
        restrict = 'restrict' in part
        if restrict:
            part.remove('restrict')
        const = 'const' in part
        if const:
            part.remove("const")
        assert len(part) == 1 and part[0] == '*', "Unexpected tokens in pointer part: %r" % (part,)
        current_type = PointerType(current_type, const, restrict)
    return current_type
304
305


Martin Bauer's avatar
Martin Bauer committed
306
307
308
309
def get_base_type(data_type):
    """Follows the ``base_type`` chain of ``data_type`` and returns the innermost type.

    Args:
        data_type: object with a ``base_type`` attribute; the chain ends where ``base_type`` is ``None``

    Returns:
        the first object in the chain whose ``base_type`` is ``None``
    """
    innermost = data_type
    while True:
        wrapped = innermost.base_type
        if wrapped is None:
            return innermost
        innermost = wrapped
310
311


Martin Bauer's avatar
Martin Bauer committed
312
def to_ctypes(data_type):
    """
    Transforms a given Type into ctypes

    Pointer types are translated recursively, struct types become a pointer to ``c_uint8``,
    everything else is looked up in ``to_ctypes.map`` by its numpy dtype.

    :param data_type: Subclass of Type
    :return: ctypes type object
    """
    if isinstance(data_type, PointerType):
        return ctypes.POINTER(to_ctypes(data_type.base_type))
    elif isinstance(data_type, StructType):
        return ctypes.POINTER(ctypes.c_uint8)
    else:
        return to_ctypes.map[data_type.numpy_dtype]


# lookup table from numpy dtypes to the ctypes type used for them
to_ctypes.map = {
    np.dtype(np.int8): ctypes.c_int8,
    np.dtype(np.int16): ctypes.c_int16,
    np.dtype(np.int32): ctypes.c_int32,
    np.dtype(np.int64): ctypes.c_int64,

    np.dtype(np.uint8): ctypes.c_uint8,
    np.dtype(np.uint16): ctypes.c_uint16,
    np.dtype(np.uint32): ctypes.c_uint32,
    np.dtype(np.uint64): ctypes.c_uint64,

    np.dtype(np.float32): ctypes.c_float,
    np.dtype(np.float64): ctypes.c_double,
}


342
def ctypes_from_llvm(data_type):
    """Translates an llvmlite IR type into the corresponding ctypes type.

    Args:
        data_type: llvmlite IR type (pointer, int, float, double or void type)

    Returns:
        ctypes type object, or ``None`` for the void type, which ctypes cannot express.
        A pointer to void becomes ``ctypes.c_void_p``.

    Raises:
        ImportError: the error recorded at import time, if llvmlite is not available
        ValueError: for integer widths other than 8, 16, 32 and 64
        NotImplementedError: for any other IR type
    """
    if not ir:
        raise _ir_importerror
    if isinstance(data_type, ir.PointerType):
        ctype = ctypes_from_llvm(data_type.pointee)
        if ctype is None:
            return ctypes.c_void_p
        else:
            return ctypes.POINTER(ctype)
    elif isinstance(data_type, ir.IntType):
        int_types_by_width = {
            8: ctypes.c_int8,
            16: ctypes.c_int16,
            32: ctypes.c_int32,
            64: ctypes.c_int64,
        }
        if data_type.width not in int_types_by_width:
            raise ValueError("Int width %d is not supported" % data_type.width)
        return int_types_by_width[data_type.width]
    elif isinstance(data_type, ir.FloatType):
        return ctypes.c_float
    elif isinstance(data_type, ir.DoubleType):
        return ctypes.c_double
    elif isinstance(data_type, ir.VoidType):
        return None  # Void type is not supported by ctypes
    else:
        raise NotImplementedError('Data type %s of %s is not supported yet' % (type(data_type), data_type))


372
def to_llvm_type(data_type, nvvm_target=False):
    """
    Transforms a given type into an llvmlite IR type

    :param data_type: Subclass of Type
    :param nvvm_target: if true, the outermost pointer is created with ``as_pointer(1)`` instead of
                        ``as_pointer(0)``
    :return: llvmlite type object
    """
    if not ir:
        raise _ir_importerror
    if isinstance(data_type, PointerType):
        # NOTE(review): nvvm_target is not forwarded to the recursive call, so only the outermost
        # pointer level is affected by it - confirm that this is intended
        return to_llvm_type(data_type.base_type).as_pointer(1 if nvvm_target else 0)
    else:
        return to_llvm_type.map[data_type.numpy_dtype]

385

386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
# Lookup table from numpy dtypes to llvmlite IR types; only defined if llvmlite could be imported
if ir:
    to_llvm_type.map = {
        np.dtype(np.int8): ir.IntType(8),
        np.dtype(np.int16): ir.IntType(16),
        np.dtype(np.int32): ir.IntType(32),
        np.dtype(np.int64): ir.IntType(64),

        # unsigned dtypes use the same IntType widths as their signed counterparts
        np.dtype(np.uint8): ir.IntType(8),
        np.dtype(np.uint16): ir.IntType(16),
        np.dtype(np.uint32): ir.IntType(32),
        np.dtype(np.uint64): ir.IntType(64),

        np.dtype(np.float32): ir.FloatType(),
        np.dtype(np.float64): ir.DoubleType(),
    }
401

402

Martin Bauer's avatar
Martin Bauer committed
403
404
405
def peel_off_type(dtype, type_to_peel_off):
    """Removes all outer layers of exactly the given type class.

    Args:
        dtype: type object to unwrap
        type_to_peel_off: class whose instances are replaced by their ``base_type``; subclasses do not match

    Returns:
        the first object in the ``base_type`` chain whose class is not ``type_to_peel_off``
    """
    while type(dtype) is type_to_peel_off:
        dtype = dtype.base_type
    return dtype


409
def collate_types(types, forbid_collation_to_float=False):
    """
    Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
    Uses the collation rules from numpy.

    Special cases handled before numpy is consulted:

    - exactly one pointer type combined with integer types yields the pointer type (pointer arithmetic)
    - if vector types occur, the result is a vector type of the same width
    - if at least one floating point type occurs, only the floating point types are collated

    Args:
        types: iterable of type objects
        forbid_collation_to_float: if true, floating point types are ignored; if nothing is left,
                                   ``int32`` is returned

    Raises:
        ValueError: for invalid pointer arithmetic or vector types of different width
    """
    # the types are traversed several times below, so a one-shot iterable has to be materialised first
    types = list(types)

    if forbid_collation_to_float:
        types = [t for t in types if not (hasattr(t, 'is_float') and t.is_float())]
        if not types:
            return create_type('int32')

    # Pointer arithmetic case i.e. pointer + integer is allowed
    if any(type(t) is PointerType for t in types):
        pointer_type = None
        for t in types:
            if type(t) is PointerType:
                if pointer_type is not None:
                    raise ValueError("Cannot collate the combination of two pointer types")
                pointer_type = t
            elif type(t) is BasicType:
                if not (t.is_int() or t.is_uint()):
                    raise ValueError("Invalid pointer arithmetic")
            else:
                raise ValueError("Invalid pointer arithmetic")
        return pointer_type

    # peel of vector types, if at least one vector type occurred the result will also be the vector type
    vector_type = [t for t in types if type(t) is VectorType]
    if not all_equal(t.width for t in vector_type):
        raise ValueError("Collation failed because of vector types with different width")
    types = [peel_off_type(t, VectorType) for t in types]

    # now we should have a list of basic types - struct types are not yet supported
    assert all(type(t) is BasicType for t in types)

    if any(t.is_float() for t in types):
        types = tuple(t for t in types if t.is_float())
    # use numpy collation -> create type from numpy type -> and, put vector type around if necessary
    result_numpy_type = np.result_type(*(t.numpy_dtype for t in types))
    result = BasicType(result_numpy_type)
    if vector_type:
        result = VectorType(result, vector_type[0].width)
    return result


454
455
456
457
458
@memorycache_if_hashable(maxsize=2048)
def get_type_of_expression(expr,
                           default_float_type='double',
                           default_int_type='int',
                           symbol_type_dict=None):
    """Determines the data type of a (sympy) expression.

    Args:
        expr: expression, passed through ``sympy.sympify`` first
        default_float_type: type specification used for rational and float constants
        default_int_type: type specification used for integer constants
        symbol_type_dict: mapping from the names of untyped symbols to their types

    Returns:
        type object describing the result of the expression

    Raises:
        ValueError: if an untyped symbol is encountered and ``symbol_type_dict`` is empty
        NotImplementedError: if no rule matches the expression
    """
    from pystencils.astnodes import ResolvedFieldAccess
    from pystencils.cpu.vectorization import vec_all, vec_any

    if not symbol_type_dict:
        symbol_type_dict = defaultdict(lambda: create_type('double'))

    get_type = partial(get_type_of_expression,
                       default_float_type=default_float_type,
                       default_int_type=default_int_type,
                       symbol_type_dict=symbol_type_dict)

    expr = sp.sympify(expr)
    if isinstance(expr, sp.Integer):
        return create_type(default_int_type)
    elif isinstance(expr, (sp.Rational, sp.Float)):
        return create_type(default_float_type)
    elif isinstance(expr, ResolvedFieldAccess):
        return expr.field.dtype
    elif isinstance(expr, pystencils.field.Field.AbstractAccess):
        return expr.field.dtype
    elif isinstance(expr, TypedSymbol):
        return expr.dtype
    elif isinstance(expr, sp.Symbol):
        # an empty mapping is falsy, so with the default (empty) defaultdict this raises
        if symbol_type_dict:
            return symbol_type_dict[expr.name]
        else:
            raise ValueError("All symbols inside this expression have to be typed! ", str(expr))
    elif isinstance(expr, cast_func):
        return expr.args[1]
    elif isinstance(expr, (vec_any, vec_all)):
        return create_type("bool")
    elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
        collated_result_type = collate_types(tuple(get_type(a[0]) for a in expr.args))
        collated_condition_type = collate_types(tuple(get_type(a[1]) for a in expr.args))
        # a vectorized condition makes the whole piecewise expression vectorized
        if type(collated_condition_type) is VectorType and type(collated_result_type) is not VectorType:
            collated_result_type = VectorType(collated_result_type, width=collated_condition_type.width)
        return collated_result_type
    elif isinstance(expr, sp.Indexed):
        typed_symbol = expr.base.label
        return typed_symbol.dtype.base_type
    elif isinstance(expr, Boolean):
        # Boolean (imported at module level) also covers BooleanFunction, which derives from it
        # if any arg is of vector type return a vector boolean, else return a normal scalar boolean
        result = create_type("bool")
        arg_types = [get_type(a) for a in expr.args]
        vec_args = [t for t in arg_types if isinstance(t, VectorType)]
        if vec_args:
            result = VectorType(result, width=vec_args[0].width)
        return result
    elif isinstance(expr, (sp.Pow, sp.Sum, sp.Product)):
        # the type is taken from the first argument only
        return get_type(expr.args[0])
    elif isinstance(expr, sp.Expr):
        if expr.args:
            types = tuple(get_type(a) for a in expr.args)
            return collate_types(types)
        else:
            if expr.is_integer:
                return create_type(default_int_type)
            else:
                return create_type(default_float_type)

    raise NotImplementedError("Could not determine type for", expr, type(expr))
520
521


522
class Type(sp.Basic):
    """Base class of the type objects in this module; a sympy ``Basic`` object without arguments."""
    is_Atom = True

    def __new__(cls, *args, **kwargs):
        # constructor arguments are consumed by the subclasses' __init__ and not stored as sympy args
        return sp.Basic.__new__(cls)

    def _sympystr(self, *args, **kwargs):
        return str(self)
530
531
532
533


class BasicType(Type):
    """Scalar type backed by a plain (non-structured) numpy dtype, optionally const qualified."""

    @staticmethod
    def numpy_name_to_c(name):
        """Translates a numpy dtype name like ``'float64'`` or ``'uint8'`` into the C type name.

        Raises:
            NotImplementedError: if there is no translation for the name
        """
        if name == 'float64':
            return 'double'
        elif name == 'float32':
            return 'float'
        elif name.startswith('int'):
            width = int(name[len("int"):])
            return "int%d_t" % (width,)
        elif name.startswith('uint'):
            width = int(name[len("uint"):])
            return "uint%d_t" % (width,)
        elif name == 'bool':
            return 'bool'
        else:
            raise NotImplementedError("Can't map numpy to C name for %s" % (name,))

    def __init__(self, dtype, const=False):
        """
        Args:
            dtype: another type object (its ``numpy_dtype`` is taken) or anything accepted by ``numpy.dtype``
            const: whether the type is const qualified
        """
        self.const = const
        if isinstance(dtype, Type):
            self._dtype = dtype.numpy_dtype
        else:
            self._dtype = np.dtype(dtype)
        assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type"
        assert self._dtype.hasobject is False
        assert self._dtype.subdtype is None

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def sympy_dtype(self):
        # attribute of sympy.codegen.ast with the same name as the numpy dtype
        return getattr(sympy.codegen.ast, str(self.numpy_dtype))

    @property
    def item_size(self):
        return 1

    # The predicates below use the numpy kind code of the dtype instead of the ``np.sctypes`` table,
    # which no longer exists in NumPy 2.0.

    def is_int(self):
        """True for signed and unsigned integer types."""
        return self.numpy_dtype.kind in ('i', 'u')

    def is_float(self):
        return self.numpy_dtype.kind == 'f'

    def is_uint(self):
        return self.numpy_dtype.kind == 'u'

    def is_complex(self):
        return self.numpy_dtype.kind == 'c'

    def is_other(self):
        """True for bool, bytes, str and void dtypes."""
        return self.numpy_dtype.kind in ('b', 'O', 'S', 'U', 'V')

    @property
    def base_name(self):
        return BasicType.numpy_name_to_c(str(self._dtype))

    def __str__(self):
        result = BasicType.numpy_name_to_c(str(self._dtype))
        if self.const:
            result += " const"
        return result

    def __repr__(self):
        return str(self)

    def __eq__(self, other):
        if not isinstance(other, BasicType):
            return False
        else:
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)

    def __hash__(self):
        return hash(str(self))


617
class VectorType(Type):
    """Type consisting of ``width`` elements of a base type."""

    # Mapping with the keys 'int', 'double', 'float' and 'bool' giving the vector type names used by
    # __str__; while it is None a generic "<base type>[<width>]" notation is used.
    instruction_set = None

    def __init__(self, base_type, width=4):
        self._base_type = base_type
        self.width = width

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        return self.width * self.base_type.item_size

    def __eq__(self, other):
        if not isinstance(other, VectorType):
            return False
        else:
            return (self.base_type, self.width) == (other.base_type, other.width)

    def __str__(self):
        if self.instruction_set is None:
            return "%s[%d]" % (self.base_type, self.width)

        # (numpy name of the base type, key into the instruction set), tested in this order
        name_to_key = (("int64", 'int'), ("float64", 'double'), ("float32", 'float'), ("bool", 'bool'))
        for numpy_name, key in name_to_key:
            if self.base_type == create_type(numpy_name):
                return self.instruction_set[key]
        raise NotImplementedError()

    def __hash__(self):
        return hash((self.base_type, self.width))

    def __getnewargs__(self):
        return self._base_type, self.width

659

660
class PointerType(Type):
    """Pointer to a base type with optional ``const`` and ``restrict`` qualifiers."""

    def __init__(self, base_type, const=False, restrict=True):
        self._base_type = base_type
        self.const = const
        self.restrict = restrict

    def __getnewargs__(self):
        return self.base_type, self.const, self.restrict

    @property
    def alias(self):
        """Negation of ``restrict``."""
        return not self.restrict

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        return self.base_type.item_size

    def __eq__(self, other):
        if not isinstance(other, PointerType):
            return False
        else:
            return (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict)

    def __str__(self):
        components = [str(self.base_type), '*']
        if self.restrict:
            components.append('RESTRICT')
        if self.const:
            components.append("const")
        return " ".join(components)

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self._base_type, self.const, self.restrict))
700

Jan Hoenig's avatar
Jan Hoenig committed
701

702
class StructType:
    """Type backed by a structured numpy dtype, optionally const qualified.

    The string representation is ``uint8_t`` because structs are handled byte-wise.
    """

    def __init__(self, numpy_type, const=False):
        self.const = const
        self._dtype = np.dtype(numpy_type)

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def item_size(self):
        return self.numpy_dtype.itemsize

    def get_element_offset(self, element_name):
        """Byte offset of the named field inside the struct."""
        return self.numpy_dtype.fields[element_name][1]

    def get_element_type(self, element_name):
        """Type of the named field as :class:`BasicType`; inherits the const qualifier of the struct."""
        np_element_type = self.numpy_dtype.fields[element_name][0]
        return BasicType(np_element_type, self.const)

    def has_element(self, element_name):
        return element_name in self.numpy_dtype.fields

    def __eq__(self, other):
        if not isinstance(other, StructType):
            return False
        else:
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)

    def __str__(self):
        # structs are handled byte-wise
        result = "uint8_t"
        if self.const:
            result += " const"
        return result

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self.numpy_dtype, self.const))