data_types.py 16.4 KB
Newer Older
1
import ctypes
2
import sympy as sp
3
import numpy as np
4
5
6
7
8
try:
    import llvmlite.ir as ir
except ImportError as e:
    ir = None
    _ir_importerror = e
9
from sympy.core.cache import cacheit
10

11
from pystencils.cache import memorycache
Martin Bauer's avatar
Martin Bauer committed
12
from pystencils.utils import all_equal
Martin Bauer's avatar
Martin Bauer committed
13
from sympy.logic.boolalg import Boolean
14

15

Martin Bauer's avatar
Martin Bauer committed
16
17
18
19
# noinspection PyPep8Naming
class cast_func(sp.Function, Boolean):
    """Symbolic cast node: ``args[0]`` is the wrapped expression, ``args[1]`` its data type.

    The class also derives from ``Boolean`` because sp.Piecewise only accepts
    boolean objects as conditions, and casts have to be usable there.
    """

    @property
    def canonical(self):
        """Canonical form of the wrapped expression.

        Raises NotImplementedError if the wrapped expression has no ``canonical`` attribute.
        """
        wrapped = self.args[0]
        if not hasattr(wrapped, 'canonical'):
            raise NotImplementedError()
        return wrapped.canonical

    @property
    def is_commutative(self):
        """Commutativity is taken over from the wrapped expression."""
        wrapped = self.args[0]
        return wrapped.is_commutative

    @property
    def dtype(self):
        """The second argument of the node, i.e. the type that was passed to the cast."""
        return self.args[1]


# noinspection PyPep8Naming
class vector_memory_access(cast_func):
    # Specialisation of cast_func that takes exactly four arguments
    # (sympy reads ``nargs`` to validate the argument count of a Function).
    # NOTE(review): the meaning of the four arguments is not visible in this file.
    nargs = (4,)

40

Martin Bauer's avatar
Martin Bauer committed
41
42
# noinspection PyPep8Naming
class pointer_arithmetic_func(sp.Function, Boolean):
    """sympy function node that forwards ``canonical`` to its first argument."""

    @property
    def canonical(self):
        """Canonical form of the first argument.

        Raises NotImplementedError if the first argument has no ``canonical`` attribute.
        """
        first_arg = self.args[0]
        if not hasattr(first_arg, 'canonical'):
            raise NotImplementedError()
        return first_arg.canonical


51
class TypedSymbol(sp.Symbol):
    """sympy symbol that additionally carries a data type, available as ``dtype``."""

    def __new__(cls, *args, **kwds):
        # instances are created through the cacheit-wrapped factory defined below
        return TypedSymbol.__xnew_cached_(cls, *args, **kwds)

    def __new_stage2__(cls, name, dtype):
        """Builds the symbol and attaches the type created from ``dtype``."""
        symbol = super(TypedSymbol, cls).__xnew__(cls, name)
        try:
            resolved_type = create_type(dtype)
        except TypeError:
            # on error keep the specification exactly as it was passed in
            resolved_type = dtype
        symbol._dtype = resolved_type
        return symbol

    # uncached and cached variants of the factory
    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    @property
    def dtype(self):
        """Data type attached to the symbol at construction."""
        return self._dtype

    def _hashable_content(self):
        # extend the content of the parent class by the hash of the type, so that
        # symbols with the same name but different types are told apart
        inherited = tuple(super()._hashable_content())
        return inherited + (hash(self._dtype),)

    def __getnewargs__(self):
        # constructor arguments needed to re-create the symbol
        return self.name, self.dtype


Martin Bauer's avatar
Martin Bauer committed
80
def create_type(specification):
    """Creates a type object according to a specification or passes an existing type object through.

    Args:
        specification: Type or StructType object, or anything ``np.dtype`` accepts (e.g. a string)

    Returns:
        The specification itself if it already is a type object. Otherwise a new, non-const
        BasicType (numpy dtype without fields) or StructType (numpy dtype with fields).

    Raises:
        TypeError: if numpy cannot interpret the specification as a dtype
    """
    # StructType does not derive from Type, so it has to be listed explicitly;
    # otherwise an existing StructType would be handed to np.dtype and rejected
    if isinstance(specification, (Type, StructType)):
        return specification

    numpy_dtype = np.dtype(specification)
    if numpy_dtype.fields is None:
        return BasicType(numpy_dtype, const=False)
    return StructType(numpy_dtype, const=False)
97
98


99
@memorycache(maxsize=64)
def create_composite_type_from_string(specification):
    """Creates a new Type object from a c-like string specification.

    The specification is lower-cased and split into tokens. All tokens in front of the
    first ``*`` describe the base type (a numpy type name, optionally ``const``). Every
    ``*`` opens one pointer level, which may be followed by ``const`` and/or ``restrict``.
    Pointer levels are applied from left to right, so the leftmost ``*`` is the innermost
    pointer. A ``*`` may be written directly next to another token (``"double*"``).

    Example: ``"const double * restrict const"`` is a const, restrict pointer to const double.

    Args:
        specification: Specification string

    Returns:
        Type object

    Raises:
        AssertionError: if the base part is not exactly one type name (plus optional const)
            or a pointer level contains unknown tokens
        IndexError: if the specification contains no tokens at all
    """
    # Surround every '*' with blanks so that it always becomes a token of its own.
    # Before, a star glued to the base type name was added as the *outermost* pointer
    # level and qualifiers following it were attributed to the base type.
    # Lower-casing also accepts the upper-case 'RESTRICT' that PointerType prints.
    tokens = specification.lower().replace('*', ' * ').split()

    # group the tokens: first group = base type, every further group starts with '*'
    parts = []
    current = []
    for token in tokens:
        if token == '*':
            parts.append(current)
            current = [token]
        else:
            current.append(token)
    if current:
        parts.append(current)

    # Parse native part
    base_part = parts.pop(0)
    const = 'const' in base_part
    if const:
        base_part.remove('const')
    assert len(base_part) == 1, "Base type has to consist of exactly one type name"
    current_type = BasicType(np.dtype(base_part[0]), const)

    # Parse pointer parts
    for part in parts:
        restrict = 'restrict' in part
        if restrict:
            part.remove('restrict')
        const = 'const' in part
        if const:
            part.remove('const')
        assert part == ['*'], "Unexpected tokens in pointer specification"
        current_type = PointerType(current_type, const, restrict)
    return current_type
144
145


Martin Bauer's avatar
Martin Bauer committed
146
147
148
149
def get_base_type(data_type):
    """Follows the ``base_type`` chain of a type and returns the innermost type.

    The innermost type is the first one whose ``base_type`` is None.
    """
    inner = data_type.base_type
    while inner is not None:
        data_type, inner = inner, inner.base_type
    return data_type
150
151


Martin Bauer's avatar
Martin Bauer committed
152
def to_ctypes(data_type):
    """
    Maps a type object of this module to the corresponding ctypes type

    :param data_type: PointerType, StructType or a type with a ``numpy_dtype`` listed in ``to_ctypes.map``
    :return: ctypes type object
    """
    if isinstance(data_type, PointerType):
        pointee = to_ctypes(data_type.base_type)
        return ctypes.POINTER(pointee)
    if isinstance(data_type, StructType):
        # structs are passed as pointer to raw bytes
        return ctypes.POINTER(ctypes.c_uint8)
    return to_ctypes.map[data_type.numpy_dtype]
164

Martin Bauer's avatar
Martin Bauer committed
165
166

# lookup table used by to_ctypes for non-pointer, non-struct types: numpy dtype -> ctypes type
to_ctypes.map = {
    np.dtype(numpy_type): c_type
    for numpy_type, c_type in (
        (np.int8, ctypes.c_int8),
        (np.int16, ctypes.c_int16),
        (np.int32, ctypes.c_int32),
        (np.int64, ctypes.c_int64),

        (np.uint8, ctypes.c_uint8),
        (np.uint16, ctypes.c_uint16),
        (np.uint32, ctypes.c_uint32),
        (np.uint64, ctypes.c_uint64),

        (np.float32, ctypes.c_float),
        (np.float64, ctypes.c_double),
    )
}


182
def ctypes_from_llvm(data_type):
    """
    Maps an llvmlite type to the corresponding ctypes type

    :param data_type: llvmlite type object (pointer, int, float, double or void)
    :return: ctypes type object, or None for the void type
    :raises ValueError: for integer widths other than 8, 16, 32 and 64
    :raises NotImplementedError: for all other llvmlite types
    """
    if not ir:
        # llvmlite could not be imported - re-raise the stored import error
        raise _ir_importerror

    if isinstance(data_type, ir.PointerType):
        pointee = ctypes_from_llvm(data_type.pointee)
        # a pointer to void has no typed ctypes counterpart
        return ctypes.c_void_p if pointee is None else ctypes.POINTER(pointee)

    if isinstance(data_type, ir.IntType):
        int_types_by_width = {
            8: ctypes.c_int8,
            16: ctypes.c_int16,
            32: ctypes.c_int32,
            64: ctypes.c_int64,
        }
        width = data_type.width
        if width not in int_types_by_width:
            raise ValueError("Int width %d is not supported" % width)
        return int_types_by_width[width]

    if isinstance(data_type, ir.FloatType):
        return ctypes.c_float
    if isinstance(data_type, ir.DoubleType):
        return ctypes.c_double
    if isinstance(data_type, ir.VoidType):
        return None  # Void type is not supported by ctypes

    raise NotImplementedError('Data type %s of %s is not supported yet' % (type(data_type), data_type))


def to_llvm_type(data_type):
    """
    Transforms a given type into the corresponding llvmlite type

    :param data_type: Subclass of Type
    :return: llvmlite type object
    """
    if not ir:
        # llvmlite could not be imported - re-raise the stored import error
        raise _ir_importerror
    if not isinstance(data_type, PointerType):
        return to_llvm_type.map[data_type.numpy_dtype]
    pointee = to_llvm_type(data_type.base_type)
    return pointee.as_pointer()

225

226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
if ir:
    # lookup table used by to_llvm_type, only built when llvmlite is available;
    # signed and unsigned integers of equal width map to the same llvmlite integer type
    to_llvm_type.map = {
        np.dtype(numpy_type): llvm_type
        for numpy_type, llvm_type in (
            (np.int8, ir.IntType(8)),
            (np.int16, ir.IntType(16)),
            (np.int32, ir.IntType(32)),
            (np.int64, ir.IntType(64)),

            (np.uint8, ir.IntType(8)),
            (np.uint16, ir.IntType(16)),
            (np.uint32, ir.IntType(32)),
            (np.uint64, ir.IntType(64)),

            (np.float32, ir.FloatType()),
            (np.float64, ir.DoubleType()),
        )
    }
241

242

Martin Bauer's avatar
Martin Bauer committed
243
244
245
def peel_off_type(dtype, type_to_peel_off):
    """Strips all leading layers of exactly the class ``type_to_peel_off`` from ``dtype``.

    While the type is of that class (subclasses do not count) it is replaced by its
    ``base_type``; the first type of a different class is returned.
    """
    current = dtype
    while True:
        if type(current) is not type_to_peel_off:
            return current
        current = current.base_type


Martin Bauer's avatar
Martin Bauer committed
249
def collate_types(types):
    """
    Determines the "common type" of a sequence of types, e.g. (float, double, float) -> double

    If a pointer type is present it is returned; all other entries then have to be
    signed or unsigned integer basic types. Otherwise the numpy collation rules are applied
    to the basic types, and the result is a vector type if any input was a vector type.

    Raises ValueError for two pointer types, for a pointer combined with a non-integer type
    and for vector types of different width.
    """
    # Pointer arithmetic case i.e. pointer + integer is allowed
    if any(type(t) is PointerType for t in types):
        pointer_type = None
        for candidate in types:
            kind = type(candidate)
            if kind is PointerType:
                if pointer_type is not None:
                    raise ValueError("Cannot collate the combination of two pointer types")
                pointer_type = candidate
                continue
            is_integer = kind is BasicType and (candidate.is_int() or candidate.is_uint())
            if not is_integer:
                raise ValueError("Invalid pointer arithmetic")
        return pointer_type

    # peel of vector types, if at least one vector type occurred the result will also be the vector type
    vector_types = [t for t in types if type(t) is VectorType]
    if not all_equal(v.width for v in vector_types):
        raise ValueError("Collation failed because of vector types with different width")
    scalar_types = [peel_off_type(t, VectorType) for t in types]

    # now we should have a list of basic types - struct types are not yet supported
    assert all(type(s) is BasicType for s in scalar_types)

    # use numpy collation -> create type from numpy type -> and, put vector type around if necessary
    common_numpy_type = np.result_type(*(s.numpy_dtype for s in scalar_types))
    result = BasicType(common_numpy_type)
    if not vector_types:
        return result
    return VectorType(result, vector_types[0].width)


@memorycache(maxsize=2048)
def get_type_of_expression(expr):
    """Determines the data type of an expression whose symbols are all typed.

    Integers 1 and -1 are typed "int16", other integers "int"; rationals and floats are
    "double". Typed leaves (field accesses, TypedSymbols, casts, indexed objects) report
    their own type. Boolean expressions are "bool", powers take the type of their base,
    and all remaining expressions get the collated type of their arguments.

    Raises:
        ValueError: if the expression contains an untyped sympy symbol
        NotImplementedError: if no rule applies to the expression
    """
    from pystencils.astnodes import ResolvedFieldAccess
    expr = sp.sympify(expr)

    if isinstance(expr, sp.Integer):
        type_name = "int16" if expr == 1 or expr == -1 else "int"
        return create_type(type_name)
    if isinstance(expr, (sp.Rational, sp.Float)):
        return create_type("double")
    if isinstance(expr, ResolvedFieldAccess):
        return expr.field.dtype
    if isinstance(expr, TypedSymbol):
        return expr.dtype
    if isinstance(expr, sp.Symbol):
        # TypedSymbol was handled above, so this is an untyped symbol
        raise ValueError("All symbols inside this expression have to be typed!")
    if isinstance(expr, cast_func):
        # has to be checked before the boolean case, since cast_func is a Boolean as well
        return expr.args[1]

    if hasattr(expr, 'func') and expr.func == sp.Piecewise:
        # each argument is a (value, condition) pair
        value_type = collate_types(tuple(get_type_of_expression(pair[0]) for pair in expr.args))
        condition_type = collate_types(tuple(get_type_of_expression(pair[1]) for pair in expr.args))
        # a vectorized condition makes the whole piecewise a vector expression
        if type(condition_type) is VectorType and type(value_type) is not VectorType:
            value_type = VectorType(value_type, width=condition_type.width)
        return value_type

    if isinstance(expr, sp.Indexed):
        typed_symbol = expr.base.label
        return typed_symbol.dtype.base_type

    if isinstance(expr, (sp.boolalg.Boolean, sp.boolalg.BooleanFunction)):
        # if any arg is of vector type return a vector boolean, else return a normal scalar boolean
        result = create_type("bool")
        arg_types = (get_type_of_expression(arg) for arg in expr.args)
        vec_args = [arg_type for arg_type in arg_types if isinstance(arg_type, VectorType)]
        if vec_args:
            result = VectorType(result, width=vec_args[0].width)
        return result

    if isinstance(expr, sp.Pow):
        # only the base is considered, the type of the exponent is ignored
        return get_type_of_expression(expr.args[0])

    if isinstance(expr, sp.Expr):
        arg_types = tuple(get_type_of_expression(arg) for arg in expr.args)
        return collate_types(arg_types)

    raise NotImplementedError("Could not determine type for", expr, type(expr))
329
330


331
332
333
class Type(sp.Basic):
    """Common base class of the type classes, itself a sympy ``Basic`` object."""

    def __new__(cls, *args, **kwargs):
        # constructor arguments are not handed to sympy; subclasses store them in __init__
        return super().__new__(cls)

    def _sympystr(self, *args, **kwargs):
        # sympy's string printer shows the same text as str()
        return str(self)
337
338
339
340


class BasicType(Type):
    """Scalar type backed by a plain (non-structured) numpy dtype plus a ``const`` flag."""

    @staticmethod
    def numpy_name_to_c(name):
        """Translates the name of a numpy dtype into the name of the matching C type.

        Args:
            name: numpy dtype name, e.g. 'float64', 'int32', 'uint8' or 'bool'

        Returns:
            C type name, e.g. 'double', 'int32_t', 'uint8_t' or 'bool'

        Raises:
            NotImplementedError: if no C name is known for the given numpy name
        """
        if name == 'float64':
            return 'double'
        elif name == 'float32':
            return 'float'
        elif name.startswith('int'):
            width = int(name[len("int"):])
            return "int%d_t" % (width,)
        elif name.startswith('uint'):
            width = int(name[len("uint"):])
            return "uint%d_t" % (width,)
        elif name == 'bool':
            return 'bool'
        else:
            # NotImplemented is a constant and cannot be raised - the exception class is meant
            raise NotImplementedError("Cannot map numpy to C name for %s" % (name,))

    def __init__(self, dtype, const=False):
        """
        Args:
            dtype: a Type object (its numpy dtype is taken over) or anything ``np.dtype`` accepts
            const: True if the type is const qualified
        """
        self.const = const
        if isinstance(dtype, Type):
            self._dtype = dtype.numpy_dtype
        else:
            self._dtype = np.dtype(dtype)
        assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type"
        assert self._dtype.hasobject is False
        assert self._dtype.subdtype is None

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        """Basic types do not wrap another type, so this is always None."""
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def item_size(self):
        return 1

    # The predicates below use the dtype kind code instead of the np.sctypes table,
    # which was removed in NumPy 2.0. The kind codes select the same scalar types.

    def is_int(self):
        """True for signed integer types."""
        return self.numpy_dtype.kind == 'i'

    def is_float(self):
        """True for floating point types."""
        return self.numpy_dtype.kind == 'f'

    def is_uint(self):
        """True for unsigned integer types."""
        return self.numpy_dtype.kind == 'u'

    def is_complex(self):
        """True for complex floating point types."""
        return self.numpy_dtype.kind == 'c'

    def is_other(self):
        """True if the dtype equals one of bool, object, bytes, str or void."""
        # same scalar types that np.sctypes['others'] listed before NumPy 2.0
        return self.numpy_dtype in (bool, object, bytes, str, np.void)

    @property
    def base_name(self):
        """C name of the type without qualifiers."""
        return BasicType.numpy_name_to_c(str(self._dtype))

    def __str__(self):
        result = self.base_name
        if self.const:
            result += " const"
        return result

    def __repr__(self):
        return str(self)

    def __eq__(self, other):
        if not isinstance(other, BasicType):
            return False
        else:
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)

    def __hash__(self):
        return hash(str(self))


420
class VectorType(Type):
    """Vector of ``width`` elements of a base type."""

    # Optional mapping with the keys 'int', 'double', 'float' and 'bool'; when set,
    # str() returns the entry for the base type instead of the generic "type[width]" text.
    instruction_set = None

    def __init__(self, base_type, width=4):
        self.width = width
        self._base_type = base_type

    @property
    def base_type(self):
        """Element type of the vector."""
        return self._base_type

    @property
    def item_size(self):
        """Item size of the base type multiplied by the vector width."""
        return self.width * self.base_type.item_size

    def __eq__(self, other):
        if isinstance(other, VectorType):
            return (self.base_type, self.width) == (other.base_type, other.width)
        return False

    def __str__(self):
        if self.instruction_set is None:
            return "%s[%d]" % (self.base_type, self.width)

        # look up the vector type name that the instruction set defines for the base type
        known_base_types = (
            ("int64", 'int'),
            ("float64", 'double'),
            ("float32", 'float'),
            ("bool", 'bool'),
        )
        for numpy_name, instruction_set_key in known_base_types:
            if self.base_type == create_type(numpy_name):
                return self.instruction_set[instruction_set_key]
        raise NotImplementedError()

    def __hash__(self):
        return hash((self.base_type, self.width))

    def __getnewargs__(self):
        return self._base_type, self.width

462

463
class PointerType(Type):
    """Pointer to ``base_type`` with ``const`` and ``restrict`` qualifiers."""

    def __init__(self, base_type, const=False, restrict=True):
        self.restrict = restrict
        self.const = const
        self._base_type = base_type

    def __getnewargs__(self):
        return self.base_type, self.const, self.restrict

    @property
    def alias(self):
        """True if the pointer is not restrict qualified."""
        return not self.restrict

    @property
    def base_type(self):
        """Type the pointer points to."""
        return self._base_type

    @property
    def item_size(self):
        """Item size of the pointed-to type."""
        return self.base_type.item_size

    def __eq__(self, other):
        if isinstance(other, PointerType):
            return (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict)
        return False

    def __str__(self):
        # "<base type> *", followed by the qualifiers that are set
        text = str(self.base_type) + " *"
        if self.restrict:
            text += " RESTRICT"
        if self.const:
            text += " const"
        return text

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self._base_type, self.const, self.restrict))
503

Jan Hoenig's avatar
Jan Hoenig committed
504

505
class StructType:
    """Type backed by a structured numpy dtype; printed as a byte type (``uint8_t``)."""

    def __init__(self, numpy_type, const=False):
        self._dtype = np.dtype(numpy_type)
        self.const = const

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        """Struct types do not wrap another type, so this is always None."""
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def item_size(self):
        """Size of one struct in bytes, as reported by numpy."""
        return self.numpy_dtype.itemsize

    def get_element_offset(self, element_name):
        """Byte offset of the given field inside the struct."""
        field_info = self.numpy_dtype.fields[element_name]
        return field_info[1]

    def get_element_type(self, element_name):
        """BasicType of the given field; it takes over the const flag of the struct."""
        field_info = self.numpy_dtype.fields[element_name]
        return BasicType(field_info[0], self.const)

    def has_element(self, element_name):
        """True if the struct has a field with the given name."""
        return element_name in self.numpy_dtype.fields

    def __eq__(self, other):
        if isinstance(other, StructType):
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)
        return False

    def __str__(self):
        # structs are handled byte-wise
        return "uint8_t const" if self.const else "uint8_t"

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self.numpy_dtype, self.const))