data_types.py 17.1 KB
Newer Older
1
import ctypes
2
import sympy as sp
3
import numpy as np
4
5
6
7
8
try:
    import llvmlite.ir as ir
except ImportError as e:
    ir = None
    _ir_importerror = e
9
from sympy.core.cache import cacheit
10

11
from pystencils.cache import memorycache
Martin Bauer's avatar
Martin Bauer committed
12
from pystencils.utils import all_equal
Martin Bauer's avatar
Martin Bauer committed
13
from sympy.logic.boolalg import Boolean
14

15

Martin Bauer's avatar
Martin Bauer committed
16
# noinspection PyPep8Naming
17
class cast_func(sp.Function):
    """Symbolic cast of an expression (``args[0]``) to a data type (``args[1]``).

    If the casted expression is itself a sympy ``Boolean``, an instance of
    ``boolean_cast_func`` is created instead, so that the cast can appear in conditions
    (e.g. of ``sp.Piecewise``).
    """
    is_Atom = True

    def __new__(cls, *args, **kwargs):
        # A cast must only be a Boolean when the casted argument is a Boolean.
        # If every cast were a Boolean, equality reasoning would break, e.g.
        #
        #   lhs = bitwise_and(a, cast_func(1, 'int'))
        #   rhs = cast_func(0, 'int')
        #   print( sp.Ne(lhs, rhs) )  # would give true if all cast_funcs are booleans
        #
        # which is why the separate subclass boolean_cast_func exists.
        target_cls = boolean_cast_func if isinstance(args[0], Boolean) else cls
        return sp.Function.__new__(target_cls, *args, **kwargs)

    @property
    def canonical(self):
        """``canonical`` of the casted expression; NotImplementedError if it has none."""
        casted = self.args[0]
        if not hasattr(casted, 'canonical'):
            raise NotImplementedError()
        return casted.canonical

    @property
    def is_commutative(self):
        """Commutativity is inherited from the casted expression."""
        casted = self.args[0]
        return casted.is_commutative

    @property
    def dtype(self):
        """The target type of the cast (second argument)."""
        target_type = self.args[1]
        return target_type


49
50
51
52
53
# noinspection PyPep8Naming
class boolean_cast_func(cast_func, Boolean):
    """Cast of a boolean expression; also a sympy ``Boolean`` so it can be used in conditions.

    Instances are created automatically by ``cast_func.__new__`` when its first argument
    is a ``Boolean``.
    """


Martin Bauer's avatar
Martin Bauer committed
54
55
56
57
# noinspection PyPep8Naming
class vector_memory_access(cast_func):
    """cast_func variant that takes exactly four arguments (``nargs = (4,)``).

    The first two arguments are read like those of cast_func (expression, dtype).
    NOTE(review): the meaning of the remaining two arguments is defined by the code that
    creates these objects, not here -- confirm there.
    """
    nargs = (4,)

58

59
60
61
62
63
# noinspection PyPep8Naming
class reinterpret_cast_func(cast_func):
    """cast_func subclass that marks a different kind of cast; behaves exactly like cast_func here.

    NOTE(review): presumably a bit-level reinterpretation rather than a value conversion,
    as the name suggests -- confirm in the code generator that consumes it.
    """


Martin Bauer's avatar
Martin Bauer committed
64
65
# noinspection PyPep8Naming
class pointer_arithmetic_func(sp.Function, Boolean):
    """Symbolic function node for pointer arithmetic; also a sympy ``Boolean``.

    NOTE(review): the reason for the Boolean base is not visible in this file -- presumably
    the same as for boolean_cast_func (usability inside conditions); confirm.
    """

    @property
    def canonical(self):
        """``canonical`` of the first argument; NotImplementedError if it has none."""
        first_arg = self.args[0]
        if not hasattr(first_arg, 'canonical'):
            raise NotImplementedError()
        return first_arg.canonical


74
class TypedSymbol(sp.Symbol):
    """A sympy Symbol that carries a data type.

    The type is stored as returned by ``create_type``. If the specification cannot be
    parsed (``create_type`` raises ``TypeError``), the raw specification is stored instead.
    Construction goes through a ``cacheit``-wrapped factory.
    """

    def __new__(cls, *args, **kwds):
        return TypedSymbol.__xnew_cached_(cls, *args, **kwds)

    def __new_stage2__(cls, name, dtype):
        symbol = super(TypedSymbol, cls).__xnew__(cls, name)
        try:
            parsed_type = create_type(dtype)
        except TypeError:
            # unparsable specification: keep it exactly as given (e.g. a string)
            parsed_type = dtype
        symbol._dtype = parsed_type
        return symbol

    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    @property
    def dtype(self):
        """The data type of this symbol."""
        return self._dtype

    def _hashable_content(self):
        # the type takes part in hashing/equality, so equally named symbols of
        # different types are distinct
        symbol_content = super()._hashable_content()
        return symbol_content, hash(self._dtype)

    def __getnewargs__(self):
        # (name, dtype) are the arguments needed to recreate the symbol when unpickling
        return self.name, self.dtype


Martin Bauer's avatar
Martin Bauer committed
102
def create_type(specification):
    """Return a Type object for ``specification``.

    Args:
        specification: a Type instance (returned unchanged) or anything accepted by
                       ``np.dtype``, e.g. the string "float64"

    Returns:
        the given Type, or a new non-const BasicType for plain numpy dtypes, or a new
        non-const StructType for structured numpy dtypes (dtypes with fields)

    Raises:
        TypeError: propagated from ``np.dtype`` for specifications it does not understand
    """
    if isinstance(specification, Type):
        return specification
    numpy_dtype = np.dtype(specification)
    type_class = BasicType if numpy_dtype.fields is None else StructType
    return type_class(numpy_dtype, const=False)
119
120


121
@memorycache(maxsize=64)
def create_composite_type_from_string(specification):
    """Creates a new Type object from a c-like string specification.

    The specification is a base type (a numpy dtype name such as "double" or "int64"),
    optionally qualified with "const". It is followed by any number of pointer levels "*",
    each optionally qualified with "const" and/or "restrict". As in C, qualifiers written
    after a "*" apply to that pointer level. A "*" may be attached to neighbouring tokens
    ("double*", "*const", "double**").

    Examples:
        "double const * restrict" -> restrict pointer to const double
        "double* const"           -> const (non-restrict) pointer to double
        "double**"                -> pointer to pointer to double

    Args:
        specification: Specification string (case-insensitive)

    Returns:
        Type object: a BasicType, wrapped in one PointerType per "*"
    """
    # Make every '*' a token of its own. Then "double*", "*const" and "double**" are
    # parsed exactly like "double *", "* const" and "double * *", and the pointer level
    # of an attached '*' keeps its correct position (innermost first).
    tokens = specification.lower().replace('*', ' * ').split()

    # Group tokens: the first group is the base type, every following group starts with '*'
    parts = []
    current = []
    for token in tokens:
        if token == '*':
            parts.append(current)
            current = [token]
        else:
            current.append(token)
    if current:
        parts.append(current)

    # Parse native part
    base_part = parts.pop(0)
    const = 'const' in base_part
    if const:
        base_part.remove('const')
    assert len(base_part) == 1
    current_type = BasicType(np.dtype(base_part[0]), const)

    # Parse pointer parts, innermost pointer level first
    for part in parts:
        restrict = 'restrict' in part
        if restrict:
            part.remove('restrict')
        const = 'const' in part
        if const:
            part.remove('const')
        assert len(part) == 1 and part[0] == '*'
        current_type = PointerType(current_type, const, restrict)
    return current_type
166
167


Martin Bauer's avatar
Martin Bauer committed
168
169
170
171
def get_base_type(data_type):
    """Follow the ``base_type`` chain of ``data_type`` and return the innermost type,
    i.e. the first one whose ``base_type`` is None (``data_type`` itself if it has none)."""
    current = data_type
    while True:
        inner = current.base_type
        if inner is None:
            return current
        current = inner
172
173


Martin Bauer's avatar
Martin Bauer committed
174
def to_ctypes(data_type):
    """Transform a Type into the corresponding ctypes type.

    Args:
        data_type: instance of a Type subclass

    Returns:
        ctypes type object. Pointers are translated recursively; struct types become a
        ``uint8`` pointer; every other type is looked up by its numpy dtype in
        ``to_ctypes.map`` (KeyError if absent).
    """
    if isinstance(data_type, PointerType):
        pointee = to_ctypes(data_type.base_type)
        return ctypes.POINTER(pointee)
    if isinstance(data_type, StructType):
        # structs are handled as raw bytes
        return ctypes.POINTER(ctypes.c_uint8)
    return to_ctypes.map[data_type.numpy_dtype]
186

Martin Bauer's avatar
Martin Bauer committed
187
188

# numpy dtype -> ctypes type, used by to_ctypes for scalar (non-pointer, non-struct) types
to_ctypes.map = {
    np.dtype(np.int8): ctypes.c_int8,
    np.dtype(np.int16): ctypes.c_int16,
    np.dtype(np.int32): ctypes.c_int32,
    np.dtype(np.int64): ctypes.c_int64,

    np.dtype(np.uint8): ctypes.c_uint8,
    np.dtype(np.uint16): ctypes.c_uint16,
    np.dtype(np.uint32): ctypes.c_uint32,
    np.dtype(np.uint64): ctypes.c_uint64,

    np.dtype(np.float32): ctypes.c_float,
    np.dtype(np.float64): ctypes.c_double,

    # BasicType supports bool (C name 'bool'), e.g. create_type("bool"); without this
    # entry to_ctypes would raise KeyError for it
    np.dtype(np.bool_): ctypes.c_bool,
}


204
def ctypes_from_llvm(data_type):
    """Translate an llvmlite IR type into the matching ctypes type.

    Pointers are translated recursively; a pointer whose pointee has no ctypes
    counterpart (void) becomes ``ctypes.c_void_p``. Integer types of any signedness map
    to the signed ctypes integer of the same width.

    Returns:
        ctypes type, or None for ``ir.VoidType``

    Raises:
        ImportError: the error recorded when importing llvmlite failed
        ValueError: integer width other than 8, 16, 32 or 64
        NotImplementedError: any other llvm type
    """
    if not ir:
        raise _ir_importerror

    if isinstance(data_type, ir.PointerType):
        pointee = ctypes_from_llvm(data_type.pointee)
        return ctypes.c_void_p if pointee is None else ctypes.POINTER(pointee)

    if isinstance(data_type, ir.IntType):
        ints_by_width = {8: ctypes.c_int8, 16: ctypes.c_int16, 32: ctypes.c_int32, 64: ctypes.c_int64}
        int_ctype = ints_by_width.get(data_type.width)
        if int_ctype is None:
            raise ValueError("Int width %d is not supported" % data_type.width)
        return int_ctype

    if isinstance(data_type, ir.FloatType):
        return ctypes.c_float
    if isinstance(data_type, ir.DoubleType):
        return ctypes.c_double
    if isinstance(data_type, ir.VoidType):
        return None  # Void type is not supported by ctypes
    raise NotImplementedError('Data type %s of %s is not supported yet' % (type(data_type), data_type))


def to_llvm_type(data_type):
    """
    Transforms a given type into an llvmlite type
    :param data_type: Subclass of Type -- a PointerType, or a type whose numpy_dtype is a
                      key of to_llvm_type.map
    :return: llvmlite type object
    :raises ImportError: the error recorded when importing llvmlite failed
    :raises KeyError: if the numpy dtype has no entry in to_llvm_type.map
    """
    if not ir:
        raise _ir_importerror
    if isinstance(data_type, PointerType):
        # pointers are translated recursively: llvm pointer to the translated base type
        return to_llvm_type(data_type.base_type).as_pointer()
    else:
        return to_llvm_type.map[data_type.numpy_dtype]

247

248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
if ir:
    # numpy dtype -> llvmlite type. LLVM integer types carry no signedness, so signed and
    # unsigned integers of the same bit width share one IntType.
    to_llvm_type.map = {
        np.dtype(int_type): ir.IntType(8 * np.dtype(int_type).itemsize)
        for int_type in (np.int8, np.int16, np.int32, np.int64,
                         np.uint8, np.uint16, np.uint32, np.uint64)
    }
    to_llvm_type.map[np.dtype(np.float32)] = ir.FloatType()
    to_llvm_type.map[np.dtype(np.float64)] = ir.DoubleType()
263

264

Martin Bauer's avatar
Martin Bauer committed
265
266
267
def peel_off_type(dtype, type_to_peel_off):
    """Unwrap ``dtype`` through ``base_type`` while it is exactly of class ``type_to_peel_off``.

    Subclasses of ``type_to_peel_off`` are not unwrapped. Returns the first type that is
    not of that class.
    """
    if type(dtype) is not type_to_peel_off:
        return dtype
    return peel_off_type(dtype.base_type, type_to_peel_off)


Martin Bauer's avatar
Martin Bauer committed
271
def collate_types(types):
    """
    Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
    Uses the collation rules from numpy.

    Rules:
      - Pointer arithmetic: if any type is a PointerType, exactly one pointer may be combined
        with signed/unsigned integer BasicTypes. The pointer type is returned.
      - VectorTypes are peeled off. They must all have the same width, and the result is
        wrapped into a VectorType of that width.
      - If floating point types are present, only these take part in the collation.
      - The remaining BasicTypes are combined with ``np.result_type``; the result is a
        non-const BasicType.

    Args:
        types: iterable of Type objects. It may be any iterable (e.g. a generator), because
               it is materialized into a tuple first.

    Raises:
        ValueError: two pointer types, invalid pointer arithmetic, or vector types of
                    different widths
    """
    # The types are traversed several times below. A one-shot iterable such as a generator
    # would be (partially) consumed by the first pass and give wrong results, so
    # materialize it first.
    types = tuple(types)

    # Pointer arithmetic case i.e. pointer + integer is allowed
    if any(type(t) is PointerType for t in types):
        pointer_type = None
        for t in types:
            if type(t) is PointerType:
                if pointer_type is not None:
                    raise ValueError("Cannot collate the combination of two pointer types")
                pointer_type = t
            elif type(t) is BasicType:
                if not (t.is_int() or t.is_uint()):
                    raise ValueError("Invalid pointer arithmetic")
            else:
                raise ValueError("Invalid pointer arithmetic")
        return pointer_type

    # peel off vector types, if at least one vector type occurred the result will also be the vector type
    vector_types = [t for t in types if type(t) is VectorType]
    if not all_equal(t.width for t in vector_types):
        raise ValueError("Collation failed because of vector types with different width")
    types = [peel_off_type(t, VectorType) for t in types]

    # now we should have a list of basic types - struct types are not yet supported
    assert all(type(t) is BasicType for t in types)

    # floating point types dominate: integers are ignored when any float type is present
    if any(t.is_float() for t in types):
        types = [t for t in types if t.is_float()]

    # use numpy collation -> create type from numpy type -> and, put vector type around if necessary
    result_numpy_type = np.result_type(*(t.numpy_dtype for t in types))
    result = BasicType(result_numpy_type)
    if vector_types:
        result = VectorType(result, vector_types[0].width)
    return result


@memorycache(maxsize=2048)
def get_type_of_expression(expr):
    """Determine the data type of a (typed) sympy expression.

    Integers are typed "int", other rationals and floats "double". Typed symbols, field
    accesses and casts report their own type. Piecewise, boolean and general expressions
    collate the types of their arguments (see collate_types). Boolean expressions become
    "bool", or a bool VectorType if any argument is a vector type. Powers take the type of
    their base.

    Args:
        expr: sympy expression (anything ``sp.sympify`` accepts)

    Returns:
        Type object

    Raises:
        ValueError: if the expression contains a plain (untyped) sympy Symbol
        NotImplementedError: if no typing rule applies
    """
    from pystencils.astnodes import ResolvedFieldAccess
    expr = sp.sympify(expr)
    if isinstance(expr, sp.Integer):
        return create_type("int")
    elif isinstance(expr, (sp.Rational, sp.Float)):
        return create_type("double")
    elif isinstance(expr, ResolvedFieldAccess):
        return expr.field.dtype
    elif isinstance(expr, TypedSymbol):
        return expr.dtype
    elif isinstance(expr, sp.Symbol):
        raise ValueError("All symbols inside this expression have to be typed! ", str(expr))
    elif isinstance(expr, cast_func):
        return expr.args[1]
    elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
        collated_result_type = collate_types(tuple(get_type_of_expression(a[0]) for a in expr.args))
        collated_condition_type = collate_types(tuple(get_type_of_expression(a[1]) for a in expr.args))
        # vector conditions make the whole piecewise a vector expression
        if type(collated_condition_type) is VectorType and type(collated_result_type) is not VectorType:
            collated_result_type = VectorType(collated_result_type, width=collated_condition_type.width)
        return collated_result_type
    elif isinstance(expr, sp.Indexed):
        typed_symbol = expr.base.label
        return typed_symbol.dtype.base_type
    elif isinstance(expr, Boolean):
        # Boolean is imported from sympy.logic.boolalg and also covers BooleanFunction;
        # sp.boolalg is not reliably exposed in sympy's top-level namespace.
        # if any arg is of vector type return a vector boolean, else return a normal scalar boolean
        result = create_type("bool")
        arg_types = [get_type_of_expression(a) for a in expr.args]
        vec_args = [t for t in arg_types if isinstance(t, VectorType)]
        if vec_args:
            result = VectorType(result, width=vec_args[0].width)
        return result
    elif isinstance(expr, sp.Pow):
        return get_type_of_expression(expr.args[0])
    elif isinstance(expr, sp.Expr):
        types = tuple(get_type_of_expression(a) for a in expr.args)
        return collate_types(types)

    raise NotImplementedError("Could not determine type for", expr, type(expr))
350
351


352
class Type(sp.Basic):
    """Common base class of the data types.

    It is a sympy ``Basic`` atom, so types can appear as arguments of sympy expressions
    (e.g. as the second argument of ``cast_func``).
    """
    is_Atom = True

    def __new__(cls, *args, **kwargs):
        # constructor arguments are not forwarded to sympy; subclasses consume them in __init__
        instance = sp.Basic.__new__(cls)
        return instance

    def _sympystr(self, *args, **kwargs):
        # sympy's string printer uses the type's own str()
        text = str(self)
        return text
360
361
362
363


class BasicType(Type):
    """Scalar (non-pointer, non-struct) data type backed by a numpy dtype, optionally const."""

    @staticmethod
    def numpy_name_to_c(name):
        """Map a numpy dtype name ('float64', 'float32', 'intN', 'uintN', 'bool') to its C name.

        Raises:
            NotImplementedError: for any other dtype name
        """
        if name == 'float64':
            return 'double'
        elif name == 'float32':
            return 'float'
        elif name.startswith('int'):
            width = int(name[len("int"):])
            return "int%d_t" % (width,)
        elif name.startswith('uint'):
            width = int(name[len("uint"):])
            return "uint%d_t" % (width,)
        elif name == 'bool':
            return 'bool'
        else:
            raise NotImplementedError("Cannot map numpy to C name for %s" % (name,))

    def __init__(self, dtype, const=False):
        self.const = const
        if isinstance(dtype, Type):
            self._dtype = dtype.numpy_dtype
        else:
            self._dtype = np.dtype(dtype)
        assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type"
        assert self._dtype.hasobject is False
        assert self._dtype.subdtype is None

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def item_size(self):
        return 1

    # The is_* predicates use dtype.kind instead of np.sctypes, which was removed in
    # NumPy 2.0 (accessing it raises AttributeError there).

    def is_int(self):
        """True for signed integer types."""
        return self.numpy_dtype.kind == 'i'

    def is_float(self):
        """True for floating point types."""
        return self.numpy_dtype.kind == 'f'

    def is_uint(self):
        """True for unsigned integer types."""
        return self.numpy_dtype.kind == 'u'

    def is_complex(self):
        """True for complex floating point types."""
        return self.numpy_dtype.kind == 'c'

    def is_other(self):
        """True for bool, object, bytes, str and void types (numpy's former 'others' category)."""
        return self.numpy_dtype.kind in ('b', 'O', 'S', 'U', 'V')

    @property
    def base_name(self):
        """C name of the type, without const qualifier."""
        return BasicType.numpy_name_to_c(str(self._dtype))

    def __str__(self):
        result = BasicType.numpy_name_to_c(str(self._dtype))
        if self.const:
            result += " const"
        return result

    def __repr__(self):
        return str(self)

    def __eq__(self, other):
        if not isinstance(other, BasicType):
            return False
        else:
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)

    def __hash__(self):
        # str() covers dtype and const, matching the fields compared in __eq__
        return hash(str(self))


443
class VectorType(Type):
    """Vector of ``width`` elements of ``base_type``.

    While the class attribute ``instruction_set`` is None, the string form is
    "<base_type>[<width>]". Otherwise ``instruction_set`` is indexed with 'int', 'double',
    'float' or 'bool' (for int64, float64, float32 and bool elements) to obtain the name.
    """
    instruction_set = None

    def __init__(self, base_type, width=4):
        self._base_type = base_type
        self.width = width

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        """``width`` times the item size of the element type."""
        return self.width * self.base_type.item_size

    def __eq__(self, other):
        if isinstance(other, VectorType):
            return (self.base_type, self.width) == (other.base_type, other.width)
        return False

    def __str__(self):
        if self.instruction_set is None:
            return "%s[%d]" % (self.base_type, self.width)
        # element type name -> key of the instruction set's vector type names
        for element_type_name, instruction_set_key in (("int64", 'int'),
                                                       ("float64", 'double'),
                                                       ("float32", 'float'),
                                                       ("bool", 'bool')):
            if self.base_type == create_type(element_type_name):
                return self.instruction_set[instruction_set_key]
        raise NotImplementedError()

    def __hash__(self):
        return hash((self.base_type, self.width))

    def __getnewargs__(self):
        return self._base_type, self.width

485

486
class PointerType(Type):
    """Pointer to ``base_type``; the pointer itself may be const and is restrict by default."""

    def __init__(self, base_type, const=False, restrict=True):
        self._base_type = base_type
        self.const = const
        self.restrict = restrict

    def __getnewargs__(self):
        return self.base_type, self.const, self.restrict

    @property
    def alias(self):
        """True if the pointer is not restrict-qualified."""
        return not self.restrict

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        """Item size of the pointed-to type."""
        return self.base_type.item_size

    def __eq__(self, other):
        return isinstance(other, PointerType) and \
            (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict)

    def __str__(self):
        qualifiers = [name for name, active in (('RESTRICT', self.restrict), ("const", self.const)) if active]
        return " ".join([str(self.base_type), '*'] + qualifiers)

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self._base_type, self.const, self.restrict))
526

Jan Hoenig's avatar
Jan Hoenig committed
527

528
class StructType:
    """Structured (record) data type backed by a numpy dtype with fields, optionally const.

    In generated code structs are handled byte-wise, so the C name is ``uint8_t``.
    """

    def __init__(self, numpy_type, const=False):
        self.const = const
        self._dtype = np.dtype(numpy_type)

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def item_size(self):
        """Size of one struct in bytes."""
        return self.numpy_dtype.itemsize

    def get_element_offset(self, element_name):
        """Byte offset of field ``element_name`` inside the struct."""
        field_info = self.numpy_dtype.fields[element_name]
        return field_info[1]

    def get_element_type(self, element_name):
        """BasicType of field ``element_name``, inheriting the struct's const qualifier."""
        field_dtype = self.numpy_dtype.fields[element_name][0]
        return BasicType(field_dtype, self.const)

    def has_element(self, element_name):
        return element_name in self.numpy_dtype.fields

    def __eq__(self, other):
        if isinstance(other, StructType):
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)
        return False

    def __str__(self):
        # structs are handled byte-wise
        return "uint8_t const" if self.const else "uint8_t"

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self.numpy_dtype, self.const))