From abb11958cdaaa3bb65e43b78eb69b9b59bbb3e97 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20H=C3=B6nig?= <jan.hoenig@fau.de> Date: Fri, 20 Oct 2017 16:52:11 +0200 Subject: [PATCH] Implemented the support of llvm-backend for piecewise, eq, ne, lt, le, gt, gt. Testcases test_mu_equivalence, test_phi_equivalence are passing. --- astnodes.py | 50 ----------------- llvm/kernelcreation.py | 10 +--- llvm/llvm.py | 112 ++++++++++++++++++++++++++++++++++---- transformations/stage2.py | 37 ++++++++----- 4 files changed, 126 insertions(+), 83 deletions(-) diff --git a/astnodes.py b/astnodes.py index 2255e9455..66d4ae8d0 100644 --- a/astnodes.py +++ b/astnodes.py @@ -507,53 +507,3 @@ class TemporaryMemoryFree(Node): def args(self): return [] - -#_expr_dict = {'Add': ' + ', 'Mul': ' * ', 'Pow': '**'} -# -# -#class Expr(Node): -# def __init__(self, args, parent=None): -# super(Expr, self).__init__(parent) -# self._args = list(args) -# self.dtype = None -# -# @property -# def args(self): -# return self._args -# -# @args.setter -# def args(self, value): -# self._args = value -# -# def replace(self, child, replacements): -# idx = self.args.index(child) -# del self.args[idx] -# if type(replacements) is list: -# for e in replacements: -# e.parent = self -# self.args = self.args[:idx] + replacements + self.args[idx:] -# else: -# replacements.parent = self -# self.args.insert(idx, replacements) -# -# @property -# def symbolsDefined(self): -# return set() # Todo fix for symbol analysis -# -# @property -# def undefinedSymbols(self): -# return set() # Todo fix for symbol analysis -# -# def __repr__(self): -# return _expr_dict[self.__class__.__name__].join(repr(arg) for arg in self.args) -# -# -#class PointerArithmetic(Expr): -# def __init__(self, args, pointer, parent=None): -# super(PointerArithmetic, self).__init__([args] + [pointer], parent) -# self.pointer = pointer -# self.offset = args -# self.dtype = pointer.dtype -# -# def __repr__(self): -# return '*(%s + %s)' % (self.pointer, self.args) diff --git a/llvm/kernelcreation.py b/llvm/kernelcreation.py index d6287c87e..e0c16f9a1 100644 --- a/llvm/kernelcreation.py +++ b/llvm/kernelcreation.py @@ -63,15 +63,9 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl resolveFieldAccesses(code, readOnlyFields, fieldToBasePointerInfo=basePointerInfos) moveConstantsBeforeLoop(code) - #print('Ast:') - #print(code) - #desympy_ast(code) - #print('Desympied ast:') - #print(code) - #insert_casts(code) - #print(code) + print(code) code = insertCasts(code) - #print(code) + print(code) return code diff --git a/llvm/llvm.py b/llvm/llvm.py index 0eac24b93..b33e4e57a 100644 --- a/llvm/llvm.py +++ b/llvm/llvm.py @@ -6,8 +6,9 @@ from sympy import S # S is numbers? from pystencils.llvm.control_flow import Loop -from pystencils.data_types import createType, to_llvm_type, getTypeOfExpression -from sympy import Indexed # TODO used astnodes, this should not work! +from pystencils.data_types import createType, to_llvm_type, getTypeOfExpression, collateTypes +from sympy import Indexed +from sympy.codegen.ast import Assignment def generateLLVM(ast_node, module=None, builder=None): @@ -19,11 +20,12 @@ def generateLLVM(ast_node, module=None, builder=None): if builder is None: builder = ir.IRBuilder() printer = LLVMPrinter(module, builder) - return printer._print(ast_node) #TODO use doprint() instead??? + return printer._print(ast_node) # TODO use doprint() instead??? class LLVMPrinter(Printer): """Convert expressions to LLVM IR""" + def __init__(self, module, builder, fn=None, *args, **kwargs): self.func_arg_map = kwargs.pop("func_arg_map", {}) super(LLVMPrinter, self).__init__(*args, **kwargs) @@ -116,6 +118,45 @@ class LLVMPrinter(Printer): e = add(e, node) return e + def _print_Or(self, expr): + nodes = [self._print(a) for a in expr.args] + e = nodes[0] + for node in nodes[1:]: + e = self.builder.or_(e, node) + return e + + def _print_And(self, expr): + nodes = [self._print(a) for a in expr.args] + e = nodes[0] + for node in nodes[1:]: + e = self.builder.and_(e, node) + return e + + def _print_StrictLessThan(self, expr): + return self._comparison('<', expr) + + def _print_LessThan(self, expr): + return self._comparison('<=', expr) + + def _print_StrictGreaterThan(self, expr): + return self._comparison('>', expr) + + def _print_GreaterThan(self, expr): + return self._comparison('>=', expr) + + def _print_Unequality(self, expr): + return self._comparison('!=', expr) + + def _print_Equality(self, expr): + return self._comparison('==', expr) + + def _comparison(self, cmpop, expr): + if collateTypes([getTypeOfExpression(arg) for arg in expr.args]) == createType('double'): + comparison = self.builder.fcmp_unordered + else: + comparison = self.builder.icmp_signed + return comparison(cmpop, self._print(expr.lhs), self._print(expr.rhs)) + def _print_KernelFunction(self, function): # KernelFunction does not posses a return type return_type = self.void @@ -135,7 +176,7 @@ class LLVMPrinter(Printer): # func.attributes.add("inlinehint") # func.attributes.add("argmemonly") block = fn.append_basic_block(name="entry") - self.builder = ir.IRBuilder(block) #TODO use goto_block instead + self.builder = ir.IRBuilder(block) # TODO use goto_block instead self._print(function.body) self.builder.ret_void() self.fn = fn @@ -172,12 +213,17 @@ class LLVMPrinter(Printer): (createType("int"), createType("double")): functools.partial(self.builder.sitofp, node, self.fp_type), (createType("double"), createType("int")): functools.partial(self.builder.fptosi, node, self.integer), (createType("double *"), createType("int")): functools.partial(self.builder.ptrtoint, node, self.integer), - (createType("int"), createType("double *")): functools.partial(self.builder.inttoptr, node, self.fp_pointer), - (createType("double * restrict"), createType("int")): functools.partial(self.builder.ptrtoint, node, self.integer), - (createType("int"), createType("double * restrict")): functools.partial(self.builder.inttoptr, node, self.fp_pointer), - (createType("double * restrict const"), createType("int")): functools.partial(self.builder.ptrtoint, node, self.integer), - (createType("int"), createType("double * restrict const")): functools.partial(self.builder.inttoptr, node, self.fp_pointer), - } + (createType("int"), createType("double *")): functools.partial(self.builder.inttoptr, node, + self.fp_pointer), + (createType("double * restrict"), createType("int")): functools.partial(self.builder.ptrtoint, node, + self.integer), + (createType("int"), createType("double * restrict")): functools.partial(self.builder.inttoptr, node, + self.fp_pointer), + (createType("double * restrict const"), createType("int")): functools.partial(self.builder.ptrtoint, node, + self.integer), + (createType("int"), createType("double * restrict const")): functools.partial(self.builder.inttoptr, node, + self.fp_pointer), + } # TODO float, TEST: const, restrict # TODO bitcast, addrspacecast # TODO unsigned/signed fills @@ -199,6 +245,43 @@ class LLVMPrinter(Printer): gep = self.builder.gep(ptr, [index]) return self.builder.load(gep, name=indexed.base.label.name) + def _print_Piecewise(self, piece): + if not piece.args[-1].cond: + # We need the last conditional to be a True, otherwise the resulting + # function may not return a result. + raise ValueError("All Piecewise expressions must contain an " + "(expr, True) statement to be used as a default " + "condition. Without one, the generated " + "expression may not evaluate to anything under " + "some condition.") + if piece.has(Assignment): + raise NotImplementedError('The llvm-backend does not support assignments' + 'in the Piecewise function. It is questionable' + 'whether to implement it. So far there is no' + 'use-case to test it.') + else: + phiData = [] + after_block = self.builder.append_basic_block() + for (expr, condition) in piece.args: + if condition == True: # Don't use 'is' use '=='! + phiData.append((self._print(expr), self.builder.block)) + self.builder.branch(after_block) + self.builder.position_at_end(after_block) + else: + cond = self._print(condition) + trueBlock = self.builder.append_basic_block() + falseBlock = self.builder.append_basic_block() + self.builder.cbranch(cond, trueBlock, falseBlock) + self.builder.position_at_end(trueBlock) + phiData.append((self._print(expr), trueBlock)) + self.builder.branch(after_block) + self.builder.position_at_end(falseBlock) + + phi = self.builder.phi(to_llvm_type(getTypeOfExpression(piece))) + for (val, block) in phiData: + phi.add_incoming(val, block) + return phi + # Should have a list of math library functions to validate this. # TODO function calls to libs def _print_Function(self, expr): @@ -212,5 +295,10 @@ class LLVMPrinter(Printer): return self.builder.call(fn, [e0], name) def emptyPrinter(self, expr): - raise TypeError("Unsupported type for LLVM JIT conversion: %s %s" - % (type(expr), expr)) + try: + import inspect + mro = inspect.getmro(expr) + except AttributeError: + mro = "None" + raise TypeError("Unsupported type for LLVM JIT conversion: Expression:\"%s\", Type:\"%s\", MRO:%s" + % (expr, type(expr), mro)) diff --git a/transformations/stage2.py b/transformations/stage2.py index 42264a309..ebfc5309d 100644 --- a/transformations/stage2.py +++ b/transformations/stage2.py @@ -36,7 +36,7 @@ def insertCasts(node): pointer = None newArgs = [] for arg, dataType in args: - if dataType.func == PointerType: + if dataType.func is PointerType: assert pointer is None pointer = arg for arg, dataType in args: @@ -51,33 +51,44 @@ def insertCasts(node): args = [] for arg in node.args: args.append(insertCasts(arg)) - # TODO indexed, SympyAssignment, LoopOverCoordinate - if node.func in (sp.Add, sp.Mul, sp.Pow): # TODO fix pow, don't cast integer on double + # TODO indexed, LoopOverCoordinate + if node.func in (sp.Add, sp.Mul, sp.Or, sp.And, sp.Pow, sp.Eq, sp.Ne, sp.Lt, sp.Le, sp.Gt, sp.Ge): + # TODO optimize pow, don't cast integer on double types = [getTypeOfExpression(arg) for arg in args] assert len(types) > 0 target = collateTypes(types) zipped = list(zip(args, types)) - if target.func == PointerType: - assert node.func == sp.Add + if target.func is PointerType: + assert node.func is sp.Add return pointerArithmetic(zipped) else: return node.func(*cast(zipped, target)) - elif node.func == ast.SympyAssignment: - # TODO casting of rhs/lhs - return node.func(*args) - elif node.func == ast.ResolvedFieldAccess: - #print("Node:", node, type(node), node.__class__.mro()) + elif node.func is ast.SympyAssignment: + lhs = args[0] + rhs = args[1] + target = getTypeOfExpression(lhs) + if target.func is PointerType: + return node.func(*args) # TODO fix, not complete + else: + return node.func(lhs, *cast([(rhs, getTypeOfExpression(rhs))], target)) + elif node.func is ast.ResolvedFieldAccess: return node - elif node.func == ast.Block: + elif node.func is ast.Block: for oldArg, newArg in zip(node.args, args): node.replace(oldArg, newArg) return node - elif node.func == ast.LoopOverCoordinate: + elif node.func is ast.LoopOverCoordinate: for oldArg, newArg in zip(node.args, args): node.replace(oldArg, newArg) return node + elif node.func is sp.Piecewise: + exprs = [expr for (expr, _) in args] + types = [getTypeOfExpression(expr) for expr in exprs] + target = collateTypes(types) + zipped = list(zip(exprs, types)) + casted_exprs = cast(zipped, target) + args = [arg.func(*[expr, arg.cond]) for (arg, expr) in zip(args, casted_exprs)] - #print(node.func(*args)) return node.func(*args) -- GitLab