types.py 6.34 KB
Newer Older
1
import ctypes
2
import sympy as sp
3
import numpy as np
4
5
6
7
8
9
10
11
12
from sympy.core.cache import cacheit


class TypedSymbol(sp.Symbol):
    """Sympy symbol that additionally carries a data type.

    The ``dtype`` argument is normalized with ``createType`` and exposed through the
    read-only ``dtype`` property; ``castTo`` is stored unchanged as a plain attribute.
    """

    def __new__(cls, name, *args, **kwds):
        # Instances are created through the factory wrapped with sympy's ``cacheit``
        obj = TypedSymbol.__xnew_cached_(cls, name, *args, **kwds)
        return obj

    def __new_stage2__(cls, name, dtype, castTo=None):
        """Create the underlying symbol and attach the type information.

        :param name: symbol name, forwarded to the parent class' ``__xnew__``
        :param dtype: any specification accepted by ``createType``
        :param castTo: stored on the symbol as ``castTo`` without further processing
        """
        obj = super(TypedSymbol, cls).__xnew__(cls, name)
        obj._dtype = createType(dtype)
        obj.castTo = castTo
        return obj

    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    @property
    def dtype(self):
        """Type object created from the ``dtype`` constructor argument."""
        return self._dtype

    def _hashable_content(self):
        """Extend the parent's hashable content with the type information.

        The type is hashed by value. Hashing ``repr`` of the type object (as done
        previously) is address based for classes that define no ``__repr__``, which
        made symbols with equal types compare unequal. Plain integers are appended so
        that the content stays orderable.
        """
        def valueHash(obj):
            try:
                return hash(obj)
            except (TypeError, NotImplementedError):
                # Unhashable objects, or types whose hash cannot be computed,
                # fall back to the textual form that was used before
                return hash(repr(obj))

        superClassContents = tuple(super(TypedSymbol, self)._hashable_content())
        return superClassContents + (valueHash(self._dtype), valueHash(self.castTo))

    def __getnewargs__(self):
        """Arguments used to re-create this symbol, e.g. when unpickling."""
        return self.name, self.dtype, self.castTo
33

34

35
36
37
38
39
40
41
def createType(specification):
    """Convert a type specification into a type object of this module.

    :param specification: one of
        - a ``Type`` or ``StructType`` instance, which is returned unchanged
        - a string, which is parsed by ``createTypeFromString``
        - anything else, which is passed to ``numpy.dtype``
    :return: the given type object, the parsed type, or a non-const ``BasicType`` /
        ``StructType`` depending on whether the numpy dtype has fields
    """
    # StructType does not derive from Type, so it has to be listed explicitly;
    # otherwise an existing StructType would be handed to np.dtype and fail
    if isinstance(specification, (Type, StructType)):
        return specification
    if isinstance(specification, str):
        return createTypeFromString(specification)

    npDataType = np.dtype(specification)
    if npDataType.fields is None:
        return BasicType(npDataType, const=False)
    return StructType(npDataType, const=False)
46
47
48
49
50
51
52
53
54
55


def createTypeFromString(specification):
    """Parse a C-like type string such as ``"const double * restrict"``.

    The string is lower-cased and split into tokens. Everything before the first
    ``*`` describes the base type (a numpy type name, optionally with ``const``);
    every ``*`` starts a pointer level that may carry ``const`` and/or ``restrict``.
    The ``*`` does not need surrounding whitespace, so ``"double*"`` and
    ``"double **"`` are accepted as well.

    :param specification: type string
    :return: a ``BasicType`` if there is no pointer level, otherwise a ``PointerType``
    :raises AssertionError: if a part contains unexpected tokens
    :raises IndexError: if the specification contains no tokens at all
    """
    # Pad '*' with blanks so that it always becomes a token of its own
    tokens = specification.lower().replace('*', ' * ').split()

    # Group tokens: first group is the base type, each further group starts with '*'
    parts = []
    current = []
    for token in tokens:
        if token == '*':
            parts.append(current)
            current = [token]
        else:
            current.append(token)
    if current:
        parts.append(current)

    # Parse native part
    basePart = parts.pop(0)
    const = False
    if 'const' in basePart:
        const = True
        basePart.remove('const')
    assert len(basePart) == 1, "Expected exactly one base type name, got %r" % (basePart,)
    baseType = BasicType(basePart[0], const)

    currentType = baseType
    # Parse pointer parts
    for part in parts:
        restrict = False
        const = False
        if 'restrict' in part:
            restrict = True
            part.remove('restrict')
        if 'const' in part:
            const = True
            part.remove("const")
        assert len(part) == 1 and part[0] == '*', "Unexpected tokens in pointer part: %r" % (part,)
        currentType = PointerType(currentType, const, restrict)
    return currentType


def getBaseType(type):
    """Follow the ``baseType`` chain and return the innermost type.

    :param type: object with a ``baseType`` attribute
    :return: the first object in the chain whose ``baseType`` is None
    """
    current = type
    while True:
        inner = current.baseType
        if inner is None:
            return current
        current = inner


def toCtypes(dataType):
    """Return the ctypes type that corresponds to a type object of this module.

    :param dataType: ``PointerType``, ``StructType`` or an object with a ``numpyDtype``
        attribute
    :return: a ctypes pointer type for pointers (built recursively from the base type),
        a ``uint8`` pointer for structs, otherwise the entry of ``toCtypes.map``
    :raises KeyError: if the numpy dtype has no entry in ``toCtypes.map``
    """
    if isinstance(dataType, PointerType):
        return ctypes.POINTER(toCtypes(dataType.baseType))
    if isinstance(dataType, StructType):
        # structs are handled as raw bytes
        return ctypes.POINTER(ctypes.c_uint8)
    return toCtypes.map[dataType.numpyDtype]


# Lookup table from numpy dtype to the matching ctypes scalar type
toCtypes.map = {
    np.dtype(np.int8): ctypes.c_int8,
    np.dtype(np.int16): ctypes.c_int16,
    np.dtype(np.int32): ctypes.c_int32,
    np.dtype(np.int64): ctypes.c_int64,

    np.dtype(np.uint8): ctypes.c_uint8,
    np.dtype(np.uint16): ctypes.c_uint16,
    np.dtype(np.uint32): ctypes.c_uint32,
    np.dtype(np.uint64): ctypes.c_uint64,

    np.dtype(np.float32): ctypes.c_float,
    np.dtype(np.float64): ctypes.c_double,

    # bool can be mapped to a C name by BasicType.numpyNameToC, so it needs an entry here too
    np.dtype(np.bool_): ctypes.c_bool,
}


class Type:
    """Common base class of the type classes that derive from it in this module."""


class BasicType(Type):
    """Scalar type described by a plain numpy dtype plus a ``const`` flag."""

    @staticmethod
    def numpyNameToC(name):
        """Translate a numpy dtype name into the corresponding C type name.

        :param name: numpy dtype name, e.g. ``'float64'``, ``'int32'``, ``'uint8'``, ``'bool'``
        :return: C type name, e.g. ``'double'``, ``'int32_t'``, ``'uint8_t'``, ``'bool'``
        :raises NotImplementedError: if there is no mapping for the given name
        """
        if name == 'float64':
            return 'double'
        elif name == 'float32':
            return 'float'
        elif name.startswith('int'):
            width = int(name[len("int"):])
            return "int%d_t" % (width,)
        elif name.startswith('uint'):
            width = int(name[len("uint"):])
            return "uint%d_t" % (width,)
        elif name == 'bool':
            return 'bool'
        else:
            # NotImplemented is a non-callable constant; the exception class is NotImplementedError
            raise NotImplementedError("Can't map numpy to C name for %s" % (name,))

    def __init__(self, dtype, const=False):
        """
        :param dtype: anything accepted by ``numpy.dtype``; must not be a structured,
            object or sub-array dtype
        :param const: whether the type is const-qualified
        """
        self.const = const
        self._dtype = np.dtype(dtype)
        assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type"
        assert self._dtype.hasobject is False
        assert self._dtype.subdtype is None

    @property
    def baseType(self):
        """Always None: a basic type wraps no other type."""
        return None

    @property
    def numpyDtype(self):
        """The numpy dtype this type was created from."""
        return self._dtype

    @property
    def itemSize(self):
        """Always 1 for basic types."""
        return 1

    def __str__(self):
        """C spelling of the type, followed by `` const`` if const-qualified."""
        result = BasicType.numpyNameToC(str(self._dtype))
        if self.const:
            result += " const"
        return result

    def __eq__(self, other):
        """Equal to another ``BasicType`` with the same numpy dtype and const flag."""
        if not isinstance(other, BasicType):
            return False
        else:
            return (self.numpyDtype, self.const) == (other.numpyDtype, other.const)

    def __hash__(self):
        return hash(str(self))


class PointerType(Type):
    """Pointer to another type, with ``const`` and ``restrict`` flags."""

    def __init__(self, baseType, const=False, restrict=True):
        """
        :param baseType: the type that is pointed to
        :param const: whether the pointer is const-qualified
        :param restrict: whether the pointer is restrict-qualified
        """
        self.restrict = restrict
        self.const = const
        self._baseType = baseType

    @property
    def alias(self):
        """Negation of the ``restrict`` flag."""
        return not self.restrict

    @property
    def baseType(self):
        """The type that is pointed to."""
        return self._baseType

    @property
    def itemSize(self):
        """Item size of the pointed-to type."""
        return self.baseType.itemSize

    def __eq__(self, other):
        """Equal to another ``PointerType`` with equal base type, const and restrict flag."""
        if isinstance(other, PointerType):
            mine = (self.baseType, self.const, self.restrict)
            theirs = (other.baseType, other.const, other.restrict)
            return mine == theirs
        return False

    def __str__(self):
        """Base type followed by `` *`` and the qualifiers that are set."""
        restrictSuffix = " RESTRICT" if self.restrict else ""
        constSuffix = " const" if self.const else ""
        return "%s *%s%s" % (self.baseType, restrictSuffix, constSuffix)

    def __hash__(self):
        return hash(str(self))

Jan Hoenig's avatar
Jan Hoenig committed
201

202
class StructType(object):
    """Type described by a numpy dtype (looked up by field name) plus a ``const`` flag."""

    def __init__(self, numpyType, const=False):
        """
        :param numpyType: anything accepted by ``numpy.dtype``
        :param const: whether the type is const-qualified
        """
        self._dtype = np.dtype(numpyType)
        self.const = const

    @property
    def baseType(self):
        """Always None: a struct type wraps no other type."""
        return None

    @property
    def numpyDtype(self):
        """The numpy dtype this type was created from."""
        return self._dtype

    @property
    def itemSize(self):
        """``itemsize`` of the underlying numpy dtype."""
        return self.numpyDtype.itemsize

    def getElementOffset(self, elementName):
        """Second entry of the numpy field description of ``elementName`` (its offset)."""
        fieldInfo = self.numpyDtype.fields[elementName]
        return fieldInfo[1]

    def getElementType(self, elementName):
        """``BasicType`` of the field ``elementName``, inheriting this struct's const flag."""
        fieldInfo = self.numpyDtype.fields[elementName]
        return BasicType(fieldInfo[0], self.const)

    def __eq__(self, other):
        """Equal to another ``StructType`` with the same numpy dtype and const flag."""
        if isinstance(other, StructType):
            return (self.numpyDtype, self.const) == (other.numpyDtype, other.const)
        return False

    def __str__(self):
        # structs are handled byte-wise
        return "uint8_t const" if self.const else "uint8_t"

    def __hash__(self):
        key = (self.numpyDtype, self.const)
        return hash(key)