data_types.py 24.5 KB
Newer Older
1
import ctypes
2
3
from collections import defaultdict
from functools import partial
4
from typing import Tuple
Martin Bauer's avatar
Martin Bauer committed
5

6
import numpy as np
Martin Bauer's avatar
Martin Bauer committed
7

8
import pystencils
9
10
import sympy as sp
import sympy.codegen.ast
11
from pystencils.cache import memorycache, memorycache_if_hashable
Martin Bauer's avatar
Martin Bauer committed
12
from pystencils.utils import all_equal
13
14
from sympy.core.cache import cacheit
from sympy.logic.boolalg import Boolean
15

16
17
18
19
20
try:
    import llvmlite.ir as ir
except ImportError as e:
    ir = None
    _ir_importerror = e
21

22

23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
def typed_symbols(names, dtype, *args):
    """Creates one or several :class:`TypedSymbol` objects that share the same data type.

    Args:
        names: symbol name specification forwarded to :func:`sympy.symbols`
               (e.g. ``"x"``, ``"x, y"``, ``"x:3"`` or a list of names)
        dtype: data type given to every created symbol
        *args: further positional arguments forwarded to :func:`sympy.symbols`

    Returns:
        a tuple of TypedSymbols if ``sympy.symbols`` produced a sequence of symbols,
        otherwise a single TypedSymbol
    """
    symbols = sp.symbols(names, *args)
    # sympy returns a tuple for string specifications like "x, y" or "x:3", but returns a list
    # when ``names`` itself is a list -- both are sequences of symbols and are handled alike
    # (previously a list result was turned into a single symbol named "[x, y]").
    if isinstance(symbols, (tuple, list)):
        return tuple(TypedSymbol(str(s), dtype) for s in symbols)
    return TypedSymbol(str(symbols), dtype)


def matrix_symbols(names, dtype, rows, cols):
    """Creates one ``rows x cols`` sympy Matrix of :class:`TypedSymbol` entries per base name.

    Args:
        names: comma-separated string (spaces are ignored) or an iterable of base names
        dtype: data type of every matrix entry
        rows: number of rows of each matrix
        cols: number of columns of each matrix

    Returns:
        tuple with one Matrix per base name; entry ``(i, j)`` is the ``i * cols + j``-th symbol
        of the sympy range specification ``"<name>:<rows*cols>"``
    """
    name_list = names.replace(' ', '').split(',') if isinstance(names, str) else names

    def build_matrix(base_name):
        entries = typed_symbols("%s:%i" % (base_name, rows * cols), dtype)
        return sp.Matrix(rows, cols, lambda row, col: entries[row * cols + col])

    return tuple(build_matrix(base_name) for base_name in name_list)


43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
def assumptions_from_dtype(dtype):
    """Derives SymPy assumptions from :class:`BasicType` or a Numpy dtype

    Integer dtypes yield ``integer=True``, unsigned integer dtypes additionally ``negative=False``,
    and integer or floating point dtypes yield ``real=True``. Anything numpy cannot interpret
    as a dtype simply yields no assumptions.

    Args:
        dtype (BasicType, np.dtype): a Numpy data type, or an object with a ``numpy_dtype`` attribute
    Returns:
        A dict of SymPy assumptions
    """
    numpy_dtype = getattr(dtype, 'numpy_dtype', dtype)

    derived = {}
    try:
        integral = np.issubdtype(numpy_dtype, np.integer)
        if integral:
            derived['integer'] = True
        if np.issubdtype(numpy_dtype, np.unsignedinteger):
            derived['negative'] = False
        if integral or np.issubdtype(numpy_dtype, np.floating):
            derived['real'] = True
    except Exception:
        # best effort: e.g. pointer/vector types or unparsable strings carry no assumptions
        pass

    return derived


72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# noinspection PyPep8Naming
class address_of(sp.Function):
    """SymPy function node taking the address of its single argument.

    Its ``dtype`` is a restrict pointer to the argument's ``dtype``, or to ``'void'`` if the argument
    has no ``dtype``.
    """
    is_Atom = True

    def __new__(cls, arg):
        # exactly one argument is accepted
        return sp.Function.__new__(cls, arg)

    @property
    def canonical(self):
        operand = self.args[0]
        if not hasattr(operand, 'canonical'):
            raise NotImplementedError()
        return operand.canonical

    @property
    def is_commutative(self):
        return self.args[0].is_commutative

    @property
    def dtype(self):
        operand = self.args[0]
        pointee = operand.dtype if hasattr(operand, 'dtype') else 'void'
        return PointerType(pointee, restrict=True)


Martin Bauer's avatar
Martin Bauer committed
99
# noinspection PyPep8Naming
100
class cast_func(sp.Function):
    """SymPy function representing the cast of ``args[0]`` to the data type ``args[1]``.

    Construct as ``cast_func(expr, dtype, *other_args)``. ``dtype`` is converted with
    :func:`create_type` unless it already is a :class:`Type`. Additional positional arguments are
    kept as further function arguments (e.g. :class:`vector_memory_access` declares four in total).
    The predicates ``is_integer``, ``is_negative``, ``is_nonnegative`` and ``is_real`` are derived
    from the target data type where it provides a numpy dtype.
    """
    is_Atom = True

    def __new__(cls, *args, **kwargs):
        # Unpacking requires at least the expression and the data type (ValueError otherwise).
        expr, dtype, *other_args = args
        if not isinstance(dtype, Type):
            dtype = create_type(dtype)
        # to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well
        # however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads
        # to problems when for example comparing cast_func's for equality
        #
        # lhs = bitwise_and(a, cast_func(1, 'int'))
        # rhs = cast_func(0, 'int')
        # print( sp.Ne(lhs, rhs) ) # would give true if all cast_funcs are booleans
        # -> thus a separate class boolean_cast_func is introduced
        # NOTE(review): this also replaces subclasses (e.g. vector_memory_access) by boolean_cast_func
        # when expr is a Boolean -- confirm that is intended.
        if isinstance(expr, Boolean):
            cls = boolean_cast_func

        return sp.Function.__new__(cls, expr, dtype, *other_args, **kwargs)

    @property
    def canonical(self):
        if hasattr(self.args[0], 'canonical'):
            return self.args[0].canonical
        else:
            raise NotImplementedError()

    @property
    def is_commutative(self):
        return self.args[0].is_commutative

    def _eval_evalf(self, *args, **kwargs):
        # NOTE(review): the requested precision is ignored (evalf() uses its default) and the
        # result is the evaluated operand, i.e. the cast itself is dropped.
        return self.args[0].evalf()

    @property
    def dtype(self):
        """Target data type of the cast (second function argument)."""
        return self.args[1]

    @property
    def is_integer(self):
        """
        Uses Numpy type hierarchy to determine :func:`sympy.Expr.is_integer` predicate

        For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
        """
        if hasattr(self.dtype, 'numpy_dtype'):
            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or super().is_integer
        else:
            return super().is_integer

    @property
    def is_negative(self):
        """
        False for casts to unsigned integer types, otherwise as determined by sympy.
        See :attr:`is_integer`
        """
        if hasattr(self.dtype, 'numpy_dtype'):
            if np.issubdtype(self.dtype.numpy_dtype, np.unsignedinteger):
                return False

        return super().is_negative

    @property
    def is_nonnegative(self):
        """
        True whenever :attr:`is_negative` is False, otherwise as determined by sympy.
        See :attr:`is_integer`
        """
        if self.is_negative is False:
            return True
        else:
            return super().is_nonnegative

    @property
    def is_real(self):
        """
        True for casts to integer or floating point types, otherwise as determined by sympy.
        See :attr:`is_integer`
        """
        if hasattr(self.dtype, 'numpy_dtype'):
            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or \
                np.issubdtype(self.dtype.numpy_dtype, np.floating) or \
                super().is_real
        else:
            return super().is_real

Martin Bauer's avatar
Martin Bauer committed
185

186
187
188
189
190
# noinspection PyPep8Naming
class boolean_cast_func(cast_func, Boolean):
    """:class:`cast_func` variant chosen automatically when the cast expression is a sympy ``Boolean``.

    Deriving from ``Boolean`` lets such casts be used as conditions, e.g. in ``sp.Piecewise``.
    """


Martin Bauer's avatar
Martin Bauer committed
191
192
193
194
# noinspection PyPep8Naming
class vector_memory_access(cast_func):
    """:class:`cast_func` subclass that takes exactly four arguments.

    The first two are the expression and the data type as for :class:`cast_func`.
    NOTE(review): the meaning of the remaining two arguments is not visible here -- document them.
    """
    nargs = (4,)

195

196
197
198
199
200
# noinspection PyPep8Naming
class reinterpret_cast_func(cast_func):
    """Marker subclass of :class:`cast_func` without behaviour of its own.

    Judging by its name, presumably a bit-reinterpreting cast; the distinction is handled elsewhere.
    """


Martin Bauer's avatar
Martin Bauer committed
201
202
# noinspection PyPep8Naming
class pointer_arithmetic_func(sp.Function, Boolean):
    """SymPy function node for pointer arithmetic; also derives from sympy ``Boolean``.

    ``canonical`` forwards to the first argument's ``canonical``.
    """

    @property
    def canonical(self):
        first_arg = self.args[0]
        if not hasattr(first_arg, 'canonical'):
            raise NotImplementedError()
        return first_arg.canonical


211
class TypedSymbol(sp.Symbol):
    """SymPy symbol that carries a data type, accessible as ``dtype``.

    SymPy assumptions are derived from the data type via :func:`assumptions_from_dtype`;
    assumptions passed explicitly as keyword arguments take precedence over derived ones.
    """

    def __new__(cls, *args, **kwds):
        obj = TypedSymbol.__xnew_cached_(cls, *args, **kwds)
        return obj

    def __new_stage2__(cls, name, dtype, *args, **kwargs):
        # Merge into one dict instead of passing ``**derived, **kwargs`` separately: the latter raises
        # TypeError ("got multiple values for keyword argument") when a caller passes an assumption
        # that is also derived from the dtype, e.g. TypedSymbol('i', 'int64', integer=True).
        assumptions = {**assumptions_from_dtype(dtype), **kwargs}
        obj = super(TypedSymbol, cls).__xnew__(cls, name, *args, **assumptions)
        try:
            obj._dtype = create_type(dtype)
        except (TypeError, ValueError):
            # on error keep the string
            obj._dtype = dtype
        return obj

    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    @property
    def dtype(self):
        return self._dtype

    def _hashable_content(self):
        # the dtype is part of the identity: equally named symbols with different types differ
        return super()._hashable_content(), hash(self._dtype)

    def __getnewargs__(self):
        # NOTE(review): only name and dtype are preserved; extra assumption kwargs given at
        # construction are not passed back when unpickling.
        return self.name, self.dtype

    @property
    def canonical(self):
        return self

    @property
    def reversed(self):
        return self

    @property
    def headers(self):
        """Extra headers needed by this symbol's type: the complex header for complex types.

        Checks the dtype itself and its ``base_type`` (e.g. the pointee of a pointer type);
        checks that do not apply to the dtype are skipped (best effort).
        """
        headers = []
        try:
            if np.issubdtype(self.dtype.numpy_dtype, np.complexfloating):
                headers.append('"cuda_complex.hpp"')
        except Exception:
            pass
        try:
            if np.issubdtype(self.dtype.base_type.numpy_dtype, np.complexfloating):
                headers.append('"cuda_complex.hpp"')
        except Exception:
            pass

        return headers

263

Martin Bauer's avatar
Martin Bauer committed
264
def create_type(specification):
    """Creates a subclass of Type according to a string or an object of subclass Type.

    Args:
        specification: a Type or StructType object (returned unchanged), or anything accepted by
                       ``numpy.dtype`` such as a string (``'float64'``), a numpy scalar type or dtype

    Returns:
        the given type object, or a new non-const BasicType (StructType for structured numpy dtypes)
    """
    # StructType does not derive from Type but already is a complete type object; without this
    # check it would be handed to np.dtype, which cannot interpret it.
    if isinstance(specification, (Type, StructType)):
        return specification
    numpy_dtype = np.dtype(specification)
    if numpy_dtype.fields is None:
        return BasicType(numpy_dtype, const=False)
    return StructType(numpy_dtype, const=False)
281
282


283
@memorycache(maxsize=64)
def create_composite_type_from_string(specification):
    """Creates a new Type object from a c-like string specification.

    The specification is a base part (a numpy type name, optionally ``const``) followed by any number
    of pointer levels, each a ``*`` optionally followed by ``const`` and/or ``restrict``, e.g.
    ``"const double * restrict"`` or ``"int32 ** const"``. A ``*`` need not be separated by whitespace,
    so ``"double* const"`` and ``"double * const"`` are equivalent.

    Args:
        specification: Specification string (case-insensitive)

    Returns:
        Type object: a BasicType for the base part, wrapped in one PointerType per ``*``

    Raises:
        ValueError: if the specification is empty or contains unexpected tokens
    """
    # Make every '*' its own token. Previously a '*' glued to the base name ("double*") was moved to
    # the last pointer level and qualifiers following it were wrongly applied to the base type
    # ("double* const" became "const double *") or failed an assertion ("double* restrict").
    tokens = specification.lower().replace('*', ' * ').split()
    if not tokens:
        raise ValueError("Empty type specification %r" % (specification,))

    # Group tokens: first the base part, then one group per '*' holding that level's qualifiers
    parts = []
    current = []
    for token in tokens:
        if token == '*':
            parts.append(current)
            current = [token]
        else:
            current.append(token)
    if current:
        parts.append(current)

    # Parse native part
    base_part = parts.pop(0)
    const = 'const' in base_part
    if const:
        base_part.remove('const')
    if len(base_part) != 1:
        raise ValueError("Invalid base type in type specification %r" % (specification,))
    current_type = BasicType(np.dtype(base_part[0]), const)

    # Parse pointer parts
    for part in parts:
        restrict = 'restrict' in part
        if restrict:
            part.remove('restrict')
        const = 'const' in part
        if const:
            part.remove('const')
        if part != ['*']:
            raise ValueError("Invalid pointer qualifiers in type specification %r" % (specification,))
        current_type = PointerType(current_type, const, restrict)
    return current_type
328
329


Martin Bauer's avatar
Martin Bauer committed
330
331
332
333
def get_base_type(data_type):
    """Follows ``base_type`` links and returns the innermost type, whose ``base_type`` is None.

    E.g. for a pointer to a pointer to double this returns the double type.
    """
    innermost = data_type
    while True:
        next_level = innermost.base_type
        if next_level is None:
            return innermost
        innermost = next_level
334
335


Martin Bauer's avatar
Martin Bauer committed
336
def to_ctypes(data_type):
    """
    Transforms a given Type into ctypes

    :param data_type: Subclass of Type (a PointerType, StructType or a type with ``numpy_dtype``)
    :return: ctypes type object; pointers become ``ctypes.POINTER`` of the converted base type,
             structs are passed as a byte pointer (``ctypes.POINTER(ctypes.c_uint8)``)
    """
    if isinstance(data_type, StructType):
        # structs are handled byte-wise
        return ctypes.POINTER(ctypes.c_uint8)
    if isinstance(data_type, PointerType):
        pointee = to_ctypes(data_type.base_type)
        return ctypes.POINTER(pointee)
    return to_ctypes.map[data_type.numpy_dtype]
348

Martin Bauer's avatar
Martin Bauer committed
349
350

# numpy dtype -> ctypes scalar type, used by to_ctypes for basic (non-pointer, non-struct) types
to_ctypes.map = {
    np.dtype(np.int8): ctypes.c_int8,
    np.dtype(np.int16): ctypes.c_int16,
    np.dtype(np.int32): ctypes.c_int32,
    np.dtype(np.int64): ctypes.c_int64,

    np.dtype(np.uint8): ctypes.c_uint8,
    np.dtype(np.uint16): ctypes.c_uint16,
    np.dtype(np.uint32): ctypes.c_uint32,
    np.dtype(np.uint64): ctypes.c_uint64,

    np.dtype(np.float32): ctypes.c_float,
    np.dtype(np.float64): ctypes.c_double,

    # BasicType supports 'bool' (numpy_name_to_c maps it to C 'bool'), but converting it used to
    # raise KeyError. numpy bools are one byte wide, like ctypes.c_bool (C _Bool) on common platforms.
    np.dtype(np.bool_): ctypes.c_bool,
}


366
def ctypes_from_llvm(data_type):
    """Converts an llvmlite IR type into the corresponding ctypes type.

    Pointers become ``ctypes.POINTER`` of the converted pointee (``c_void_p`` for void pointees).
    8/16/32/64 bit integers become signed ctypes integers, and float/double become
    c_float/c_double. Void maps to ``None``.

    Raises:
        the stored llvmlite ImportError if llvmlite could not be imported
        ValueError: for integer widths other than 8, 16, 32 and 64
        NotImplementedError: for all other IR types
    """
    if not ir:
        raise _ir_importerror

    if isinstance(data_type, ir.PointerType):
        pointee = ctypes_from_llvm(data_type.pointee)
        return ctypes.c_void_p if pointee is None else ctypes.POINTER(pointee)

    if isinstance(data_type, ir.IntType):
        signed_int_by_width = {
            8: ctypes.c_int8,
            16: ctypes.c_int16,
            32: ctypes.c_int32,
            64: ctypes.c_int64,
        }
        if data_type.width not in signed_int_by_width:
            raise ValueError("Int width %d is not supported" % data_type.width)
        return signed_int_by_width[data_type.width]

    if isinstance(data_type, ir.FloatType):
        return ctypes.c_float
    if isinstance(data_type, ir.DoubleType):
        return ctypes.c_double
    if isinstance(data_type, ir.VoidType):
        return None  # Void type is not supported by ctypes

    raise NotImplementedError('Data type %s of %s is not supported yet' % (type(data_type), data_type))


396
def to_llvm_type(data_type, nvvm_target=False):
    """
    Transforms a given type into an llvmlite IR type

    :param data_type: Subclass of Type (a PointerType or a type with ``numpy_dtype``)
    :param nvvm_target: if True, the outermost pointer is created in address space 1 instead of 0
    :return: llvmlite type object
    """
    if not ir:
        raise _ir_importerror
    if not isinstance(data_type, PointerType):
        return to_llvm_type.map[data_type.numpy_dtype]
    # NOTE(review): nvvm_target is not forwarded to the recursive call, so nested pointee pointers
    # always use address space 0 -- confirm this is intended.
    address_space = 1 if nvvm_target else 0
    return to_llvm_type(data_type.base_type).as_pointer(address_space)

409

410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
if ir:
    # numpy dtype -> llvmlite type used by to_llvm_type. LLVM integer types carry no signedness,
    # so signed and unsigned numpy integers of the same width share one iN type.
    to_llvm_type.map = {
        **{np.dtype(int_type): ir.IntType(8 * np.dtype(int_type).itemsize)
           for int_type in (np.int8, np.int16, np.int32, np.int64)},
        **{np.dtype(uint_type): ir.IntType(8 * np.dtype(uint_type).itemsize)
           for uint_type in (np.uint8, np.uint16, np.uint32, np.uint64)},
        np.dtype(np.float32): ir.FloatType(),
        np.dtype(np.float64): ir.DoubleType(),
    }
425

426

Martin Bauer's avatar
Martin Bauer committed
427
428
429
def peel_off_type(dtype, type_to_peel_off):
    """Strips wrapper types of exactly class ``type_to_peel_off`` from the outside of ``dtype``.

    As long as ``type(dtype) is type_to_peel_off`` (subclasses are not peeled), ``dtype`` is replaced
    by its ``base_type``; the first type of a different class is returned.
    """
    if type(dtype) is not type_to_peel_off:
        return dtype
    return peel_off_type(dtype.base_type, type_to_peel_off)


433
434
def collate_types(types,
                  forbid_collation_to_complex=False,
                  forbid_collation_to_float=False,
                  default_float_type='float64',
                  default_int_type='int64'):
    """
    Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
    Uses the collation rules from numpy.

    Args:
        types: iterable of Type objects
        forbid_collation_to_complex: drop complex scalar types before collating; if nothing is left,
                                     ``default_float_type`` is returned
        forbid_collation_to_float: drop floating point scalar types before collating; if nothing is
                                   left, ``default_int_type`` is returned
        default_float_type: see ``forbid_collation_to_complex``
        default_int_type: see ``forbid_collation_to_float``

    Returns:
        the pointer type for pointer arithmetic (one pointer combined with integers); otherwise a
        BasicType, wrapped in a VectorType if any of the inputs is a vector type

    Raises:
        ValueError: for two pointer types, pointer arithmetic with non-integer types, or vector
                    types of different width
    """
    # Materialize once: the types are traversed several times below, which would silently see an
    # exhausted iterator if a generator was passed.
    types = tuple(types)

    def has_numpy_kind(data_type, numpy_kind):
        # Only types with a numpy dtype take part in the forbid_* filters. Vector and pointer types
        # have none and are kept (accessing .numpy_dtype on them used to raise AttributeError).
        numpy_dtype = getattr(data_type, 'numpy_dtype', None)
        return numpy_dtype is not None and np.issubdtype(numpy_dtype, numpy_kind)

    if forbid_collation_to_complex:
        types = tuple(t for t in types if not has_numpy_kind(t, np.complexfloating))
        if not types:
            return create_type(default_float_type)

    if forbid_collation_to_float:
        types = tuple(t for t in types if not has_numpy_kind(t, np.floating))
        if not types:
            return create_type(default_int_type)

    # Pointer arithmetic case i.e. pointer + integer is allowed
    if any(type(t) is PointerType for t in types):
        pointer_type = None
        for t in types:
            if type(t) is PointerType:
                if pointer_type is not None:
                    raise ValueError("Cannot collate the combination of two pointer types")
                pointer_type = t
            elif type(t) is BasicType:
                if not (t.is_int() or t.is_uint()):
                    raise ValueError("Invalid pointer arithmetic")
            else:
                raise ValueError("Invalid pointer arithmetic")
        return pointer_type

    # peel of vector types, if at least one vector type occurred the result will also be the vector type
    vector_type = [t for t in types if type(t) is VectorType]
    if not all_equal(t.width for t in vector_type):
        raise ValueError("Collation failed because of vector types with different width")
    types = [peel_off_type(t, VectorType) for t in types]

    # now we should have a list of basic types - struct types are not yet supported
    assert all(type(t) is BasicType for t in types)

    # as soon as a float type is involved, only the float types take part in the collation
    if any(t.is_float() for t in types):
        types = tuple(t for t in types if t.is_float())
    # use numpy collation -> create type from numpy type -> and, put vector type around if necessary
    result_numpy_type = np.result_type(*(t.numpy_dtype for t in types))
    result = BasicType(result_numpy_type)
    if vector_type:
        result = VectorType(result, vector_type[0].width)
    return result


491
492
493
494
495
@memorycache_if_hashable(maxsize=2048)
def get_type_of_expression(expr,
                           default_float_type='double',
                           default_int_type='int',
                           symbol_type_dict=None):
    """Determines the data type of a (sympified) expression.

    Args:
        expr: expression; it is passed through ``sp.sympify`` first
        default_float_type: type used for sympy Rationals/Floats and non-integer leaves; its complex
                            counterpart is used for expressions whose ``is_real`` is False
        default_int_type: type used for sympy Integers and integer leaves
        symbol_type_dict: mapping from symbol name to type, used for plain (untyped) sympy Symbols

    Returns:
        a Type object

    Raises:
        ValueError: for an untyped sympy Symbol when no non-empty ``symbol_type_dict`` is given
        NotImplementedError: if the type of the expression cannot be determined
    """
    from pystencils.astnodes import ResolvedFieldAccess
    from pystencils.cpu.vectorization import vec_all, vec_any

    # NOTE(review): an empty defaultdict is falsy, so this default is never consulted by the
    # sp.Symbol branch below -- untyped symbols raise ValueError unless a non-empty dict is passed.
    if not symbol_type_dict:
        symbol_type_dict = defaultdict(lambda: create_type('double'))

    get_type = partial(get_type_of_expression,
                       default_float_type=default_float_type,
                       default_int_type=default_int_type,
                       symbol_type_dict=symbol_type_dict)

    expr = sp.sympify(expr)
    if isinstance(expr, sp.Integer):
        return create_type(default_int_type)
    elif expr.is_real is False:
        # complex counterpart of the default float type, e.g. double -> complex128
        return create_type((np.zeros((1,), default_float_type) * 1j).dtype)
    elif isinstance(expr, (sp.Rational, sp.Float)):
        return create_type(default_float_type)
    elif isinstance(expr, ResolvedFieldAccess):
        return expr.field.dtype
    elif isinstance(expr, pystencils.field.Field.AbstractAccess):
        return expr.field.dtype
    elif isinstance(expr, TypedSymbol):
        return expr.dtype
    elif isinstance(expr, sp.Symbol):
        if symbol_type_dict:
            return symbol_type_dict[expr.name]
        else:
            raise ValueError("All symbols inside this expression have to be typed! ", str(expr))
    elif isinstance(expr, cast_func):
        return expr.args[1]
    elif isinstance(expr, (vec_any, vec_all)):
        return create_type("bool")
    elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
        collated_result_type = collate_types(tuple(get_type(a[0]) for a in expr.args))
        collated_condition_type = collate_types(tuple(get_type(a[1]) for a in expr.args))
        if type(collated_condition_type) is VectorType and type(collated_result_type) is not VectorType:
            collated_result_type = VectorType(collated_result_type, width=collated_condition_type.width)
        return collated_result_type
    elif isinstance(expr, sp.Indexed):
        typed_symbol = expr.base.label
        return typed_symbol.dtype.base_type
    elif isinstance(expr, Boolean):
        # ``Boolean`` also covers BooleanFunction (its subclass); ``sp.boolalg`` is not a public
        # attribute of newer SymPy versions.
        # if any arg is of vector type return a vector boolean, else return a normal scalar boolean
        result = create_type("bool")
        # Compute each argument type only once: the recursive calls always pass the (unhashable)
        # symbol_type_dict, so they are presumably not served from the cache, and evaluating
        # get_type twice per argument grows exponentially with the nesting depth.
        arg_types = [get_type(a) for a in expr.args]
        vec_args = [t for t in arg_types if isinstance(t, VectorType)]
        if vec_args:
            result = VectorType(result, width=vec_args[0].width)
        return result
    elif isinstance(expr, (sp.Pow, sp.Sum, sp.Product)):
        return get_type(expr.args[0])
    elif isinstance(expr, sp.Expr):
        if expr.args:
            types = tuple(get_type(a) for a in expr.args)
            return collate_types(
                types,
                forbid_collation_to_complex=expr.is_real is True,
                forbid_collation_to_float=expr.is_integer is True,
                default_float_type=default_float_type,
                default_int_type=default_int_type)
        else:
            if expr.is_integer:
                return create_type(default_int_type)
            else:
                return create_type(default_float_type)

    raise NotImplementedError("Could not determine type for", expr, type(expr))
564
565


566
class Type(sp.Basic):
    """Base class of the data types; a sympy atom so types can be stored as function arguments.

    For example, :class:`cast_func` keeps its target type as an argument.
    """
    is_Atom = True

    def __new__(cls, *args, **kwargs):
        # Constructor arguments are handled by the subclasses' __init__ and are not
        # stored as sympy args.
        return sp.Basic.__new__(cls)

    def _sympystr(self, *args, **kwargs):
        # sympy's string printer uses the type's own __str__
        return str(self)
574
575
576
577


class BasicType(Type):
    """Scalar data type backed by a non-structured numpy dtype, optionally ``const``.

    Its string form is the C type name, e.g. ``double`` or ``int32_t const``.
    """

    @staticmethod
    def numpy_name_to_c(name):
        """Returns the C type name for the numpy dtype name ``name`` (e.g. ``'float64'`` -> ``'double'``).

        Raises:
            NotImplementedError: for numpy names without a C counterpart here
        """
        if name == 'float64':
            return 'double'
        elif name == 'float32':
            return 'float'
        elif name == 'complex64':
            return 'ComplexFloat'
        elif name == 'complex128':
            return 'ComplexDouble'
        elif name.startswith('int'):
            width = int(name[len("int"):])
            return "int%d_t" % (width,)
        elif name.startswith('uint'):
            width = int(name[len("uint"):])
            return "uint%d_t" % (width,)
        elif name == 'bool':
            return 'bool'
        else:
            raise NotImplementedError("Cannot map numpy to C name for %s" % (name,))

    def __init__(self, dtype, const=False):
        self.const = const
        if isinstance(dtype, Type):
            self._dtype = dtype.numpy_dtype
        else:
            self._dtype = np.dtype(dtype)
        assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type"
        assert self._dtype.hasobject is False
        assert self._dtype.subdtype is None

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def sympy_dtype(self):
        return getattr(sympy.codegen.ast, str(self.numpy_dtype))

    @property
    def item_size(self):
        return 1

    # The category checks below use ``dtype.kind``. ``np.sctypes``, used previously, was removed
    # in NumPy 2.0. The kinds cover exactly the former np.sctypes lists: 'i' signed and 'u'
    # unsigned integers, 'f' floats, 'c' complex, 'b'/'O'/'S'/'U'/'V' the "others" (bool,
    # object, bytes, str, void).
    def is_int(self):
        """True for signed and unsigned integer types."""
        return self.numpy_dtype.kind in ('i', 'u')

    def is_float(self):
        """True for (real) floating point types."""
        return self.numpy_dtype.kind == 'f'

    def is_uint(self):
        """True for unsigned integer types."""
        return self.numpy_dtype.kind == 'u'

    def is_complex(self):
        """True for complex floating point types."""
        return self.numpy_dtype.kind == 'c'

    def is_other(self):
        """True for bool, object, bytes, str and void types."""
        return self.numpy_dtype.kind in ('b', 'O', 'S', 'U', 'V')

    @property
    def base_name(self):
        return BasicType.numpy_name_to_c(str(self._dtype))

    def __str__(self):
        result = BasicType.numpy_name_to_c(str(self._dtype))
        if self.const:
            result += " const"
        return result

    def __repr__(self):
        return str(self)

    def __eq__(self, other):
        if not isinstance(other, BasicType):
            return False
        else:
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)

    def __hash__(self):
        # consistent with __eq__: equal dtype and constness give the same string
        return hash(str(self))


665
class VectorType(Type):
    """Vector of ``width`` elements of ``base_type``.

    ``instruction_set`` is a class attribute (initially None). Once set, it maps the keys 'int',
    'double', 'float' and 'bool' to the vector type names printed for int64, float64, float32 and
    bool base types. While it is None the type prints as ``"<base_type>[<width>]"``.
    """
    instruction_set = None

    def __init__(self, base_type, width=4):
        self._base_type = base_type
        self.width = width

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        return self.width * self.base_type.item_size

    def __eq__(self, other):
        return isinstance(other, VectorType) and \
            (self.base_type, self.width) == (other.base_type, other.width)

    def __str__(self):
        if self.instruction_set is None:
            return "%s[%d]" % (self.base_type, self.width)
        # checked in this order; the first matching base type selects the instruction set entry
        for numpy_name, instruction_set_key in (("int64", 'int'), ("float64", 'double'),
                                                ("float32", 'float'), ("bool", 'bool')):
            if self.base_type == create_type(numpy_name):
                return self.instruction_set[instruction_set_key]
        raise NotImplementedError()

    def __hash__(self):
        return hash((self.base_type, self.width))

    def __getnewargs__(self):
        return self._base_type, self.width

707

708
class PointerType(Type):
    """Pointer to ``base_type``; restrict by default.

    ``const`` is printed after the ``*``, i.e. it qualifies the pointer itself,
    e.g. ``"double * RESTRICT const"``.
    """

    def __init__(self, base_type, const=False, restrict=True):
        self._base_type = base_type
        self.const = const
        self.restrict = restrict

    def __getnewargs__(self):
        return self.base_type, self.const, self.restrict

    @property
    def alias(self):
        # a pointer may alias others exactly when it is not declared restrict
        return not self.restrict

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        return self.base_type.item_size

    def __eq__(self, other):
        if isinstance(other, PointerType):
            return (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict)
        return False

    def __str__(self):
        qualifiers = ['RESTRICT'] if self.restrict else []
        if self.const:
            qualifiers.append("const")
        return " ".join([str(self.base_type), '*'] + qualifiers)

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self._base_type, self.const, self.restrict))
748

Jan Hoenig's avatar
Jan Hoenig committed
749

750
class StructType:
    """Structured (record) numpy dtype, optionally ``const``.

    In generated code it is handled byte-wise and printed as ``uint8_t``.
    """

    def __init__(self, numpy_type, const=False):
        self.const = const
        self._dtype = np.dtype(numpy_type)

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def item_size(self):
        return self.numpy_dtype.itemsize

    def get_element_offset(self, element_name):
        """Byte offset of field ``element_name`` inside the struct."""
        _, offset = self.numpy_dtype.fields[element_name][:2]
        return offset

    def get_element_type(self, element_name):
        """BasicType of field ``element_name``, with the struct's constness."""
        field_dtype = self.numpy_dtype.fields[element_name][0]
        return BasicType(field_dtype, self.const)

    def has_element(self, element_name):
        return element_name in self.numpy_dtype.fields

    def __eq__(self, other):
        return isinstance(other, StructType) and \
            (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)

    def __str__(self):
        # structs are handled byte-wise
        return "uint8_t const" if self.const else "uint8_t"

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self.numpy_dtype, self.const))
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817


class TypedImaginaryUnit(TypedSymbol):
    """The imaginary unit as a TypedSymbol named ``_i`` with the sympy assumption ``imaginary=True``.

    Callers only pass the data type; the name is fixed.
    """
    headers = ['"cuda_complex.hpp"']

    def __new__(cls, *args, **kwds):
        return TypedImaginaryUnit.__xnew_cached_(cls, *args, **kwds)

    def __new_stage2__(cls, dtype, *args, **kwargs):
        return super(TypedImaginaryUnit, cls).__xnew__(cls, "_i", dtype, *args, imaginary=True, **kwargs)

    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))