data_types.py 16.4 KB
Newer Older
1
import ctypes
2
import sympy as sp
3
import numpy as np
4
5
6
7
8
try:
    import llvmlite.ir as ir
except ImportError as e:
    ir = None
    _ir_importerror = e
9
from sympy.core.cache import cacheit
10

11
from pystencils.cache import memorycache
Martin Bauer's avatar
Martin Bauer committed
12
from pystencils.utils import all_equal
Martin Bauer's avatar
Martin Bauer committed
13
from sympy.logic.boolalg import Boolean
14

15

Martin Bauer's avatar
Martin Bauer committed
16
17
18
19
# noinspection PyPep8Naming
class cast_func(sp.Function, Boolean):
    """Symbolic cast of ``args[0]`` to the type held in ``args[1]``.

    The class derives from ``Boolean`` as well as ``sp.Function`` so that a cast
    is accepted inside the conditions of an ``sp.Piecewise``.
    """

    @property
    def canonical(self):
        """Canonical form of the cast operand; NotImplementedError if it has none."""
        operand = self.args[0]
        if not hasattr(operand, 'canonical'):
            raise NotImplementedError()
        return operand.canonical

    @property
    def is_commutative(self):
        """Commutativity is taken over from the cast operand."""
        operand = self.args[0]
        return operand.is_commutative

    @property
    def dtype(self):
        """Target type of the cast (the second argument)."""
        target_type = self.args[1]
        return target_type


# noinspection PyPep8Naming
class vector_memory_access(cast_func):
    """``cast_func`` variant that takes exactly four arguments (``nargs = (4,)``).

    NOTE(review): judging by the name it marks vectorized loads/stores; the meaning
    of the four arguments is defined by its users, not here — confirm.
    """
    nargs = (4,)

40

Martin Bauer's avatar
Martin Bauer committed
41
42
# noinspection PyPep8Naming
class pointer_arithmetic_func(sp.Function, Boolean):
    """Symbolic function node used for pointer arithmetic.

    Like ``cast_func`` it also derives from ``Boolean``; presumably so it is accepted
    in the same contexts (e.g. Piecewise conditions) — confirm.
    """

    @property
    def canonical(self):
        """Canonical form of the first argument; NotImplementedError if it has none."""
        first_arg = self.args[0]
        if hasattr(first_arg, 'canonical'):
            return first_arg.canonical
        raise NotImplementedError()


51
class TypedSymbol(sp.Symbol):
    """A SymPy symbol that additionally carries a data type (``dtype``).

    Construction is routed through SymPy's ``cacheit``, so repeated constructions
    with the same arguments are served from SymPy's cache.
    """

    def __new__(cls, *args, **kwds):
        return TypedSymbol.__xnew_cached_(cls, *args, **kwds)

    def __new_stage2__(cls, name, dtype):
        instance = super(TypedSymbol, cls).__xnew__(cls, name)
        try:
            resolved_type = create_type(dtype)
        except TypeError:
            # dtype cannot be turned into a Type -> keep the given object (e.g. a string) as is
            resolved_type = dtype
        instance._dtype = resolved_type
        return instance

    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    @property
    def dtype(self):
        return self._dtype

    def _hashable_content(self):
        # Symbol's hashable content plus the dtype, so equally named symbols of
        # different type are distinct
        base_content = tuple(super(TypedSymbol, self)._hashable_content())
        return base_content + (hash(self._dtype),)

    def __getnewargs__(self):
        return self.name, self.dtype


Martin Bauer's avatar
Martin Bauer committed
80
def create_type(specification):
    """Return a Type object for *specification*.

    Args:
        specification: an existing Type instance (returned unchanged), or anything
            ``np.dtype`` accepts, e.g. a string such as ``"float64"``

    Returns:
        the given Type, or a new non-const type built from the numpy dtype:
        a StructType if the dtype has fields, otherwise a BasicType.
        Errors raised by ``np.dtype`` for unparsable specifications propagate.
    """
    if isinstance(specification, Type):
        return specification
    numpy_dtype = np.dtype(specification)
    type_class = BasicType if numpy_dtype.fields is None else StructType
    return type_class(numpy_dtype, const=False)
97
98


99
@memorycache(maxsize=64)
def create_composite_type_from_string(specification):
    """Creates a new Type object from a c-like string specification.

    The specification is a base type name understood by ``np.dtype`` (optionally
    qualified with ``const``), followed by any number of pointer levels ``*``.
    Each pointer level may be qualified with ``restrict`` and/or ``const``.
    Examples: ``"const double * restrict"``, ``"int64*"``, ``"float ** const"``.
    Matching is case-insensitive.

    Args:
        specification: Specification string

    Returns:
        Type object: a BasicType, wrapped in one PointerType per ``*``
        (the first ``*`` is the innermost pointer)
    """
    # Surround every '*' with whitespace so that "double*", "double**" and
    # "double*restrict" tokenize exactly like "double * * restrict".
    # (Previously a star glued to the base token was appended as the LAST
    # pointer level and swallowed the qualifiers that followed it.)
    tokens = specification.lower().replace('*', ' * ').split()

    # Group tokens: the first group is the base type; each following group
    # starts with '*' and contains that pointer level's qualifiers.
    parts = []
    current = []
    for token in tokens:
        if token == '*':
            parts.append(current)
            current = [token]
        else:
            current.append(token)
    if current:
        parts.append(current)

    # Parse native part
    base_part = parts.pop(0)
    const = 'const' in base_part
    if const:
        base_part.remove('const')
    assert len(base_part) == 1
    current_type = BasicType(np.dtype(base_part[0]), const)

    # Parse pointer parts, innermost pointer first
    for part in parts:
        restrict = 'restrict' in part
        if restrict:
            part.remove('restrict')
        const = 'const' in part
        if const:
            part.remove('const')
        assert len(part) == 1 and part[0] == '*'
        current_type = PointerType(current_type, const, restrict)
    return current_type
144
145


Martin Bauer's avatar
Martin Bauer committed
146
147
148
149
def get_base_type(data_type):
    """Follow ``base_type`` links from *data_type* and return the first type whose ``base_type`` is None."""
    current = data_type
    while True:
        inner = current.base_type
        if inner is None:
            return current
        current = inner
150
151


Martin Bauer's avatar
Martin Bauer committed
152
def to_ctypes(data_type):
    """
    Transforms a given Type into ctypes

    Pointer types become ``ctypes.POINTER`` of the converted base type.
    Struct types are passed as a byte pointer (``POINTER(c_uint8)``).
    All other types are looked up by their numpy dtype in ``to_ctypes.map``.

    :param data_type: Subclass of Type
    :return: ctypes type object
    :raises KeyError: if the numpy dtype of a basic type has no entry in ``to_ctypes.map``
    """
    if isinstance(data_type, PointerType):
        return ctypes.POINTER(to_ctypes(data_type.base_type))
    elif isinstance(data_type, StructType):
        # struct data is passed as raw bytes
        return ctypes.POINTER(ctypes.c_uint8)
    else:
        return to_ctypes.map[data_type.numpy_dtype]


to_ctypes.map = {
    np.dtype(np.int8): ctypes.c_int8,
    np.dtype(np.int16): ctypes.c_int16,
    np.dtype(np.int32): ctypes.c_int32,
    np.dtype(np.int64): ctypes.c_int64,

    np.dtype(np.uint8): ctypes.c_uint8,
    np.dtype(np.uint16): ctypes.c_uint16,
    np.dtype(np.uint32): ctypes.c_uint32,
    np.dtype(np.uint64): ctypes.c_uint64,

    np.dtype(np.float32): ctypes.c_float,
    np.dtype(np.float64): ctypes.c_double,

    # BasicType supports 'bool' (mapped to C 'bool'); without this entry
    # converting a bool type raised KeyError
    np.dtype(np.bool_): ctypes.c_bool,
}


182
def ctypes_from_llvm(data_type):
    """Map an llvmlite IR type to the corresponding ctypes type.

    Pointer types map as follows: ``ctypes.POINTER`` of the mapped pointee, or
    ``c_void_p`` when the pointee maps to None (void). Integer types of width
    8/16/32/64 map to the signed ctypes integers. ``FloatType`` maps to
    ``c_float``, ``DoubleType`` to ``c_double`` and ``VoidType`` to None.

    Raises:
        the ImportError recorded at import time, if llvmlite is unavailable
        ValueError: for integer widths other than 8, 16, 32 and 64
        NotImplementedError: for any other IR type
    """
    if not ir:
        raise _ir_importerror

    if isinstance(data_type, ir.PointerType):
        pointee = ctypes_from_llvm(data_type.pointee)
        return ctypes.c_void_p if pointee is None else ctypes.POINTER(pointee)

    if isinstance(data_type, ir.IntType):
        signed_by_width = {
            8: ctypes.c_int8,
            16: ctypes.c_int16,
            32: ctypes.c_int32,
            64: ctypes.c_int64,
        }
        matching = signed_by_width.get(data_type.width)
        if matching is None:
            raise ValueError("Int width %d is not supported" % data_type.width)
        return matching

    if isinstance(data_type, ir.FloatType):
        return ctypes.c_float
    if isinstance(data_type, ir.DoubleType):
        return ctypes.c_double
    if isinstance(data_type, ir.VoidType):
        return None  # Void type is not supported by ctypes

    raise NotImplementedError('Data type %s of %s is not supported yet' % (type(data_type), data_type))


def to_llvm_type(data_type):
    """
    Transforms a given type into an llvmlite IR type
    :param data_type: Subclass of Type
    :return: llvmlite type object
    :raises KeyError: if a non-pointer type's numpy dtype is not in ``to_llvm_type.map``
    """
    if not ir:
        raise _ir_importerror
    if not isinstance(data_type, PointerType):
        return to_llvm_type.map[data_type.numpy_dtype]
    return to_llvm_type(data_type.base_type).as_pointer()


if ir:
    # LLVM integer types carry no signedness, so signed and unsigned numpy
    # integers of the same width map to one IR integer type of that width
    to_llvm_type.map = dict(
        [(np.dtype('%s%d' % (kind, bits)), ir.IntType(bits))
         for kind in ('int', 'uint')
         for bits in (8, 16, 32, 64)]
        + [(np.dtype(np.float32), ir.FloatType()),
           (np.dtype(np.float64), ir.DoubleType())]
    )
241

242

Martin Bauer's avatar
Martin Bauer committed
243
244
245
def peel_off_type(dtype, type_to_peel_off):
    """Strip wrappers whose class is exactly *type_to_peel_off* (subclasses are not stripped).

    As long as ``type(dtype) is type_to_peel_off``, *dtype* is replaced by its
    ``base_type``; the first value of a different type is returned.
    """
    if type(dtype) is not type_to_peel_off:
        return dtype
    return peel_off_type(dtype.base_type, type_to_peel_off)


Martin Bauer's avatar
Martin Bauer committed
249
def collate_types(types):
    """
    Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
    Uses the collation rules from numpy.

    Rules applied, in order:
      * if any PointerType is present, the result is that pointer. Every other entry
        must be an integer/unsigned BasicType, and at most one pointer is allowed.
      * vector types are peeled off; if any were present, the result is wrapped in a
        VectorType of their (common) width.
      * if any float type is present, only the float types take part in the collation.
      * the remaining numpy dtypes are combined with ``np.result_type``.

    Raises:
        ValueError: for two pointer types, invalid pointer arithmetic, or vector
            types of different width
        AssertionError: if non-basic types remain after peeling vector types
    """
    # Pointer arithmetic case i.e. pointer + integer is allowed
    if any(type(t) is PointerType for t in types):
        pointer_type = None
        for t in types:
            kind = type(t)
            if kind is PointerType:
                if pointer_type is not None:
                    raise ValueError("Cannot collate the combination of two pointer types")
                pointer_type = t
            elif kind is not BasicType or not (t.is_int() or t.is_uint()):
                raise ValueError("Invalid pointer arithmetic")
        return pointer_type

    # peel off vector types; if at least one vector type occurred the result is a vector type too
    vector_types = [t for t in types if type(t) is VectorType]
    if not all_equal(t.width for t in vector_types):
        raise ValueError("Collation failed because of vector types with different width")
    scalar_types = [peel_off_type(t, VectorType) for t in types]

    # now we should have a list of basic types - struct types are not yet supported
    assert all(type(t) is BasicType for t in scalar_types)

    float_types = tuple(t for t in scalar_types if t.is_float())
    if float_types:
        scalar_types = float_types

    # numpy collation -> BasicType -> wrap in VectorType if needed
    result = BasicType(np.result_type(*(t.numpy_dtype for t in scalar_types)))
    if vector_types:
        result = VectorType(result, vector_types[0].width)
    return result


@memorycache(maxsize=2048)
def get_type_of_expression(expr):
    """Determine the data Type of a sympy expression.

    Args:
        expr: expression (anything ``sp.sympify`` accepts); every symbol in it
            must be a TypedSymbol

    Returns:
        Type object. The rules, checked in order:
          * sympy Integers are ``int``; Rationals/Floats are ``double``.
          * field accesses, typed symbols and casts return their stored type.
          * Piecewise collates its branch values; a vectorized condition
            vectorizes the result.
          * Indexed returns the base type of the indexed symbol's (pointer) type.
          * Boolean expressions are ``bool``, vectorized if any argument is a vector.
          * powers take the type of their base; other expressions collate their
            arguments' types.

    Raises:
        ValueError: if an untyped ``sp.Symbol`` is encountered
        NotImplementedError: if no rule applies to the expression
    """
    # local import — presumably to avoid a circular import with pystencils.astnodes
    from pystencils.astnodes import ResolvedFieldAccess
    expr = sp.sympify(expr)
    if isinstance(expr, sp.Integer):
        return create_type("int")
    elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float):
        return create_type("double")
    elif isinstance(expr, ResolvedFieldAccess):
        return expr.field.dtype
    elif isinstance(expr, TypedSymbol):
        return expr.dtype
    elif isinstance(expr, sp.Symbol):
        raise ValueError("All symbols inside this expression have to be typed!")
    elif isinstance(expr, cast_func):
        return expr.args[1]
    elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
        collated_result_type = collate_types(tuple(get_type_of_expression(a[0]) for a in expr.args))
        collated_condition_type = collate_types(tuple(get_type_of_expression(a[1]) for a in expr.args))
        # a vectorized condition forces a vectorized result
        if type(collated_condition_type) is VectorType and type(collated_result_type) is not VectorType:
            collated_result_type = VectorType(collated_result_type, width=collated_condition_type.width)
        return collated_result_type
    elif isinstance(expr, sp.Indexed):
        typed_symbol = expr.base.label
        return typed_symbol.dtype.base_type
    elif isinstance(expr, Boolean):
        # BooleanFunction is a subclass of Boolean, so this covers both. The module-level
        # `Boolean` import is used because `sp.boolalg` is not exposed by newer SymPy.
        # If any arg is of vector type return a vector boolean, else a scalar boolean.
        result = create_type("bool")
        arg_types = (get_type_of_expression(a) for a in expr.args)
        vec_args = [t for t in arg_types if isinstance(t, VectorType)]
        if vec_args:
            result = VectorType(result, width=vec_args[0].width)
        return result
    elif isinstance(expr, sp.Pow):
        return get_type_of_expression(expr.args[0])
    elif isinstance(expr, sp.Expr):
        types = tuple(get_type_of_expression(a) for a in expr.args)
        return collate_types(types)

    raise NotImplementedError("Could not determine type for %s (%s)" % (expr, type(expr)))
328
329


330
class Type(sp.Basic):
    """Common base class of the pystencils data types.

    Types are sympy atoms (``is_Atom = True``). Presumably this is so they can be
    used as arguments of sympy expressions, e.g. the target type of ``cast_func``.
    """
    is_Atom = True

    def __new__(cls, *args, **kwargs):
        # constructor arguments are handled by the subclass __init__, not by sp.Basic
        return sp.Basic.__new__(cls)

    def _sympystr(self, *args, **kwargs):
        # sympy's string printer reuses the type's own __str__
        text = str(self)
        return text
338
339
340
341


class BasicType(Type):
    """Scalar type backed by a plain (non-structured, non-object) numpy dtype, optionally const."""

    @staticmethod
    def numpy_name_to_c(name):
        """Map a numpy dtype name (e.g. ``'float64'``, ``'uint16'``) to its C type name.

        Raises:
            NotImplementedError: if ``name`` has no C equivalent here
        """
        if name == 'float64':
            return 'double'
        elif name == 'float32':
            return 'float'
        elif name.startswith('int'):
            width = int(name[len("int"):])
            return "int%d_t" % (width,)
        elif name.startswith('uint'):
            width = int(name[len("uint"):])
            return "uint%d_t" % (width,)
        elif name == 'bool':
            return 'bool'
        else:
            # NotImplemented is not callable (calling it raised TypeError); use the exception class
            raise NotImplementedError("Cannot map numpy to C name for %s" % (name,))

    def __init__(self, dtype, const=False):
        self.const = const
        if isinstance(dtype, Type):
            self._dtype = dtype.numpy_dtype
        else:
            self._dtype = np.dtype(dtype)
        assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type"
        assert self._dtype.hasobject is False
        assert self._dtype.subdtype is None

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def item_size(self):
        # number of elements, not bytes (VectorType multiplies this by its width)
        return 1

    # np.sctypes was removed in NumPy 2.0; np.issubdtype against the abstract
    # scalar hierarchy gives the same classification and works on all versions.
    def is_int(self):
        return np.issubdtype(self.numpy_dtype, np.signedinteger)

    def is_float(self):
        return np.issubdtype(self.numpy_dtype, np.floating)

    def is_uint(self):
        return np.issubdtype(self.numpy_dtype, np.unsignedinteger)

    def is_complex(self):
        return np.issubdtype(self.numpy_dtype, np.complexfloating)

    def is_other(self):
        # the former np.sctypes['others'] group: bool, object, bytes, str, void
        return any(np.issubdtype(self.numpy_dtype, t)
                   for t in (np.bool_, np.object_, np.bytes_, np.str_, np.void))

    @property
    def base_name(self):
        return BasicType.numpy_name_to_c(str(self._dtype))

    def __str__(self):
        result = BasicType.numpy_name_to_c(str(self._dtype))
        if self.const:
            result += " const"
        return result

    def __repr__(self):
        return str(self)

    def __eq__(self, other):
        if not isinstance(other, BasicType):
            return False
        else:
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)

    def __hash__(self):
        # hash the same fields __eq__ compares; hashing via str() crashed for dtypes
        # that have no C name (e.g. float16)
        return hash((self.numpy_dtype, self.const))


421
class VectorType(Type):
    """SIMD vector of ``width`` elements of ``base_type``.

    ``instruction_set`` is a class attribute. While it is None the type prints as
    ``"<base>[<width>]"``; otherwise the name is looked up in the instruction set
    for int64, float64, float32 and bool base types.
    """
    instruction_set = None

    def __init__(self, base_type, width=4):
        self._base_type = base_type
        self.width = width

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        return self.width * self.base_type.item_size

    def __eq__(self, other):
        return isinstance(other, VectorType) and (self.base_type, self.width) == (other.base_type, other.width)

    def __str__(self):
        if self.instruction_set is None:
            return "%s[%d]" % (self.base_type, self.width)
        # (numpy type name, instruction-set key), checked in this order
        lookup_order = (("int64", 'int'), ("float64", 'double'), ("float32", 'float'), ("bool", 'bool'))
        for numpy_name, isa_key in lookup_order:
            if self.base_type == create_type(numpy_name):
                return self.instruction_set[isa_key]
        raise NotImplementedError()

    def __hash__(self):
        return hash((self.base_type, self.width))

    def __getnewargs__(self):
        return self._base_type, self.width

463

464
class PointerType(Type):
    """Pointer to ``base_type`` with optional ``const`` and ``restrict`` qualifiers (restrict by default)."""

    def __init__(self, base_type, const=False, restrict=True):
        self._base_type = base_type
        self.const = const
        self.restrict = restrict

    def __getnewargs__(self):
        return self.base_type, self.const, self.restrict

    @property
    def alias(self):
        """True when the pointer is not declared restrict."""
        return not self.restrict

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        return self.base_type.item_size

    def __eq__(self, other):
        if isinstance(other, PointerType):
            return (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict)
        return False

    def __str__(self):
        qualifiers = (('RESTRICT', self.restrict), ("const", self.const))
        words = [str(self.base_type), '*'] + [word for word, enabled in qualifiers if enabled]
        return " ".join(words)

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self._base_type, self.const, self.restrict))
504

Jan Hoenig's avatar
Jan Hoenig committed
505

506
class StructType:
    """Type for a structured numpy dtype (a dtype with named fields), optionally const.

    Structs are handled byte-wise, so the C name is ``uint8_t``.
    """

    def __init__(self, numpy_type, const=False):
        self.const = const
        self._dtype = np.dtype(numpy_type)

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def item_size(self):
        return self.numpy_dtype.itemsize

    def get_element_offset(self, element_name):
        """Byte offset of field ``element_name`` inside the struct."""
        field_info = self.numpy_dtype.fields[element_name]
        offset = field_info[1]
        return offset

    def get_element_type(self, element_name):
        """BasicType of field ``element_name``; it inherits this struct's constness."""
        field_dtype = self.numpy_dtype.fields[element_name][0]
        return BasicType(field_dtype, self.const)

    def has_element(self, element_name):
        return element_name in self.numpy_dtype.fields

    def __eq__(self, other):
        if isinstance(other, StructType):
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)
        return False

    def __str__(self):
        # structs are handled byte-wise
        return "uint8_t const" if self.const else "uint8_t"

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self.numpy_dtype, self.const))