data_types.py 16.4 KB
Newer Older
1
import ctypes
2
import sympy as sp
3
import numpy as np
4
5
6
7
8
try:
    import llvmlite.ir as ir
except ImportError as e:
    ir = None
    _ir_importerror = e
9
from sympy.core.cache import cacheit
10

11
from pystencils.cache import memorycache
Martin Bauer's avatar
Martin Bauer committed
12
from pystencils.utils import all_equal
Martin Bauer's avatar
Martin Bauer committed
13
from sympy.logic.boolalg import Boolean
14

15

Martin Bauer's avatar
Martin Bauer committed
16
17
18
19
# noinspection PyPep8Naming
class cast_func(sp.Function, Boolean):
    """Symbolic cast of the expression ``args[0]`` to the data type ``args[1]``.

    The class also derives from ``Boolean`` so that casts are accepted inside the
    conditions of ``sp.Piecewise``.
    """

    @property
    def canonical(self):
        # forward to the wrapped expression; only possible when it has a canonical form itself
        inner = self.args[0]
        if not hasattr(inner, 'canonical'):
            raise NotImplementedError()
        return inner.canonical

    @property
    def is_commutative(self):
        # a cast is exactly as commutative as the expression being cast
        wrapped_expression = self.args[0]
        return wrapped_expression.is_commutative

    @property
    def dtype(self):
        """The target data type of the cast (second argument)."""
        return self.args[1]

# noinspection PyPep8Naming
class vector_memory_access(cast_func):
    """A cast_func variant marking a vectorized memory access.

    sympy only accepts it with exactly four arguments (``nargs``).
    """
    nargs = (4,)


# noinspection PyPep8Naming
class pointer_arithmetic_func(sp.Function, Boolean):
    """Symbolic pointer arithmetic expression.

    Like cast_func it also derives from ``Boolean`` (presumably so it may appear inside
    ``sp.Piecewise`` conditions as well).
    """

    @property
    def canonical(self):
        # delegate to the first argument; without a canonical form there is nothing to return
        first_arg = self.args[0]
        if not hasattr(first_arg, 'canonical'):
            raise NotImplementedError()
        return first_arg.canonical


class TypedSymbol(sp.Symbol):
    """A sympy Symbol that additionally carries a data type.

    Instances are created through a factory memoized with sympy's ``cacheit``. The dtype
    is normalized with ``create_type``. If that raises ``TypeError``, the raw dtype
    specification is stored unchanged.
    """

    def __new__(cls, *args, **kwds):
        return TypedSymbol.__xnew_cached_(cls, *args, **kwds)

    def __new_stage2__(cls, name, dtype):
        symbol = super(TypedSymbol, cls).__xnew__(cls, name)
        try:
            parsed_dtype = create_type(dtype)
        except TypeError:
            # dtype could not be interpreted -> keep the original specification (e.g. a string)
            parsed_dtype = dtype
        symbol._dtype = parsed_dtype
        return symbol

    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    @property
    def dtype(self):
        """The data type attached to this symbol."""
        return self._dtype

    def _hashable_content(self):
        # symbols that differ only in their dtype must hash/compare as different symbols
        return tuple(super(TypedSymbol, self)._hashable_content()) + (hash(self._dtype),)

    def __getnewargs__(self):
        return self.name, self.dtype


def create_type(specification):
    """Creates a subclass of Type according to a string or an object of subclass Type.

    Args:
        specification: a type object (a Type subclass instance or a StructType), or anything
            ``np.dtype`` understands, e.g. a string such as ``"float64"`` or a numpy dtype

    Returns:
        ``specification`` itself if it already is a type object. Otherwise, a new non-const
        BasicType, or a non-const StructType if the numpy dtype has fields.

    Raises:
        TypeError: raised by ``np.dtype`` for specifications it cannot interpret
            (TypedSymbol relies on this to fall back to the raw specification)
    """
    # StructType does not derive from Type, so it must be accepted explicitly; otherwise it
    # would be handed to np.dtype, which cannot interpret it.
    if isinstance(specification, (Type, StructType)):
        return specification

    numpy_dtype = np.dtype(specification)
    if numpy_dtype.fields is None:
        return BasicType(numpy_dtype, const=False)
    return StructType(numpy_dtype, const=False)


@memorycache(maxsize=64)
def create_composite_type_from_string(specification):
    """Creates a new Type object from a c-like string specification.

    Format of the specification string:
        - a base type, optionally ``const``;
        - followed by any number of pointer levels, each written as ``*`` and optionally
          followed by ``restrict`` and/or ``const``.

    Examples: ``"const double * restrict"`` or ``"int64**"``. Parsing is case-insensitive,
    and a ``*`` may be glued to neighbouring words.

    Args:
        specification: Specification string

    Returns:
        Type object: a BasicType wrapped in one PointerType per ``*`` (the first ``*`` is
        the innermost pointer)
    """
    # Detach every '*' so that "double*", "double *" and "double *restrict" tokenize alike,
    # and qualifiers written after a glued star ("double* const") bind to the pointer.
    tokens = specification.lower().replace('*', ' * ').split()

    # Group the tokens: parts[0] describes the base type, every further group starts with '*'
    parts = [[]]
    for token in tokens:
        if token == '*':
            parts.append([token])
        else:
            parts[-1].append(token)

    # Parse native part
    base_part = parts[0]
    const = 'const' in base_part
    if const:
        base_part.remove('const')
    assert len(base_part) == 1, "Invalid base type in type specification %r" % (specification,)
    current_type = BasicType(np.dtype(base_part[0]), const)

    # Parse pointer parts
    for part in parts[1:]:
        restrict = 'restrict' in part
        if restrict:
            part.remove('restrict')
        const = 'const' in part
        if const:
            part.remove('const')
        assert part == ['*'], "Invalid pointer qualifiers in type specification %r" % (specification,)
        current_type = PointerType(current_type, const, restrict)
    return current_type


def get_base_type(data_type):
    """Follows the ``base_type`` chain of a (possibly nested) type to its innermost element.

    For example, for a pointer to a pointer to double this returns the double type. The
    chain ends at the first object whose ``base_type`` is ``None``.
    """
    current = data_type
    while True:
        inner = current.base_type
        if inner is None:
            return current
        current = inner


def to_ctypes(data_type):
    """
    Transforms a given Type into ctypes
    :param data_type: Subclass of Type (PointerType, StructType, or a type exposing ``numpy_dtype``)
    :return: ctypes type object. Pointers become ``ctypes.POINTER`` of the converted base
             type; structs become ``ctypes.POINTER(ctypes.c_uint8)`` because they are
             handled byte-wise.
    :raises KeyError: if the numpy dtype of a basic type has no entry in ``to_ctypes.map``
    """
    if isinstance(data_type, PointerType):
        return ctypes.POINTER(to_ctypes(data_type.base_type))
    elif isinstance(data_type, StructType):
        # NOTE(review): a PointerType to a StructType therefore maps to POINTER(POINTER(c_uint8))
        # while its C spelling is "uint8_t *" - confirm this is intended
        return ctypes.POINTER(ctypes.c_uint8)
    else:
        return to_ctypes.map[data_type.numpy_dtype]


to_ctypes.map = {
    np.dtype(np.int8): ctypes.c_int8,
    np.dtype(np.int16): ctypes.c_int16,
    np.dtype(np.int32): ctypes.c_int32,
    np.dtype(np.int64): ctypes.c_int64,

    np.dtype(np.uint8): ctypes.c_uint8,
    np.dtype(np.uint16): ctypes.c_uint16,
    np.dtype(np.uint32): ctypes.c_uint32,
    np.dtype(np.uint64): ctypes.c_uint64,

    np.dtype(np.float32): ctypes.c_float,
    np.dtype(np.float64): ctypes.c_double,

    # BasicType supports 'bool' (mapped to C 'bool'); numpy stores it in one byte,
    # like C's _Bool on common platforms
    np.dtype(np.bool_): ctypes.c_bool,
}


def ctypes_from_llvm(data_type):
    """Maps an llvmlite IR type to the corresponding ctypes type.

    Pointers whose pointee maps to ``None`` (i.e. void) become ``ctypes.c_void_p``.
    ``void`` itself maps to ``None``, because ctypes has no void type.

    Raises:
        ImportError: (re-raised) if llvmlite could not be imported
        ValueError: for integer widths other than 8, 16, 32 and 64
        NotImplementedError: for any other llvmlite type
    """
    if not ir:
        raise _ir_importerror

    if isinstance(data_type, ir.PointerType):
        pointee = ctypes_from_llvm(data_type.pointee)
        return ctypes.c_void_p if pointee is None else ctypes.POINTER(pointee)

    if isinstance(data_type, ir.IntType):
        int_types_by_width = {
            8: ctypes.c_int8,
            16: ctypes.c_int16,
            32: ctypes.c_int32,
            64: ctypes.c_int64,
        }
        int_type = int_types_by_width.get(data_type.width)
        if int_type is None:
            raise ValueError("Int width %d is not supported" % data_type.width)
        return int_type

    if isinstance(data_type, ir.FloatType):
        return ctypes.c_float
    if isinstance(data_type, ir.DoubleType):
        return ctypes.c_double
    if isinstance(data_type, ir.VoidType):
        return None  # Void type is not supported by ctypes

    raise NotImplementedError('Data type %s of %s is not supported yet' % (type(data_type), data_type))


def to_llvm_type(data_type):
    """
    Transforms a given type into an llvmlite IR type
    :param data_type: Subclass of Type (a PointerType or a type exposing ``numpy_dtype``)
    :return: llvmlite type object
    :raises ImportError: (re-raised) if llvmlite could not be imported
    """
    if not ir:
        raise _ir_importerror
    if not isinstance(data_type, PointerType):
        return to_llvm_type.map[data_type.numpy_dtype]
    return to_llvm_type(data_type.base_type).as_pointer()


if ir:
    # LLVM integer types are signless, so signed and unsigned numpy types of the same
    # width share one IR type
    to_llvm_type.map = {
        np.dtype(int_type): ir.IntType(8 * np.dtype(int_type).itemsize)
        for int_type in (np.int8, np.int16, np.int32, np.int64,
                         np.uint8, np.uint16, np.uint32, np.uint64)
    }
    to_llvm_type.map[np.dtype(np.float32)] = ir.FloatType()
    to_llvm_type.map[np.dtype(np.float64)] = ir.DoubleType()


def peel_off_type(dtype, type_to_peel_off):
    """Strips all outer layers whose exact type is ``type_to_peel_off``.

    Only exact type matches are peeled (subclasses are not). Each peeled layer is
    replaced by its ``base_type``.
    """
    peeled = dtype
    while True:
        if type(peeled) is not type_to_peel_off:
            return peeled
        peeled = peeled.base_type


def collate_types(types):
    """
    Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
    Uses the collation rules from numpy, with one deviation: when any float type is involved,
    integer types are ignored. So (int64, float) -> float instead of numpy's double.

    Pointer arithmetic: exactly one PointerType may be combined with (unsigned) integer
    basic types; the result is then that pointer type.
    If at least one VectorType is involved, the result is a VectorType (of the common width)
    around the collated basic type.

    :param types: iterable of type objects; it is consumed exactly once, so generators work
    :raises ValueError: for two pointer types, invalid pointer arithmetic, or vector types of
                        different widths
    """
    # materialize once: the checks below iterate over the types several times, which would
    # silently go wrong for a generator argument
    types = tuple(types)

    # Pointer arithmetic case i.e. pointer + integer is allowed
    if any(type(t) is PointerType for t in types):
        pointer_type = None
        for t in types:
            if type(t) is PointerType:
                if pointer_type is not None:
                    raise ValueError("Cannot collate the combination of two pointer types")
                pointer_type = t
            elif type(t) is BasicType:
                if not (t.is_int() or t.is_uint()):
                    raise ValueError("Invalid pointer arithmetic")
            else:
                raise ValueError("Invalid pointer arithmetic")
        return pointer_type

    # peel off vector types, if at least one vector type occurred the result will also be the vector type
    vector_types = [t for t in types if type(t) is VectorType]
    if not all_equal(t.width for t in vector_types):
        raise ValueError("Collation failed because of vector types with different width")
    types = [peel_off_type(t, VectorType) for t in types]

    # now we should have a list of basic types - struct types are not yet supported
    assert all(type(t) is BasicType for t in types)

    # floats dominate: drop integer operands so e.g. int64 + float stays float instead of double
    if any(t.is_float() for t in types):
        types = [t for t in types if t.is_float()]

    # use numpy collation -> create type from numpy type -> and, put vector type around if necessary
    result = BasicType(np.result_type(*(t.numpy_dtype for t in types)))
    if vector_types:
        result = VectorType(result, vector_types[0].width)
    return result


@memorycache(maxsize=2048)
def get_type_of_expression(expr):
    """Determines the data type of a (sympified) expression.

    Leaves:
        - Integer literals give ``create_type("int")``.
        - Rational/Float literals give ``create_type("double")``.
        - TypedSymbol, ResolvedFieldAccess, cast_func and ``sp.Indexed`` (over a typed
          pointer) report their own type.

    Compound expressions:
        - Piecewise, Boolean expressions and other ``sp.Expr`` collate the types of their
          arguments.
        - Pow takes the type of its base.

    Raises:
        ValueError: if the expression contains an untyped sympy Symbol
        NotImplementedError: if no rule applies to ``expr``
    """
    from pystencils.astnodes import ResolvedFieldAccess
    expr = sp.sympify(expr)
    if isinstance(expr, sp.Integer):
        return create_type("int")
    elif isinstance(expr, (sp.Rational, sp.Float)):
        return create_type("double")
    elif isinstance(expr, ResolvedFieldAccess):
        return expr.field.dtype
    elif isinstance(expr, TypedSymbol):
        return expr.dtype
    elif isinstance(expr, sp.Symbol):
        raise ValueError("All symbols inside this expression have to be typed!")
    elif isinstance(expr, cast_func):
        return expr.args[1]
    elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
        collated_result_type = collate_types(tuple(get_type_of_expression(a[0]) for a in expr.args))
        collated_condition_type = collate_types(tuple(get_type_of_expression(a[1]) for a in expr.args))
        # a vectorized condition makes the whole Piecewise a vector of the same width
        if type(collated_condition_type) is VectorType and type(collated_result_type) is not VectorType:
            collated_result_type = VectorType(collated_result_type, width=collated_condition_type.width)
        return collated_result_type
    elif isinstance(expr, sp.Indexed):
        typed_symbol = expr.base.label
        return typed_symbol.dtype.base_type
    elif isinstance(expr, Boolean):
        # BooleanFunction derives from Boolean, so this single check covers both
        # if any arg is of vector type return a vector boolean, else return a normal scalar boolean
        result = create_type("bool")
        arg_types = (get_type_of_expression(a) for a in expr.args)
        vec_args = [t for t in arg_types if isinstance(t, VectorType)]
        if vec_args:
            result = VectorType(result, width=vec_args[0].width)
        return result
    elif isinstance(expr, sp.Pow):
        return get_type_of_expression(expr.args[0])
    elif isinstance(expr, sp.Expr):
        types = tuple(get_type_of_expression(a) for a in expr.args)
        return collate_types(types)

    raise NotImplementedError("Could not determine type for %s (of type %s)" % (expr, type(expr)))


class Type(sp.Basic):
    """Common base class of the data type objects.

    Types are sympy ``Basic`` objects, so they can appear as arguments of sympy
    functions (e.g. the target type of a cast_func).
    """

    def __new__(cls, *args, **kwargs):
        # constructor arguments are handled by the subclasses' __init__, not passed to sympy
        instance = sp.Basic.__new__(cls)
        return instance

    def _sympystr(self, *args, **kwargs):
        # make sympy's string printer use the type's own C-like __str__
        text = str(self)
        return text


class BasicType(Type):
    """A scalar C data type backed by a non-structured numpy dtype, optionally ``const``."""

    @staticmethod
    def numpy_name_to_c(name):
        """Translates a numpy dtype name (e.g. ``'float64'``, ``'uint8'``) to its C type name.

        Raises:
            NotImplementedError: if no C name is known for ``name``
        """
        if name == 'float64':
            return 'double'
        elif name == 'float32':
            return 'float'
        elif name.startswith('int'):
            width = int(name[len("int"):])
            return "int%d_t" % (width,)
        elif name.startswith('uint'):
            width = int(name[len("uint"):])
            return "uint%d_t" % (width,)
        elif name == 'bool':
            return 'bool'
        else:
            # NotImplemented is a constant, not an exception class - calling it raised a TypeError
            raise NotImplementedError("Cannot map numpy to C name for %s" % (name,))

    def __init__(self, dtype, const=False):
        self.const = const
        if isinstance(dtype, Type):
            self._dtype = dtype.numpy_dtype
        else:
            self._dtype = np.dtype(dtype)
        assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type"
        assert self._dtype.hasobject is False
        assert self._dtype.subdtype is None

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        # basic types are the leaves of the type hierarchy
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def item_size(self):
        # NOTE(review): always 1 regardless of the dtype's byte size - presumably sizes are counted
        # in elements here (VectorType multiplies this by its width); confirm
        return 1

    # The is_* predicates use dtype.kind instead of np.sctypes, which was removed in NumPy 2.0.
    def is_int(self):
        """True for signed integer dtypes."""
        return self.numpy_dtype.kind == 'i'

    def is_float(self):
        """True for floating point dtypes."""
        return self.numpy_dtype.kind == 'f'

    def is_uint(self):
        """True for unsigned integer dtypes."""
        return self.numpy_dtype.kind == 'u'

    def is_complex(self):
        """True for complex dtypes."""
        return self.numpy_dtype.kind == 'c'

    def is_other(self):
        """True for bool, object, bytes, str and void dtypes (numpy's former 'others' group)."""
        return self.numpy_dtype.kind in ('b', 'O', 'S', 'U', 'V')

    @property
    def base_name(self):
        """The C type name without qualifiers."""
        return BasicType.numpy_name_to_c(str(self._dtype))

    def __str__(self):
        result = BasicType.numpy_name_to_c(str(self._dtype))
        if self.const:
            result += " const"
        return result

    def __repr__(self):
        return str(self)

    def __eq__(self, other):
        if not isinstance(other, BasicType):
            return False
        else:
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)

    def __hash__(self):
        # str() covers both dtype and const, consistent with __eq__
        return hash(str(self))


class VectorType(Type):
    """A SIMD vector of ``width`` elements of ``base_type``.

    ``instruction_set`` is a class attribute, ``None`` by default. When set, ``str()``
    looks up the keys 'int', 'double', 'float' and 'bool' in it to obtain the C vector
    type name.
    """
    instruction_set = None

    def __init__(self, base_type, width=4):
        self._base_type = base_type
        self.width = width

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        return self.width * self.base_type.item_size

    def __eq__(self, other):
        if isinstance(other, VectorType):
            return (self.base_type, self.width) == (other.base_type, other.width)
        return False

    def __str__(self):
        if self.instruction_set is None:
            return "%s[%d]" % (self.base_type, self.width)
        # element type -> key into the instruction set, checked in this order
        lookup_order = (("int64", 'int'), ("float64", 'double'), ("float32", 'float'), ("bool", 'bool'))
        for element_type_name, instruction_set_key in lookup_order:
            if self.base_type == create_type(element_type_name):
                return self.instruction_set[instruction_set_key]
        raise NotImplementedError()

    def __hash__(self):
        return hash((self.base_type, self.width))

    def __getnewargs__(self):
        return self._base_type, self.width


class PointerType(Type):
    """Pointer to ``base_type``, optionally ``const`` and/or ``restrict`` (restrict by default)."""

    def __init__(self, base_type, const=False, restrict=True):
        self._base_type = base_type
        self.const = const
        self.restrict = restrict

    def __getnewargs__(self):
        return self.base_type, self.const, self.restrict

    @property
    def alias(self):
        # only a pointer without the restrict qualifier may alias
        return not self.restrict

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        return self.base_type.item_size

    def __eq__(self, other):
        if isinstance(other, PointerType):
            mine = (self.base_type, self.const, self.restrict)
            theirs = (other.base_type, other.const, other.restrict)
            return mine == theirs
        return False

    def __str__(self):
        qualifiers = [name for name, enabled in (('RESTRICT', self.restrict), ("const", self.const)) if enabled]
        return " ".join([str(self.base_type), '*'] + qualifiers)

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self._base_type, self.const, self.restrict))


class StructType:
    """Data type backed by a numpy structured dtype; generated code accesses it byte-wise.

    Unlike the other type classes in this module, this one does not derive from Type.
    """

    def __init__(self, numpy_type, const=False):
        self.const = const
        self._dtype = np.dtype(numpy_type)

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        # a struct is a leaf of the type hierarchy
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def item_size(self):
        # size in bytes of one struct instance
        return self.numpy_dtype.itemsize

    def get_element_offset(self, element_name):
        """Byte offset of the field ``element_name`` inside the struct."""
        field_info = self.numpy_dtype.fields[element_name]
        return field_info[1]

    def get_element_type(self, element_name):
        """BasicType of the field ``element_name``, inheriting the struct's constness."""
        field_dtype = self.numpy_dtype.fields[element_name][0]
        return BasicType(field_dtype, self.const)

    def has_element(self, element_name):
        return element_name in self.numpy_dtype.fields

    def __eq__(self, other):
        if isinstance(other, StructType):
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)
        return False

    def __str__(self):
        # structs are handled byte-wise
        suffix = " const" if self.const else ""
        return "uint8_t" + suffix

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self.numpy_dtype, self.const))