data_types.py 25 KB
Newer Older
1
import ctypes
2
3
from collections import defaultdict
from functools import partial
4
from typing import Tuple
Martin Bauer's avatar
Martin Bauer committed
5

6
import numpy as np
7
8
import sympy as sp
import sympy.codegen.ast
9
10
11
12
from sympy.core.cache import cacheit
from sympy.logic.boolalg import Boolean, BooleanFunction

import pystencils
13
from pystencils.cache import memorycache, memorycache_if_hashable
Martin Bauer's avatar
Martin Bauer committed
14
from pystencils.utils import all_equal
15

16
17
18
19
20
try:
    import llvmlite.ir as ir
except ImportError as e:
    ir = None
    _ir_importerror = e
21

22

23
24
25
26
27
28
29
def typed_symbols(names, dtype, *args):
    """Create one or several :class:`TypedSymbol` objects that share a data type.

    Args:
        names: symbol names, forwarded unchanged to :func:`sympy.symbols`
        dtype: data type given to every created symbol
        *args: further positional arguments forwarded to :func:`sympy.symbols`

    Returns:
        A tuple of TypedSymbols if :func:`sympy.symbols` produced a tuple, otherwise a single TypedSymbol
    """
    symbols = sp.symbols(names, *args)
    # check against the builtin tuple: typing.Tuple is an annotation alias, not a runtime type
    if isinstance(symbols, tuple):
        return tuple(TypedSymbol(str(s), dtype) for s in symbols)
    return TypedSymbol(str(symbols), dtype)

Martin Bauer's avatar
Martin Bauer committed
30

31
32
33
def type_all_numbers(expr, dtype):
    """Return ``expr`` with every numeric atom wrapped in a :class:`cast_func` to ``dtype``.

    Args:
        expr: sympy expression whose ``sp.Number`` atoms are replaced
        dtype: target type handed to each cast

    Returns:
        The expression after substitution
    """
    casts = {}
    for number in expr.atoms(sp.Number):
        casts[number] = cast_func(number, dtype)
    return expr.subs(casts)
34

Martin Bauer's avatar
Martin Bauer committed
35

36
37
38
39
40
41
42
43
44
45
46
47
def matrix_symbols(names, dtype, rows, cols):
    """Create matrices whose entries are typed symbols.

    Args:
        names: comma separated string (blanks are ignored) or an iterable of base names, one per matrix
        dtype: data type of every matrix entry
        rows: number of rows of each matrix
        cols: number of columns of each matrix

    Returns:
        Tuple containing one ``rows x cols`` sympy Matrix per name
    """
    name_list = names.replace(' ', '').split(',') if isinstance(names, str) else names

    def build(base_name):
        entries = typed_symbols("%s:%i" % (base_name, rows * cols), dtype)
        # the flat symbol sequence is laid out row by row
        return sp.Matrix(rows, cols, lambda r, c: entries[r * cols + c])

    return tuple(build(base_name) for base_name in name_list)


48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
def assumptions_from_dtype(dtype):
    """Translate a data type into the SymPy assumptions it implies.

    Args:
        dtype (BasicType, np.dtype): object exposing ``numpy_dtype``, or anything numpy
            can interpret as a data type

    Returns:
        A dict of SymPy assumptions; empty if the type cannot be classified
    """
    numpy_type = dtype.numpy_dtype if hasattr(dtype, 'numpy_dtype') else dtype

    result = dict()
    try:
        integer_like = np.issubdtype(numpy_type, np.integer)
        if integer_like:
            result['integer'] = True

        if np.issubdtype(numpy_type, np.unsignedinteger):
            result['negative'] = False

        if integer_like or np.issubdtype(numpy_type, np.floating):
            result['real'] = True
    except Exception:
        # best effort: a type numpy cannot classify contributes no assumptions
        pass

    return result


77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# noinspection PyPep8Naming
class address_of(sp.Function):
    """Function node with a single argument whose ``dtype`` is a pointer to the argument's type."""
    is_Atom = True

    def __new__(cls, arg):
        return sp.Function.__new__(cls, arg)

    @property
    def canonical(self):
        """``canonical`` of the argument; raises NotImplementedError if the argument has none."""
        target = self.args[0]
        if not hasattr(target, 'canonical'):
            raise NotImplementedError()
        return target.canonical

    @property
    def is_commutative(self):
        """Commutativity is taken over from the argument."""
        return self.args[0].is_commutative

    @property
    def dtype(self):
        """Restrict pointer to the argument's dtype, or to ``'void'`` if the argument carries no dtype."""
        target = self.args[0]
        pointee = target.dtype if hasattr(target, 'dtype') else 'void'
        return PointerType(pointee, restrict=True)


Martin Bauer's avatar
Martin Bauer committed
104
# noinspection PyPep8Naming
105
class cast_func(sp.Function):
    """Symbolic cast of an expression to a data type.

    The first argument is the expression, the second one the target type. Type specifications
    that are not :class:`Type` instances are converted with :func:`create_type`. Any further
    arguments are passed through unchanged.
    """
    is_Atom = True

    def __new__(cls, *args, **kwargs):
        # unpacking raises ValueError if fewer than two arguments are given
        expr, dtype, *other_args = args
        if not isinstance(dtype, Type):
            dtype = create_type(dtype)
        # to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well
        # however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads
        # to problems when for example comparing cast_func's for equality
        #
        # lhs = bitwise_and(a, cast_func(1, 'int'))
        # rhs = cast_func(0, 'int')
        # print( sp.Ne(lhs, rhs) ) # would give true if all cast_funcs are booleans
        # -> thus a separate class boolean_cast_func is introduced
        if isinstance(expr, Boolean):
            cls = boolean_cast_func

        return sp.Function.__new__(cls, expr, dtype, *other_args, **kwargs)

    @property
    def canonical(self):
        """``canonical`` of the casted expression; raises NotImplementedError if it has none."""
        if hasattr(self.args[0], 'canonical'):
            return self.args[0].canonical
        else:
            raise NotImplementedError()

    @property
    def is_commutative(self):
        """Commutativity is taken over from the casted expression."""
        return self.args[0].is_commutative

    def _eval_evalf(self, *args, **kwargs):
        # numerical evaluation ignores the cast and evaluates the wrapped expression
        return self.args[0].evalf()

    @property
    def dtype(self):
        """Target type of the cast (second argument)."""
        return self.args[1]

    @property
    def is_integer(self):
        """
        Uses Numpy type hierarchy to determine :func:`sympy.Expr.is_integer` predicate

        For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
        """
        if hasattr(self.dtype, 'numpy_dtype'):
            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or super().is_integer
        else:
            return super().is_integer

    @property
    def is_negative(self):
        """
        See :func:`.TypedSymbol.is_integer`
        """
        if hasattr(self.dtype, 'numpy_dtype'):
            if np.issubdtype(self.dtype.numpy_dtype, np.unsignedinteger):
                return False

        return super().is_negative

    @property
    def is_nonnegative(self):
        """
        See :func:`.TypedSymbol.is_integer`
        """
        if self.is_negative is False:
            return True
        else:
            return super().is_nonnegative

    @property
    def is_real(self):
        """
        See :func:`.TypedSymbol.is_integer`
        """
        if hasattr(self.dtype, 'numpy_dtype'):
            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or \
                np.issubdtype(self.dtype.numpy_dtype, np.floating) or \
                super().is_real
        else:
            return super().is_real

Martin Bauer's avatar
Martin Bauer committed
190

191
192
193
194
195
# noinspection PyPep8Naming
class boolean_cast_func(cast_func, Boolean):
    """A :class:`cast_func` that is additionally a sympy ``Boolean``.

    ``cast_func.__new__`` switches to this class whenever the expression being cast is a ``Boolean``.
    """
    pass


Martin Bauer's avatar
Martin Bauer committed
196
197
# noinspection PyPep8Naming
class vector_memory_access(cast_func):
    """:class:`cast_func` variant that takes exactly five arguments describing a vector memory access."""
    # Arguments are: read/write expression, type, aligned, nontemporal, mask (or none)
    nargs = (5,)
Martin Bauer's avatar
Martin Bauer committed
200

201

202
203
204
205
206
# noinspection PyPep8Naming
class reinterpret_cast_func(cast_func):
    """Subclass of :class:`cast_func` that adds no behaviour of its own.

    NOTE(review): presumably consumers dispatch on this class to emit a reinterpreting cast - confirm.
    """
    pass


Martin Bauer's avatar
Martin Bauer committed
207
208
# noinspection PyPep8Naming
class pointer_arithmetic_func(sp.Function, Boolean):
    """Function node that is also a sympy ``Boolean`` and forwards ``canonical`` to its first argument."""

    @property
    def canonical(self):
        """``canonical`` of the first argument; raises NotImplementedError if it has none."""
        if hasattr(self.args[0], 'canonical'):
            return self.args[0].canonical
        else:
            raise NotImplementedError()


217
class TypedSymbol(sp.Symbol):
    """sympy Symbol that additionally carries a data type.

    SymPy assumptions are derived from the data type via :func:`assumptions_from_dtype`.
    Assumptions passed explicitly as keyword arguments take precedence over the derived ones.
    """

    def __new__(cls, *args, **kwds):
        obj = TypedSymbol.__xnew_cached_(cls, *args, **kwds)
        return obj

    def __new_stage2__(cls, name, dtype, *args, **kwargs):
        assumptions = assumptions_from_dtype(dtype)
        # merge instead of passing both dicts as **kwargs: an assumption given explicitly that the
        # dtype also implies would otherwise raise "got multiple values for keyword argument"
        assumptions.update(kwargs)
        obj = super(TypedSymbol, cls).__xnew__(cls, name, *args, **assumptions)
        try:
            obj._dtype = create_type(dtype)
        except (TypeError, ValueError):
            # on error keep the string
            obj._dtype = dtype
        return obj

    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    @property
    def dtype(self):
        """Data type of the symbol, or the raw specification if it could not be converted to a type."""
        return self._dtype

    def _hashable_content(self):
        # include the data type so that equally named symbols of different type are distinct
        return super()._hashable_content(), hash(self._dtype)

    def __getnewargs__(self):
        return self.name, self.dtype

    @property
    def canonical(self):
        """A typed symbol is its own canonical form."""
        return self

    @property
    def reversed(self):
        return self

    @property
    def headers(self):
        """List with the ``"cuda_complex.hpp"`` header if the type, or its base type, is complex; else empty."""
        headers = []
        # both probes are best effort: dtype may lack numpy_dtype or base_type, or base_type may be None
        try:
            if np.issubdtype(self.dtype.numpy_dtype, np.complexfloating):
                headers.append('"cuda_complex.hpp"')
        except Exception:
            pass
        try:
            if np.issubdtype(self.dtype.base_type.numpy_dtype, np.complexfloating):
                headers.append('"cuda_complex.hpp"')
        except Exception:
            pass

        return headers

269

Martin Bauer's avatar
Martin Bauer committed
270
def create_type(specification):
    """Creates a type object from a specification, or passes an existing Type object through.

    Args:
        specification: Type object, or anything ``np.dtype`` accepts (e.g. a string or a numpy dtype)

    Returns:
        The given Type object unchanged; otherwise a non-const BasicType, or a non-const StructType
        if the numpy dtype has fields
    """
    if isinstance(specification, Type):
        return specification

    numpy_dtype = np.dtype(specification)
    if numpy_dtype.fields is None:
        return BasicType(numpy_dtype, const=False)
    return StructType(numpy_dtype, const=False)
287
288


289
@memorycache(maxsize=64)
def create_composite_type_from_string(specification):
    """Creates a new Type object from a c-like string specification.

    The specification is lower-cased and split at whitespace. The tokens before the first
    stand-alone ``*`` describe the base type (optionally ``const``); each following ``*`` starts a
    pointer level that may be qualified with ``const`` and/or ``restrict``. A single ``*`` glued
    to the base type name (``"double*"``) is treated as the innermost pointer level.

    Args:
        specification: Specification string

    Returns:
        Type object
    """
    tokens = specification.lower().split()
    parts = []
    current = []
    for s in tokens:
        if s == '*':
            parts.append(current)
            current = [s]
        else:
            current.append(s)
    if len(current) > 0:
        parts.append(current)

    # Parse native part
    base_part = parts.pop(0)
    const = False
    if 'const' in base_part:
        const = True
        base_part.remove('const')
    assert len(base_part) == 1
    if base_part[0][-1] == "*":
        base_part[0] = base_part[0][:-1]
        # the glued '*' binds tightest, so it has to become the first (innermost) pointer part
        parts.insert(0, ['*'])
    current_type = BasicType(np.dtype(base_part[0]), const)

    # Parse pointer parts, wrapping from the innermost pointer outwards
    for part in parts:
        restrict = False
        const = False
        if 'restrict' in part:
            restrict = True
            part.remove('restrict')
        if 'const' in part:
            const = True
            part.remove("const")
        assert len(part) == 1 and part[0] == '*'
        current_type = PointerType(current_type, const, restrict)
    return current_type
334
335


Martin Bauer's avatar
Martin Bauer committed
336
337
338
339
def get_base_type(data_type):
    """Follow ``base_type`` links until a type whose ``base_type`` is None is reached and return it."""
    current = data_type
    while True:
        inner = current.base_type
        if inner is None:
            return current
        current = inner
340
341


Martin Bauer's avatar
Martin Bauer committed
342
def to_ctypes(data_type):
    """
    Transforms a given Type into ctypes

    Pointer types are converted recursively, struct types become a pointer to ``uint8``, all
    other types are looked up in ``to_ctypes.map`` by their numpy dtype.

    :param data_type: Subclass of Type
    :return: ctypes type object
    """
    if isinstance(data_type, PointerType):
        return ctypes.POINTER(to_ctypes(data_type.base_type))
    elif isinstance(data_type, StructType):
        return ctypes.POINTER(ctypes.c_uint8)
    else:
        return to_ctypes.map[data_type.numpy_dtype]
354

Martin Bauer's avatar
Martin Bauer committed
355
356

# numpy dtype -> ctypes type, used by to_ctypes for everything that is neither pointer nor struct
to_ctypes.map = {
    np.dtype(np.int8): ctypes.c_int8,
    np.dtype(np.int16): ctypes.c_int16,
    np.dtype(np.int32): ctypes.c_int32,
    np.dtype(np.int64): ctypes.c_int64,

    np.dtype(np.uint8): ctypes.c_uint8,
    np.dtype(np.uint16): ctypes.c_uint16,
    np.dtype(np.uint32): ctypes.c_uint32,
    np.dtype(np.uint64): ctypes.c_uint64,

    np.dtype(np.float32): ctypes.c_float,
    np.dtype(np.float64): ctypes.c_double,
}


372
def ctypes_from_llvm(data_type):
    """Translate an llvmlite IR type into the corresponding ctypes type.

    Args:
        data_type: llvmlite ``ir`` type object

    Returns:
        ctypes type object; ``ctypes.c_void_p`` for a pointer to void, and None for the void type itself

    Raises:
        ImportError: the error recorded when importing llvmlite failed
        ValueError: for integer widths other than 8, 16, 32 or 64
        NotImplementedError: for any other IR type
    """
    if not ir:
        raise _ir_importerror

    if isinstance(data_type, ir.PointerType):
        ctype = ctypes_from_llvm(data_type.pointee)
        if ctype is None:
            return ctypes.c_void_p
        else:
            return ctypes.POINTER(ctype)
    elif isinstance(data_type, ir.IntType):
        int_types = {
            8: ctypes.c_int8,
            16: ctypes.c_int16,
            32: ctypes.c_int32,
            64: ctypes.c_int64,
        }
        if data_type.width not in int_types:
            raise ValueError("Int width %d is not supported" % data_type.width)
        return int_types[data_type.width]
    elif isinstance(data_type, ir.FloatType):
        return ctypes.c_float
    elif isinstance(data_type, ir.DoubleType):
        return ctypes.c_double
    elif isinstance(data_type, ir.VoidType):
        return None  # Void type is not supported by ctypes
    else:
        raise NotImplementedError(f'Data type {type(data_type)} of {data_type} is not supported yet')
400
401


402
def to_llvm_type(data_type, nvvm_target=False):
    """
    Transforms a given type into an llvmlite IR type

    :param data_type: Subclass of Type
    :param nvvm_target: if True, a pointer type is created in address space 1 instead of 0
    :return: llvmlite type object
    """
    if not ir:
        raise _ir_importerror
    if isinstance(data_type, PointerType):
        # note: nvvm_target is not forwarded, so the pointee is always converted with the default
        return to_llvm_type(data_type.base_type).as_pointer(1 if nvvm_target else 0)
    else:
        return to_llvm_type.map[data_type.numpy_dtype]

415

416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
# Lookup table used by to_llvm_type; it is only built when llvmlite could be imported
if ir:
    to_llvm_type.map = {
        np.dtype(np.int8): ir.IntType(8),
        np.dtype(np.int16): ir.IntType(16),
        np.dtype(np.int32): ir.IntType(32),
        np.dtype(np.int64): ir.IntType(64),

        # unsigned types map to the IntType of the same width as their signed counterparts
        np.dtype(np.uint8): ir.IntType(8),
        np.dtype(np.uint16): ir.IntType(16),
        np.dtype(np.uint32): ir.IntType(32),
        np.dtype(np.uint64): ir.IntType(64),

        np.dtype(np.float32): ir.FloatType(),
        np.dtype(np.float64): ir.DoubleType(),
    }
431

432

Martin Bauer's avatar
Martin Bauer committed
433
434
435
def peel_off_type(dtype, type_to_peel_off):
    """Strip all leading layers of exactly the class ``type_to_peel_off`` from ``dtype``.

    Args:
        dtype: type object, possibly nested through ``base_type``
        type_to_peel_off: class whose instances are unwrapped; subclasses are not unwrapped

    Returns:
        The first object along the ``base_type`` chain whose class is not ``type_to_peel_off``
    """
    while type(dtype) is type_to_peel_off:
        dtype = dtype.base_type
    return dtype


439
440
def collate_types(types,
                  forbid_collation_to_complex=False,
                  forbid_collation_to_float=False,
                  default_float_type='float64',
                  default_int_type='int64'):
    """
    Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
    Uses the collation rules from numpy.

    Args:
        types: sequence of type objects
        forbid_collation_to_complex: drop complex types before collating
        forbid_collation_to_float: drop floating point types before collating
        default_float_type: result if dropping complex types leaves nothing
        default_int_type: result if dropping floating point types leaves nothing

    Returns:
        The pointer type if exactly one pointer is combined with integers, otherwise a BasicType,
        wrapped in a VectorType if at least one vector type was passed

    Raises:
        ValueError: for invalid pointer arithmetic or vector types of different width
    """
    if forbid_collation_to_complex:
        types = [
            t for t in types
            if not np.issubdtype(t.numpy_dtype, np.complexfloating)
        ]
        if not types:
            return create_type(default_float_type)

    if forbid_collation_to_float:
        types = [
            t for t in types if not np.issubdtype(t.numpy_dtype, np.floating)
        ]
        if not types:
            return create_type(default_int_type)

    # Pointer arithmetic case i.e. pointer + integer is allowed
    if any(type(t) is PointerType for t in types):
        pointer_type = None
        for t in types:
            if type(t) is PointerType:
                if pointer_type is not None:
                    raise ValueError("Cannot collate the combination of two pointer types")
                pointer_type = t
            elif type(t) is BasicType:
                if not (t.is_int() or t.is_uint()):
                    raise ValueError("Invalid pointer arithmetic")
            else:
                raise ValueError("Invalid pointer arithmetic")
        return pointer_type

    # peel off vector types, if at least one vector type occurred the result will also be the vector type
    vector_type = [t for t in types if type(t) is VectorType]
    if not all_equal(t.width for t in vector_type):
        raise ValueError("Collation failed because of vector types with different width")
    types = [peel_off_type(t, VectorType) for t in types]

    # now we should have a list of basic types - struct types are not yet supported
    assert all(type(t) is BasicType for t in types)

    # as soon as one float is involved, only the float types determine the result
    if any(t.is_float() for t in types):
        types = tuple(t for t in types if t.is_float())
    # use numpy collation -> create type from numpy type -> and, put vector type around if necessary
    result_numpy_type = np.result_type(*(t.numpy_dtype for t in types))
    result = BasicType(result_numpy_type)
    if vector_type:
        result = VectorType(result, vector_type[0].width)
    return result


497
498
499
500
501
@memorycache_if_hashable(maxsize=2048)
def get_type_of_expression(expr,
                           default_float_type='double',
                           default_int_type='int',
                           symbol_type_dict=None):
    """Determine the data type of a (sympified) expression.

    Args:
        expr: expression; it is passed through ``sp.sympify`` first
        default_float_type: type used for rationals, floats and untyped non-integer leaves
        default_int_type: type used for integers and untyped integer leaves
        symbol_type_dict: mapping from symbol name to type, consulted for plain sympy Symbols

    Returns:
        Type object of the expression

    Raises:
        ValueError: if a plain sympy Symbol is encountered and ``symbol_type_dict`` is empty
        NotImplementedError: if no rule matches the expression
    """
    from pystencils.astnodes import ResolvedFieldAccess
    from pystencils.cpu.vectorization import vec_all, vec_any

    if not symbol_type_dict:
        symbol_type_dict = defaultdict(lambda: create_type('double'))

    get_type = partial(get_type_of_expression,
                       default_float_type=default_float_type,
                       default_int_type=default_int_type,
                       symbol_type_dict=symbol_type_dict)

    expr = sp.sympify(expr)
    if isinstance(expr, sp.Integer):
        return create_type(default_int_type)
    elif expr.is_real is False:
        # complex counterpart of the default float type, obtained via numpy type promotion
        return create_type((np.zeros((1,), default_float_type) * 1j).dtype)
    elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float):
        return create_type(default_float_type)
    elif isinstance(expr, ResolvedFieldAccess):
        return expr.field.dtype
    elif isinstance(expr, pystencils.field.Field.AbstractAccess):
        return expr.field.dtype
    elif isinstance(expr, TypedSymbol):
        return expr.dtype
    elif isinstance(expr, sp.Symbol):
        # NOTE(review): the defaultdict created above is empty and therefore falsy, so without a
        # non-empty symbol_type_dict this raises instead of falling back to 'double' - confirm intent
        if symbol_type_dict:
            return symbol_type_dict[expr.name]
        else:
            raise ValueError("All symbols inside this expression have to be typed! ", str(expr))
    elif isinstance(expr, cast_func):
        return expr.args[1]
    elif isinstance(expr, (vec_any, vec_all)):
        return create_type("bool")
    elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
        collated_result_type = collate_types(tuple(get_type(a[0]) for a in expr.args))
        collated_condition_type = collate_types(tuple(get_type(a[1]) for a in expr.args))
        if type(collated_condition_type) is VectorType and type(collated_result_type) is not VectorType:
            collated_result_type = VectorType(collated_result_type, width=collated_condition_type.width)
        return collated_result_type
    elif isinstance(expr, sp.Indexed):
        typed_symbol = expr.base.label
        return typed_symbol.dtype.base_type
    elif isinstance(expr, (Boolean, BooleanFunction)):
        # if any arg is of vector type return a vector boolean, else return a normal scalar boolean
        result = create_type("bool")
        # determine each argument type only once
        arg_types = [get_type(a) for a in expr.args]
        vec_args = [t for t in arg_types if isinstance(t, VectorType)]
        if vec_args:
            result = VectorType(result, width=vec_args[0].width)
        return result
    elif isinstance(expr, sp.Pow):
        base_type = get_type(expr.args[0])
        if expr.exp.is_integer:
            return base_type
        else:
            return collate_types([create_type(default_float_type), base_type])
    elif isinstance(expr, (sp.Sum, sp.Product)):
        return get_type(expr.args[0])
    elif isinstance(expr, sp.Expr):
        if expr.args:
            types = tuple(get_type(a) for a in expr.args)
            return collate_types(
                types,
                forbid_collation_to_complex=expr.is_real is True,
                forbid_collation_to_float=expr.is_integer is True,
                default_float_type=default_float_type,
                default_int_type=default_int_type)
        else:
            if expr.is_integer:
                return create_type(default_int_type)
            else:
                return create_type(default_float_type)

    raise NotImplementedError("Could not determine type for", expr, type(expr))
576
577


578
class Type(sp.Atom):
    """Common base class of :class:`BasicType`, :class:`VectorType` and :class:`PointerType`."""

    def __new__(cls, *args, **kwargs):
        # constructor arguments are not forwarded to sympy; subclasses store them in __init__
        return sp.Basic.__new__(cls)

    def _sympystr(self, *args, **kwargs):
        return str(self)
584
585
586
587


class BasicType(Type):
    """Scalar data type backed by a non-structured numpy dtype, with an optional const qualifier."""

    @staticmethod
    def numpy_name_to_c(name):
        """Map a numpy dtype name such as ``'float64'`` or ``'uint8'`` to the C type name.

        Raises:
            NotImplementedError: if the name has no C counterpart here
        """
        if name == 'float64':
            return 'double'
        elif name == 'float32':
            return 'float'
        elif name == 'complex64':
            return 'ComplexFloat'
        elif name == 'complex128':
            return 'ComplexDouble'
        elif name.startswith('int'):
            width = int(name[len("int"):])
            return "int%d_t" % (width,)
        elif name.startswith('uint'):
            width = int(name[len("uint"):])
            return "uint%d_t" % (width,)
        elif name == 'bool':
            return 'bool'
        else:
            raise NotImplementedError(f"Cannot map numpy to C name for {name}")

    def __init__(self, dtype, const=False):
        self.const = const
        if isinstance(dtype, Type):
            self._dtype = dtype.numpy_dtype
        else:
            self._dtype = np.dtype(dtype)
        assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type"
        assert self._dtype.hasobject is False
        assert self._dtype.subdtype is None

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        """Basic types wrap no other type, so this is always None."""
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def sympy_dtype(self):
        """Attribute of ``sympy.codegen.ast`` that has the same name as the numpy dtype."""
        return getattr(sympy.codegen.ast, str(self.numpy_dtype))

    @property
    def item_size(self):
        return 1

    # The predicates below use the dtype kind character instead of the np.sctypes table,
    # which is no longer available in NumPy 2.0.

    def is_int(self):
        """True for signed and unsigned integer types."""
        return self.numpy_dtype.kind in 'iu'

    def is_float(self):
        return self.numpy_dtype.kind == 'f'

    def is_uint(self):
        return self.numpy_dtype.kind == 'u'

    def is_complex(self):
        return self.numpy_dtype.kind == 'c'

    def is_other(self):
        """True for bool, object, bytes, str and void kinds."""
        return self.numpy_dtype.kind in 'bOSUV'

    @property
    def base_name(self):
        """C name of the type without the const qualifier."""
        return BasicType.numpy_name_to_c(str(self._dtype))

    def __str__(self):
        result = BasicType.numpy_name_to_c(str(self._dtype))
        if self.const:
            result += " const"
        return result

    def __repr__(self):
        return str(self)

    def __eq__(self, other):
        if not isinstance(other, BasicType):
            return False
        else:
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)

    def __hash__(self):
        return hash(str(self))


675
class VectorType(Type):
    """Vector of ``width`` elements of ``base_type``."""

    # When set, __str__ looks up the type name under the keys 'int', 'double', 'float' and 'bool'
    instruction_set = None

    def __init__(self, base_type, width=4):
        self._base_type = base_type
        self.width = width

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        return self.width * self.base_type.item_size

    def __eq__(self, other):
        if not isinstance(other, VectorType):
            return False
        else:
            return (self.base_type, self.width) == (other.base_type, other.width)

    def __str__(self):
        if self.instruction_set is None:
            return "%s[%d]" % (self.base_type, self.width)
        else:
            if self.base_type == create_type("int64"):
                return self.instruction_set['int']
            elif self.base_type == create_type("float64"):
                return self.instruction_set['double']
            elif self.base_type == create_type("float32"):
                return self.instruction_set['float']
            elif self.base_type == create_type("bool"):
                return self.instruction_set['bool']
            else:
                raise NotImplementedError()

    def __hash__(self):
        return hash((self.base_type, self.width))

    def __getnewargs__(self):
        return self._base_type, self.width

717

718
class PointerType(Type):
    """Pointer to ``base_type`` with optional ``const`` and ``restrict`` qualifiers."""

    def __init__(self, base_type, const=False, restrict=True):
        self._base_type = base_type
        self.const = const
        self.restrict = restrict

    def __getnewargs__(self):
        return self.base_type, self.const, self.restrict

    @property
    def alias(self):
        """True if the pointer is not restrict-qualified."""
        return not self.restrict

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        return self.base_type.item_size

    def __eq__(self, other):
        if not isinstance(other, PointerType):
            return False
        else:
            return (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict)

    def __str__(self):
        components = [str(self.base_type), '*']
        if self.restrict:
            components.append('RESTRICT')
        if self.const:
            components.append("const")
        return " ".join(components)

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self._base_type, self.const, self.restrict))
758

Jan Hoenig's avatar
Jan Hoenig committed
759

760
class StructType:
    """Data type wrapping a numpy dtype together with a const flag; printed as a byte type."""

    def __init__(self, numpy_type, const=False):
        self.const = const
        self._dtype = np.dtype(numpy_type)

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def item_size(self):
        """Size of one item in bytes, as reported by numpy."""
        return self.numpy_dtype.itemsize

    def get_element_offset(self, element_name):
        """Byte offset of the named field inside the struct."""
        return self.numpy_dtype.fields[element_name][1]

    def get_element_type(self, element_name):
        """BasicType of the named field; it inherits the const flag of the struct."""
        np_element_type = self.numpy_dtype.fields[element_name][0]
        return BasicType(np_element_type, self.const)

    def has_element(self, element_name):
        """True if the struct has a field of that name; False as well if the dtype has no fields at all."""
        fields = self.numpy_dtype.fields
        # fields is None for a non-structured dtype; a membership test on None would raise TypeError
        return fields is not None and element_name in fields

    def __eq__(self, other):
        if not isinstance(other, StructType):
            return False
        else:
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)

    def __str__(self):
        # structs are handled byte-wise
        result = "uint8_t"
        if self.const:
            result += " const"
        return result

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self.numpy_dtype, self.const))
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827


class TypedImaginaryUnit(TypedSymbol):
    """Typed symbol named ``_i`` that is created with the ``imaginary=True`` assumption."""

    def __new__(cls, *args, **kwds):
        obj = TypedImaginaryUnit.__xnew_cached_(cls, *args, **kwds)
        return obj

    def __new_stage2__(cls, dtype, *args, **kwargs):
        # the name is fixed; only the data type is chosen by the caller
        obj = super(TypedImaginaryUnit, cls).__xnew__(cls,
                                                      "_i",
                                                      dtype,
                                                      *args,
                                                      imaginary=True,
                                                      **kwargs)
        return obj

    # replaces the headers property of TypedSymbol with a constant list
    headers = ['"cuda_complex.hpp"']

    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    def __getnewargs__(self):
        return (self.dtype,)