data_types.py 16 KB
Newer Older
1
import ctypes
2
import sympy as sp
3
import numpy as np
4
5
6
7
8
try:
    import llvmlite.ir as ir
except ImportError as e:
    ir = None
    _ir_importerror = e
9
from sympy.core.cache import cacheit
10

11
from pystencils.cache import memorycache
Martin Bauer's avatar
Martin Bauer committed
12
from pystencils.utils import all_equal
13

14

Martin Bauer's avatar
Martin Bauer committed
15
# to work in conditions of sp.Piecewise cast_func has to be of type Relational as well
class cast_func(sp.Function, sp.Rel):
    """Symbolic cast ``cast_func(expression, target_type)``.

    The second argument is the target type of the cast. Deriving from ``sp.Rel``
    in addition to ``sp.Function`` lets a cast appear as a condition inside
    ``sp.Piecewise``.
    """

    @property
    def canonical(self):
        """Canonical form of the casted expression.

        Raises:
            NotImplementedError: if the casted expression has no ``canonical``
        """
        casted = self.args[0]
        if not hasattr(casted, 'canonical'):
            raise NotImplementedError()
        return casted.canonical

    @property
    def is_commutative(self):
        """A cast is commutative exactly when the casted expression is."""
        casted = self.args[0]
        return casted.is_commutative

28

Martin Bauer's avatar
Martin Bauer committed
29
class pointer_arithmetic_func(sp.Function, sp.Rel):
    """Symbolic function node used for pointer arithmetic.

    NOTE(review): the meaning of its arguments is defined by callers outside this
    view. Like ``cast_func`` it also derives from ``sp.Rel``, presumably so it can
    be used inside ``sp.Piecewise`` conditions. TODO confirm.
    """

    @property
    def canonical(self):
        """Canonical form of the first argument.

        Raises:
            NotImplementedError: if the first argument has no ``canonical``
        """
        first_arg = self.args[0]
        if not hasattr(first_arg, 'canonical'):
            raise NotImplementedError()
        return first_arg.canonical


38
class TypedSymbol(sp.Symbol):
    """SymPy symbol that additionally carries a data type (``dtype``).

    The dtype passed at construction is converted with ``create_type``. If that
    raises ``TypeError``, the value is stored exactly as given. Construction is
    routed through a function memoized with sympy's ``cacheit``.
    """

    def __new__(cls, *args, **kwds):
        # Delegate to the cached second construction stage.
        return TypedSymbol.__xnew_cached_(cls, *args, **kwds)

    def __new_stage2__(cls, name, dtype):
        symbol = super(TypedSymbol, cls).__xnew__(cls, name)
        try:
            resolved_dtype = create_type(dtype)
        except TypeError:
            # on error keep the string
            resolved_dtype = dtype
        symbol._dtype = resolved_dtype
        return symbol

    # Uncached and cached entry points of the construction stage above.
    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    @property
    def dtype(self):
        """Data type attached to this symbol."""
        return self._dtype

    def _hashable_content(self):
        # Append the dtype's hash to Symbol's content, so that equally named
        # symbols with different dtypes hash and compare differently.
        parent_content = tuple(super(TypedSymbol, self)._hashable_content())
        return parent_content + (hash(self._dtype),)

    def __getnewargs__(self):
        # Arguments needed to recreate the symbol (e.g. when unpickling).
        return (self.name, self.dtype)


Martin Bauer's avatar
Martin Bauer committed
67
def create_type(specification):
    """Creates a type object according to a specification.

    Args:
        specification: an existing type object (instance of a ``Type`` subclass
            or a ``StructType``), which is returned unchanged, or anything that
            ``numpy.dtype`` accepts, e.g. a string like ``'double'``, a numpy
            scalar type or a numpy dtype

    Returns:
        the given type object; otherwise a BasicType for plain numpy dtypes or a
        StructType for structured numpy dtypes (dtypes with fields)

    Raises:
        TypeError: propagated from ``numpy.dtype`` if the specification cannot be
            interpreted as a dtype
    """
    # StructType does not derive from Type, so it has to be recognized explicitly.
    # Otherwise numpy.dtype() would be called on it and fail.
    if isinstance(specification, (Type, StructType)):
        return specification
    numpy_dtype = np.dtype(specification)
    if numpy_dtype.fields is None:
        return BasicType(numpy_dtype, const=False)
    return StructType(numpy_dtype, const=False)
84
85


86
@memorycache(maxsize=64)
def create_composite_type_from_string(specification):
    """Creates a new Type object from a c-like string specification.

    The string consists of a base type name understood by ``numpy.dtype``
    (optionally qualified with ``const``), followed by any number of ``*``
    pointer levels. Qualifiers (``const``, ``restrict``) written after a ``*``
    apply to that pointer level. For example, ``"double const * restrict"`` is a
    restrict pointer to a const double, and ``"double * const"`` is a const
    pointer to double. A ``*`` needs no surrounding whitespace, so
    ``"double*"`` and ``"double**"`` are accepted as well.

    Args:
        specification: Specification string (case-insensitive)

    Returns:
        Type object: a BasicType, possibly wrapped in one or more PointerTypes

    Raises:
        AssertionError: if the string is malformed
        TypeError: propagated from ``numpy.dtype`` for unknown base type names
    """
    # Pad '*' with spaces, so that "double*", "double**" or "double*restrict"
    # tokenize exactly like their whitespace-separated forms. Every pointer level
    # then lands in its own group, in source order (innermost pointer first).
    tokens = specification.lower().replace('*', ' * ').split()

    # Group tokens: the first group describes the base type. Each following group
    # starts with '*' and holds the qualifiers of that pointer level.
    parts = []
    current = []
    for token in tokens:
        if token == '*':
            parts.append(current)
            current = [token]
        else:
            current.append(token)
    if current:
        parts.append(current)

    # Parse native part
    base_part = parts.pop(0)
    const = False
    if 'const' in base_part:
        const = True
        base_part.remove('const')
    assert len(base_part) == 1, "Invalid base type in specification %r" % (specification,)
    current_type = BasicType(np.dtype(base_part[0]), const)

    # Parse pointer parts, innermost pointer level first
    for part in parts:
        restrict = False
        const = False
        if 'restrict' in part:
            restrict = True
            part.remove('restrict')
        if 'const' in part:
            const = True
            part.remove('const')
        assert len(part) == 1 and part[0] == '*', "Invalid pointer qualifiers in specification %r" % (specification,)
        current_type = PointerType(current_type, const, restrict)
    return current_type
131
132


Martin Bauer's avatar
Martin Bauer committed
133
134
135
136
def get_base_type(data_type):
    """Follow the ``base_type`` chain of a type down to its innermost type.

    Args:
        data_type: object exposing a ``base_type`` attribute (``None`` at the end
            of the chain)

    Returns:
        the first object in the chain whose ``base_type`` is ``None``
    """
    current = data_type
    while True:
        inner = current.base_type
        if inner is None:
            return current
        current = inner
137
138


Martin Bauer's avatar
Martin Bauer committed
139
def to_ctypes(data_type):
    """Transform a given type object into a ctypes type.

    Args:
        data_type: a PointerType (converted recursively to a ctypes pointer), a
            StructType, or a type whose ``numpy_dtype`` is a key of
            ``to_ctypes.map``

    Returns:
        ctypes type object

    Raises:
        KeyError: if the numpy dtype has no entry in ``to_ctypes.map``
    """
    if isinstance(data_type, PointerType):
        return ctypes.POINTER(to_ctypes(data_type.base_type))
    if isinstance(data_type, StructType):
        # structs are handled byte-wise, so they are exposed as uint8 pointers
        return ctypes.POINTER(ctypes.c_uint8)
    return to_ctypes.map[data_type.numpy_dtype]
151

Martin Bauer's avatar
Martin Bauer committed
152
153

# numpy dtype -> ctypes type, used by to_ctypes for non-pointer, non-struct types
to_ctypes.map = {
    np.dtype(np.int8): ctypes.c_int8,
    np.dtype(np.int16): ctypes.c_int16,
    np.dtype(np.int32): ctypes.c_int32,
    np.dtype(np.int64): ctypes.c_int64,

    np.dtype(np.uint8): ctypes.c_uint8,
    np.dtype(np.uint16): ctypes.c_uint16,
    np.dtype(np.uint32): ctypes.c_uint32,
    np.dtype(np.uint64): ctypes.c_uint64,

    np.dtype(np.float32): ctypes.c_float,
    np.dtype(np.float64): ctypes.c_double,

    # numpy bools are single bytes holding 0/1, matching C's _Bool / ctypes.c_bool
    np.dtype(np.bool_): ctypes.c_bool,
}


169
def ctypes_from_llvm(data_type):
    """Map an llvmlite IR type to the corresponding ctypes type.

    Args:
        data_type: llvmlite IR type (pointer, integer, float, double or void)

    Returns:
        ctypes type object. Void maps to ``None``, and a pointer to void maps to
        ``ctypes.c_void_p``. Integers map to the signed ctypes integer of the
        same width.

    Raises:
        the stored llvmlite ImportError if llvmlite is not available
        ValueError: for integer widths other than 8, 16, 32 and 64
        NotImplementedError: for any other IR type
    """
    if not ir:
        raise _ir_importerror

    if isinstance(data_type, ir.PointerType):
        pointee = ctypes_from_llvm(data_type.pointee)
        # ctypes has no typed pointer to void
        return ctypes.c_void_p if pointee is None else ctypes.POINTER(pointee)

    if isinstance(data_type, ir.IntType):
        signed_ints_by_width = {
            8: ctypes.c_int8,
            16: ctypes.c_int16,
            32: ctypes.c_int32,
            64: ctypes.c_int64,
        }
        ctype = signed_ints_by_width.get(data_type.width)
        if ctype is None:
            raise ValueError("Int width %d is not supported" % data_type.width)
        return ctype

    if isinstance(data_type, ir.FloatType):
        return ctypes.c_float
    if isinstance(data_type, ir.DoubleType):
        return ctypes.c_double
    if isinstance(data_type, ir.VoidType):
        return None  # Void type is not supported by ctypes

    raise NotImplementedError('Data type %s of %s is not supported yet' % (type(data_type), data_type))


def to_llvm_type(data_type):
    """Transform a given type object into an llvmlite IR type.

    Args:
        data_type: a PointerType (converted recursively via ``as_pointer()``) or
            a type whose ``numpy_dtype`` is a key of ``to_llvm_type.map``

    Returns:
        llvmlite type object

    Raises:
        the stored llvmlite ImportError if llvmlite is not available
        KeyError: if the numpy dtype has no entry in ``to_llvm_type.map``
    """
    if not ir:
        raise _ir_importerror
    if not isinstance(data_type, PointerType):
        return to_llvm_type.map[data_type.numpy_dtype]
    return to_llvm_type(data_type.base_type).as_pointer()

212

213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
if ir:
    # numpy dtype -> llvmlite IR type. LLVM integer types carry no signedness, so
    # signed and unsigned numpy integers of equal width map to the same IntType.
    to_llvm_type.map = {np.dtype('%sint%d' % (prefix, bits)): ir.IntType(bits)
                        for prefix in ('', 'u')
                        for bits in (8, 16, 32, 64)}
    to_llvm_type.map[np.dtype(np.float32)] = ir.FloatType()
    to_llvm_type.map[np.dtype(np.float64)] = ir.DoubleType()
228

229

Martin Bauer's avatar
Martin Bauer committed
230
231
232
def peel_off_type(dtype, type_to_peel_off):
    """Strip wrapper types of exactly ``type_to_peel_off`` from ``dtype``.

    Wrappers are unwrapped via ``base_type`` as long as the outermost object's
    type is exactly ``type_to_peel_off``; subclasses are not unwrapped.

    Returns:
        the first object in the ``base_type`` chain of a different type
    """
    if type(dtype) is not type_to_peel_off:
        return dtype
    return peel_off_type(dtype.base_type, type_to_peel_off)


Martin Bauer's avatar
Martin Bauer committed
236
def collate_types(types):
    """Return the common type of a sequence of types, e.g. (float, double, float) -> double.

    Rules:
      * pointer arithmetic: exactly one PointerType plus only signed/unsigned
        integer BasicTypes collates to that pointer type
      * otherwise vector types are peeled off (all must share one width), the
        remaining BasicTypes are combined with numpy's ``result_type``, and the
        result is wrapped in a VectorType again if any input was a vector

    Raises:
        ValueError: for two pointers, invalid pointer arithmetic or vector types
            of different widths
    """
    # Pointer arithmetic case i.e. pointer + integer is allowed
    if any(type(t) is PointerType for t in types):
        found_pointer = None
        for t in types:
            kind = type(t)
            if kind is PointerType:
                if found_pointer is not None:
                    raise ValueError("Cannot collate the combination of two pointer types")
                found_pointer = t
            elif kind is not BasicType or not (t.is_int() or t.is_uint()):
                raise ValueError("Invalid pointer arithmetic")
        return found_pointer

    # peel of vector types, if at least one vector type occurred the result will also be the vector type
    vector_types = [t for t in types if type(t) is VectorType]
    if not all_equal(t.width for t in vector_types):
        raise ValueError("Collation failed because of vector types with different width")
    scalar_types = [peel_off_type(t, VectorType) for t in types]

    # now we should have a list of basic types - struct types are not yet supported
    assert all(type(t) is BasicType for t in scalar_types)

    # use numpy collation -> create type from numpy type -> and, put vector type around if necessary
    common = BasicType(np.result_type(*(t.numpy_dtype for t in scalar_types)))
    return VectorType(common, vector_types[0].width) if vector_types else common


@memorycache(maxsize=2048)
def get_type_of_expression(expr):
    """Determine the data type of a sympy expression whose symbols are all typed.

    Args:
        expr: sympy expression, or anything ``sp.sympify`` accepts

    Returns:
        Type object:
          * integer literals -> ``create_type("int")``
          * other rationals and floats -> ``create_type("double")``
          * boolean expressions -> a bool type, vectorized if any argument has
            a VectorType
          * other expressions -> the collated type of their arguments

    Raises:
        ValueError: if the expression contains an untyped ``sp.Symbol``
        NotImplementedError: if no typing rule applies to the expression
    """
    from pystencils.astnodes import ResolvedFieldAccess
    # Import from the defining module: the ``sp.boolalg`` shortcut is not exposed
    # at the top level of newer SymPy versions.
    from sympy.logic.boolalg import Boolean

    expr = sp.sympify(expr)
    if isinstance(expr, sp.Integer):
        return create_type("int")
    elif isinstance(expr, (sp.Rational, sp.Float)):
        return create_type("double")
    elif isinstance(expr, ResolvedFieldAccess):
        return expr.field.dtype
    elif isinstance(expr, TypedSymbol):
        return expr.dtype
    elif isinstance(expr, sp.Symbol):
        raise ValueError("All symbols inside this expression have to be typed!")
    elif hasattr(expr, 'func') and expr.func == cast_func:
        # cast_func(expression, target_type): the result type is the target type
        return expr.args[1]
    elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
        # Piecewise args are (value, condition) pairs. Collate values and conditions
        # separately, and vectorize the value type if the conditions are vectors.
        collated_result_type = collate_types(tuple(get_type_of_expression(a[0]) for a in expr.args))
        collated_condition_type = collate_types(tuple(get_type_of_expression(a[1]) for a in expr.args))
        if type(collated_condition_type) is VectorType and type(collated_result_type) is not VectorType:
            collated_result_type = VectorType(collated_result_type, width=collated_condition_type.width)
        return collated_result_type
    elif isinstance(expr, sp.Indexed):
        # the element type is the base_type of the label symbol's dtype
        typed_symbol = expr.base.label
        return typed_symbol.dtype.base_type
    elif isinstance(expr, Boolean):
        # BooleanFunction derives from Boolean, so this single check covers both.
        # if any arg is of vector type return a vector boolean, else return a normal scalar boolean
        result = create_type("bool")
        arg_types = [get_type_of_expression(a) for a in expr.args]
        vec_args = [t for t in arg_types if isinstance(t, VectorType)]
        if vec_args:
            result = VectorType(result, width=vec_args[0].width)
        return result
    elif isinstance(expr, sp.Expr):
        types = tuple(get_type_of_expression(a) for a in expr.args)
        return collate_types(types)

    raise NotImplementedError("Could not determine type for %s (of type %s)" % (expr, type(expr)))
311
312


313
314
315
class Type(sp.Basic):
    """Common base class of the data type objects in this module.

    It is a sympy ``Basic`` so that types can appear inside sympy expressions.
    Constructor arguments are consumed by the subclasses' ``__init__``;
    ``sp.Basic.__new__`` is called without them.
    """

    def __new__(cls, *args, **kwargs):
        instance = sp.Basic.__new__(cls)
        return instance

    def _sympystr(self, *args, **kwargs):
        # Hook for sympy's string printer: print types via their own __str__.
        text = str(self)
        return text
319
320
321
322


class BasicType(Type):
    """Scalar data type backed by a plain (non-structured) numpy dtype, optionally const."""

    @staticmethod
    def numpy_name_to_c(name):
        """Map a numpy dtype name (e.g. 'float64', 'int32', 'uint8', 'bool') to a C type name.

        Raises:
            NotImplementedError: for names without a mapping (e.g. complex dtypes)
        """
        if name == 'float64':
            return 'double'
        elif name == 'float32':
            return 'float'
        elif name.startswith('int'):
            width = int(name[len("int"):])
            return "int%d_t" % (width,)
        elif name.startswith('uint'):
            width = int(name[len("uint"):])
            return "uint%d_t" % (width,)
        elif name == 'bool':
            return 'bool'
        else:
            # `NotImplemented` is a constant, not an exception class; calling it
            # would raise an unrelated TypeError.
            raise NotImplementedError("Cannot map numpy to C name for %s" % (name,))

    def __init__(self, dtype, const=False):
        """
        Args:
            dtype: a type object (its ``numpy_dtype`` is used) or anything
                ``numpy.dtype`` accepts; must not be structured, object or sub-array
            const: whether the type is const-qualified
        """
        self.const = const
        if isinstance(dtype, Type):
            self._dtype = dtype.numpy_dtype
        else:
            self._dtype = np.dtype(dtype)
        assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type"
        assert self._dtype.hasobject is False
        assert self._dtype.subdtype is None

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def item_size(self):
        return 1

    # The is_* predicates use np.issubdtype instead of np.sctypes, which was
    # removed in NumPy 2.0.
    def is_int(self):
        return np.issubdtype(self.numpy_dtype, np.signedinteger)

    def is_float(self):
        return np.issubdtype(self.numpy_dtype, np.floating)

    def is_uint(self):
        return np.issubdtype(self.numpy_dtype, np.unsignedinteger)

    def is_complex(self):
        return np.issubdtype(self.numpy_dtype, np.complexfloating)

    def is_other(self):
        # Same members as the former np.sctypes['others']: bool, object, bytes, str, void
        others = [np.dtype(t) for t in (np.bool_, np.object_, np.bytes_, np.str_, np.void)]
        return self.numpy_dtype in others

    @property
    def base_name(self):
        """C name of the type without qualifiers."""
        return BasicType.numpy_name_to_c(str(self._dtype))

    def __str__(self):
        result = self.base_name
        if self.const:
            result += " const"
        return result

    def __repr__(self):
        return str(self)

    def __eq__(self, other):
        if not isinstance(other, BasicType):
            return False
        else:
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)

    def __hash__(self):
        # Hash exactly what __eq__ compares. The former hash(str(self)) raised for
        # dtypes without a C name (e.g. complex).
        return hash((self.numpy_dtype, self.const))


402
class VectorType(Type):
    """SIMD vector type holding ``width`` elements of ``base_type``."""

    # Class-level mapping used by __str__, with keys 'int', 'double', 'float' and
    # 'bool'. It is assigned outside this class (TODO confirm where); while it is
    # None, vectors print as "<base>[<width>]".
    instruction_set = None

    def __init__(self, base_type, width=4):
        self._base_type = base_type
        self.width = width

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        return self.width * self.base_type.item_size

    def __eq__(self, other):
        return isinstance(other, VectorType) and \
            (self.base_type, self.width) == (other.base_type, other.width)

    def __str__(self):
        if self.instruction_set is None:
            return "%s[%d]" % (self.base_type, self.width)
        # checked in order: scalar type name -> instruction_set key
        candidates = (("int64", 'int'), ("float64", 'double'), ("float32", 'float'), ("bool", 'bool'))
        for scalar_name, key in candidates:
            if self.base_type == create_type(scalar_name):
                return self.instruction_set[key]
        raise NotImplementedError()

    def __hash__(self):
        return hash((self.base_type, self.width))

    def __getnewargs__(self):
        return (self._base_type, self.width)

444

445
class PointerType(Type):
    """Pointer to ``base_type``, optionally const and/or restrict (restrict by default)."""

    def __init__(self, base_type, const=False, restrict=True):
        self._base_type = base_type
        self.const = const
        self.restrict = restrict

    def __getnewargs__(self):
        return (self.base_type, self.const, self.restrict)

    @property
    def alias(self):
        """True if the pointer may alias other pointers, i.e. it is not restrict."""
        return not self.restrict

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        return self.base_type.item_size

    def __eq__(self, other):
        if isinstance(other, PointerType):
            return (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict)
        return False

    def __str__(self):
        qualifiers = ['RESTRICT'] if self.restrict else []
        if self.const:
            qualifiers.append("const")
        return " ".join([str(self.base_type), '*'] + qualifiers)

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self._base_type, self.const, self.restrict))
485

Jan Hoenig's avatar
Jan Hoenig committed
486

487
class StructType:
    """Data type backed by a structured numpy dtype (a dtype with named fields).

    Unlike the other type classes in this module, it does not derive from ``Type``.
    """

    def __init__(self, numpy_type, const=False):
        self.const = const
        self._dtype = np.dtype(numpy_type)

    def __getnewargs__(self):
        return (self.numpy_dtype, self.const)

    @property
    def base_type(self):
        # a struct does not wrap another type
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def item_size(self):
        """Size of one struct instance in bytes."""
        return self.numpy_dtype.itemsize

    def get_element_offset(self, element_name):
        """Byte offset of the field ``element_name`` inside the struct."""
        _, offset = self.numpy_dtype.fields[element_name][:2]
        return offset

    def get_element_type(self, element_name):
        """BasicType of the field ``element_name``, inheriting the struct's constness."""
        field_dtype = self.numpy_dtype.fields[element_name][0]
        return BasicType(field_dtype, self.const)

    def has_element(self, element_name):
        return element_name in self.numpy_dtype.fields

    def __eq__(self, other):
        if isinstance(other, StructType):
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)
        return False

    def __str__(self):
        # structs are handled byte-wise
        return "uint8_t const" if self.const else "uint8_t"

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self.numpy_dtype, self.const))