data_types.py 16.4 KB
Newer Older
1
import ctypes
2
import sympy as sp
3
import numpy as np
4
5
6
7
8
try:
    import llvmlite.ir as ir
except ImportError as e:
    ir = None
    _ir_importerror = e
9
from sympy.core.cache import cacheit
10

11
from pystencils.cache import memorycache
Martin Bauer's avatar
Martin Bauer committed
12
from pystencils.utils import all_equal
Martin Bauer's avatar
Martin Bauer committed
13
from sympy.logic.boolalg import Boolean
14

15

Martin Bauer's avatar
Martin Bauer committed
16
17
18
# noinspection PyPep8Naming
class cast_func(sp.Function, Boolean):
    """Symbolic type cast ``cast_func(expr, dtype)``.

    The node also derives from sympy's ``Boolean``: to work in the conditions of
    ``sp.Piecewise``, cast_func has to be of type Boolean as well.
    """
    is_Atom = True

    @property
    def canonical(self):
        """Canonical form of the cast expression; only available if that expression defines one."""
        inner = self.args[0]
        if not hasattr(inner, 'canonical'):
            raise NotImplementedError()
        return inner.canonical

    @property
    def is_commutative(self):
        # commutativity is taken over from the expression being cast
        return self.args[0].is_commutative

    @property
    def dtype(self):
        """The target data type (second argument)."""
        return self.args[1]


# noinspection PyPep8Naming
class vector_memory_access(cast_func):
    """cast_func variant that takes exactly four arguments.

    As for cast_func, the second argument is exposed as ``dtype``. The meaning of the
    other arguments is defined by the code that creates these nodes (TODO document).
    """
    nargs = (4,)

41

Martin Bauer's avatar
Martin Bauer committed
42
43
# noinspection PyPep8Naming
class pointer_arithmetic_func(sp.Function, Boolean):
    """Sympy function node used for pointer arithmetic.

    Like cast_func it also derives from sympy's ``Boolean`` (presumably for the same
    reason, usability inside ``sp.Piecewise`` conditions -- TODO confirm).
    """

    @property
    def canonical(self):
        """Canonical form of the first argument, if it defines one; otherwise NotImplementedError."""
        first_arg = self.args[0]
        if not hasattr(first_arg, 'canonical'):
            raise NotImplementedError()
        return first_arg.canonical


52
class TypedSymbol(sp.Symbol):
    """A sympy Symbol carrying a data type (``dtype``).

    ``dtype`` is converted with ``create_type``; if that raises a TypeError, the given
    value is stored unchanged. The dtype's hash is part of the hashable content, so
    symbols that differ only in their dtype are treated as different by sympy.
    """

    def __new__(cls, *args, **kwds):
        return TypedSymbol.__xnew_cached_(cls, *args, **kwds)

    def __new_stage2__(cls, name, dtype):
        instance = super(TypedSymbol, cls).__xnew__(cls, name)
        try:
            resolved_dtype = create_type(dtype)
        except TypeError:
            # on error keep the specification as given (e.g. a string)
            resolved_dtype = dtype
        instance._dtype = resolved_dtype
        return instance

    # uncached and cached constructors, mirroring the scheme used by sympy.Symbol
    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    @property
    def dtype(self):
        return self._dtype

    def _hashable_content(self):
        parent_content = tuple(super(TypedSymbol, self)._hashable_content())
        return parent_content + (hash(self._dtype),)

    def __getnewargs__(self):
        # arguments used to recreate the symbol when unpickling
        return (self.name, self.dtype)


Martin Bauer's avatar
Martin Bauer committed
81
def create_type(specification):
    """Creates a Type object from a specification, or passes an existing type object through.

    Args:
        specification: Type object, StructType object, or anything accepted by ``np.dtype``
                       (e.g. a string like "float64" or a numpy scalar type)

    Returns:
        ``specification`` itself if it already is a Type or StructType. Otherwise a new,
        non-const BasicType for plain numpy dtypes, or a new non-const StructType for
        structured numpy dtypes.

    Raises:
        TypeError: if ``np.dtype`` cannot interpret the specification
    """
    # StructType does not derive from Type, so it has to be passed through explicitly;
    # np.dtype() would otherwise reject it with a TypeError.
    if isinstance(specification, (Type, StructType)):
        return specification
    numpy_dtype = np.dtype(specification)
    if numpy_dtype.fields is None:
        return BasicType(numpy_dtype, const=False)
    return StructType(numpy_dtype, const=False)
98
99


100
@memorycache(maxsize=64)
def create_composite_type_from_string(specification):
    """Creates a new Type object from a C-like string specification.

    The string is lower-cased and split into tokens, where every ``*`` is a token of its own
    (so "double*", "double *restrict" and "double * restrict" are equivalent).
    - Tokens before the first ``*`` describe the base type: one numpy type name plus an
      optional ``const`` anywhere in that group ("const double" == "double const").
    - Each ``*`` opens a pointer level. A ``const`` and/or ``restrict`` following it
      qualifies that pointer level. Pointers are only ``restrict`` if requested.

    Examples: "double", "const float64", "double *", "double * restrict const", "int64**".

    Args:
        specification: Specification string

    Returns:
        Type object: a BasicType, possibly wrapped in (nested) PointerTypes
    """
    # Surround every '*' with spaces so that a star glued to a neighbouring token
    # ("double*", "*restrict", "**") is tokenized like the fully spaced form.
    tokens = specification.lower().replace('*', ' * ').split()

    # Group tokens: the first group is the base type, every '*' starts a new group.
    parts = []
    current = []
    for token in tokens:
        if token == '*':
            parts.append(current)
            current = [token]
        else:
            current.append(token)
    if current:
        parts.append(current)

    # Parse native part
    base_part = parts.pop(0)
    const = 'const' in base_part
    if const:
        base_part.remove('const')
    assert len(base_part) == 1
    current_type = BasicType(np.dtype(base_part[0]), const)

    # Parse pointer parts, innermost first
    for part in parts:
        restrict = 'restrict' in part
        if restrict:
            part.remove('restrict')
        const = 'const' in part
        if const:
            part.remove('const')
        assert len(part) == 1 and part[0] == '*'
        current_type = PointerType(current_type, const, restrict)
    return current_type
145
146


Martin Bauer's avatar
Martin Bauer committed
147
148
149
150
def get_base_type(data_type):
    """Returns the innermost type reached by following ``base_type`` links.

    For example, for a pointer to a pointer to double this is the double type. A type whose
    ``base_type`` is None is returned unchanged.
    """
    inner = data_type.base_type
    if inner is None:
        return data_type
    return get_base_type(inner)
151
152


Martin Bauer's avatar
Martin Bauer committed
153
def to_ctypes(data_type):
    """Transforms a given Type into the corresponding ctypes type.

    Args:
        data_type: Type object (PointerType, StructType or a type with ``numpy_dtype``)

    Returns:
        ctypes type object. Pointers are converted recursively. A StructType maps to a
        pointer to uint8 (structs are handled byte-wise). Everything else is looked up
        by its numpy dtype in ``to_ctypes.map``.

    Raises:
        KeyError: if the numpy dtype has no entry in ``to_ctypes.map``
    """
    if isinstance(data_type, PointerType):
        return ctypes.POINTER(to_ctypes(data_type.base_type))
    elif isinstance(data_type, StructType):
        return ctypes.POINTER(ctypes.c_uint8)
    else:
        return to_ctypes.map[data_type.numpy_dtype]


to_ctypes.map = {
    np.dtype(np.int8): ctypes.c_int8,
    np.dtype(np.int16): ctypes.c_int16,
    np.dtype(np.int32): ctypes.c_int32,
    np.dtype(np.int64): ctypes.c_int64,

    np.dtype(np.uint8): ctypes.c_uint8,
    np.dtype(np.uint16): ctypes.c_uint16,
    np.dtype(np.uint32): ctypes.c_uint32,
    np.dtype(np.uint64): ctypes.c_uint64,

    np.dtype(np.float32): ctypes.c_float,
    np.dtype(np.float64): ctypes.c_double,

    # bool BasicTypes (create_type("bool"), C name 'bool') previously raised KeyError
    np.dtype(np.bool_): ctypes.c_bool,
}


183
def ctypes_from_llvm(data_type):
    """Maps an llvmlite IR type to the matching ctypes type.

    Args:
        data_type: llvmlite ``ir`` type object

    Returns:
        ctypes type, or None for ``ir.VoidType`` (ctypes has no void type). A pointer whose
        pointee maps to None becomes ``ctypes.c_void_p``. Integer types map to the signed
        ctypes integers of the same width.

    Raises:
        ValueError: for integer widths other than 8, 16, 32 and 64
        NotImplementedError: for any other llvmlite type
    """
    if not ir:
        # llvmlite failed to import at module load time -> re-raise that ImportError
        raise _ir_importerror

    if isinstance(data_type, ir.PointerType):
        pointee_ctype = ctypes_from_llvm(data_type.pointee)
        return ctypes.c_void_p if pointee_ctype is None else ctypes.POINTER(pointee_ctype)

    if isinstance(data_type, ir.IntType):
        signed_int_by_width = {
            8: ctypes.c_int8,
            16: ctypes.c_int16,
            32: ctypes.c_int32,
            64: ctypes.c_int64,
        }
        int_ctype = signed_int_by_width.get(data_type.width)
        if int_ctype is None:
            raise ValueError("Int width %d is not supported" % data_type.width)
        return int_ctype

    if isinstance(data_type, ir.FloatType):
        return ctypes.c_float
    if isinstance(data_type, ir.DoubleType):
        return ctypes.c_double
    if isinstance(data_type, ir.VoidType):
        return None  # Void type is not supported by ctypes

    raise NotImplementedError('Data type %s of %s is not supported yet' % (type(data_type), data_type))


def to_llvm_type(data_type):
    """Transforms a given Type into an llvmlite IR type.

    Args:
        data_type: Type object. PointerTypes are converted recursively; all other types
                   are looked up by their ``numpy_dtype`` in ``to_llvm_type.map``.

    Returns:
        llvmlite type object

    Raises:
        ImportError: if llvmlite could not be imported (the original import error is re-raised)
        KeyError: if the numpy dtype has no entry in ``to_llvm_type.map`` (e.g. bool)
    """
    if not ir:
        raise _ir_importerror
    if isinstance(data_type, PointerType):
        return to_llvm_type(data_type.base_type).as_pointer()
    else:
        return to_llvm_type.map[data_type.numpy_dtype]


if ir:
    # LLVM integer types carry no signedness, so signed and unsigned numpy integers of
    # the same width share one IntType.
    to_llvm_type.map = {
        np.dtype(np.int8): ir.IntType(8),
        np.dtype(np.int16): ir.IntType(16),
        np.dtype(np.int32): ir.IntType(32),
        np.dtype(np.int64): ir.IntType(64),

        np.dtype(np.uint8): ir.IntType(8),
        np.dtype(np.uint16): ir.IntType(16),
        np.dtype(np.uint32): ir.IntType(32),
        np.dtype(np.uint64): ir.IntType(64),

        np.dtype(np.float32): ir.FloatType(),
        np.dtype(np.float64): ir.DoubleType(),
    }
242

243

Martin Bauer's avatar
Martin Bauer committed
244
245
246
def peel_off_type(dtype, type_to_peel_off):
    """Removes all outer layers whose class is exactly ``type_to_peel_off``.

    For example, ``peel_off_type(vector_of_double, VectorType)`` gives the double type.
    Only exact class matches are peeled (subclasses are not). The first layer of a
    different class is returned.
    """
    if type(dtype) is not type_to_peel_off:
        return dtype
    return peel_off_type(dtype.base_type, type_to_peel_off)


Martin Bauer's avatar
Martin Bauer committed
250
def collate_types(types):
    """Returns the common type of a sequence of types, e.g. (float, double, float) -> double.

    Rules, applied in this order:
      * Pointer arithmetic: if any type is a PointerType, the sequence may contain exactly one
        pointer plus (signed or unsigned) integer BasicTypes; the pointer type is returned.
      * VectorTypes must all share one width. They are peeled off, and if any was present the
        result is wrapped into a VectorType of that width again.
      * If any of the remaining scalar types is a float, only the float types are collated.
      * The numpy dtypes are combined with ``np.result_type`` (numpy collation rules); the
        resulting BasicType is non-const.

    Args:
        types: sequence of types (iterated several times, so not a one-shot iterator)

    Raises:
        ValueError: for two pointer types, pointer arithmetic with non-integer types, or
                    vector types of different widths
    """
    if any(type(t) is PointerType for t in types):
        result_pointer = None
        for current in types:
            kind = type(current)
            if kind is PointerType:
                if result_pointer is not None:
                    raise ValueError("Cannot collate the combination of two pointer types")
                result_pointer = current
            elif kind is BasicType and (current.is_int() or current.is_uint()):
                continue
            else:
                raise ValueError("Invalid pointer arithmetic")
        return result_pointer

    vector_types = [t for t in types if type(t) is VectorType]
    if not all_equal(v.width for v in vector_types):
        raise ValueError("Collation failed because of vector types with different width")

    # struct types are not supported yet: after peeling only basic types may remain
    scalar_types = [peel_off_type(t, VectorType) for t in types]
    assert all(type(t) is BasicType for t in scalar_types)

    float_types = [t for t in scalar_types if t.is_float()]
    candidates = float_types if float_types else scalar_types
    result = BasicType(np.result_type(*(t.numpy_dtype for t in candidates)))
    return VectorType(result, vector_types[0].width) if vector_types else result


@memorycache(maxsize=2048)
def get_type_of_expression(expr):
    """Determines the data type of a (sympified) expression.

    - integers -> "int"; other rationals and floats -> "double"
    - ResolvedFieldAccess -> the field's dtype; TypedSymbol -> its dtype; cast_func -> its target type
    - Piecewise -> collated type of the branch values (vectorized if the conditions are vectors)
    - Indexed -> ``base_type`` of the label's dtype (presumably a PointerType label)
    - boolean expressions -> bool, as a vector type if any argument is a vector type
    - powers -> type of the base; other expressions -> collated types of their arguments

    Raises:
        ValueError: if an untyped sp.Symbol is encountered (or collate_types rejects the argument types)
        NotImplementedError: if no rule applies
    """
    # function-level import, presumably to avoid an import cycle -- confirm
    from pystencils.astnodes import ResolvedFieldAccess
    expr = sp.sympify(expr)
    if isinstance(expr, sp.Integer):
        return create_type("int")
    elif isinstance(expr, (sp.Rational, sp.Float)):
        return create_type("double")
    elif isinstance(expr, ResolvedFieldAccess):
        return expr.field.dtype
    elif isinstance(expr, TypedSymbol):
        return expr.dtype
    elif isinstance(expr, sp.Symbol):
        raise ValueError("All symbols inside this expression have to be typed! ", str(expr))
    elif isinstance(expr, cast_func):
        return expr.args[1]
    elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
        collated_result_type = collate_types(tuple(get_type_of_expression(a[0]) for a in expr.args))
        collated_condition_type = collate_types(tuple(get_type_of_expression(a[1]) for a in expr.args))
        if type(collated_condition_type) is VectorType and type(collated_result_type) is not VectorType:
            collated_result_type = VectorType(collated_result_type, width=collated_condition_type.width)
        return collated_result_type
    elif isinstance(expr, sp.Indexed):
        typed_symbol = expr.base.label
        return typed_symbol.dtype.base_type
    elif isinstance(expr, Boolean):
        # BooleanFunction derives from Boolean, so this single check covers both. It also avoids
        # the ``sp.boolalg`` attribute, which is not guaranteed to exist in every sympy version.
        # if any arg is of vector type return a vector boolean, else return a normal scalar boolean
        result = create_type("bool")
        arg_types = [get_type_of_expression(a) for a in expr.args]
        vec_args = [t for t in arg_types if isinstance(t, VectorType)]
        if vec_args:
            result = VectorType(result, width=vec_args[0].width)
        return result
    elif isinstance(expr, sp.Pow):
        return get_type_of_expression(expr.args[0])
    elif isinstance(expr, sp.Expr):
        types = tuple(get_type_of_expression(a) for a in expr.args)
        return collate_types(types)

    raise NotImplementedError("Could not determine type for", expr, type(expr))
329
330


331
class Type(sp.Basic):
    """Common base class of BasicType, VectorType and PointerType.

    Types are sympy atoms, so they can be stored as arguments of sympy expressions.
    """
    is_Atom = True

    def __new__(cls, *args, **kwargs):
        # sympy.Basic gets no arguments; the constructor arguments are handled by the subclass __init__
        return sp.Basic.__new__(cls)

    def _sympystr(self, *args, **kwargs):
        # sympy's printer uses the type's own string representation
        return str(self)
339
340
341
342


class BasicType(Type):
    """Scalar data type described by a non-structured numpy dtype plus a ``const`` flag."""

    @staticmethod
    def numpy_name_to_c(name):
        """Maps a numpy dtype name ('float64', 'int32', 'uint8', 'bool', ...) to its C type name.

        Raises:
            NotImplementedError: if no C name is known for ``name`` (e.g. 'float16')
        """
        if name == 'float64':
            return 'double'
        elif name == 'float32':
            return 'float'
        elif name.startswith('int'):
            width = int(name[len("int"):])
            return "int%d_t" % (width,)
        elif name.startswith('uint'):
            width = int(name[len("uint"):])
            return "uint%d_t" % (width,)
        elif name == 'bool':
            return 'bool'
        else:
            # NotImplemented is a constant, not an exception class: calling it raised a TypeError
            raise NotImplementedError("Cannot map numpy to C name for %s" % (name,))

    def __init__(self, dtype, const=False):
        self.const = const
        if isinstance(dtype, Type):
            # take over the numpy dtype of another type object
            self._dtype = dtype.numpy_dtype
        else:
            self._dtype = np.dtype(dtype)
        assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type"
        assert self._dtype.hasobject is False
        assert self._dtype.subdtype is None

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        # basic types do not wrap another type
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def item_size(self):
        # NOTE(review): always 1, not the byte size -- seems to count elements; confirm
        return 1

    # The category checks use dtype.kind because np.sctypes was removed in NumPy 2.0.
    def is_int(self):
        return self.numpy_dtype.kind == 'i'

    def is_float(self):
        return self.numpy_dtype.kind == 'f'

    def is_uint(self):
        return self.numpy_dtype.kind == 'u'

    def is_complex(self):
        return self.numpy_dtype.kind == 'c'

    def is_other(self):
        # bool, object, bytes, str and void -- the former np.sctypes['others'] group
        return self.numpy_dtype.kind in ('b', 'O', 'S', 'U', 'V')

    @property
    def base_name(self):
        """C name of the type without qualifiers."""
        return BasicType.numpy_name_to_c(str(self._dtype))

    def __str__(self):
        result = BasicType.numpy_name_to_c(str(self._dtype))
        if self.const:
            result += " const"
        return result

    def __repr__(self):
        return str(self)

    def __eq__(self, other):
        if not isinstance(other, BasicType):
            return False
        else:
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)

    def __hash__(self):
        # consistent with __eq__: the string contains the C name and the const qualifier
        return hash(str(self))


422
class VectorType(Type):
    """SIMD vector of ``width`` elements of ``base_type``.

    While the class attribute ``instruction_set`` is None, ``str()`` gives "base[width]".
    Otherwise ``instruction_set`` is indexed with 'int', 'double', 'float' or 'bool',
    depending on the base type, to obtain the type name.
    """
    instruction_set = None

    def __init__(self, base_type, width=4):
        self._base_type = base_type
        self.width = width

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        return self.width * self.base_type.item_size

    def __eq__(self, other):
        if isinstance(other, VectorType):
            return (self.base_type, self.width) == (other.base_type, other.width)
        return False

    def __str__(self):
        if self.instruction_set is None:
            return "%s[%d]" % (self.base_type, self.width)
        # (scalar type name, instruction set key), checked in this order
        for type_name, key in (("int64", 'int'), ("float64", 'double'),
                               ("float32", 'float'), ("bool", 'bool')):
            if self.base_type == create_type(type_name):
                return self.instruction_set[key]
        raise NotImplementedError()

    def __hash__(self):
        return hash((self.base_type, self.width))

    def __getnewargs__(self):
        return self._base_type, self.width

464

465
class PointerType(Type):
    """Pointer to ``base_type`` with ``const`` and ``restrict`` qualifiers.

    Pointers are ``restrict`` by default; ``alias`` is the negation of ``restrict``.
    """

    def __init__(self, base_type, const=False, restrict=True):
        self._base_type = base_type
        self.const = const
        self.restrict = restrict

    def __getnewargs__(self):
        return self.base_type, self.const, self.restrict

    @property
    def alias(self):
        # True when the pointer is not restrict-qualified
        return not self.restrict

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        # a pointer reports the item size of the type it points to
        return self.base_type.item_size

    def __eq__(self, other):
        if isinstance(other, PointerType):
            return (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict)
        return False

    def __str__(self):
        qualifiers = [word for word, present in (('RESTRICT', self.restrict), ("const", self.const)) if present]
        return " ".join([str(self.base_type), '*'] + qualifiers)

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self._base_type, self.const, self.restrict))
505

Jan Hoenig's avatar
Jan Hoenig committed
506

507
class StructType:
    """Type for numpy structured dtypes (records with named fields).

    Note that this class does not derive from Type. Structs are handled byte-wise, so the C
    name is ``uint8_t`` (followed by ``const`` for const structs).
    """

    def __init__(self, numpy_type, const=False):
        self._dtype = np.dtype(numpy_type)
        self.const = const

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        # structs do not wrap another type
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def item_size(self):
        # size of one record in bytes
        return self.numpy_dtype.itemsize

    def get_element_offset(self, element_name):
        """Byte offset of field ``element_name`` within the record."""
        field_info = self.numpy_dtype.fields[element_name]
        return field_info[1]

    def get_element_type(self, element_name):
        """BasicType of field ``element_name``; inherits this struct's ``const`` flag."""
        field_dtype = self.numpy_dtype.fields[element_name][0]
        return BasicType(field_dtype, self.const)

    def has_element(self, element_name):
        return element_name in self.numpy_dtype.fields

    def __eq__(self, other):
        if isinstance(other, StructType):
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)
        return False

    def __str__(self):
        # structs are handled byte-wise
        c_name = "uint8_t"
        return c_name + " const" if self.const else c_name

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self.numpy_dtype, self.const))