data_types.py 17 KB
Newer Older
1
import ctypes
2
import sympy as sp
3
import numpy as np
4
5
6
7
8
try:
    import llvmlite.ir as ir
except ImportError as e:
    ir = None
    _ir_importerror = e
9
from sympy.core.cache import cacheit
10

11
from pystencils.cache import memorycache
Martin Bauer's avatar
Martin Bauer committed
12
from pystencils.utils import all_equal
Martin Bauer's avatar
Martin Bauer committed
13
from sympy.logic.boolalg import Boolean
14

15

Martin Bauer's avatar
Martin Bauer committed
16
# noinspection PyPep8Naming
17
class cast_func(sp.Function):
    """Sympy function node that casts ``args[0]`` to the data type ``args[1]``.

    When the argument is a sympy Boolean the node is created as ``boolean_cast_func``
    instead (see ``__new__``).
    """
    is_Atom = True

    def __new__(cls, *args, **kwargs):
        # to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well
        # however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads
        # to problems when for example comparing cast_func's for equality
        #
        # lhs = bitwise_and(a, cast_func(1, 'int'))
        # rhs = cast_func(0, 'int')
        # print( sp.Ne(lhs, rhs) ) # would give true if all cast_funcs are booleans
        # -> thus a separate class boolean_cast_func is introduced
        #
        # The argument is sympified before the check so that plain Python booleans (True/False)
        # are recognized as booleans too; sympy's Function constructor sympifies its arguments
        # as well, so this only affects which class gets instantiated.
        if isinstance(sp.sympify(args[0]), Boolean):
            cls = boolean_cast_func
        return sp.Function.__new__(cls, *args, **kwargs)

    @property
    def canonical(self):
        """Canonical form of the cast argument, if the argument provides one.

        Raises:
            NotImplementedError: if the argument has no ``canonical`` attribute
        """
        if hasattr(self.args[0], 'canonical'):
            return self.args[0].canonical
        else:
            raise NotImplementedError()

    @property
    def is_commutative(self):
        # commutativity is taken over from the cast argument
        return self.args[0].is_commutative

    @property
    def dtype(self):
        """Target data type of the cast (the second argument)."""
        return self.args[1]


49
50
51
52
53
# noinspection PyPep8Naming
class boolean_cast_func(cast_func, Boolean):
    """cast_func variant used when the cast argument is a Boolean.

    Being a sympy Boolean allows the cast to be used inside sp.Piecewise conditions.
    """


Martin Bauer's avatar
Martin Bauer committed
54
55
56
57
# noinspection PyPep8Naming
class vector_memory_access(cast_func):
    """cast_func subclass used for vector memory accesses.

    Sympy's ``nargs`` restricts it to exactly four arguments.
    NOTE(review): the meaning of the arguments is defined by the code that constructs it.
    """

    nargs = (4,)

58

Martin Bauer's avatar
Martin Bauer committed
59
60
# noinspection PyPep8Naming
class pointer_arithmetic_func(sp.Function, Boolean):
    """Sympy function node whose name indicates pointer arithmetic.

    NOTE(review): it also derives from Boolean, presumably for the same Piecewise-condition
    reason as boolean_cast_func; confirm.
    """

    @property
    def canonical(self):
        """Canonical form of the first argument, if it provides one.

        Raises:
            NotImplementedError: if the first argument has no ``canonical`` attribute
        """
        first_arg = self.args[0]
        if not hasattr(first_arg, 'canonical'):
            raise NotImplementedError()
        return first_arg.canonical


69
class TypedSymbol(sp.Symbol):
    """Sympy Symbol that additionally carries a data type (``dtype``)."""

    def __new__(cls, *args, **kwds):
        # Construction goes through the cacheit-wrapped stage-2 constructor, so repeated
        # (name, dtype) arguments may be served from sympy's cache.
        return TypedSymbol.__xnew_cached_(cls, *args, **kwds)

    def __new_stage2__(cls, name, dtype):
        instance = super(TypedSymbol, cls).__xnew__(cls, name)
        try:
            resolved_dtype = create_type(dtype)
        except TypeError:
            # create_type could not interpret dtype: keep the specification (e.g. a string) as is
            resolved_dtype = dtype
        instance._dtype = resolved_dtype
        return instance

    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    @property
    def dtype(self):
        """Data type of the symbol (a Type, or the raw specification if it could not be parsed)."""
        return self._dtype

    def _hashable_content(self):
        # extend sympy's Symbol content with the dtype hash so the dtype takes part in
        # sympy's hashing / structural comparison
        symbol_content = super()._hashable_content()
        return symbol_content, hash(self._dtype)

    def __getnewargs__(self):
        # pickling support: the symbol is recreated from (name, dtype)
        return self.name, self.dtype


Martin Bauer's avatar
Martin Bauer committed
97
def create_type(specification):
    """Return a Type object for the given specification.

    Args:
        specification: a Type (returned unchanged) or anything ``np.dtype`` accepts,
            e.g. a string such as "double" or "int32", or a numpy scalar type

    Returns:
        the given Type itself; otherwise a non-const BasicType for plain dtypes, or a
        non-const StructType for structured dtypes (dtypes with fields)
    """
    if isinstance(specification, Type):
        return specification
    numpy_dtype = np.dtype(specification)
    is_structured = numpy_dtype.fields is not None
    type_class = StructType if is_structured else BasicType
    return type_class(numpy_dtype, const=False)
114
115


116
@memorycache(maxsize=64)
def create_composite_type_from_string(specification):
    """Creates a new Type object from a c-like string specification.

    The specification is a base type, optionally qualified with 'const', followed by any
    number of pointer levels '*'. Each pointer level may be qualified with 'const' and/or
    'restrict', e.g. "const double * restrict * const". A '*' does not need surrounding
    whitespace, so "double*", "double **" and "double *restrict" are accepted. Matching is
    case-insensitive.

    Args:
        specification: Specification string

    Returns:
        Type object
    """
    import re

    # Every '*' becomes a token of its own, so "double*", "**" and "*restrict" tokenize the
    # same way as their whitespace-separated forms.
    tokens = re.findall(r'\*|[^\s*]+', specification.lower())

    # Group the tokens: the first group is the base type, every following group starts with '*'
    parts = []
    current = []
    for token in tokens:
        if token == '*':
            parts.append(current)
            current = [token]
        else:
            current.append(token)
    if current:
        parts.append(current)

    # Parse native part
    base_part = parts.pop(0)
    const = False
    if 'const' in base_part:
        const = True
        base_part.remove('const')
    assert len(base_part) == 1
    current_type = BasicType(np.dtype(base_part[0]), const)

    # Parse pointer parts, innermost pointer level first
    for part in parts:
        restrict = False
        const = False
        if 'restrict' in part:
            restrict = True
            part.remove('restrict')
        if 'const' in part:
            const = True
            part.remove('const')
        assert len(part) == 1 and part[0] == '*'
        current_type = PointerType(current_type, const, restrict)
    return current_type
161
162


Martin Bauer's avatar
Martin Bauer committed
163
164
165
166
def get_base_type(data_type):
    """Follow the ``base_type`` chain of ``data_type`` down to its innermost type.

    Returns the first object in the chain whose ``base_type`` is None (``data_type`` itself
    if it has no base type).
    """
    current = data_type
    nested = current.base_type
    while nested is not None:
        current, nested = nested, nested.base_type
    return current
167
168


Martin Bauer's avatar
Martin Bauer committed
169
def to_ctypes(data_type):
    """
    Transforms a given Type into ctypes
    :param data_type: Subclass of Type
    :return: ctypes type object; pointers become ctypes pointers to the converted base type,
             struct types become a pointer to uint8, all other types are looked up in to_ctypes.map
    """
    if isinstance(data_type, PointerType):
        pointee = to_ctypes(data_type.base_type)
        return ctypes.POINTER(pointee)
    if isinstance(data_type, StructType):
        # structs are passed as raw byte pointers
        return ctypes.POINTER(ctypes.c_uint8)
    return to_ctypes.map[data_type.numpy_dtype]
181

Martin Bauer's avatar
Martin Bauer committed
182
183

# numpy dtype -> ctypes scalar type lookup used by to_ctypes for non-pointer, non-struct types
to_ctypes.map = {
    np.dtype(np.int8): ctypes.c_int8,
    np.dtype(np.int16): ctypes.c_int16,
    np.dtype(np.int32): ctypes.c_int32,
    np.dtype(np.int64): ctypes.c_int64,

    np.dtype(np.uint8): ctypes.c_uint8,
    np.dtype(np.uint16): ctypes.c_uint16,
    np.dtype(np.uint32): ctypes.c_uint32,
    np.dtype(np.uint64): ctypes.c_uint64,

    np.dtype(np.float32): ctypes.c_float,
    np.dtype(np.float64): ctypes.c_double,

    # BasicType supports 'bool', so it needs a ctypes counterpart as well
    np.dtype(np.bool_): ctypes.c_bool,
}


199
def ctypes_from_llvm(data_type):
    """
    Transforms an llvmlite IR type into the corresponding ctypes type
    :param data_type: llvmlite ir type (pointer, integer, float, double or void)
    :return: ctypes type object, or None for the void type
    :raises ValueError: for integer widths other than 8, 16, 32 and 64
    :raises NotImplementedError: for any other llvmlite type
    """
    if not ir:
        raise _ir_importerror

    if isinstance(data_type, ir.PointerType):
        pointee_ctype = ctypes_from_llvm(data_type.pointee)
        # a pointer to void has no typed ctypes equivalent
        return ctypes.c_void_p if pointee_ctype is None else ctypes.POINTER(pointee_ctype)

    if isinstance(data_type, ir.IntType):
        # integers are always mapped to the signed ctypes of the same width
        int_ctypes = {8: ctypes.c_int8, 16: ctypes.c_int16, 32: ctypes.c_int32, 64: ctypes.c_int64}
        int_ctype = int_ctypes.get(data_type.width)
        if int_ctype is None:
            raise ValueError("Int width %d is not supported" % data_type.width)
        return int_ctype

    if isinstance(data_type, ir.FloatType):
        return ctypes.c_float
    if isinstance(data_type, ir.DoubleType):
        return ctypes.c_double
    if isinstance(data_type, ir.VoidType):
        return None  # Void type is not supported by ctypes

    raise NotImplementedError('Data type %s of %s is not supported yet' % (type(data_type), data_type))


def to_llvm_type(data_type):
    """
    Transforms a given Type into an llvmlite IR type
    :param data_type: Subclass of Type; pointers are converted recursively, all other types
                      are looked up by numpy dtype in to_llvm_type.map
    :return: llvmlite type object
    :raises: the ImportError captured when importing llvmlite, if llvmlite is not available
    """
    if not ir:
        raise _ir_importerror
    if isinstance(data_type, PointerType):
        return to_llvm_type(data_type.base_type).as_pointer()
    else:
        return to_llvm_type.map[data_type.numpy_dtype]

242

243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
if ir:
    # numpy dtype -> llvmlite type lookup used by to_llvm_type. Signed and unsigned numpy
    # integers of the same width map to an ir.IntType of that width.
    to_llvm_type.map = {
        np.dtype('%s%d' % (prefix, bits)): ir.IntType(bits)
        for prefix in ('int', 'uint')
        for bits in (8, 16, 32, 64)
    }
    to_llvm_type.map[np.dtype(np.float32)] = ir.FloatType()
    to_llvm_type.map[np.dtype(np.float64)] = ir.DoubleType()
258

259

Martin Bauer's avatar
Martin Bauer committed
260
261
262
def peel_off_type(dtype, type_to_peel_off):
    """Strip consecutive wrappers of exactly ``type_to_peel_off`` from ``dtype``.

    Follows ``base_type`` while the current object's type is exactly ``type_to_peel_off``
    (instances of subclasses are not peeled) and returns the first object of another type.
    """
    current = dtype
    while True:
        if type(current) is not type_to_peel_off:
            return current
        current = current.base_type


Martin Bauer's avatar
Martin Bauer committed
266
def collate_types(types):
    """
    Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
    Uses the collation rules from numpy.

    Rules:
      * If a PointerType is present, it is the result (pointer arithmetic). All other entries must
        be integer BasicTypes; a second pointer or any other type raises ValueError.
      * Otherwise VectorTypes are unwrapped to their element types (all vectors must share one
        width, else ValueError) and the result is re-wrapped in a VectorType of that width.
      * If any floating point type is present, integer types are ignored; the remaining types are
        combined with np.result_type.
    """
    # Pointer arithmetic case i.e. pointer + integer is allowed
    if any(type(t) is PointerType for t in types):
        found_pointer = None
        for candidate in types:
            candidate_class = type(candidate)
            if candidate_class is PointerType:
                if found_pointer is not None:
                    raise ValueError("Cannot collate the combination of two pointer types")
                found_pointer = candidate
                continue
            is_integer = candidate_class is BasicType and (candidate.is_int() or candidate.is_uint())
            if not is_integer:
                raise ValueError("Invalid pointer arithmetic")
        return found_pointer

    # peel off vector types; if at least one vector type occurred the result will also be a vector type
    vectors = [t for t in types if type(t) is VectorType]
    if not all_equal(v.width for v in vectors):
        raise ValueError("Collation failed because of vector types with different width")
    scalars = [peel_off_type(t, VectorType) for t in types]

    # only basic types remain at this point - struct types are not yet supported
    assert all(type(t) is BasicType for t in scalars)

    # floating point types dominate integer types
    floats = tuple(t for t in scalars if t.is_float())
    if floats:
        scalars = floats

    common = BasicType(np.result_type(*[t.numpy_dtype for t in scalars]))
    if vectors:
        return VectorType(common, vectors[0].width)
    return common


@memorycache(maxsize=2048)
def get_type_of_expression(expr):
    """Determine the data type of a (sympified) expression.

    Rules, checked in this order:
      * sympy Integer -> "int"; other Rational or Float -> "double"
      * ResolvedFieldAccess -> the field's dtype; TypedSymbol -> its dtype
      * untyped Symbol -> ValueError
      * cast_func -> its target type (second argument)
      * Piecewise -> collated type of all branch values, wrapped into a VectorType if the
        collated condition type is a vector type
      * Indexed -> ``base_type`` of the dtype of the index base label
      * Boolean -> "bool", or a bool VectorType if any argument has a vector type
      * Pow -> type of the base
      * any other Expr -> collated type of all arguments

    Raises:
        ValueError: if an untyped sympy Symbol is encountered
        NotImplementedError: if none of the rules applies
    """
    from pystencils.astnodes import ResolvedFieldAccess
    expr = sp.sympify(expr)
    if isinstance(expr, sp.Integer):
        return create_type("int")
    elif isinstance(expr, (sp.Rational, sp.Float)):
        return create_type("double")
    elif isinstance(expr, ResolvedFieldAccess):
        return expr.field.dtype
    elif isinstance(expr, TypedSymbol):
        return expr.dtype
    elif isinstance(expr, sp.Symbol):
        raise ValueError("All symbols inside this expression have to be typed! " + str(expr))
    elif isinstance(expr, cast_func):
        return expr.args[1]
    elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
        collated_result_type = collate_types(tuple(get_type_of_expression(a[0]) for a in expr.args))
        collated_condition_type = collate_types(tuple(get_type_of_expression(a[1]) for a in expr.args))
        if type(collated_condition_type) is VectorType and type(collated_result_type) is not VectorType:
            collated_result_type = VectorType(collated_result_type, width=collated_condition_type.width)
        return collated_result_type
    elif isinstance(expr, sp.Indexed):
        typed_symbol = expr.base.label
        return typed_symbol.dtype.base_type
    elif isinstance(expr, Boolean):
        # BooleanFunction is a subclass of Boolean, so this single check covers both; the former
        # spelling sp.boolalg.* relied on sympy exposing the boolalg module at top level.
        # if any arg is of vector type return a vector boolean, else return a normal scalar boolean
        result = create_type("bool")
        arg_types = [get_type_of_expression(a) for a in expr.args]
        vec_args = [t for t in arg_types if isinstance(t, VectorType)]
        if vec_args:
            result = VectorType(result, width=vec_args[0].width)
        return result
    elif isinstance(expr, sp.Pow):
        return get_type_of_expression(expr.args[0])
    elif isinstance(expr, sp.Expr):
        types = tuple(get_type_of_expression(a) for a in expr.args)
        return collate_types(types)

    raise NotImplementedError("Could not determine type for %s of type %s" % (expr, type(expr)))
345
346


347
class Type(sp.Basic):
    """Common base class of the pystencils data types (a sympy atom)."""

    is_Atom = True

    def __new__(cls, *args, **kwargs):
        # The sympy node itself is created without args; constructor arguments are
        # handled by the subclasses' __init__ methods.
        return sp.Basic.__new__(cls)

    def _sympystr(self, *args, **kwargs):
        # sympy's string printer defers to the type's own __str__
        return str(self)
355
356
357
358


class BasicType(Type):
    """Scalar type backed by a plain (non-structured) numpy dtype, optionally ``const``."""

    @staticmethod
    def numpy_name_to_c(name):
        """Translate a numpy dtype name into the corresponding C type name.

        Args:
            name: numpy dtype name such as 'float64', 'int32', 'uint8' or 'bool'

        Returns:
            C type name, e.g. 'double', 'int32_t', 'uint8_t' or 'bool'

        Raises:
            NotImplementedError: if there is no C mapping for ``name``
        """
        if name == 'float64':
            return 'double'
        elif name == 'float32':
            return 'float'
        elif name.startswith('int'):
            width = int(name[len("int"):])
            return "int%d_t" % (width,)
        elif name.startswith('uint'):
            width = int(name[len("uint"):])
            return "uint%d_t" % (width,)
        elif name == 'bool':
            return 'bool'
        else:
            # NotImplemented is a constant, not an exception class - calling it raised TypeError
            raise NotImplementedError("Cannot map numpy to C name for %s" % (name,))

    def __init__(self, dtype, const=False):
        self.const = const
        if isinstance(dtype, Type):
            self._dtype = dtype.numpy_dtype
        else:
            self._dtype = np.dtype(dtype)
        assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type"
        assert self._dtype.hasobject is False
        assert self._dtype.subdtype is None

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        # scalar types do not wrap another type
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def item_size(self):
        # counted in elements; NOTE(review): StructType.item_size reports bytes instead
        return 1

    # The predicates below classify the dtype by its numpy ``kind`` character instead of
    # looking it up in ``np.sctypes``, which was removed in NumPy 2.0.
    def is_int(self):
        return self.numpy_dtype.kind == 'i'

    def is_float(self):
        return self.numpy_dtype.kind == 'f'

    def is_uint(self):
        return self.numpy_dtype.kind == 'u'

    def is_complex(self):
        return self.numpy_dtype.kind == 'c'

    def is_other(self):
        # bool, object, bytes, str and void: the categories of the former np.sctypes['others']
        return self.numpy_dtype.kind in 'bOSUV'

    @property
    def base_name(self):
        """C name of the dtype, without the const qualifier."""
        return BasicType.numpy_name_to_c(str(self._dtype))

    def __str__(self):
        result = BasicType.numpy_name_to_c(str(self._dtype))
        if self.const:
            result += " const"
        return result

    def __repr__(self):
        return str(self)

    def __eq__(self, other):
        if not isinstance(other, BasicType):
            return False
        else:
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)

    def __hash__(self):
        # hash exactly the fields compared in __eq__ (as StructType does); going through
        # str(self) failed for dtypes numpy_name_to_c cannot translate, e.g. float16
        return hash((self.numpy_dtype, self.const))


438
class VectorType(Type):
    """Vector of ``width`` elements of ``base_type``."""

    # When None, vectors print as 'base[width]'. Otherwise a mapping whose 'int', 'double',
    # 'float' and 'bool' entries are used as the printed names of vectors of int64, float64,
    # float32 and bool elements respectively.
    instruction_set = None

    def __init__(self, base_type, width=4):
        self._base_type = base_type
        self.width = width

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        return self.width * self.base_type.item_size

    def __eq__(self, other):
        if isinstance(other, VectorType):
            return (self.base_type, self.width) == (other.base_type, other.width)
        return False

    def __str__(self):
        if self.instruction_set is None:
            return "%s[%d]" % (self.base_type, self.width)
        element_to_key = (
            ("int64", 'int'),
            ("float64", 'double'),
            ("float32", 'float'),
            ("bool", 'bool'),
        )
        for element_name, key in element_to_key:
            if self.base_type == create_type(element_name):
                return self.instruction_set[key]
        raise NotImplementedError()

    def __hash__(self):
        return hash((self.base_type, self.width))

    def __getnewargs__(self):
        return self._base_type, self.width

480

481
class PointerType(Type):
    """Pointer to ``base_type`` with optional ``const`` and ``restrict`` qualifiers."""

    def __init__(self, base_type, const=False, restrict=True):
        self._base_type = base_type
        self.const = const
        self.restrict = restrict

    def __getnewargs__(self):
        return self.base_type, self.const, self.restrict

    @property
    def alias(self):
        # the pointer may alias exactly when it is not declared restrict
        return not self.restrict

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        return self.base_type.item_size

    def __eq__(self, other):
        return isinstance(other, PointerType) and \
            (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict)

    def __str__(self):
        qualifiers = [name for name, enabled in (('RESTRICT', self.restrict), ("const", self.const)) if enabled]
        return " ".join([str(self.base_type), '*'] + qualifiers)

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self._base_type, self.const, self.restrict))
521

Jan Hoenig's avatar
Jan Hoenig committed
522

523
class StructType:
    """Type backed by a structured (record) numpy dtype; printed as a byte type."""

    def __init__(self, numpy_type, const=False):
        self.const = const
        self._dtype = np.dtype(numpy_type)

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def item_size(self):
        # size of one record in bytes
        return self.numpy_dtype.itemsize

    def get_element_offset(self, element_name):
        """Byte offset of the named field inside one record."""
        field_info = self.numpy_dtype.fields[element_name]
        return field_info[1]

    def get_element_type(self, element_name):
        """BasicType of the named field, carrying over the struct's const flag."""
        field_dtype = self.numpy_dtype.fields[element_name][0]
        return BasicType(field_dtype, self.const)

    def has_element(self, element_name):
        return element_name in self.numpy_dtype.fields

    def __eq__(self, other):
        if isinstance(other, StructType):
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)
        return False

    def __str__(self):
        # structs are handled byte-wise
        return "uint8_t const" if self.const else "uint8_t"

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self.numpy_dtype, self.const))