Commit 56af5641 authored by Jan Hoenig's avatar Jan Hoenig
Browse files

LLVM-backend generates Object-files, whcih can be imported in C.

parent 39a096fa
......@@ -487,6 +487,17 @@ class Indexed(Expr):
return '%s[%s]' % (self.args[0], self.args[1])
class PointerArithmetic(Expr):
def __init__(self, args, pointer, parent=None):
super(PointerArithmetic, self).__init__([args] + [pointer], parent)
self.pointer = pointer
self.offset = args
self.dtype = pointer.dtype
def __repr__(self):
return '*(%s + %s)' % (self.pointer, self.args)
class Number(Node, sp.AtomicExpr):
def __init__(self, number, parent=None):
super(Number, self).__init__(parent)
......
......@@ -188,6 +188,11 @@ class LLVMPrinter(Printer):
gep = self.builder.gep(ptr, [index])
return self.builder.load(gep, name=indexed.base.label.name)
def _print_PointerArithmetic(self, pointer):
ptr = self._print(pointer.pointer)
index = self._print(pointer.offset)
return self.builder.gep(ptr, [index])
# Should have a list of math library functions to validate this.
# TODO function calls
def _print_Function(self, expr):
......@@ -202,4 +207,4 @@ class LLVMPrinter(Printer):
def emptyPrinter(self, expr):
raise TypeError("Unsupported type for LLVM JIT conversion: %s %s"
% type(expr), expr)
% (type(expr), expr))
......@@ -60,9 +60,9 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl
resolveFieldAccesses(code, readOnlyFields, fieldToBasePointerInfo=basePointerInfos)
moveConstantsBeforeLoop(code)
# print(code)
print(code)
desympy_ast(code)
# print(code)
print(code)
insert_casts(code)
return code
......@@ -549,23 +549,38 @@ def insert_casts(node):
:param node: ast which should be traversed
:return: node
"""
def conversion(args):
target = args[0]
if isinstance(target.dtype, PointerType):
# Pointer arithmetic
for arg in args[1:]:
# Check validness
if not arg.dtype.is_int() and not arg.dtype.is_uint():
raise ValueError("Impossible pointer arithmetic", target, arg)
pointer = ast.PointerArithmetic(ast.Add(args[1:]), target)
return [pointer]
else:
for i in range(len(args)):
if args[i].dtype != target.dtype:
args[i] = ast.Conversion(args[i], target.dtype, node)
return args
for arg in node.args:
insert_casts(arg)
if isinstance(node, ast.Indexed):
#TODO revmove this
pass
elif isinstance(node, ast.Expr):
print(node.args)
print([type(arg) for arg in node.args])
print([arg.dtype for arg in node.args])
#print(node, node.args)
#print([type(arg) for arg in node.args])
#print([arg.dtype for arg in node.args])
args = sorted((arg for arg in node.args), key=attrgetter('dtype'))
target = args[0]
for i in range(len(args)):
if args[i].dtype != target.dtype:
args[i] = ast.Conversion(args[i], target.dtype, node)
node.args = args
node.args = conversion(args)
node.dtype = target.dtype
print(node.dtype)
#print(node.dtype)
#print(node)
elif isinstance(node, ast.SympyAssignment):
if node.lhs.dtype != node.rhs.dtype:
node.replace(node.rhs, ast.Conversion(node.rhs, node.lhs.dtype))
......@@ -600,7 +615,8 @@ def desympy_ast(node):
#elif isinstance(arg, sp.containers.Tuple):
#
else:
print('Not transforming:', type(arg), arg)
#print('Not transforming:', type(arg), arg)
pass
for arg in node.args:
desympy_ast(arg)
return node
......@@ -616,8 +632,9 @@ def check_dtype(node):
elif isinstance(node, ast.SympyAssignment):
pass
else:
print(node)
print(node.dtype)
#print(node)
#print(node.dtype)
pass
for arg in node.args:
check_dtype(arg)
......@@ -160,12 +160,12 @@ class Type(object):
if isinstance(other, BasicType):
return self.numpyDtype < other.numpyDtype # TODO const
if isinstance(other, PointerType):
return True # TODO test
return False # TODO test
if isinstance(other, StructType):
raise NotImplementedError("Struct type comparison is not yet implemented")
if isinstance(self, PointerType):
if isinstance(other, BasicType):
return False # TODO test
return True # TODO test
if isinstance(other, PointerType):
return self.baseType < other.baseType # TODO const, restrict
if isinstance(other, StructType):
......@@ -207,6 +207,21 @@ class BasicType(Type):
def numpyDtype(self):
return self._dtype
def is_int(self):
return self.numpyDtype in np.sctypes['int']
def is_float(self):
return self.numpyDtype in np.sctypes['float']
def is_uint(self):
return self.numpyDtype in np.sctypes['uint']
def is_comlex(self):
return self.numpyDtype in np.sctypes['complex']
def is_other(self):
return self.numpyDtype in np.sctypes['others']
def __repr__(self):
result = BasicType.numpyNameToC(str(self._dtype))
if self.const:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment