Commit abb11958 authored by Jan Hönig's avatar Jan Hönig
Browse files

Implemented the support of llvm-backend for piecewise, eq, ne, lt, le, gt, gt.

Testcases test_mu_equivalence, test_phi_equivalence are passing.
parent fc7e815a
......@@ -507,53 +507,3 @@ class TemporaryMemoryFree(Node):
def args(self):
return []
#_expr_dict = {'Add': ' + ', 'Mul': ' * ', 'Pow': '**'}
#
#
#class Expr(Node):
# def __init__(self, args, parent=None):
# super(Expr, self).__init__(parent)
# self._args = list(args)
# self.dtype = None
#
# @property
# def args(self):
# return self._args
#
# @args.setter
# def args(self, value):
# self._args = value
#
# def replace(self, child, replacements):
# idx = self.args.index(child)
# del self.args[idx]
# if type(replacements) is list:
# for e in replacements:
# e.parent = self
# self.args = self.args[:idx] + replacements + self.args[idx:]
# else:
# replacements.parent = self
# self.args.insert(idx, replacements)
#
# @property
# def symbolsDefined(self):
# return set() # Todo fix for symbol analysis
#
# @property
# def undefinedSymbols(self):
# return set() # Todo fix for symbol analysis
#
# def __repr__(self):
# return _expr_dict[self.__class__.__name__].join(repr(arg) for arg in self.args)
#
#
#class PointerArithmetic(Expr):
# def __init__(self, args, pointer, parent=None):
# super(PointerArithmetic, self).__init__([args] + [pointer], parent)
# self.pointer = pointer
# self.offset = args
# self.dtype = pointer.dtype
#
# def __repr__(self):
# return '*(%s + %s)' % (self.pointer, self.args)
......@@ -63,15 +63,9 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl
resolveFieldAccesses(code, readOnlyFields, fieldToBasePointerInfo=basePointerInfos)
moveConstantsBeforeLoop(code)
#print('Ast:')
#print(code)
#desympy_ast(code)
#print('Desympied ast:')
#print(code)
#insert_casts(code)
#print(code)
print(code)
code = insertCasts(code)
#print(code)
print(code)
return code
......
......@@ -6,8 +6,9 @@ from sympy import S
# S is numbers?
from pystencils.llvm.control_flow import Loop
from pystencils.data_types import createType, to_llvm_type, getTypeOfExpression
from sympy import Indexed # TODO used astnodes, this should not work!
from pystencils.data_types import createType, to_llvm_type, getTypeOfExpression, collateTypes
from sympy import Indexed
from sympy.codegen.ast import Assignment
def generateLLVM(ast_node, module=None, builder=None):
......@@ -19,11 +20,12 @@ def generateLLVM(ast_node, module=None, builder=None):
if builder is None:
builder = ir.IRBuilder()
printer = LLVMPrinter(module, builder)
return printer._print(ast_node) #TODO use doprint() instead???
return printer._print(ast_node) # TODO use doprint() instead???
class LLVMPrinter(Printer):
"""Convert expressions to LLVM IR"""
def __init__(self, module, builder, fn=None, *args, **kwargs):
self.func_arg_map = kwargs.pop("func_arg_map", {})
super(LLVMPrinter, self).__init__(*args, **kwargs)
......@@ -116,6 +118,45 @@ class LLVMPrinter(Printer):
e = add(e, node)
return e
def _print_Or(self, expr):
nodes = [self._print(a) for a in expr.args]
e = nodes[0]
for node in nodes[1:]:
e = self.builder.or_(e, node)
return e
def _print_And(self, expr):
nodes = [self._print(a) for a in expr.args]
e = nodes[0]
for node in nodes[1:]:
e = self.builder.and_(e, node)
return e
def _print_StrictLessThan(self, expr):
return self._comparison('<', expr)
def _print_LessThan(self, expr):
return self._comparison('<=', expr)
def _print_StrictGreaterThan(self, expr):
return self._comparison('>', expr)
def _print_GreaterThan(self, expr):
return self._comparison('>=', expr)
def _print_Unequality(self, expr):
return self._comparison('!=', expr)
def _print_Equality(self, expr):
return self._comparison('==', expr)
def _comparison(self, cmpop, expr):
if collateTypes([getTypeOfExpression(arg) for arg in expr.args]) == createType('double'):
comparison = self.builder.fcmp_unordered
else:
comparison = self.builder.icmp_signed
return comparison(cmpop, self._print(expr.lhs), self._print(expr.rhs))
def _print_KernelFunction(self, function):
# KernelFunction does not posses a return type
return_type = self.void
......@@ -135,7 +176,7 @@ class LLVMPrinter(Printer):
# func.attributes.add("inlinehint")
# func.attributes.add("argmemonly")
block = fn.append_basic_block(name="entry")
self.builder = ir.IRBuilder(block) #TODO use goto_block instead
self.builder = ir.IRBuilder(block) # TODO use goto_block instead
self._print(function.body)
self.builder.ret_void()
self.fn = fn
......@@ -172,12 +213,17 @@ class LLVMPrinter(Printer):
(createType("int"), createType("double")): functools.partial(self.builder.sitofp, node, self.fp_type),
(createType("double"), createType("int")): functools.partial(self.builder.fptosi, node, self.integer),
(createType("double *"), createType("int")): functools.partial(self.builder.ptrtoint, node, self.integer),
(createType("int"), createType("double *")): functools.partial(self.builder.inttoptr, node, self.fp_pointer),
(createType("double * restrict"), createType("int")): functools.partial(self.builder.ptrtoint, node, self.integer),
(createType("int"), createType("double * restrict")): functools.partial(self.builder.inttoptr, node, self.fp_pointer),
(createType("double * restrict const"), createType("int")): functools.partial(self.builder.ptrtoint, node, self.integer),
(createType("int"), createType("double * restrict const")): functools.partial(self.builder.inttoptr, node, self.fp_pointer),
}
(createType("int"), createType("double *")): functools.partial(self.builder.inttoptr, node,
self.fp_pointer),
(createType("double * restrict"), createType("int")): functools.partial(self.builder.ptrtoint, node,
self.integer),
(createType("int"), createType("double * restrict")): functools.partial(self.builder.inttoptr, node,
self.fp_pointer),
(createType("double * restrict const"), createType("int")): functools.partial(self.builder.ptrtoint, node,
self.integer),
(createType("int"), createType("double * restrict const")): functools.partial(self.builder.inttoptr, node,
self.fp_pointer),
}
# TODO float, TEST: const, restrict
# TODO bitcast, addrspacecast
# TODO unsigned/signed fills
......@@ -199,6 +245,43 @@ class LLVMPrinter(Printer):
gep = self.builder.gep(ptr, [index])
return self.builder.load(gep, name=indexed.base.label.name)
def _print_Piecewise(self, piece):
if not piece.args[-1].cond:
# We need the last conditional to be a True, otherwise the resulting
# function may not return a result.
raise ValueError("All Piecewise expressions must contain an "
"(expr, True) statement to be used as a default "
"condition. Without one, the generated "
"expression may not evaluate to anything under "
"some condition.")
if piece.has(Assignment):
raise NotImplementedError('The llvm-backend does not support assignments'
'in the Piecewise function. It is questionable'
'whether to implement it. So far there is no'
'use-case to test it.')
else:
phiData = []
after_block = self.builder.append_basic_block()
for (expr, condition) in piece.args:
if condition == True: # Don't use 'is' use '=='!
phiData.append((self._print(expr), self.builder.block))
self.builder.branch(after_block)
self.builder.position_at_end(after_block)
else:
cond = self._print(condition)
trueBlock = self.builder.append_basic_block()
falseBlock = self.builder.append_basic_block()
self.builder.cbranch(cond, trueBlock, falseBlock)
self.builder.position_at_end(trueBlock)
phiData.append((self._print(expr), trueBlock))
self.builder.branch(after_block)
self.builder.position_at_end(falseBlock)
phi = self.builder.phi(to_llvm_type(getTypeOfExpression(piece)))
for (val, block) in phiData:
phi.add_incoming(val, block)
return phi
# Should have a list of math library functions to validate this.
# TODO function calls to libs
def _print_Function(self, expr):
......@@ -212,5 +295,10 @@ class LLVMPrinter(Printer):
return self.builder.call(fn, [e0], name)
def emptyPrinter(self, expr):
raise TypeError("Unsupported type for LLVM JIT conversion: %s %s"
% (type(expr), expr))
try:
import inspect
mro = inspect.getmro(expr)
except AttributeError:
mro = "None"
raise TypeError("Unsupported type for LLVM JIT conversion: Expression:\"%s\", Type:\"%s\", MRO:%s"
% (expr, type(expr), mro))
......@@ -36,7 +36,7 @@ def insertCasts(node):
pointer = None
newArgs = []
for arg, dataType in args:
if dataType.func == PointerType:
if dataType.func is PointerType:
assert pointer is None
pointer = arg
for arg, dataType in args:
......@@ -51,33 +51,44 @@ def insertCasts(node):
args = []
for arg in node.args:
args.append(insertCasts(arg))
# TODO indexed, SympyAssignment, LoopOverCoordinate
if node.func in (sp.Add, sp.Mul, sp.Pow): # TODO fix pow, don't cast integer on double
# TODO indexed, LoopOverCoordinate
if node.func in (sp.Add, sp.Mul, sp.Or, sp.And, sp.Pow, sp.Eq, sp.Ne, sp.Lt, sp.Le, sp.Gt, sp.Ge):
# TODO optimize pow, don't cast integer on double
types = [getTypeOfExpression(arg) for arg in args]
assert len(types) > 0
target = collateTypes(types)
zipped = list(zip(args, types))
if target.func == PointerType:
assert node.func == sp.Add
if target.func is PointerType:
assert node.func is sp.Add
return pointerArithmetic(zipped)
else:
return node.func(*cast(zipped, target))
elif node.func == ast.SympyAssignment:
# TODO casting of rhs/lhs
return node.func(*args)
elif node.func == ast.ResolvedFieldAccess:
#print("Node:", node, type(node), node.__class__.mro())
elif node.func is ast.SympyAssignment:
lhs = args[0]
rhs = args[1]
target = getTypeOfExpression(lhs)
if target.func is PointerType:
return node.func(*args) # TODO fix, not complete
else:
return node.func(lhs, *cast([(rhs, getTypeOfExpression(rhs))], target))
elif node.func is ast.ResolvedFieldAccess:
return node
elif node.func == ast.Block:
elif node.func is ast.Block:
for oldArg, newArg in zip(node.args, args):
node.replace(oldArg, newArg)
return node
elif node.func == ast.LoopOverCoordinate:
elif node.func is ast.LoopOverCoordinate:
for oldArg, newArg in zip(node.args, args):
node.replace(oldArg, newArg)
return node
elif node.func is sp.Piecewise:
exprs = [expr for (expr, _) in args]
types = [getTypeOfExpression(expr) for expr in exprs]
target = collateTypes(types)
zipped = list(zip(exprs, types))
casted_exprs = cast(zipped, target)
args = [arg.func(*[expr, arg.cond]) for (arg, expr) in zip(args, casted_exprs)]
#print(node.func(*args))
return node.func(*args)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment