import ctypes

import numpy as np
import sympy as sp
from sympy.core.cache import cacheit
from sympy.logic.boolalg import Boolean

from pystencils.cache import memorycache
from pystencils.utils import all_equal

# llvmlite is optional: without it ``ir`` is None and the LLVM helpers re-raise the stored ImportError
try:
    import llvmlite.ir as ir
except ImportError as e:
    ir = None
    _ir_importerror = e


# noinspection PyPep8Naming
class cast_func(sp.Function):
    """Explicit cast node: ``cast_func(expression, dtype)``.

    The target type is the second argument (see ``dtype``). Creating a cast of a sympy ``Boolean`` yields a
    ``boolean_cast_func`` instead of a plain ``cast_func``.
    """
    is_Atom = True

    def __new__(cls, *args, **kwargs):
        # To work in conditions of sp.Piecewise a cast_func has to be of type Boolean as well.
        # However, a cast_func should only be a Boolean if its argument is a Boolean, otherwise this leads
        # to problems when for example comparing cast_funcs for equality:
        #
        # lhs = bitwise_and(a, cast_func(1, 'int'))
        # rhs = cast_func(0, 'int')
        # print( sp.Ne(lhs, rhs) ) # would give true if all cast_funcs are booleans
        #
        # -> thus the separate class boolean_cast_func is selected here for Boolean arguments
        target_cls = boolean_cast_func if isinstance(args[0], Boolean) else cls
        return sp.Function.__new__(target_cls, *args, **kwargs)

    @property
    def canonical(self):
        """``canonical`` of the casted expression; NotImplementedError if it does not provide one."""
        casted = self.args[0]
        if not hasattr(casted, 'canonical'):
            raise NotImplementedError()
        return casted.canonical

    @property
    def is_commutative(self):
        """Commutativity is taken over from the casted expression."""
        return self.args[0].is_commutative

    @property
    def dtype(self):
        """Target type of the cast (the second argument)."""
        return self.args[1]


# noinspection PyPep8Naming
class boolean_cast_func(cast_func, Boolean):
    """``cast_func`` whose casted expression is a sympy ``Boolean``.

    ``cast_func.__new__`` selects this class automatically for Boolean arguments, so these casts can be used
    as conditions (e.g. inside ``sp.Piecewise``) while all other casts stay non-Boolean.
    """


# noinspection PyPep8Naming
class vector_memory_access(cast_func):
    """``cast_func`` variant that takes exactly four arguments.

    As for every ``cast_func``, the second argument is exposed as ``dtype``.
    NOTE(review): the meaning of the remaining arguments is defined by the code creating these nodes
    (presumably vectorized loads/stores, judging by the name) — confirm with callers.
    """
    nargs = (4,)


# noinspection PyPep8Naming
class pointer_arithmetic_func(sp.Function, Boolean):
    """sympy Function that is also a sympy ``Boolean``.

    NOTE(review): presumably represents a pointer arithmetic expression (from the name); the meaning of its
    arguments is defined by the code that constructs it — confirm.
    """

    @property
    def canonical(self):
        """``canonical`` of the first argument; NotImplementedError if it does not provide one."""
        first = self.args[0]
        if hasattr(first, 'canonical'):
            return first.canonical
        raise NotImplementedError()


class TypedSymbol(sp.Symbol):
    """sympy Symbol that additionally carries a data type, available as ``dtype``.

    The given dtype is converted with ``create_type``; if that raises ``TypeError`` the value is stored
    unchanged. The dtype is part of ``_hashable_content``, so equally named symbols with different
    types are distinct.
    """

    def __new__(cls, *args, **kwds):
        return TypedSymbol.__xnew_cached_(cls, *args, **kwds)

    def __new_stage2__(cls, name, dtype):
        obj = super(TypedSymbol, cls).__xnew__(cls, name)
        try:
            resolved_dtype = create_type(dtype)
        except TypeError:
            # not understood by create_type (e.g. an unknown type name): keep the given value as is
            resolved_dtype = dtype
        obj._dtype = resolved_dtype
        return obj

    # plain and cached construction entry points (the cached one is used by __new__)
    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    @property
    def dtype(self):
        return self._dtype

    def _hashable_content(self):
        symbol_content = tuple(super(TypedSymbol, self)._hashable_content())
        return symbol_content + (hash(self._dtype),)

    def __getnewargs__(self):
        # arguments used to re-create the symbol, e.g. when unpickling
        return self.name, self.dtype


def create_type(specification):
    """Creates a type object from a specification, or passes an existing type object through.

    Args:
        specification: an existing ``Type`` or ``StructType`` instance (returned unchanged), or anything
                       ``np.dtype`` understands, e.g. a string such as ``"float64"`` or a numpy dtype

    Returns:
        ``specification`` itself if it already is a type object; otherwise a non-const ``BasicType`` for
        plain dtypes or a non-const ``StructType`` for structured dtypes

    Raises:
        TypeError: raised by ``np.dtype`` when it cannot interpret ``specification``
    """
    # StructType does not derive from Type, so it must be checked explicitly; otherwise it would be
    # handed to np.dtype, which cannot interpret it.
    if isinstance(specification, (Type, StructType)):
        return specification
    numpy_dtype = np.dtype(specification)
    if numpy_dtype.fields is None:
        return BasicType(numpy_dtype, const=False)
    return StructType(numpy_dtype, const=False)


@memorycache(maxsize=64)
def create_composite_type_from_string(specification):
    """Creates a new Type object from a c-like string specification.

    The specification is lower-cased and split at whitespace.
    - The tokens before the first standalone ``*`` describe the base type: a numpy dtype name, optionally
      with ``const``.
    - Every standalone ``*`` starts a new pointer level. A ``const`` and/or ``restrict`` that follows it
      qualifies that level.
    - Stars attached to the base type name (``"double*"``, ``"double**"``) are unqualified innermost
      pointer levels.

    Examples: ``"double"``, ``"const float"``, ``"double * restrict"``, ``"double* * restrict const"``

    Args:
        specification: Specification string

    Returns:
        Type object: a BasicType wrapped in one PointerType per pointer level (innermost level first)
    """
    tokens = specification.lower().split()
    parts = []
    current = []
    for token in tokens:
        if token == '*':
            parts.append(current)
            current = [token]
        else:
            current.append(token)
    if len(current) > 0:
        parts.append(current)

    # Parse native part
    base_part = parts.pop(0)
    const = False
    if 'const' in base_part:
        const = True
        base_part.remove('const')
    assert len(base_part) == 1
    base_name = base_part[0].rstrip('*')
    attached_stars = len(base_part[0]) - len(base_name)
    # Attached stars bind directly to the base type, i.e. they are the innermost pointer levels and have to
    # be processed before any separately written '*' part, whose qualifiers belong to outer levels.
    parts[0:0] = [['*'] for _ in range(attached_stars)]
    current_type = BasicType(np.dtype(base_name), const)

    # Parse pointer parts, innermost level first
    for part in parts:
        restrict = False
        const = False
        if 'restrict' in part:
            restrict = True
            part.remove('restrict')
        if 'const' in part:
            const = True
            part.remove("const")
        assert len(part) == 1 and part[0] == '*'
        current_type = PointerType(current_type, const, restrict)
    return current_type


def get_base_type(data_type):
    """Follows the ``base_type`` chain of ``data_type`` down to its innermost type.

    Args:
        data_type: type object with a ``base_type`` attribute; ``None`` marks the end of the chain

    Returns:
        the first object in the chain whose ``base_type`` is None (``data_type`` itself if it wraps nothing)
    """
    innermost = data_type
    wrapped = innermost.base_type
    while wrapped is not None:
        innermost = wrapped
        wrapped = innermost.base_type
    return innermost


def to_ctypes(data_type):
    """Transforms a given Type into the corresponding ctypes type.

    Args:
        data_type: PointerType, StructType, or a type with a ``numpy_dtype`` listed in ``to_ctypes.map``

    Returns:
        ctypes type object. Pointers become ``ctypes.POINTER`` of the converted base type. Struct types
        become a pointer to ``c_uint8``, since structs are handled byte-wise.

    Raises:
        KeyError: if the numpy dtype has no entry in ``to_ctypes.map``
    """
    if isinstance(data_type, PointerType):
        return ctypes.POINTER(to_ctypes(data_type.base_type))
    elif isinstance(data_type, StructType):
        return ctypes.POINTER(ctypes.c_uint8)
    else:
        return to_ctypes.map[data_type.numpy_dtype]


# numpy dtype -> ctypes type, used by to_ctypes for scalar types
to_ctypes.map = {
    np.dtype(np.int8): ctypes.c_int8,
    np.dtype(np.int16): ctypes.c_int16,
    np.dtype(np.int32): ctypes.c_int32,
    np.dtype(np.int64): ctypes.c_int64,

    np.dtype(np.uint8): ctypes.c_uint8,
    np.dtype(np.uint16): ctypes.c_uint16,
    np.dtype(np.uint32): ctypes.c_uint32,
    np.dtype(np.uint64): ctypes.c_uint64,

    np.dtype(np.float32): ctypes.c_float,
    np.dtype(np.float64): ctypes.c_double,

    # BasicType supports bool (numpy_name_to_c maps it), so it needs a ctypes counterpart as well
    np.dtype(np.bool_): ctypes.c_bool,
}


def ctypes_from_llvm(data_type):
    """Translates an llvmlite IR type into the matching ctypes type.

    Args:
        data_type: llvmlite ``ir`` type (pointer, integer, float, double or void)

    Returns:
        ctypes type object. ``ir.VoidType`` gives None, because ctypes has no void type. A pointer whose
        pointee translates to None becomes ``ctypes.c_void_p``.

    Raises:
        ImportError: the error stored at import time if llvmlite is unavailable
        ValueError: for integer widths other than 8, 16, 32 and 64
        NotImplementedError: for any other IR type
    """
    if not ir:
        raise _ir_importerror

    if isinstance(data_type, ir.PointerType):
        pointee = ctypes_from_llvm(data_type.pointee)
        return ctypes.c_void_p if pointee is None else ctypes.POINTER(pointee)

    if isinstance(data_type, ir.IntType):
        int_types_by_width = {8: ctypes.c_int8, 16: ctypes.c_int16, 32: ctypes.c_int32, 64: ctypes.c_int64}
        if data_type.width not in int_types_by_width:
            raise ValueError("Int width %d is not supported" % data_type.width)
        return int_types_by_width[data_type.width]

    if isinstance(data_type, ir.FloatType):
        return ctypes.c_float
    if isinstance(data_type, ir.DoubleType):
        return ctypes.c_double
    if isinstance(data_type, ir.VoidType):
        return None  # Void type is not supported by ctypes
    raise NotImplementedError('Data type %s of %s is not supported yet' % (type(data_type), data_type))


def to_llvm_type(data_type):
    """Transforms a given type into an llvmlite IR type.

    Args:
        data_type: PointerType, or a type with a ``numpy_dtype`` listed in ``to_llvm_type.map``

    Returns:
        llvmlite type object; pointers are built with ``as_pointer()`` on the converted base type

    Raises:
        ImportError: the error stored at import time if llvmlite is unavailable
        KeyError: if the numpy dtype has no entry in ``to_llvm_type.map``
    """
    if not ir:
        raise _ir_importerror
    if not isinstance(data_type, PointerType):
        return to_llvm_type.map[data_type.numpy_dtype]
    return to_llvm_type(data_type.base_type).as_pointer()


if ir:
    # numpy dtype -> llvmlite type. LLVM integer types carry no signedness, so signed and unsigned
    # integers of the same width map to the same IntType.
    to_llvm_type.map = {
        np.dtype('%sint%d' % (sign_prefix, bits)): ir.IntType(bits)
        for sign_prefix in ('', 'u')
        for bits in (8, 16, 32, 64)
    }
    to_llvm_type.map[np.dtype(np.float32)] = ir.FloatType()
    to_llvm_type.map[np.dtype(np.float64)] = ir.DoubleType()


def peel_off_type(dtype, type_to_peel_off):
    """Strips wrappers of exactly the class ``type_to_peel_off`` from ``dtype``.

    ``base_type`` is followed as long as the current object's type *is* ``type_to_peel_off``.
    Instances of subclasses are not peeled off.

    Returns:
        the first object in the ``base_type`` chain that is not of type ``type_to_peel_off``
    """
    current = dtype
    while True:
        if type(current) is not type_to_peel_off:
            return current
        current = current.base_type


def collate_types(types):
    """
    Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
    Uses the collation rules from numpy.

    Rules:
    - If a PointerType is present, the result is that pointer type. All other entries must be integer
      BasicTypes (pointer arithmetic).
    - VectorTypes are peeled off. If any was present, the result is wrapped into a VectorType of the
      same width.
    - If any type is floating point, only the floating point types take part in collation.

    Args:
        types: iterable of type objects (a one-shot iterator such as a generator is accepted)

    Returns:
        the collated type object

    Raises:
        ValueError: for two pointer types, pointer arithmetic with non-integer types,
                    or vector types of different widths
    """
    # Materialize once: the checks below iterate over ``types`` several times, which would silently
    # skip elements if a generator were passed in.
    types = tuple(types)

    # Pointer arithmetic case i.e. pointer + integer is allowed
    if any(type(t) is PointerType for t in types):
        pointer_type = None
        for t in types:
            if type(t) is PointerType:
                if pointer_type is not None:
                    raise ValueError("Cannot collate the combination of two pointer types")
                pointer_type = t
            elif type(t) is BasicType:
                if not (t.is_int() or t.is_uint()):
                    raise ValueError("Invalid pointer arithmetic")
            else:
                raise ValueError("Invalid pointer arithmetic")
        return pointer_type

    # peel off vector types, if at least one vector type occurred the result will also be a vector type
    vector_types = [t for t in types if type(t) is VectorType]
    if not all_equal(t.width for t in vector_types):
        raise ValueError("Collation failed because of vector types with different width")
    types = [peel_off_type(t, VectorType) for t in types]

    # now we should have a list of basic types - struct types are not yet supported
    assert all(type(t) is BasicType for t in types)

    # Floating point types dominate: integers are dropped when a float is present, so e.g.
    # (float32, int64) collates to float32 instead of numpy's float64.
    if any(t.is_float() for t in types):
        types = tuple(t for t in types if t.is_float())
    # use numpy collation -> create type from numpy type -> and, put vector type around if necessary
    result_numpy_type = np.result_type(*(t.numpy_dtype for t in types))
    result = BasicType(result_numpy_type)
    if vector_types:
        result = VectorType(result, vector_types[0].width)
    return result


@memorycache(maxsize=2048)
def get_type_of_expression(expr):
    """Determines the data type of a sympy expression.

    Rules, checked in this order:
    - sympy integers -> ``int``; other rationals and floats -> ``double``
    - resolved field accesses, typed symbols and cast functions -> their stored type
    - untyped sympy symbols -> ValueError
    - Piecewise -> collated type of all values; it is widened to a vector type if the collated
      condition type is a vector type
    - Indexed -> ``base_type`` of the dtype of the indexed symbol
    - Booleans -> ``bool``, or a ``bool`` vector if any argument has a vector type
    - Pow -> type of the base
    - any other expression -> collated type of its arguments

    Args:
        expr: sympy expression, or anything ``sp.sympify`` accepts

    Returns:
        Type object

    Raises:
        ValueError: if an untyped ``sp.Symbol`` is encountered
        NotImplementedError: if no rule applies
    """
    from pystencils.astnodes import ResolvedFieldAccess
    expr = sp.sympify(expr)
    if isinstance(expr, sp.Integer):
        return create_type("int")
    elif isinstance(expr, (sp.Rational, sp.Float)):
        return create_type("double")
    elif isinstance(expr, ResolvedFieldAccess):
        return expr.field.dtype
    elif isinstance(expr, TypedSymbol):
        return expr.dtype
    elif isinstance(expr, sp.Symbol):
        raise ValueError("All symbols inside this expression have to be typed! %s" % (expr,))
    elif isinstance(expr, cast_func):
        return expr.args[1]
    elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
        collated_result_type = collate_types(tuple(get_type_of_expression(a[0]) for a in expr.args))
        collated_condition_type = collate_types(tuple(get_type_of_expression(a[1]) for a in expr.args))
        if type(collated_condition_type) is VectorType and type(collated_result_type) is not VectorType:
            collated_result_type = VectorType(collated_result_type, width=collated_condition_type.width)
        return collated_result_type
    elif isinstance(expr, sp.Indexed):
        typed_symbol = expr.base.label
        return typed_symbol.dtype.base_type
    elif isinstance(expr, Boolean):
        # BooleanFunction derives from Boolean, so this single check covers both. The directly imported
        # Boolean is used because ``sympy.boolalg`` is not a top-level attribute in newer SymPy releases.
        # If any arg is of vector type, return a vector boolean, else a normal scalar boolean.
        result = create_type("bool")
        arg_types = [get_type_of_expression(a) for a in expr.args]
        vec_args = [t for t in arg_types if isinstance(t, VectorType)]
        if vec_args:
            result = VectorType(result, width=vec_args[0].width)
        return result
    elif isinstance(expr, sp.Pow):
        return get_type_of_expression(expr.args[0])
    elif isinstance(expr, sp.Expr):
        types = tuple(get_type_of_expression(a) for a in expr.args)
        return collate_types(types)

    raise NotImplementedError("Could not determine type for %s (%s)" % (expr, type(expr)))


class Type(sp.Basic):
    """Common base class of BasicType, VectorType and PointerType.

    It derives from sympy's ``Basic``, so type objects can appear as arguments of sympy expressions,
    e.g. as the second argument of ``cast_func``.
    """
    is_Atom = True

    def __new__(cls, *_args, **_kwargs):
        # Constructor arguments are consumed by the subclasses' __init__; the sympy Basic gets none.
        return sp.Basic.__new__(cls)

    def _sympystr(self, *_args, **_kwargs):
        # Hook of sympy's string printer: print the type via its own __str__.
        return str(self)


class BasicType(Type):
    """Scalar type backed by a non-structured numpy dtype, optionally ``const`` qualified."""

    @staticmethod
    def numpy_name_to_c(name):
        """Maps a numpy dtype name to the name of the corresponding C type.

        Args:
            name: numpy dtype name, e.g. ``'float64'``, ``'int32'``, ``'uint8'`` or ``'bool'``

        Returns:
            C type name, e.g. ``'double'``, ``'int32_t'``, ``'uint8_t'`` or ``'bool'``

        Raises:
            NotImplementedError: if ``name`` has no C equivalent here (e.g. ``'float16'``, ``'complex128'``)
        """
        if name == 'float64':
            return 'double'
        elif name == 'float32':
            return 'float'
        elif name.startswith('int'):
            width = int(name[len("int"):])
            return "int%d_t" % (width,)
        elif name.startswith('uint'):
            width = int(name[len("uint"):])
            return "uint%d_t" % (width,)
        elif name == 'bool':
            return 'bool'
        else:
            # NotImplementedError is the exception class; NotImplemented is a constant and is not callable
            raise NotImplementedError("Cannot map numpy to C name for %s" % (name,))

    def __init__(self, dtype, const=False):
        self.const = const
        if isinstance(dtype, Type):
            self._dtype = dtype.numpy_dtype
        else:
            self._dtype = np.dtype(dtype)
        assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type"
        assert self._dtype.hasobject is False
        assert self._dtype.subdtype is None

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        """Scalars do not wrap another type; always None."""
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def item_size(self):
        # Counted in elements, not bytes: a scalar is one element (VectorType multiplies this by its width)
        return 1

    # The is_* predicates classify by ``dtype.kind``. This reproduces the former ``np.sctypes`` groups
    # ('int', 'float', 'uint', 'complex', 'others') and also works with NumPy >= 2.0, where np.sctypes
    # was removed.
    def is_int(self):
        return self.numpy_dtype.kind == 'i'

    def is_float(self):
        return self.numpy_dtype.kind == 'f'

    def is_uint(self):
        return self.numpy_dtype.kind == 'u'

    def is_complex(self):
        return self.numpy_dtype.kind == 'c'

    def is_other(self):
        # bool, object, bytes, str and void
        return self.numpy_dtype.kind in ('b', 'O', 'S', 'U', 'V')

    @property
    def base_name(self):
        """C name of the type, without the ``const`` qualifier."""
        return BasicType.numpy_name_to_c(str(self._dtype))

    def __str__(self):
        result = BasicType.numpy_name_to_c(str(self._dtype))
        if self.const:
            result += " const"
        return result

    def __repr__(self):
        return str(self)

    def __eq__(self, other):
        if not isinstance(other, BasicType):
            return False
        else:
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)

    def __hash__(self):
        # Hash exactly the fields compared by __eq__. Hashing str(self) raised for dtypes without a C name
        # (numpy_name_to_c rejects them), which made such types unusable as dict/set keys.
        return hash((self.numpy_dtype, self.const))


class VectorType(Type):
    """Vector of ``width`` elements of ``base_type``.

    The class attribute ``instruction_set`` is None by default. When it is set, ``__str__`` indexes it with
    'int', 'double', 'float' or 'bool' to obtain the C name of the vector type.
    """
    instruction_set = None

    def __init__(self, base_type, width=4):
        self._base_type = base_type
        self.width = width

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        return self.width * self.base_type.item_size

    def __eq__(self, other):
        if not isinstance(other, VectorType):
            return False
        return (self.base_type, self.width) == (other.base_type, other.width)

    def __str__(self):
        if self.instruction_set is None:
            return "%s[%d]" % (self.base_type, self.width)
        # Only these exact (non-const) base types have an entry in the instruction set.
        for numpy_name, instruction_set_key in (("int64", 'int'), ("float64", 'double'),
                                               ("float32", 'float'), ("bool", 'bool')):
            if self.base_type == create_type(numpy_name):
                return self.instruction_set[instruction_set_key]
        raise NotImplementedError()

    def __hash__(self):
        return hash((self.base_type, self.width))

    def __getnewargs__(self):
        return self._base_type, self.width


class PointerType(Type):
    """Pointer to ``base_type`` with optional ``const`` and ``restrict`` qualifiers (restrict by default)."""

    def __init__(self, base_type, const=False, restrict=True):
        self._base_type = base_type
        self.const = const
        self.restrict = restrict

    def __getnewargs__(self):
        return self.base_type, self.const, self.restrict

    @property
    def alias(self):
        """True if the pointer may alias, i.e. it is not ``restrict``."""
        return not self.restrict

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        return self.base_type.item_size

    def __eq__(self, other):
        if isinstance(other, PointerType):
            return (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict)
        return False

    def __str__(self):
        qualifiers = ['RESTRICT'] if self.restrict else []
        if self.const:
            qualifiers.append("const")
        return " ".join([str(self.base_type), '*'] + qualifiers)

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self._base_type, self.const, self.restrict))


class StructType:
    """Type wrapping a structured (record) numpy dtype, optionally ``const`` qualified.

    Unlike the other types in this module it is a plain class, not a ``Type`` subclass.
    """

    def __init__(self, numpy_type, const=False):
        self._dtype = np.dtype(numpy_type)
        self.const = const

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        """Structs do not wrap another type; always None."""
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def item_size(self):
        """Size of one struct instance in bytes (numpy ``itemsize``)."""
        return self.numpy_dtype.itemsize

    def get_element_offset(self, element_name):
        """Byte offset of field ``element_name`` within the struct."""
        field_info = self.numpy_dtype.fields[element_name]
        return field_info[1]

    def get_element_type(self, element_name):
        """BasicType of field ``element_name``, carrying over the struct's ``const`` flag."""
        field_dtype = self.numpy_dtype.fields[element_name][0]
        return BasicType(field_dtype, self.const)

    def has_element(self, element_name):
        return element_name in self.numpy_dtype.fields

    def __eq__(self, other):
        return isinstance(other, StructType) and (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)

    def __str__(self):
        # structs are handled byte-wise
        return "uint8_t const" if self.const else "uint8_t"

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self.numpy_dtype, self.const))