An error occurred while loading the file. Please try again.
-
Martin Bauer authored
- alignment of index array when using additional data - data types have been hashed incorrectly
af435b7f
data_types.py 18.10 KiB
import ctypes
import sympy as sp
import numpy as np
try:
import llvmlite.ir as ir
except ImportError as e:
ir = None
_ir_importerror = e
from sympy.core.cache import cacheit
from pystencils.cache import memorycache
from pystencils.utils import allEqual
# castFunc also derives from sp.Rel so that it may appear inside the conditions of sp.Piecewise
class castFunc(sp.Function, sp.Rel):
    """
    Sympy function node representing a type cast: castFunc(expression, targetType).
    """

    @property
    def canonical(self):
        # Delegate to the wrapped expression; it must provide 'canonical' itself
        wrapped = self.args[0]
        if not hasattr(wrapped, 'canonical'):
            raise NotImplementedError()
        return wrapped.canonical

    @property
    def is_commutative(self):
        # A cast is commutative exactly when the expression being cast is
        wrapped = self.args[0]
        return wrapped.is_commutative
class pointerArithmeticFunc(sp.Function, sp.Rel):
    """
    Sympy function node used for pointer arithmetic. Like castFunc it also derives from sp.Rel.
    NOTE(review): the meaning of its arguments is not visible here - presumably (pointer, offset); confirm.
    """

    @property
    def canonical(self):
        # The canonical form is taken from the first argument, if that argument offers one
        first = self.args[0]
        if not hasattr(first, 'canonical'):
            raise NotImplementedError()
        return first.canonical
class TypedSymbol(sp.Symbol):
    """
    Sympy symbol that carries a data type in addition to its name.

    The data type is included in ``_hashable_content``, which sympy uses for hashing and equality,
    so two TypedSymbols with the same name but different data types are different symbols.
    """
    def __new__(cls, *args, **kwds):
        # Construct through the cacheit-wrapped constructor, so that repeated construction with the same
        # arguments can return a cached instance.
        # Note: '__xnew_cached_' ends with a single underscore, so Python name-mangles it to
        # '_TypedSymbol__xnew_cached_'; this works because every use of the name is inside this class body.
        obj = TypedSymbol.__xnew_cached_(cls, *args, **kwds)
        return obj

    def __new_stage2__(cls, name, dtype):
        # Uncached constructor: creates the underlying sp.Symbol and attaches the data type to it.
        obj = super(TypedSymbol, cls).__xnew__(cls, name)
        try:
            obj._dtype = createType(dtype)
        except TypeError:
            # on error keep the string
            # (i.e. store the specification exactly as passed if createType cannot interpret it)
            obj._dtype = dtype
        return obj

    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    @property
    def dtype(self):
        # Type object created by createType, or the raw specification if createType raised TypeError
        return self._dtype

    def _hashable_content(self):
        # Extend the symbol's hashable content with the data type, so the type takes part in hashing/equality
        superClassContents = list(super(TypedSymbol, self)._hashable_content())
        return tuple(superClassContents + [hash(self._dtype)])

    def __getnewargs__(self):
        # Pickle support: the object is recreated via TypedSymbol(name, dtype)
        return self.name, self.dtype
def createType(specification):
    """
    Create a subclass of Type according to a string or an object of subclass Type

    :param specification: Type or StructType object (returned unchanged), a c-like type string
                          (parsed by createTypeFromString), or anything accepted by numpy.dtype
    :return: Type object, or a new Type object parsed from the string; a structured numpy dtype yields
             a StructType, every other numpy dtype a BasicType
    :raises TypeError: (from numpy.dtype) if the specification cannot be interpreted as a data type
    """
    # StructType does not derive from Type, so it has to be listed explicitly; otherwise it would be
    # handed to np.dtype(), which cannot interpret it
    if isinstance(specification, (Type, StructType)):
        return specification
    elif isinstance(specification, str):
        return createTypeFromString(specification)
    else:
        npDataType = np.dtype(specification)
        if npDataType.fields is None:
            return BasicType(npDataType, const=False)
        else:
            return StructType(npDataType, const=False)
@memorycache(maxsize=64)
def createTypeFromString(specification):
    """
    Creates a new Type object from a c-like string specification

    The specification is split at whitespace (and lower-cased). The first group names the base type,
    optionally together with "const". The base name is either one of the abbreviations in
    createTypeFromString.map (e.g. "i32") or a name understood by numpy.dtype (e.g. "double", "float32",
    "int64"). Note that names are resolved by numpy, so "float" means float64.
    Every following "*" starts a pointer level that may carry "const" and/or "restrict".
    Stars written directly after the base name ("double*", "double**") add unqualified pointer levels.

    Examples: "double", "const double *", "i32 * restrict const", "double**"

    :param specification: Specification string
    :return: Type object
    :raises KeyError: if the base type name is neither a known abbreviation nor a numpy type name
    """
    specification = specification.lower().split()
    parts = []
    current = []
    for s in specification:
        if s == '*':
            parts.append(current)
            current = [s]
        else:
            current.append(s)
    if len(current) > 0:
        parts.append(current)

    # Parse native part
    basePart = parts.pop(0)
    const = False
    if 'const' in basePart:
        const = True
        basePart.remove('const')
    assert len(basePart) == 1

    # Stars glued to the base name each add one unqualified pointer level. They are the innermost
    # levels, so they are applied before any separately written "*" groups.
    baseName = basePart[0].rstrip('*')
    parts = [['*'] for _ in range(len(basePart[0]) - len(baseName))] + parts

    # Consult the abbreviations first: numpy accepts some of them itself with a different meaning,
    # e.g. np.dtype('i8') is an 8-byte (64 bit) integer, whereas 'i8' here means int8.
    if baseName in createTypeFromString.map:
        baseType = BasicType(createTypeFromString.map[baseName], const)
    else:
        try:
            baseType = BasicType(baseName, const)
        except TypeError:
            # unknown base type name - reported as KeyError, like a failed abbreviation lookup
            raise KeyError(baseName)
    currentType = baseType

    # Parse pointer parts
    for part in parts:
        restrict = False
        const = False
        if 'restrict' in part:
            restrict = True
            part.remove('restrict')
        if 'const' in part:
            const = True
            part.remove("const")
        assert len(part) == 1 and part[0] == '*'
        currentType = PointerType(currentType, const, restrict)
    return currentType
# Abbreviated integer type names (bit width after the 'i') that createTypeFromString maps to numpy types
createTypeFromString.map = dict(
    i64=np.int64,
    i32=np.int32,
    i16=np.int16,
    i8=np.int8,
)
def getBaseType(type):
    """
    Follows the baseType chain (e.g. pointer -> pointee -> ...) and returns its last element,
    i.e. the first object whose baseType is None.
    """
    current = type
    while True:
        inner = current.baseType
        if inner is None:
            return current
        current = inner
def toCtypes(dataType):
    """
    Converts a Type into the matching ctypes type.

    Pointers become ctypes pointers to the converted pointee, struct types are exposed as uint8 pointers
    (structs are handled byte-wise), and basic types are looked up in toCtypes.map by their numpy dtype.

    :param dataType: Subclass of Type (or StructType)
    :return: ctypes type object
    """
    if isinstance(dataType, StructType):
        return ctypes.POINTER(ctypes.c_uint8)
    if isinstance(dataType, PointerType):
        pointee = toCtypes(dataType.baseType)
        return ctypes.POINTER(pointee)
    lookup = toCtypes.map
    return lookup[dataType.numpyDtype]
# numpy dtype -> ctypes scalar type; used by toCtypes for basic (non-pointer, non-struct) types.
# bool is included so that bool-typed BasicTypes (e.g. from createTypeFromString("bool")) can be converted.
toCtypes.map = {
    np.dtype(np.int8): ctypes.c_int8,
    np.dtype(np.int16): ctypes.c_int16,
    np.dtype(np.int32): ctypes.c_int32,
    np.dtype(np.int64): ctypes.c_int64,
    np.dtype(np.uint8): ctypes.c_uint8,
    np.dtype(np.uint16): ctypes.c_uint16,
    np.dtype(np.uint32): ctypes.c_uint32,
    np.dtype(np.uint64): ctypes.c_uint64,
    np.dtype(np.float32): ctypes.c_float,
    np.dtype(np.float64): ctypes.c_double,
    np.dtype(np.bool_): ctypes.c_bool,
}
def ctypes_from_llvm(data_type):
    """
    Returns the ctypes type corresponding to an llvmlite IR type.

    :param data_type: llvmlite.ir type (pointer, integer of width 8/16/32/64, float, double or void)
    :return: ctypes type; None for void, ctypes.c_void_p for a pointer whose pointee maps to None;
             integers map to the signed ctypes integer of the same width
    :raises ValueError: for integer widths other than 8, 16, 32 and 64
    :raises NotImplementedError: for any other llvmlite type
    """
    if not ir:
        # llvmlite could not be imported - re-raise the original ImportError
        raise _ir_importerror

    if isinstance(data_type, ir.PointerType):
        pointee = ctypes_from_llvm(data_type.pointee)
        return ctypes.c_void_p if pointee is None else ctypes.POINTER(pointee)

    if isinstance(data_type, ir.IntType):
        bits = data_type.width
        signedIntegers = {8: ctypes.c_int8, 16: ctypes.c_int16, 32: ctypes.c_int32, 64: ctypes.c_int64}
        if bits not in signedIntegers:
            raise ValueError("Int width %d is not supported" % bits)
        return signedIntegers[bits]

    if isinstance(data_type, ir.FloatType):
        return ctypes.c_float
    if isinstance(data_type, ir.DoubleType):
        return ctypes.c_double
    if isinstance(data_type, ir.VoidType):
        return None  # ctypes has no void type

    raise NotImplementedError('Data type %s of %s is not supported yet' % (type(data_type), data_type))
def to_llvm_type(data_type):
    """
    Transforms a given Type into the corresponding llvmlite IR type.

    Pointer types are translated recursively (pointee type, then .as_pointer()); all other types are
    looked up in to_llvm_type.map by their numpy dtype.

    :param data_type: Subclass of Type
    :return: llvmlite type object
    """
    if not ir:
        raise _ir_importerror
    if not isinstance(data_type, PointerType):
        return to_llvm_type.map[data_type.numpyDtype]
    pointeeType = to_llvm_type(data_type.baseType)
    return pointeeType.as_pointer()
if ir:
    # numpy dtype -> llvmlite IR type. LLVM integer types carry no signedness, so signed and unsigned
    # numpy integers of the same bit width both map to ir.IntType(bits).
    to_llvm_type.map = {np.dtype(prefix + str(bits)): ir.IntType(bits)
                        for prefix in ('int', 'uint')
                        for bits in (8, 16, 32, 64)}
    to_llvm_type.map[np.dtype(np.float32)] = ir.FloatType()
    to_llvm_type.map[np.dtype(np.float64)] = ir.DoubleType()
def peelOffType(dtype, typeToPeelOff):
    """
    Strips wrapper types: while the object is exactly of type ``typeToPeelOff`` (subclasses do not count),
    replace it by its baseType. Returns the first object that is not of that exact type.
    """
    current = dtype
    while True:
        if type(current) is not typeToPeelOff:
            return current
        current = current.baseType
def collateTypes(types):
    """
    Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
    Uses the collation rules from numpy.

    Special cases:
      * pointer arithmetic: one PointerType combined with (signed or unsigned) integer BasicTypes yields
        that pointer type
      * vector types: if any type is a VectorType (all of them must have the same width), the result is a
        VectorType of that width around the collated scalar type

    :param types: sequence of types; it is iterated several times, so it must not be a one-shot iterator
    :return: the common type
    :raises ValueError: for two pointers, a pointer combined with a non-integer, or vectors of unequal width
    """
    if any(type(t) is PointerType for t in types):
        resultPointer = None
        for t in types:
            kind = type(t)
            if kind is PointerType:
                if resultPointer is not None:
                    raise ValueError("Cannot collate the combination of two pointer types")
                resultPointer = t
            elif kind is not BasicType or not (t.is_int() or t.is_uint()):
                raise ValueError("Invalid pointer arithmetic")
        return resultPointer

    vectorTypes = [t for t in types if type(t) is VectorType]
    if not allEqual(v.width for v in vectorTypes):
        raise ValueError("Collation failed because of vector types with different width")
    scalarTypes = [peelOffType(t, VectorType) for t in types]

    # from here on only basic types are supported - struct types are not (yet)
    assert all(type(t) is BasicType for t in scalarTypes)
    commonScalar = BasicType(np.result_type(*[t.numpyDtype for t in scalarTypes]))
    return VectorType(commonScalar, vectorTypes[0].width) if vectorTypes else commonScalar
@memorycache(maxsize=2048)
def getTypeOfExpression(expr):
    """
    Determines the data type of a sympy expression.

    Integers are "int", other rationals and floats "double"; field accesses and typed symbols report their
    own type; castFunc reports its target type; Piecewise collates its branch values (widened to a vector
    if the conditions are vectors); Indexed yields the base type of the indexed symbol's type; boolean
    expressions are "bool" (a bool vector if any argument is vector-typed); other expressions collate
    the types of their arguments.

    :param expr: sympy expression (or anything sp.sympify accepts)
    :return: Type object
    :raises ValueError: if an untyped sp.Symbol is encountered
    :raises NotImplementedError: if none of the rules applies
    """
    from pystencils.astnodes import ResolvedFieldAccess
    expr = sp.sympify(expr)
    hasFunc = hasattr(expr, 'func')

    if isinstance(expr, sp.Integer):
        return createTypeFromString("int")
    if isinstance(expr, sp.Rational) or isinstance(expr, sp.Float):
        return createTypeFromString("double")
    if isinstance(expr, ResolvedFieldAccess):
        return expr.field.dtype
    if isinstance(expr, TypedSymbol):
        return expr.dtype
    if isinstance(expr, sp.Symbol):
        raise ValueError("All symbols inside this expression have to be typed!")
    if hasFunc and expr.func == castFunc:
        # castFunc(expression, targetType)
        return expr.args[1]
    if hasFunc and expr.func == sp.Piecewise:
        valueType = collateTypes(tuple(getTypeOfExpression(pair[0]) for pair in expr.args))
        conditionType = collateTypes(tuple(getTypeOfExpression(pair[1]) for pair in expr.args))
        if type(conditionType) is VectorType and type(valueType) is not VectorType:
            valueType = VectorType(valueType, width=conditionType.width)
        return valueType
    if isinstance(expr, sp.Indexed):
        # presumably the label is a pointer-typed TypedSymbol; the element type is its baseType
        return expr.base.label.dtype.baseType
    if isinstance(expr, sp.boolalg.Boolean) or isinstance(expr, sp.boolalg.BooleanFunction):
        boolType = createTypeFromString("bool")
        argTypes = [getTypeOfExpression(a) for a in expr.args]
        vectorArgTypes = [t for t in argTypes if isinstance(t, VectorType)]
        if vectorArgTypes:
            return VectorType(boolType, width=vectorArgTypes[0].width)
        return boolType
    if isinstance(expr, sp.Expr):
        return collateTypes(tuple(getTypeOfExpression(a) for a in expr.args))

    raise NotImplementedError("Could not determine type for", expr, type(expr))
class Type(sp.Basic):
    """
    Base class of the data types BasicType, VectorType and PointerType.

    Deriving from sp.Basic lets type objects appear as arguments of sympy expressions,
    e.g. as the target type in castFunc(expression, type).
    """
    def __new__(cls, *args, **kwargs):
        # Constructor arguments are not stored as sympy args; subclasses set their attributes in __init__
        return sp.Basic.__new__(cls)

    def __lt__(self, other):  # deprecated
        # Needed for sorting the types inside an expression
        # NOTE(review): the BasicType and PointerType branches compare with '>', which looks like a
        # deliberately reversed order - confirm before relying on it
        if isinstance(self, BasicType):
            if isinstance(other, BasicType):
                return self.numpyDtype > other.numpyDtype  # TODO const
            elif isinstance(other, PointerType):
                return False
            else:  # isinstance(other, StructType):
                raise NotImplementedError("Struct type comparison is not yet implemented")
        elif isinstance(self, PointerType):
            if isinstance(other, BasicType):
                return True
            elif isinstance(other, PointerType):
                return self.baseType > other.baseType  # TODO const, restrict
            else:  # isinstance(other, StructType):
                raise NotImplementedError("Struct type comparison is not yet implemented")
        elif isinstance(self, StructType):
            raise NotImplementedError("Struct type comparison is not yet implemented")
        else:
            raise NotImplementedError

    def _sympystr(self, *args, **kwargs):
        # sympy's string printer uses the type's own string form
        return str(self)
class BasicType(Type):
    """
    Scalar data type described by a non-structured numpy dtype, optionally const-qualified.
    """
    @staticmethod
    def numpyNameToC(name):
        """
        Returns the C type name for a numpy dtype name, e.g. 'float64' -> 'double', 'int32' -> 'int32_t'

        :param name: numpy dtype name as returned by str(numpy.dtype)
        :return: C type name
        :raises NotImplementedError: if there is no C type name for the numpy name
        """
        if name == 'float64':
            return 'double'
        elif name == 'float32':
            return 'float'
        elif name.startswith('int'):
            width = int(name[len("int"):])
            return "int%d_t" % (width,)
        elif name.startswith('uint'):
            width = int(name[len("uint"):])
            return "uint%d_t" % (width,)
        elif name == 'bool':
            return 'bool'
        else:
            # NotImplementedError (the exception), not NotImplemented (a non-callable constant whose
            # call would fail with an unrelated TypeError)
            raise NotImplementedError("Cannot map numpy to C name for %s" % (name,))

    def __init__(self, dtype, const=False):
        self.const = const
        if isinstance(dtype, Type):
            self._dtype = dtype.numpyDtype
        else:
            self._dtype = np.dtype(dtype)
        assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type"
        assert self._dtype.hasobject is False
        assert self._dtype.subdtype is None

    def __getnewargs__(self):
        return self.numpyDtype, self.const

    @property
    def baseType(self):
        # basic types terminate the baseType chain
        return None

    @property
    def numpyDtype(self):
        return self._dtype

    @property
    def itemSize(self):
        # NOTE(review): always 1, independent of the dtype's byte size; VectorType multiplies this by its
        # width, so it seems to count elements rather than bytes - confirm
        return 1

    # The category checks below use dtype.kind instead of np.sctypes, which was removed in NumPy 2.0.
    def is_int(self):
        return self.numpyDtype.kind == 'i'

    def is_float(self):
        return self.numpyDtype.kind == 'f'

    def is_uint(self):
        return self.numpyDtype.kind == 'u'

    def is_comlex(self):
        # (sic) existing name kept for compatibility.
        # NOTE(review): presumably not named is_complex because sp.Basic already has an is_complex
        # assumption attribute - confirm before adding such an alias
        return self.numpyDtype.kind == 'c'

    def is_other(self):
        # same types as the former np.sctypes['others']: bool, object, bytes, str, void
        return any(self.numpyDtype == t for t in (bool, object, bytes, str, np.void))

    @property
    def baseName(self):
        # C type name without the const qualifier
        return BasicType.numpyNameToC(str(self._dtype))

    def __str__(self):
        result = BasicType.numpyNameToC(str(self._dtype))
        if self.const:
            result += " const"
        return result

    def __repr__(self):
        return str(self)

    def __eq__(self, other):
        if not isinstance(other, BasicType):
            return False
        else:
            return (self.numpyDtype, self.const) == (other.numpyDtype, other.const)

    def __hash__(self):
        # hash of the C representation, which includes the const qualifier (consistent with __eq__)
        return hash(str(self))
class VectorType(Type):
    """
    Vector (SIMD) type of ``width`` elements of ``baseType``.
    """
    # Mapping with the keys 'int', 'double', 'float' and 'bool' to the type names used when printing;
    # if None, vectors print as "baseType[width]".
    # NOTE(review): presumably assigned by the vectorization backend - confirm
    instructionSet = None

    def __init__(self, baseType, width=4):
        self._baseType = baseType
        self.width = width

    @property
    def baseType(self):
        return self._baseType

    @property
    def itemSize(self):
        return self.width * self.baseType.itemSize

    def __eq__(self, other):
        if not isinstance(other, VectorType):
            return False
        else:
            return (self.baseType, self.width) == (other.baseType, other.width)

    def __str__(self):
        if self.instructionSet is None:
            return "%s[%d]" % (self.baseType, self.width)
        else:
            # Explicit numpy names are used on purpose: base names are resolved by numpy, where "float"
            # means float64, so comparing against createTypeFromString("float") could never match float32.
            if self.baseType == createTypeFromString("int64"):
                return self.instructionSet['int']
            elif self.baseType == createTypeFromString("float64"):
                return self.instructionSet['double']
            elif self.baseType == createTypeFromString("float32"):
                return self.instructionSet['float']
            elif self.baseType == createTypeFromString("bool"):
                return self.instructionSet['bool']
            else:
                raise NotImplementedError()

    def __hash__(self):
        return hash((self.baseType, self.width))
class PointerType(Type):
    """
    Pointer to ``baseType``.

    :param baseType: type that is pointed to
    :param const: whether the pointer itself is const (printed after the '*')
    :param restrict: whether the pointer is printed with RESTRICT; ``alias`` is its negation
    """
    def __init__(self, baseType, const=False, restrict=True):
        self._baseType = baseType
        self.const = const
        self.restrict = restrict

    @property
    def baseType(self):
        return self._baseType

    @property
    def alias(self):
        return not self.restrict

    @property
    def itemSize(self):
        # same item size as the type pointed to
        return self.baseType.itemSize

    def __getnewargs__(self):
        return self.baseType, self.const, self.restrict

    def __eq__(self, other):
        return isinstance(other, PointerType) and \
            (self.baseType, self.const, self.restrict) == (other.baseType, other.const, other.restrict)

    def __hash__(self):
        return hash((self.baseType, self.const, self.restrict))

    def __str__(self):
        qualifiers = (" RESTRICT " if self.restrict else "") + (" const " if self.const else "")
        return "%s *%s" % (self.baseType, qualifiers)

    def __repr__(self):
        return str(self)
class StructType(object):
    """
    Data type for a numpy structured dtype (a record with named fields), optionally const-qualified.

    Structs are handled byte-wise: the type prints as uint8_t, and individual elements are accessed via
    getElementOffset / getElementType.
    """
    def __init__(self, numpyType, const=False):
        self.const = const
        self._dtype = np.dtype(numpyType)

    def __getnewargs__(self):
        return self.numpyDtype, self.const

    @property
    def baseType(self):
        # struct types terminate the baseType chain
        return None

    @property
    def numpyDtype(self):
        return self._dtype

    @property
    def itemSize(self):
        # size of one record in bytes
        return self.numpyDtype.itemsize

    def getElementOffset(self, elementName):
        """Byte offset of the field ``elementName`` inside the record."""
        return self.numpyDtype.fields[elementName][1]

    def getElementType(self, elementName):
        """BasicType of the field ``elementName``; it inherits the struct's const qualifier."""
        npElementType = self.numpyDtype.fields[elementName][0]
        return BasicType(npElementType, self.const)

    def hasElement(self, elementName):
        return elementName in self.numpyDtype.fields

    def __eq__(self, other):
        if not isinstance(other, StructType):
            return False
        else:
            return (self.numpyDtype, self.const) == (other.numpyDtype, other.const)

    def __str__(self):
        # structs are handled byte-wise
        result = "uint8_t"
        if self.const:
            result += " const"
        return result

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self.numpyDtype, self.const))

    def __gt__(self, other):
        # No ordering is defined for struct types yet. Fail explicitly, with the same message that
        # Type.__lt__ uses for struct comparisons, instead of touching attributes StructType does not have.
        raise NotImplementedError("Struct type comparison is not yet implemented")
def get_type_from_sympy(node):
    """
    Creates a Type object from a Sympy number, together with the number's Python value

    :param node: Sympy number
    :return: tuple (Type object, value): ('double' type, float) for floats and non-integer rationals,
             ('int' type, int) for integers
    :raises TypeError: if node is not a sp.Number, or is a sp.Number that is neither a float, an integer
                       nor a rational
    """
    # Notes on sympy number kinds:
    #   Zero, One and NegativeOne are Integers; Half is a Rational - all handled below.
    #   Not handled (yet): NumberSymbols (Pi, Exp1, EulerGamma, Catalan, GoldenRatio), ImaginaryUnit,
    #   NaN, Infinity, NegativeInfinity.
    #   Pow, Mul, Add, Mod and Relational are expressions, not numbers.
    if not isinstance(node, sp.Number):
        raise TypeError(node, 'is not a sp.Number')

    isFloating = isinstance(node, sp.Float) or isinstance(node, sp.RealNumber)
    if isFloating:
        return createType('double'), float(node)
    if isinstance(node, sp.Integer):
        return createType('int'), int(node)
    if isinstance(node, sp.Rational):
        # TODO is it always float?
        return createType('double'), float(node.p/node.q)
    raise TypeError(node, ' is not a supported type (yet)!')