data_types.py 15.8 KB
Newer Older
1
import ctypes
2
import sympy as sp
3
import numpy as np
4
5
6
7
8
try:
    import llvmlite.ir as ir
except ImportError as e:
    ir = None
    _ir_importerror = e
9
from sympy.core.cache import cacheit
10

11
from pystencils.cache import memorycache
Martin Bauer's avatar
Martin Bauer committed
12
from pystencils.utils import all_equal
13

14
15

# To be usable inside the conditions of sp.Piecewise, cast_func has to be of type Relational as well
Martin Bauer's avatar
Martin Bauer committed
16
class cast_func(sp.Function, sp.Rel):
    """Symbolic cast node ``cast_func(expression, target_type)``.

    Derives from both ``sp.Function`` and ``sp.Rel``. The relational-style
    properties sympy queries are answered by delegating to the expression
    being cast (the first argument).
    """

    @property
    def canonical(self):
        inner = self.args[0]
        if not hasattr(inner, 'canonical'):
            raise NotImplementedError()
        return inner.canonical

    @property
    def is_commutative(self):
        inner = self.args[0]
        return inner.is_commutative

28

Martin Bauer's avatar
Martin Bauer committed
29
class pointer_arithmetic_func(sp.Function, sp.Rel):
    """Symbolic pointer-arithmetic node.

    Derives from both ``sp.Function`` and ``sp.Rel``; ``canonical`` is
    delegated to the first argument when that argument provides one.
    """

    @property
    def canonical(self):
        first = self.args[0]
        if not hasattr(first, 'canonical'):
            raise NotImplementedError()
        return first.canonical


39
class TypedSymbol(sp.Symbol):
    """A sympy ``Symbol`` that additionally carries a data type.

    The ``dtype`` argument is passed through ``create_type``; if that raises
    ``TypeError`` the value is stored exactly as given. Construction goes through
    a ``cacheit``-wrapped constructor, so repeated calls with equal arguments may
    return a cached instance. The type is part of ``_hashable_content``.
    """

    def __new__(cls, *args, **kwds):
        return TypedSymbol.__xnew_cached_(cls, *args, **kwds)

    def __new_stage2__(cls, name, dtype):
        instance = super(TypedSymbol, cls).__xnew__(cls, name)
        try:
            resolved = create_type(dtype)
        except TypeError:
            # not a recognised type specification -> keep the value as given
            resolved = dtype
        instance._dtype = resolved
        return instance

    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    @property
    def dtype(self):
        return self._dtype

    def _hashable_content(self):
        base_content = tuple(super(TypedSymbol, self)._hashable_content())
        return base_content + (hash(self._dtype),)

    def __getnewargs__(self):
        return self.name, self.dtype


Martin Bauer's avatar
Martin Bauer committed
68
def create_type(specification):
    """
    Turn a type specification into a pystencils type object.

    :param specification: a ``Type`` instance (returned unchanged) or anything ``np.dtype``
                          accepts, e.g. ``"double"``, ``np.int32`` or a structured dtype
    :return: the given ``Type``; otherwise a non-const ``BasicType`` for plain dtypes or a
             non-const ``StructType`` for dtypes that have fields
    :raises TypeError: propagated from ``np.dtype`` for specifications it does not understand

    NOTE: ``StructType`` is not a ``Type`` subclass, so a ``StructType`` instance passed in
    here is handed to ``np.dtype`` rather than returned unchanged.
    """
    if isinstance(specification, Type):
        return specification
    numpy_dtype = np.dtype(specification)
    wrapper = BasicType if numpy_dtype.fields is None else StructType
    return wrapper(numpy_dtype, const=False)
82
83


84
@memorycache(maxsize=64)
def create_composite_type_from_string(specification):
    """
    Creates a new Type object from a C-like string specification.

    Every ``*`` is treated as a token of its own, so ``"double*"``, ``"double *"`` and
    ``"double*const"`` are tokenized consistently. Qualifiers before the first ``*``
    apply to the base type. ``const`` / ``restrict`` after a ``*`` apply to that
    pointer level, and the first ``*`` is the innermost pointer.

    Examples: ``"const double"``, ``"double *"``, ``"double * restrict const"``, ``"int64**"``

    :param specification: Specification string (case-insensitive)
    :return: Type object - a BasicType, possibly wrapped in one PointerType per ``*``
    """
    # Surround every '*' with spaces so glued forms like "double*" or "int**" split the
    # same way as "double *" / "int * *". Previously a '*' glued to the base name was
    # appended as the *outermost* pointer level and its qualifiers were attached to the
    # base type, so the result depended on whitespace.
    tokens = specification.lower().replace('*', ' * ').split()

    # Group tokens: the first group is the base type, every further group starts with
    # a '*' and holds the qualifiers of that pointer level.
    parts = []
    current = []
    for token in tokens:
        if token == '*':
            parts.append(current)
            current = [token]
        else:
            current.append(token)
    if current:
        parts.append(current)

    # Parse native part
    base_part = parts.pop(0)
    const = 'const' in base_part
    if const:
        base_part.remove('const')
    assert len(base_part) == 1
    current_type = BasicType(np.dtype(base_part[0]), const)

    # Parse pointer parts, innermost first
    for part in parts:
        restrict = 'restrict' in part
        if restrict:
            part.remove('restrict')
        const = 'const' in part
        if const:
            part.remove('const')
        assert len(part) == 1 and part[0] == '*'
        current_type = PointerType(current_type, const, restrict)
    return current_type
126
127


Martin Bauer's avatar
Martin Bauer committed
128
129
130
131
def get_base_type(data_type):
    """Follow ``base_type`` links until reaching a type whose ``base_type`` is None.

    For example, for a pointer to a pointer to ``double`` this returns the ``double`` type.
    """
    current = data_type
    while True:
        inner = current.base_type
        if inner is None:
            return current
        current = inner
132
133


Martin Bauer's avatar
Martin Bauer committed
134
def to_ctypes(data_type):
    """
    Transforms a given Type into ctypes
    :param data_type: Subclass of Type (or StructType); pointer types are translated recursively
    :return: ctypes type object
    :raises KeyError: if the numpy dtype of a scalar type has no entry in ``to_ctypes.map``
    """
    if isinstance(data_type, PointerType):
        return ctypes.POINTER(to_ctypes(data_type.base_type))
    elif isinstance(data_type, StructType):
        # structs are handed over as a pointer to raw bytes
        return ctypes.POINTER(ctypes.c_uint8)
    else:
        return to_ctypes.map[data_type.numpy_dtype]


# numpy scalar dtype -> ctypes type. 'bool' is a supported BasicType, so it needs an
# entry too; without it, bool arguments raised KeyError.
to_ctypes.map = {
    np.dtype(np.int8): ctypes.c_int8,
    np.dtype(np.int16): ctypes.c_int16,
    np.dtype(np.int32): ctypes.c_int32,
    np.dtype(np.int64): ctypes.c_int64,

    np.dtype(np.uint8): ctypes.c_uint8,
    np.dtype(np.uint16): ctypes.c_uint16,
    np.dtype(np.uint32): ctypes.c_uint32,
    np.dtype(np.uint64): ctypes.c_uint64,

    np.dtype(np.float32): ctypes.c_float,
    np.dtype(np.float64): ctypes.c_double,

    np.dtype(np.bool_): ctypes.c_bool,
}


164
def ctypes_from_llvm(data_type):
    """Translate an llvmlite IR type into the corresponding ctypes type.

    Pointers are translated recursively, and a pointer to ``void`` becomes
    ``ctypes.c_void_p``. Integer types map to the signed ctypes integer of the
    same width. ``void`` itself yields ``None`` because ctypes has no void type.

    :raises: the stored llvmlite ImportError when llvmlite is unavailable
    :raises ValueError: for integer widths other than 8/16/32/64
    :raises NotImplementedError: for any other IR type
    """
    if not ir:
        raise _ir_importerror

    if isinstance(data_type, ir.PointerType):
        pointee = ctypes_from_llvm(data_type.pointee)
        return ctypes.c_void_p if pointee is None else ctypes.POINTER(pointee)

    if isinstance(data_type, ir.IntType):
        signed_ints_by_width = {
            8: ctypes.c_int8,
            16: ctypes.c_int16,
            32: ctypes.c_int32,
            64: ctypes.c_int64,
        }
        ctype = signed_ints_by_width.get(data_type.width)
        if ctype is None:
            raise ValueError("Int width %d is not supported" % data_type.width)
        return ctype

    if isinstance(data_type, ir.FloatType):
        return ctypes.c_float
    if isinstance(data_type, ir.DoubleType):
        return ctypes.c_double
    if isinstance(data_type, ir.VoidType):
        return None  # Void type is not supported by ctypes

    raise NotImplementedError('Data type %s of %s is not supported yet' % (type(data_type), data_type))


def to_llvm_type(data_type):
    """
    Transforms a given type into an llvmlite IR type
    :param data_type: Subclass of Type; pointer types are translated recursively
    :return: llvmlite type object
    :raises: the stored llvmlite ImportError when llvmlite is unavailable
    :raises KeyError: if the numpy dtype of a scalar type has no entry in ``to_llvm_type.map``
    """
    if not ir:
        raise _ir_importerror
    if isinstance(data_type, PointerType):
        return to_llvm_type(data_type.base_type).as_pointer()
    else:
        return to_llvm_type.map[data_type.numpy_dtype]


# numpy scalar dtype -> llvmlite type. LLVM integer types carry no signedness, so signed
# and unsigned numpy integers of the same width share one IntType. The table only exists
# when llvmlite was imported successfully.
if ir:
    to_llvm_type.map = {
        np.dtype(np.int8): ir.IntType(8),
        np.dtype(np.int16): ir.IntType(16),
        np.dtype(np.int32): ir.IntType(32),
        np.dtype(np.int64): ir.IntType(64),

        np.dtype(np.uint8): ir.IntType(8),
        np.dtype(np.uint16): ir.IntType(16),
        np.dtype(np.uint32): ir.IntType(32),
        np.dtype(np.uint64): ir.IntType(64),

        np.dtype(np.float32): ir.FloatType(),
        np.dtype(np.float64): ir.DoubleType(),
    }
223

224

Martin Bauer's avatar
Martin Bauer committed
225
226
227
def peel_off_type(dtype, type_to_peel_off):
    """Strip outer wrapper layers of exactly ``type_to_peel_off`` from ``dtype``.

    Descends through ``base_type`` while the current type's class *is*
    ``type_to_peel_off``; subclasses are not peeled. Returns the first type that
    does not match.
    """
    current = dtype
    while True:
        if type(current) is not type_to_peel_off:
            return current
        current = current.base_type


Martin Bauer's avatar
Martin Bauer committed
231
def collate_types(types):
    """
    Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
    Uses the collation rules from numpy.

    :param types: iterable of type objects. It is materialized once up front because it is
                  traversed several times below; a one-shot iterator such as a generator
                  would otherwise be exhausted by the first pass.
    :return: - the pointer type, if exactly one PointerType is combined only with integer
               BasicTypes (pointer arithmetic)
             - otherwise the non-const BasicType chosen by ``np.result_type``, wrapped in a
               VectorType if any input was a VectorType
    :raises ValueError: for two pointer types, a pointer combined with a non-integer type,
                        or vector types of different widths
    """
    types = tuple(types)

    # Pointer arithmetic case i.e. pointer + integer is allowed
    if any(type(t) is PointerType for t in types):
        pointer_type = None
        for t in types:
            if type(t) is PointerType:
                if pointer_type is not None:
                    raise ValueError("Cannot collate the combination of two pointer types")
                pointer_type = t
            elif type(t) is BasicType:
                if not (t.is_int() or t.is_uint()):
                    raise ValueError("Invalid pointer arithmetic")
            else:
                raise ValueError("Invalid pointer arithmetic")
        return pointer_type

    # peel off vector types, if at least one vector type occurred the result will also be the vector type
    vector_type = [t for t in types if type(t) is VectorType]
    if not all_equal(t.width for t in vector_type):
        raise ValueError("Collation failed because of vector types with different width")
    types = [peel_off_type(t, VectorType) for t in types]

    # now we should have a list of basic types - struct types are not yet supported
    assert all(type(t) is BasicType for t in types)

    # use numpy collation -> create type from numpy type -> and, put vector type around if necessary
    result_numpy_type = np.result_type(*(t.numpy_dtype for t in types))
    result = BasicType(result_numpy_type)
    if vector_type:
        result = VectorType(result, vector_type[0].width)
    return result


@memorycache(maxsize=2048)
def get_type_of_expression(expr):
    """
    Determine the pystencils data type of a sympy expression.

    Rules, checked in this order:
      - integers -> "int"; other rationals and floats -> "double"
      - ResolvedFieldAccess -> the field's dtype; TypedSymbol -> its dtype
      - untyped Symbol -> ValueError
      - cast_func(expression, target_type) -> target_type
      - Piecewise -> collated type of the branch values; vectorized if the
        conditions are vectorized
      - Indexed -> base type of the indexed symbol's (pointer) type
      - boolean expressions -> "bool", vectorized if any argument is vectorized
      - any other Expr -> collate_types of its arguments' types

    :param expr: sympy expression, or anything ``sp.sympify`` accepts
    :raises ValueError: if the expression contains an untyped symbol
    :raises NotImplementedError: if no rule applies
    """
    from pystencils.astnodes import ResolvedFieldAccess
    # sympy.logic.boolalg is reachable on all SymPy versions, whereas the sp.boolalg
    # alias is not exported by newer releases
    boolalg = sp.logic.boolalg

    expr = sp.sympify(expr)
    if isinstance(expr, sp.Integer):
        return create_type("int")
    elif isinstance(expr, (sp.Rational, sp.Float)):
        return create_type("double")
    elif isinstance(expr, ResolvedFieldAccess):
        return expr.field.dtype
    elif isinstance(expr, TypedSymbol):
        return expr.dtype
    elif isinstance(expr, sp.Symbol):
        raise ValueError("All symbols inside this expression have to be typed!")
    elif hasattr(expr, 'func') and expr.func == cast_func:
        # the second argument of a cast is the target type
        return expr.args[1]
    elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
        collated_result_type = collate_types(tuple(get_type_of_expression(a[0]) for a in expr.args))
        collated_condition_type = collate_types(tuple(get_type_of_expression(a[1]) for a in expr.args))
        if type(collated_condition_type) is VectorType and type(collated_result_type) is not VectorType:
            collated_result_type = VectorType(collated_result_type, width=collated_condition_type.width)
        return collated_result_type
    elif isinstance(expr, sp.Indexed):
        typed_symbol = expr.base.label
        return typed_symbol.dtype.base_type
    elif isinstance(expr, (boolalg.Boolean, boolalg.BooleanFunction)):
        # if any arg is of vector type return a vector boolean, else return a normal scalar boolean
        result = create_type("bool")
        arg_types = [get_type_of_expression(a) for a in expr.args]  # computed once per argument
        vec_args = [t for t in arg_types if isinstance(t, VectorType)]
        if vec_args:
            result = VectorType(result, width=vec_args[0].width)
        return result
    elif isinstance(expr, sp.Expr):
        types = tuple(get_type_of_expression(a) for a in expr.args)
        return collate_types(types)

    raise NotImplementedError("Could not determine type for %s (of type %s)" % (expr, type(expr)))
306
307


308
309
310
class Type(sp.Basic):
    """Common sympy-compatible base class of the pystencils data types.

    Constructor arguments are not forwarded to ``sp.Basic``; subclasses keep their
    state via ``__init__``. When printed through sympy's string printer, a type
    renders as its ``str()``.
    """

    def __new__(cls, *args, **kwargs):
        # deliberately drop args/kwargs at the sympy level
        return sp.Basic.__new__(cls)

    def _sympystr(self, *_printer_args, **_printer_kwargs):
        return str(self)
314
315
316
317


class BasicType(Type):
    """A scalar (non-pointer, non-struct) data type backed by a numpy dtype.

    :param dtype: anything ``np.dtype`` accepts, or a type object whose ``numpy_dtype`` is reused
    :param const: whether the C type is ``const``-qualified
    """

    @staticmethod
    def numpy_name_to_c(name):
        """Map a numpy dtype name (e.g. ``'float64'``, ``'uint8'``) to its C spelling.

        :raises NotImplementedError: for names without a C counterpart here
        """
        if name == 'float64':
            return 'double'
        elif name == 'float32':
            return 'float'
        elif name.startswith('int'):
            width = int(name[len("int"):])
            return "int%d_t" % (width,)
        elif name.startswith('uint'):
            width = int(name[len("uint"):])
            return "uint%d_t" % (width,)
        elif name == 'bool':
            return 'bool'
        else:
            # NotImplemented is a constant, not an exception class: calling it raised an
            # unrelated TypeError ("not callable") that hid the unmapped name.
            raise NotImplementedError("Cannot map numpy to C name for %s" % (name,))

    def __init__(self, dtype, const=False):
        self.const = const
        if isinstance(dtype, Type):
            self._dtype = dtype.numpy_dtype
        else:
            self._dtype = np.dtype(dtype)
        assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type"
        assert self._dtype.hasobject is False
        assert self._dtype.subdtype is None

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        # scalar types do not wrap another type
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def item_size(self):
        # NOTE(review): 1 rather than the byte size; VectorType multiplies this by its
        # width, so it appears to count elements - confirm intent.
        return 1

    # The is_* predicates classify via dtype.kind. This gives the same answers as the former
    # np.sctypes lookups, but np.sctypes was removed in NumPy 2.0.
    def is_int(self):
        return self.numpy_dtype.kind == 'i'

    def is_float(self):
        return self.numpy_dtype.kind == 'f'

    def is_uint(self):
        return self.numpy_dtype.kind == 'u'

    def is_complex(self):
        return self.numpy_dtype.kind == 'c'

    def is_other(self):
        # same members as NumPy 1.x's np.sctypes['others']
        return self.numpy_dtype in (bool, object, bytes, str, np.void)

    @property
    def base_name(self):
        return BasicType.numpy_name_to_c(str(self._dtype))

    def __str__(self):
        result = BasicType.numpy_name_to_c(str(self._dtype))
        if self.const:
            result += " const"
        return result

    def __repr__(self):
        return str(self)

    def __eq__(self, other):
        if not isinstance(other, BasicType):
            return False
        else:
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)

    def __hash__(self):
        return hash(str(self))


397
398
399
class VectorType(Type):
    """A SIMD vector of ``width`` elements of ``base_type``.

    ``instructionSet`` is a class attribute. While it is None, the type prints as
    ``base[width]``. Otherwise ``str()`` indexes it with 'int', 'double', 'float' or
    'bool', depending on the base type.
    """
    instructionSet = None

    def __init__(self, base_type, width=4):
        self._base_type = base_type
        self.width = width

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        element_size = self.base_type.item_size
        return self.width * element_size

    def __eq__(self, other):
        if isinstance(other, VectorType):
            return (self.base_type, self.width) == (other.base_type, other.width)
        return False

    def __str__(self):
        if self.instructionSet is None:
            return "%s[%d]" % (self.base_type, self.width)
        # tried in order; the base type must compare equal, constness included
        candidates = (
            ("int64", 'int'),
            ("float64", 'double'),
            ("float32", 'float'),
            ("bool", 'bool'),
        )
        for type_name, instruction_key in candidates:
            if self.base_type == create_type(type_name):
                return self.instructionSet[instruction_key]
        raise NotImplementedError()

    def __hash__(self):
        return hash((self.base_type, self.width))
435
436


437
class PointerType(Type):
    """Pointer to ``base_type``, optionally ``const`` and/or ``restrict`` qualified.

    Note that ``restrict`` defaults to True when the class is constructed directly.
    """

    def __init__(self, base_type, const=False, restrict=True):
        self._base_type = base_type
        self.const = const
        self.restrict = restrict

    def __getnewargs__(self):
        return self.base_type, self.const, self.restrict

    @property
    def alias(self):
        # the inverse of the restrict flag
        return not self.restrict

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        return self.base_type.item_size

    def __eq__(self, other):
        if isinstance(other, PointerType):
            mine = (self.base_type, self.const, self.restrict)
            theirs = (other.base_type, other.const, other.restrict)
            return mine == theirs
        return False

    def __str__(self):
        restrict_part = " RESTRICT " if self.restrict else ""
        const_part = " const " if self.const else ""
        return "%s *%s%s" % (self.base_type, restrict_part, const_part)

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self._base_type, self.const, self.restrict))
472

Jan Hoenig's avatar
Jan Hoenig committed
473

474
class StructType(object):
    """A structured (record) numpy dtype.

    Unlike the other type classes this is a plain object, not a ``Type``. In
    generated code structs are handled byte-wise, hence ``str()`` yields ``uint8_t``.
    """

    def __init__(self, numpy_type, const=False):
        self.const = const
        self._dtype = np.dtype(numpy_type)

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    @property
    def base_type(self):
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def item_size(self):
        # size of one record in bytes
        return self.numpy_dtype.itemsize

    def get_element_offset(self, element_name):
        # byte offset of the field inside one record
        return self.numpy_dtype.fields[element_name][1]

    def get_element_type(self, element_name):
        # the field's scalar type inherits the struct's constness
        field_dtype = self.numpy_dtype.fields[element_name][0]
        return BasicType(field_dtype, self.const)

    def has_element(self, element_name):
        return element_name in self.numpy_dtype.fields

    def __eq__(self, other):
        if isinstance(other, StructType):
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)
        return False

    def __str__(self):
        # structs are handled byte-wise
        return "uint8_t const" if self.const else "uint8_t"

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self.numpy_dtype, self.const))