data_types.py 26.1 KB
Newer Older
1
import ctypes
2
3
from collections import defaultdict
from functools import partial
4
from typing import Tuple
Martin Bauer's avatar
Martin Bauer committed
5

6
import numpy as np
7
8
import sympy as sp
import sympy.codegen.ast
9
10
11
12
from sympy.core.cache import cacheit
from sympy.logic.boolalg import Boolean, BooleanFunction

import pystencils
13
from pystencils.cache import memorycache, memorycache_if_hashable
Martin Bauer's avatar
Martin Bauer committed
14
from pystencils.utils import all_equal
15

16
17
18
19
20
try:
    import llvmlite.ir as ir
except ImportError as e:
    ir = None
    _ir_importerror = e
21

22

23
24
25
26
27
28
29
def typed_symbols(names, dtype, *args):
    """Create one or several :class:`TypedSymbol` objects that share a data type.

    Args:
        names: name specification understood by :func:`sympy.symbols` (e.g. ``"x y z"`` or ``"a:3"``),
               or a list/tuple of such strings
        dtype: data type given to every created symbol
        *args: forwarded positionally to :func:`sympy.symbols`
    Returns:
        a single :class:`TypedSymbol` if ``names`` describes one symbol, otherwise a tuple of them
        (a list when :func:`sympy.symbols` returned a list, i.e. when ``names`` was a list)
    """
    symbols = sp.symbols(names, *args)
    # sympy.symbols mirrors the container type of ``names``: list input gives a list back.
    # The former ``isinstance(symbols, typing.Tuple)`` check missed lists and turned the whole
    # list into one bogus symbol named like "[a, b]".
    if isinstance(symbols, tuple):
        return tuple(TypedSymbol(str(s), dtype) for s in symbols)
    elif isinstance(symbols, list):
        return [TypedSymbol(str(s), dtype) for s in symbols]
    else:
        return TypedSymbol(str(symbols), dtype)

Martin Bauer's avatar
Martin Bauer committed
30

31
32
33
def type_all_numbers(expr, dtype):
    """Wrap every numeric atom of ``expr`` into a :class:`cast_func` to ``dtype``.

    Args:
        expr: SymPy expression
        dtype: target data type for all numbers in ``expr``
    Returns:
        ``expr`` with each :class:`sympy.Number` atom replaced by ``cast_func(number, dtype)``
    """
    numbers = expr.atoms(sp.Number)
    return expr.subs({number: cast_func(number, dtype) for number in numbers})
34

Martin Bauer's avatar
Martin Bauer committed
35

36
37
38
39
40
41
42
43
44
45
46
47
def matrix_symbols(names, dtype, rows, cols):
    """Create matrices whose entries are typed symbols.

    Args:
        names: comma separated string of matrix names (spaces are removed) or a sequence of names
        dtype: data type of every matrix entry
        rows: number of rows of each matrix
        cols: number of columns of each matrix
    Returns:
        tuple with one ``rows x cols`` :class:`sympy.Matrix` per name; the entries of matrix ``A``
        are the symbols created from ``"A:<rows*cols>"``, filled in row-major order
    """
    name_list = names.replace(' ', '').split(',') if isinstance(names, str) else names
    entry_count = rows * cols

    def build(name):
        entries = typed_symbols("%s:%i" % (name, entry_count), dtype)
        return sp.Matrix(rows, cols, lambda r, c: entries[r * cols + c])

    return tuple(build(name) for name in name_list)


48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
def assumptions_from_dtype(dtype):
    """Derive SymPy assumptions from a :class:`BasicType` or a numpy dtype.

    Integer types yield ``integer`` and ``real``, unsigned integers additionally ``negative=False``,
    floating point types yield ``real``. Anything numpy cannot interpret yields no assumptions.

    Args:
        dtype (BasicType, np.dtype): a data type; objects with a ``numpy_dtype`` attribute are unwrapped
    Returns:
        A dict of SymPy assumptions
    """
    numpy_dtype = getattr(dtype, 'numpy_dtype', dtype)
    result = {}
    try:
        integral = np.issubdtype(numpy_dtype, np.integer)
        if integral:
            result['integer'] = True
        if np.issubdtype(numpy_dtype, np.unsignedinteger):
            result['negative'] = False
        if integral or np.issubdtype(numpy_dtype, np.floating):
            result['real'] = True
    except Exception:
        # best effort: unknown type specifications simply carry no assumptions
        pass
    return result


77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# noinspection PyPep8Naming
class address_of(sp.Function):
    """Symbolic address-of operation on a single argument; its dtype is a pointer to the argument's dtype."""
    is_Atom = True

    def __new__(cls, arg):
        return sp.Function.__new__(cls, arg)

    @property
    def canonical(self):
        target = self.args[0]
        if not hasattr(target, 'canonical'):
            raise NotImplementedError()
        return target.canonical

    @property
    def is_commutative(self):
        return self.args[0].is_commutative

    @property
    def dtype(self):
        """Restrict pointer to the argument's dtype, or to ``'void'`` if the argument has no dtype."""
        target = self.args[0]
        pointee = target.dtype if hasattr(target, 'dtype') else 'void'
        return PointerType(pointee, restrict=True)


Martin Bauer's avatar
Martin Bauer committed
104
# noinspection PyPep8Naming
105
class cast_func(sp.Function):
    """Symbolic type cast of an expression to a data type.

    ``args[0]`` is the expression being cast and ``args[1]`` the target :class:`Type`.
    Subclasses may carry further arguments after these two.
    """
    is_Atom = True

    def __new__(cls, *args, **kwargs):
        # The former ``if len(args) != 2: pass`` was a no-op. Fail early with the same exception type
        # that the tuple unpacking below would raise, but with an understandable message.
        if len(args) < 2:
            raise ValueError(f"{cls.__name__} expects at least an expression and a data type, "
                             f"got {len(args)} argument(s)")
        expr, dtype, *other_args = args
        if not isinstance(dtype, Type):
            dtype = create_type(dtype)
        # to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well
        # however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads
        # to problems when for example comparing cast_func's for equality
        #
        # lhs = bitwise_and(a, cast_func(1, 'int'))
        # rhs = cast_func(0, 'int')
        # print( sp.Ne(lhs, rhs) ) # would give true if all cast_funcs are booleans
        # -> thus a separate class boolean_cast_func is introduced
        if isinstance(expr, Boolean) and (not isinstance(expr, TypedSymbol) or expr.dtype == BasicType(bool)):
            cls = boolean_cast_func

        return sp.Function.__new__(cls, expr, dtype, *other_args, **kwargs)

    @property
    def canonical(self):
        if hasattr(self.args[0], 'canonical'):
            return self.args[0].canonical
        else:
            raise NotImplementedError()

    @property
    def is_commutative(self):
        return self.args[0].is_commutative

    def _eval_evalf(self, *args, **kwargs):
        # NOTE(review): evaluates the inner expression with default precision and drops the cast
        return self.args[0].evalf()

    @property
    def dtype(self):
        """Target data type of the cast."""
        return self.args[1]

    @property
    def is_integer(self):
        """
        Uses Numpy type hierarchy to determine :func:`sympy.Expr.is_integer` predicate

        For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
        """
        if hasattr(self.dtype, 'numpy_dtype'):
            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or super().is_integer
        else:
            return super().is_integer

    @property
    def is_negative(self):
        """
        See :func:`.TypedSymbol.is_integer`
        """
        if hasattr(self.dtype, 'numpy_dtype'):
            if np.issubdtype(self.dtype.numpy_dtype, np.unsignedinteger):
                return False

        return super().is_negative

    @property
    def is_nonnegative(self):
        """
        See :func:`.TypedSymbol.is_integer`
        """
        if self.is_negative is False:
            return True
        else:
            return super().is_nonnegative

    @property
    def is_real(self):
        """
        See :func:`.TypedSymbol.is_integer`
        """
        if hasattr(self.dtype, 'numpy_dtype'):
            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or \
                np.issubdtype(self.dtype.numpy_dtype, np.floating) or \
                super().is_real
        else:
            return super().is_real

Martin Bauer's avatar
Martin Bauer committed
190

191
192
193
194
195
# noinspection PyPep8Naming
class boolean_cast_func(cast_func, Boolean):
    """:class:`cast_func` variant that is also a SymPy ``Boolean``.

    ``cast_func.__new__`` switches to this class when the cast expression is boolean, so the cast
    can be used e.g. in conditions of ``sp.Piecewise``.
    """


Martin Bauer's avatar
Martin Bauer committed
196
197
# noinspection PyPep8Naming
class vector_memory_access(cast_func):
    """Vector load/store expression, modelled as a :class:`cast_func` with exactly six arguments.

    Arguments are: read/write expression, type, aligned, nontemporal, mask (or none), stride
    """
    nargs = (6,)
Martin Bauer's avatar
Martin Bauer committed
200

201

202
203
204
205
206
# noinspection PyPep8Naming
class reinterpret_cast_func(cast_func):
    """Marker subclass of :class:`cast_func`; it adds no behavior of its own.

    NOTE(review): the reinterpreting semantics are presumably implemented by the code printers.
    """


Martin Bauer's avatar
Martin Bauer committed
207
208
# noinspection PyPep8Naming
class pointer_arithmetic_func(sp.Function, Boolean):
    """SymPy function node used for pointer arithmetic; it is also a SymPy ``Boolean``."""

    @property
    def canonical(self):
        first = self.args[0]
        if not hasattr(first, 'canonical'):
            raise NotImplementedError()
        return first.canonical


217
class TypedSymbol(sp.Symbol):
    """SymPy symbol annotated with a pystencils data type (``dtype``).

    SymPy assumptions are derived from the type via :func:`assumptions_from_dtype`; keyword
    assumptions passed explicitly take precedence. If :func:`create_type` cannot convert the given
    type specification, it is stored unchanged.
    """

    def __new__(cls, *args, **kwds):
        return TypedSymbol.__xnew_cached_(cls, *args, **kwds)

    def __new_stage2__(cls, name, dtype, **kwargs):
        merged_assumptions = assumptions_from_dtype(dtype)
        merged_assumptions.update(kwargs)
        symbol = super(TypedSymbol, cls).__xnew__(cls, name, **merged_assumptions)
        try:
            symbol._dtype = create_type(dtype)
        except (TypeError, ValueError):
            # not convertible to a Type: keep the specification (e.g. a string) as given
            symbol._dtype = dtype
        return symbol

    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    @property
    def dtype(self):
        return self._dtype

    def _hashable_content(self):
        # the type is part of the identity: equally named symbols with different types differ
        return super()._hashable_content(), hash(self._dtype)

    def __getnewargs__(self):
        return self.name, self.dtype

    def __getnewargs_ex__(self):
        return (self.name, self.dtype), self.assumptions0

    @property
    def canonical(self):
        return self

    @property
    def reversed(self):
        return self

    @property
    def headers(self):
        """Headers required by this symbol: a complex type (directly or as base type) needs cuda_complex.hpp."""
        required = []
        numpy_dtype_getters = (lambda: self.dtype.numpy_dtype,
                               lambda: self.dtype.base_type.numpy_dtype)
        for get_numpy_dtype in numpy_dtype_getters:
            try:
                if np.issubdtype(get_numpy_dtype(), np.complexfloating):
                    required.append('"cuda_complex.hpp"')
            except Exception:
                # best effort: types without a (base) numpy dtype need no extra header
                pass
        return required

273

Martin Bauer's avatar
Martin Bauer committed
274
def create_type(specification):
    """Creates a pystencils type object from a specification.

    Args:
        specification: an existing :class:`Type` or :class:`StructType` object (returned unchanged),
                       or anything :class:`numpy.dtype` accepts, e.g. the string ``'float64'``

    Returns:
        the given type object, a :class:`BasicType` for plain numpy dtypes, or a :class:`StructType`
        for structured numpy dtypes

    Raises:
        TypeError: if numpy cannot interpret ``specification`` as a data type
    """
    # StructType does not derive from Type, so it must be passed through explicitly; before, numpy was
    # asked to interpret the StructType object itself and raised TypeError.
    if isinstance(specification, (Type, StructType)):
        return specification
    numpy_dtype = np.dtype(specification)
    if numpy_dtype.fields is None:
        return BasicType(numpy_dtype, const=False)
    return StructType(numpy_dtype, const=False)
291
292


293
@memorycache(maxsize=64)
def create_composite_type_from_string(specification):
    """Creates a new Type object from a c-like string specification.

    Tokens are separated by whitespace, e.g. ``"const double * restrict"``. Each ``*`` token starts a
    new pointer level; ``const``/``restrict`` tokens following it qualify that level. A ``*`` glued to
    the base type name (``"double*"``) is accepted as the innermost pointer level.
    NOTE(review): qualifiers written right after a glued ``*`` (``"double* const"``) are applied to the
    base type, not to the pointer.

    Args:
        specification: Specification string

    Returns:
        Type object
    """
    tokens = specification.lower().split()
    parts = []
    current = []
    for token in tokens:
        if token == '*':
            parts.append(current)
            current = [token]
        else:
            current.append(token)
    if len(current) > 0:
        parts.append(current)

    # Parse native part
    base_part = parts.pop(0)
    const = False
    if 'const' in base_part:
        const = True
        base_part.remove('const')
    assert len(base_part) == 1
    if base_part[0][-1] == "*":
        base_part[0] = base_part[0][:-1]
        # The glued '*' belongs directly to the base type, i.e. it is the innermost pointer level.
        # It used to be appended last (and as a bare string), which reversed the pointer levels for
        # specifications such as "double* * restrict".
        parts.insert(0, ['*'])
    current_type = BasicType(np.dtype(base_part[0]), const)

    # Parse pointer parts, innermost first
    for part in parts:
        restrict = False
        const = False
        if 'restrict' in part:
            restrict = True
            part.remove('restrict')
        if 'const' in part:
            const = True
            part.remove("const")
        assert len(part) == 1 and part[0] == '*'
        current_type = PointerType(current_type, const, restrict)
    return current_type
338
339


Martin Bauer's avatar
Martin Bauer committed
340
341
342
343
def get_base_type(data_type):
    """Follow the ``base_type`` chain of a type down to the innermost type (whose ``base_type`` is None)."""
    current = data_type
    while True:
        inner = current.base_type
        if inner is None:
            return current
        current = inner
344
345


Martin Bauer's avatar
Martin Bauer committed
346
def to_ctypes(data_type):
    """Translate a pystencils type into the corresponding ctypes type.

    Args:
        data_type: a :class:`PointerType`, :class:`StructType` or :class:`BasicType`
    Returns:
        ctypes type object; pointers map recursively to ``ctypes.POINTER``, struct types to a
        ``POINTER(c_uint8)`` and basic types are looked up in ``to_ctypes.map``
    Raises:
        KeyError: if a basic type's numpy dtype has no entry in ``to_ctypes.map``
    """
    if isinstance(data_type, StructType):
        # structs are handled byte-wise
        return ctypes.POINTER(ctypes.c_uint8)
    if isinstance(data_type, PointerType):
        return ctypes.POINTER(to_ctypes(data_type.base_type))
    return to_ctypes.map[data_type.numpy_dtype]
358

Martin Bauer's avatar
Martin Bauer committed
359
360

# numpy dtype -> ctypes type, looked up by to_ctypes() for basic types
to_ctypes.map = {
    np.dtype(np.int8): ctypes.c_int8,
    np.dtype(np.int16): ctypes.c_int16,
    np.dtype(np.int32): ctypes.c_int32,
    np.dtype(np.int64): ctypes.c_int64,

    np.dtype(np.uint8): ctypes.c_uint8,
    np.dtype(np.uint16): ctypes.c_uint16,
    np.dtype(np.uint32): ctypes.c_uint32,
    np.dtype(np.uint64): ctypes.c_uint64,

    np.dtype(np.float32): ctypes.c_float,
    np.dtype(np.float64): ctypes.c_double,

    # BasicType supports 'bool' (numpy_name_to_c maps it to C 'bool'); without this entry
    # to_ctypes raised KeyError for boolean types
    np.dtype(np.bool_): ctypes.c_bool,
}


376
def ctypes_from_llvm(data_type):
    """Translate an llvmlite IR type into the matching ctypes type.

    Pointers become ``ctypes.POINTER`` of the pointee (``c_void_p`` for void pointees), integers become
    signed ctypes integers of equal width, float/double map to ``c_float``/``c_double``, void to None.

    Raises:
        ValueError: for integer widths other than 8, 16, 32 or 64
        NotImplementedError: for any other IR type
    """
    if not ir:
        raise _ir_importerror

    if isinstance(data_type, ir.PointerType):
        pointee = ctypes_from_llvm(data_type.pointee)
        return ctypes.c_void_p if pointee is None else ctypes.POINTER(pointee)

    if isinstance(data_type, ir.IntType):
        signed_ints = {8: ctypes.c_int8, 16: ctypes.c_int16, 32: ctypes.c_int32, 64: ctypes.c_int64}
        ctype = signed_ints.get(data_type.width)
        if ctype is None:
            raise ValueError("Int width %d is not supported" % data_type.width)
        return ctype

    if isinstance(data_type, ir.FloatType):
        return ctypes.c_float
    if isinstance(data_type, ir.DoubleType):
        return ctypes.c_double
    if isinstance(data_type, ir.VoidType):
        return None  # Void type is not supported by ctypes

    raise NotImplementedError(f'Data type {type(data_type)} of {data_type} is not supported yet')
404
405


406
def to_llvm_type(data_type, nvvm_target=False):
    """
    Transforms a given type into an llvmlite type
    :param data_type: Subclass of Type
    :param nvvm_target: if True, the outermost pointer level is created in address space 1, otherwise 0
    :return: llvmlite type object
    :raises KeyError: if a basic type's numpy dtype has no entry in ``to_llvm_type.map``
    """
    if not ir:
        raise _ir_importerror
    if isinstance(data_type, PointerType):
        # NOTE(review): nested pointer levels are converted without nvvm_target and therefore stay in
        # address space 0; confirm this is intended for NVVM targets
        return to_llvm_type(data_type.base_type).as_pointer(1 if nvvm_target else 0)
    else:
        return to_llvm_type.map[data_type.numpy_dtype]

419

420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
if ir:
    # numpy dtype -> llvmlite type; signed and unsigned integers share the LLVM integer type of equal width
    to_llvm_type.map = {
        np.dtype(f'{prefix}int{bits}'): ir.IntType(bits)
        for prefix in ('', 'u')
        for bits in (8, 16, 32, 64)
    }
    to_llvm_type.map[np.dtype(np.float32)] = ir.FloatType()
    to_llvm_type.map[np.dtype(np.float64)] = ir.DoubleType()
435

436

Martin Bauer's avatar
Martin Bauer committed
437
438
439
def peel_off_type(dtype, type_to_peel_off):
    """Strip wrapper levels whose exact type is ``type_to_peel_off`` and return the first other type.

    Only exact type matches are peeled; instances of subclasses are returned unchanged.
    """
    if type(dtype) is not type_to_peel_off:
        return dtype
    return peel_off_type(dtype.base_type, type_to_peel_off)


443
444
def collate_types(types,
                  forbid_collation_to_complex=False,
                  forbid_collation_to_float=False,
                  default_float_type='float64',
                  default_int_type='int64'):
    """
    Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
    Uses the collation rules from numpy.

    Args:
        types: sequence of BasicType / VectorType / PointerType objects
        forbid_collation_to_complex: drop complex types first; ``default_float_type`` if none remain
        forbid_collation_to_float: drop floating point types first; ``default_int_type`` if none remain
        default_float_type: fallback type when only complex types were given and are forbidden
        default_int_type: fallback type when only floating point types were given and are forbidden
    Returns:
        the collated type; a VectorType if any input was a vector type
    Raises:
        ValueError: for two pointer types, invalid pointer arithmetic or vector types of different width
    """
    if forbid_collation_to_complex:
        types = [t for t in types if not np.issubdtype(t.numpy_dtype, np.complexfloating)]
        if not types:
            return create_type(default_float_type)

    if forbid_collation_to_float:
        types = [t for t in types if not np.issubdtype(t.numpy_dtype, np.floating)]
        if not types:
            return create_type(default_int_type)

    # Pointer arithmetic case i.e. pointer + integer is allowed
    if any(type(t) is PointerType for t in types):
        pointer_type = None
        for t in types:
            if type(t) is PointerType:
                if pointer_type is not None:
                    raise ValueError("Cannot collate the combination of two pointer types")
                pointer_type = t
            elif type(t) is BasicType:
                if not (t.is_int() or t.is_uint()):
                    raise ValueError("Invalid pointer arithmetic")
            else:
                raise ValueError("Invalid pointer arithmetic")
        return pointer_type

    # peel of vector types, if at least one vector type occurred the result will also be the vector type
    vector_type = [t for t in types if type(t) is VectorType]
    if not all_equal(t.width for t in vector_type):
        raise ValueError("Collation failed because of vector types with different width")
    types = [peel_off_type(t, VectorType) for t in types]

    # now we should have a list of basic types - struct types are not yet supported
    assert all(type(t) is BasicType for t in types)

    # Floating point types dominate integer types (e.g. float32 + int64 -> float32). Complex types must
    # survive this filter too: is_float() is False for them, so e.g. complex128 + float64 used to
    # collapse to float64 and silently drop the imaginary part.
    if any(t.is_float() for t in types):
        types = tuple(t for t in types if t.is_float() or t.is_complex())
    # use numpy collation -> create type from numpy type -> and, put vector type around if necessary
    result_numpy_type = np.result_type(*(t.numpy_dtype for t in types))
    result = BasicType(result_numpy_type)
    if vector_type:
        result = VectorType(result, vector_type[0].width)
    return result


496
497
498
499
500
@memorycache_if_hashable(maxsize=2048)
def get_type_of_expression(expr,
                           default_float_type='double',
                           default_int_type='int',
                           symbol_type_dict=None):
    """Infer the pystencils data type of a SymPy expression.

    Args:
        expr: expression (anything :func:`sympy.sympify` accepts)
        default_float_type: type of floating point literals (``'float'`` is treated as ``'float32'``)
        default_int_type: type of integer literals
        symbol_type_dict: mapping from symbol name to type, used for untyped :class:`sympy.Symbol` objects
    Returns:
        the inferred type object
    Raises:
        ValueError: if an untyped symbol is found while ``symbol_type_dict`` is empty or not given
        NotImplementedError: if no type can be determined for ``expr``
    """
    from pystencils.astnodes import ResolvedFieldAccess
    from pystencils.cpu.vectorization import vec_all, vec_any

    if default_float_type == 'float':
        default_float_type = 'float32'

    # NOTE(review): this empty defaultdict is falsy, so the sp.Symbol branch below raises ValueError
    # instead of falling back to 'double' for untyped symbols; confirm which behavior is intended.
    if not symbol_type_dict:
        symbol_type_dict = defaultdict(lambda: create_type('double'))

    get_type = partial(get_type_of_expression,
                       default_float_type=default_float_type,
                       default_int_type=default_int_type,
                       symbol_type_dict=symbol_type_dict)

    expr = sp.sympify(expr)
    if isinstance(expr, sp.Integer):
        return create_type(default_int_type)
    elif expr.is_real is False:
        # complex counterpart of the default float type
        return create_type((np.zeros((1,), default_float_type) * 1j).dtype)
    elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float):
        return create_type(default_float_type)
    elif isinstance(expr, ResolvedFieldAccess):
        return expr.field.dtype
    elif isinstance(expr, pystencils.field.Field.AbstractAccess):
        return expr.field.dtype
    elif isinstance(expr, TypedSymbol):
        return expr.dtype
    elif isinstance(expr, sp.Symbol):
        if symbol_type_dict:
            return symbol_type_dict[expr.name]
        else:
            raise ValueError("All symbols inside this expression have to be typed! ", str(expr))
    elif isinstance(expr, cast_func):
        return expr.args[1]
    elif isinstance(expr, (vec_any, vec_all)):
        return create_type("bool")
    elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
        collated_result_type = collate_types(tuple(get_type(a[0]) for a in expr.args))
        collated_condition_type = collate_types(tuple(get_type(a[1]) for a in expr.args))
        if type(collated_condition_type) is VectorType and type(collated_result_type) is not VectorType:
            collated_result_type = VectorType(collated_result_type, width=collated_condition_type.width)
        return collated_result_type
    elif isinstance(expr, sp.Indexed):
        typed_symbol = expr.base.label
        return typed_symbol.dtype.base_type
    elif isinstance(expr, (Boolean, BooleanFunction)):
        # if any arg is of vector type return a vector boolean, else return a normal scalar boolean
        result = create_type("bool")
        # compute every argument type only once; calling get_type twice per argument made nested
        # boolean expressions exponentially expensive
        arg_types = [get_type(a) for a in expr.args]
        vec_args = [t for t in arg_types if isinstance(t, VectorType)]
        if vec_args:
            result = VectorType(result, width=vec_args[0].width)
        return result
    elif isinstance(expr, sp.Pow):
        base_type = get_type(expr.args[0])
        if expr.exp.is_integer:
            return base_type
        else:
            return collate_types([create_type(default_float_type), base_type])
    elif isinstance(expr, (sp.Sum, sp.Product)):
        return get_type(expr.args[0])
    elif isinstance(expr, sp.Expr):
        expr: sp.Expr
        if expr.args:
            types = tuple(get_type(a) for a in expr.args)
            # collate_types checks numpy_dtype in the special cases
            if any(not hasattr(t, 'numpy_dtype') for t in types):
                forbid_collation_to_complex = False
                forbid_collation_to_float = False
            else:
                forbid_collation_to_complex = expr.is_real is True
                forbid_collation_to_float = expr.is_integer is True
            return collate_types(
                types,
                forbid_collation_to_complex=forbid_collation_to_complex,
                forbid_collation_to_float=forbid_collation_to_float,
                default_float_type=default_float_type,
                default_int_type=default_int_type)
        else:
            if expr.is_integer:
                return create_type(default_int_type)
            else:
                return create_type(default_float_type)

    raise NotImplementedError("Could not determine type for", expr, type(expr))
585
586


Michael Kuron's avatar
Michael Kuron committed
587
588
589
590
591
592
593
sympy_version = sp.__version__.split('.')
# Starting with sympy 1.9 (the version gate below), sp.Basic defines __getstate__. The state-based
# unpickling path would bypass the constructor that pystencils' types and symbols rely on
# (they define __getnewargs_ex__), so Basic's __getstate__ is moved to sp.Number only.
# The membership test keeps this safe on a module reload (attribute already removed). It also covers
# sympy versions where Basic no longer defines it itself: on Python >= 3.11, plain attribute access
# would find object.__getstate__ and the ``del`` would fail with AttributeError.
if int(sympy_version[0]) * 100 + int(sympy_version[1]) >= 109 and '__getstate__' in vars(sp.Basic):
    sp.Number.__getstate__ = sp.Basic.__getstate__
    del sp.Basic.__getstate__


594
class Type(sp.Atom):
    """Base class of pystencils data types; being SymPy atoms, they can be arguments of expressions."""

    def __new__(cls, *args, **kwargs):
        # constructor arguments are consumed by the subclasses' __init__, not by sympy
        instance = sp.Basic.__new__(cls)
        return instance

    def _sympystr(self, *args, **kwargs):
        # let SymPy's printer use the type's own string form
        text = str(self)
        return text
600
601
602
603


class BasicType(Type):
    """Scalar data type backed by a non-structured numpy dtype, optionally ``const``."""

    # dtype groups for the is_* predicates. They reproduce the former ``np.sctypes`` lists, which were
    # removed in NumPy 2.0; membership still uses dtype equality exactly as before.
    _int_dtypes = tuple(np.dtype(t) for t in (np.int8, np.int16, np.int32, np.int64))
    _uint_dtypes = tuple(np.dtype(t) for t in (np.uint8, np.uint16, np.uint32, np.uint64))
    _float_dtypes = tuple(np.dtype(t) for t in (np.float16, np.float32, np.float64, np.longdouble))
    _complex_dtypes = tuple(np.dtype(t) for t in (np.complex64, np.complex128, np.clongdouble))
    _other_dtypes = tuple(np.dtype(t) for t in (np.bool_, np.object_, np.bytes_, np.str_, np.void))

    @staticmethod
    def numpy_name_to_c(name):
        """Map a numpy dtype name (e.g. ``'float64'``, ``'uint8'``) to the C type name used in generated code.

        Raises:
            NotImplementedError: for names without a C counterpart (e.g. ``'float16'``)
        """
        if name == 'float64':
            return 'double'
        elif name == 'float32':
            return 'float'
        elif name == 'complex64':
            return 'ComplexFloat'
        elif name == 'complex128':
            return 'ComplexDouble'
        elif name.startswith('int'):
            width = int(name[len("int"):])
            return "int%d_t" % (width,)
        elif name.startswith('uint'):
            width = int(name[len("uint"):])
            return "uint%d_t" % (width,)
        elif name == 'bool':
            return 'bool'
        else:
            raise NotImplementedError(f"Cannot map numpy to C name for {name}")

    def __init__(self, dtype, const=False):
        self.const = const
        if isinstance(dtype, Type):
            self._dtype = dtype.numpy_dtype
        else:
            self._dtype = np.dtype(dtype)
        assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type"
        assert self._dtype.hasobject is False
        assert self._dtype.subdtype is None

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    def __getnewargs_ex__(self):
        return (self.numpy_dtype, self.const), {}

    @property
    def base_type(self):
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def sympy_dtype(self):
        return getattr(sympy.codegen.ast, str(self.numpy_dtype))

    @property
    def item_size(self):
        # counted in elements, not bytes (VectorType multiplies this by its width)
        return 1

    def is_int(self):
        return self.numpy_dtype in self._int_dtypes or self.numpy_dtype in self._uint_dtypes

    def is_float(self):
        return self.numpy_dtype in self._float_dtypes

    def is_uint(self):
        return self.numpy_dtype in self._uint_dtypes

    def is_complex(self):
        return self.numpy_dtype in self._complex_dtypes

    def is_other(self):
        return self.numpy_dtype in self._other_dtypes

    @property
    def base_name(self):
        return BasicType.numpy_name_to_c(str(self._dtype))

    def __str__(self):
        result = BasicType.numpy_name_to_c(str(self._dtype))
        if self.const:
            result += " const"
        return result

    def __repr__(self):
        return str(self)

    def __eq__(self, other):
        if not isinstance(other, BasicType):
            return False
        else:
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)

    def __hash__(self):
        # Hash exactly what __eq__ compares. Hashing str(self) raised NotImplementedError for dtypes
        # without a C name (e.g. float16), which made such types unusable in dicts and sympy trees.
        return hash((self.numpy_dtype, self.const))


694
class VectorType(Type):
    """SIMD vector of ``width`` elements of ``base_type``."""

    # Mapping with 'int', 'double', 'float' and 'bool' entries used by __str__;
    # None selects the generic "<base>[<width>]" form.
    instruction_set = None

    def __init__(self, base_type, width=4):
        self._base_type = base_type
        self.width = width

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        return self.width * self.base_type.item_size

    def __eq__(self, other):
        return isinstance(other, VectorType) and (self.base_type, self.width) == (other.base_type, other.width)

    def __str__(self):
        if self.instruction_set is None:
            return "%s[%s]" % (self.base_type, self.width)
        # checked in order: (numpy type names matching the base type, instruction-set key)
        lookup = ((("int64", "int32"), 'int'),
                  (("float64",), 'double'),
                  (("float32",), 'float'),
                  (("bool",), 'bool'))
        for type_names, key in lookup:
            if any(self.base_type == create_type(type_name) for type_name in type_names):
                return self.instruction_set[key]
        raise NotImplementedError()

    def __hash__(self):
        return hash((self.base_type, self.width))

    def __getnewargs__(self):
        return self._base_type, self.width

    def __getnewargs_ex__(self):
        return (self._base_type, self.width), {}

739

740
class PointerType(Type):
    """Pointer to ``base_type`` with optional ``const`` and ``restrict`` qualifiers (restrict by default)."""

    def __init__(self, base_type, const=False, restrict=True):
        self._base_type = base_type
        self.const = const
        self.restrict = restrict

    def __getnewargs__(self):
        return self.base_type, self.const, self.restrict

    def __getnewargs_ex__(self):
        return (self.base_type, self.const, self.restrict), {}

    @property
    def alias(self):
        """True if the pointer may alias, i.e. it is not restrict-qualified."""
        return not self.restrict

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        return self.base_type.item_size

    def __eq__(self, other):
        if isinstance(other, PointerType):
            return (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict)
        return False

    def __str__(self):
        qualifiers = [name for name, enabled in (('RESTRICT', self.restrict), ("const", self.const)) if enabled]
        return " ".join([str(self.base_type), '*'] + qualifiers)

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self._base_type, self.const, self.restrict))
783

Jan Hoenig's avatar
Jan Hoenig committed
784

785
class StructType:
    """Structured (record) data type backed by a numpy dtype with named fields."""

    def __init__(self, numpy_type, const=False):
        self.const = const
        self._dtype = np.dtype(numpy_type)

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    def __getnewargs_ex__(self):
        return (self.numpy_dtype, self.const), {}

    @property
    def base_type(self):
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def item_size(self):
        """Size of one struct in bytes."""
        return self.numpy_dtype.itemsize

    def get_element_offset(self, element_name):
        """Byte offset of field ``element_name`` inside the struct."""
        field_info = self.numpy_dtype.fields[element_name]
        return field_info[1]

    def get_element_type(self, element_name):
        """BasicType of field ``element_name``, inheriting this struct's constness."""
        field_info = self.numpy_dtype.fields[element_name]
        return BasicType(field_info[0], self.const)

    def has_element(self, element_name):
        return element_name in self.numpy_dtype.fields

    def __eq__(self, other):
        if isinstance(other, StructType):
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)
        return False

    def __str__(self):
        # structs are handled byte-wise
        return "uint8_t const" if self.const else "uint8_t"

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self.numpy_dtype, self.const))
836
837
838
839
840
841
842


class TypedImaginaryUnit(TypedSymbol):
    """Imaginary unit as a typed symbol named ``_i`` carrying the SymPy ``imaginary`` assumption."""

    def __new__(cls, *args, **kwds):
        return TypedImaginaryUnit.__xnew_cached_(cls, *args, **kwds)

    def __new_stage2__(cls, dtype):
        unit = super(TypedImaginaryUnit, cls).__xnew__(cls, "_i", dtype, imaginary=True)
        return unit

    # header required by this symbol (replaces the TypedSymbol.headers property)
    headers = ['"cuda_complex.hpp"']

    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    def __getnewargs__(self):
        return (self.dtype,)

    def __getnewargs_ex__(self):
        return (self.dtype,), {}