diff --git a/astnodes.py b/astnodes.py index 2ea3de40a0d35dd479cd1c61b52dd398706e55c6..565584a03d8b42a6989e324fdf98d0de0e5fc8ec 100644 --- a/astnodes.py +++ b/astnodes.py @@ -487,6 +487,17 @@ class Indexed(Expr): return '%s[%s]' % (self.args[0], self.args[1]) +class PointerArithmetic(Expr): + def __init__(self, args, pointer, parent=None): + super(PointerArithmetic, self).__init__([args] + [pointer], parent) + self.pointer = pointer + self.offset = args + self.dtype = pointer.dtype + + def __repr__(self): + return '*(%s + %s)' % (self.pointer, self.args) + + class Number(Node, sp.AtomicExpr): def __init__(self, number, parent=None): super(Number, self).__init__(parent) diff --git a/backends/llvm.py b/backends/llvm.py index 144f7a23ef550da28fc488ef3f14ce24ba60e6f0..6bbaad3c661daaca856ce0231c08db6f0f63c3fc 100644 --- a/backends/llvm.py +++ b/backends/llvm.py @@ -188,6 +188,11 @@ class LLVMPrinter(Printer): gep = self.builder.gep(ptr, [index]) return self.builder.load(gep, name=indexed.base.label.name) + def _print_PointerArithmetic(self, pointer): + ptr = self._print(pointer.pointer) + index = self._print(pointer.offset) + return self.builder.gep(ptr, [index]) + # Should have a list of math library functions to validate this. # TODO function calls def _print_Function(self, expr): @@ -202,4 +207,4 @@ class LLVMPrinter(Printer): def emptyPrinter(self, expr): raise TypeError("Unsupported type for LLVM JIT conversion: %s %s" - % type(expr), expr) + % (type(expr), expr)) diff --git a/llvm/kernelcreation.py b/llvm/kernelcreation.py index 8c520cff33199184458de651150c9b86c3287a52..7ac669686cefef329bc9213388b66b48a101f47c 100644 --- a/llvm/kernelcreation.py +++ b/llvm/kernelcreation.py @@ -60,9 +60,9 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl resolveFieldAccesses(code, readOnlyFields, fieldToBasePointerInfo=basePointerInfos) moveConstantsBeforeLoop(code) - # print(code) + print(code) desympy_ast(code) - # print(code) + print(code) insert_casts(code) return code diff --git a/transformations.py b/transformations.py index 3c94329520c8699e2befe35c483bfede2c2352b1..cd3556d5c154c13dcf45988b0b5e84abdf37f1cf 100644 --- a/transformations.py +++ b/transformations.py @@ -549,23 +549,38 @@ def insert_casts(node): :param node: ast which should be traversed :return: node """ + def conversion(args): + target = args[0] + if isinstance(target.dtype, PointerType): + # Pointer arithmetic + for arg in args[1:]: + # Check validness + if not arg.dtype.is_int() and not arg.dtype.is_uint(): + raise ValueError("Impossible pointer arithmetic", target, arg) + pointer = ast.PointerArithmetic(ast.Add(args[1:]), target) + return [pointer] + + else: + for i in range(len(args)): + if args[i].dtype != target.dtype: + args[i] = ast.Conversion(args[i], target.dtype, node) + return args + for arg in node.args: insert_casts(arg) if isinstance(node, ast.Indexed): #TODO revmove this pass elif isinstance(node, ast.Expr): - print(node.args) - print([type(arg) for arg in node.args]) - print([arg.dtype for arg in node.args]) + #print(node, node.args) + #print([type(arg) for arg in node.args]) + #print([arg.dtype for arg in node.args]) args = sorted((arg for arg in node.args), key=attrgetter('dtype')) target = args[0] - for i in range(len(args)): - if args[i].dtype != target.dtype: - args[i] = ast.Conversion(args[i], target.dtype, node) - node.args = args + node.args = conversion(args) node.dtype = target.dtype - print(node.dtype) + #print(node.dtype) + #print(node) elif isinstance(node, ast.SympyAssignment): if node.lhs.dtype != node.rhs.dtype: node.replace(node.rhs, ast.Conversion(node.rhs, node.lhs.dtype)) @@ -600,7 +615,8 @@ def desympy_ast(node): #elif isinstance(arg, sp.containers.Tuple): # else: - print('Not transforming:', type(arg), arg) + #print('Not transforming:', type(arg), arg) + pass for arg in node.args: desympy_ast(arg) return node @@ -616,8 +632,9 @@ def check_dtype(node): elif isinstance(node, ast.SympyAssignment): pass else: - print(node) - print(node.dtype) + #print(node) + #print(node.dtype) + pass for arg in node.args: check_dtype(arg) diff --git a/types.py b/types.py index fbf39ea3c55d07d1c0e1872415d659d2cbb5a4b6..702a1cc027c1a2155ca33dd139e1f7edd8b3e587 100644 --- a/types.py +++ b/types.py @@ -160,12 +160,12 @@ class Type(object): if isinstance(other, BasicType): return self.numpyDtype < other.numpyDtype # TODO const if isinstance(other, PointerType): - return True # TODO test + return False # TODO test if isinstance(other, StructType): raise NotImplementedError("Struct type comparison is not yet implemented") if isinstance(self, PointerType): if isinstance(other, BasicType): - return False # TODO test + return True # TODO test if isinstance(other, PointerType): return self.baseType < other.baseType # TODO const, restrict if isinstance(other, StructType): @@ -207,6 +207,21 @@ class BasicType(Type): def numpyDtype(self): return self._dtype + def is_int(self): + return self.numpyDtype in np.sctypes['int'] + + def is_float(self): + return self.numpyDtype in np.sctypes['float'] + + def is_uint(self): + return self.numpyDtype in np.sctypes['uint'] + + def is_comlex(self): + return self.numpyDtype in np.sctypes['complex'] + + def is_other(self): + return self.numpyDtype in np.sctypes['others'] + def __repr__(self): result = BasicType.numpyNameToC(str(self._dtype)) if self.const: