From abb11958cdaaa3bb65e43b78eb69b9b59bbb3e97 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jan=20H=C3=B6nig?= <jan.hoenig@fau.de>
Date: Fri, 20 Oct 2017 16:52:11 +0200
Subject: [PATCH] Implemented the support of llvm-backend for piecewise, eq,
 ne, lt, le, gt, gt. Testcases test_mu_equivalence, test_phi_equivalence are
 passing.

---
 astnodes.py               |  50 -----------------
 llvm/kernelcreation.py    |  10 +---
 llvm/llvm.py              | 112 ++++++++++++++++++++++++++++++++++----
 transformations/stage2.py |  37 ++++++++-----
 4 files changed, 126 insertions(+), 83 deletions(-)

diff --git a/astnodes.py b/astnodes.py
index 2255e9455..66d4ae8d0 100644
--- a/astnodes.py
+++ b/astnodes.py
@@ -507,53 +507,3 @@ class TemporaryMemoryFree(Node):
     def args(self):
         return []
 
-
-#_expr_dict = {'Add': ' + ', 'Mul': ' * ', 'Pow': '**'}
-#
-#
-#class Expr(Node):
-#    def __init__(self, args, parent=None):
-#        super(Expr, self).__init__(parent)
-#        self._args = list(args)
-#        self.dtype = None
-#
-#    @property
-#    def args(self):
-#        return self._args
-#
-#    @args.setter
-#    def args(self, value):
-#        self._args = value
-#
-#    def replace(self, child, replacements):
-#        idx = self.args.index(child)
-#        del self.args[idx]
-#        if type(replacements) is list:
-#            for e in replacements:
-#                e.parent = self
-#            self.args = self.args[:idx] + replacements + self.args[idx:]
-#        else:
-#            replacements.parent = self
-#            self.args.insert(idx, replacements)
-#
-#    @property
-#    def symbolsDefined(self):
-#        return set()  # Todo fix for symbol analysis
-#
-#    @property
-#    def undefinedSymbols(self):
-#        return set()  # Todo fix for symbol analysis
-#
-#    def __repr__(self):
-#        return _expr_dict[self.__class__.__name__].join(repr(arg) for arg in self.args)
-#
-#
-#class PointerArithmetic(Expr):
-#    def __init__(self, args, pointer, parent=None):
-#        super(PointerArithmetic, self).__init__([args] + [pointer], parent)
-#        self.pointer = pointer
-#        self.offset = args
-#        self.dtype = pointer.dtype
-#
-#    def __repr__(self):
-#        return '*(%s + %s)' % (self.pointer, self.args)
diff --git a/llvm/kernelcreation.py b/llvm/kernelcreation.py
index d6287c87e..e0c16f9a1 100644
--- a/llvm/kernelcreation.py
+++ b/llvm/kernelcreation.py
@@ -63,15 +63,9 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl
     resolveFieldAccesses(code, readOnlyFields, fieldToBasePointerInfo=basePointerInfos)
     moveConstantsBeforeLoop(code)
 
-    #print('Ast:')
-    #print(code)
-    #desympy_ast(code)
-    #print('Desympied ast:')
-    #print(code)
-    #insert_casts(code)
-    #print(code)
+    print(code)
     code = insertCasts(code)
-    #print(code)
+    print(code)
     return code
 
 
diff --git a/llvm/llvm.py b/llvm/llvm.py
index 0eac24b93..b33e4e57a 100644
--- a/llvm/llvm.py
+++ b/llvm/llvm.py
@@ -6,8 +6,9 @@ from sympy import S
 # S is numbers?
 
 from pystencils.llvm.control_flow import Loop
-from pystencils.data_types import createType, to_llvm_type, getTypeOfExpression
-from sympy import Indexed  # TODO used astnodes, this should not work!
+from pystencils.data_types import createType, to_llvm_type, getTypeOfExpression, collateTypes
+from sympy import Indexed
+from sympy.codegen.ast import Assignment
 
 
 def generateLLVM(ast_node, module=None, builder=None):
@@ -19,11 +20,12 @@ def generateLLVM(ast_node, module=None, builder=None):
     if builder is None:
         builder = ir.IRBuilder()
     printer = LLVMPrinter(module, builder)
-    return printer._print(ast_node) #TODO use doprint() instead???
+    return printer._print(ast_node)  # TODO use doprint() instead???
 
 
 class LLVMPrinter(Printer):
     """Convert expressions to LLVM IR"""
+
     def __init__(self, module, builder, fn=None, *args, **kwargs):
         self.func_arg_map = kwargs.pop("func_arg_map", {})
         super(LLVMPrinter, self).__init__(*args, **kwargs)
@@ -116,6 +118,45 @@ class LLVMPrinter(Printer):
             e = add(e, node)
         return e
 
+    def _print_Or(self, expr):
+        nodes = [self._print(a) for a in expr.args]
+        e = nodes[0]
+        for node in nodes[1:]:
+            e = self.builder.or_(e, node)
+        return e
+
+    def _print_And(self, expr):
+        nodes = [self._print(a) for a in expr.args]
+        e = nodes[0]
+        for node in nodes[1:]:
+            e = self.builder.and_(e, node)
+        return e
+
+    def _print_StrictLessThan(self, expr):
+        return self._comparison('<', expr)
+
+    def _print_LessThan(self, expr):
+        return self._comparison('<=', expr)
+
+    def _print_StrictGreaterThan(self, expr):
+        return self._comparison('>', expr)
+
+    def _print_GreaterThan(self, expr):
+        return self._comparison('>=', expr)
+
+    def _print_Unequality(self, expr):
+        return self._comparison('!=', expr)
+
+    def _print_Equality(self, expr):
+        return self._comparison('==', expr)
+
+    def _comparison(self, cmpop, expr):
+        if collateTypes([getTypeOfExpression(arg) for arg in expr.args]) == createType('double'):
+            comparison = self.builder.fcmp_unordered
+        else:
+            comparison = self.builder.icmp_signed
+        return comparison(cmpop, self._print(expr.lhs), self._print(expr.rhs))
+
     def _print_KernelFunction(self, function):
         # KernelFunction does not posses a return type
         return_type = self.void
@@ -135,7 +176,7 @@ class LLVMPrinter(Printer):
         # func.attributes.add("inlinehint")
         # func.attributes.add("argmemonly")
         block = fn.append_basic_block(name="entry")
-        self.builder = ir.IRBuilder(block) #TODO use goto_block instead
+        self.builder = ir.IRBuilder(block)  # TODO use goto_block instead
         self._print(function.body)
         self.builder.ret_void()
         self.fn = fn
@@ -172,12 +213,17 @@ class LLVMPrinter(Printer):
             (createType("int"), createType("double")): functools.partial(self.builder.sitofp, node, self.fp_type),
             (createType("double"), createType("int")): functools.partial(self.builder.fptosi, node, self.integer),
             (createType("double *"), createType("int")): functools.partial(self.builder.ptrtoint, node, self.integer),
-            (createType("int"), createType("double *")): functools.partial(self.builder.inttoptr, node, self.fp_pointer),
-            (createType("double * restrict"), createType("int")): functools.partial(self.builder.ptrtoint, node, self.integer),
-            (createType("int"), createType("double * restrict")): functools.partial(self.builder.inttoptr, node, self.fp_pointer),
-            (createType("double * restrict const"), createType("int")): functools.partial(self.builder.ptrtoint, node, self.integer),
-            (createType("int"), createType("double * restrict const")): functools.partial(self.builder.inttoptr, node, self.fp_pointer),
-            }
+            (createType("int"), createType("double *")): functools.partial(self.builder.inttoptr, node,
+                                                                           self.fp_pointer),
+            (createType("double * restrict"), createType("int")): functools.partial(self.builder.ptrtoint, node,
+                                                                                    self.integer),
+            (createType("int"), createType("double * restrict")): functools.partial(self.builder.inttoptr, node,
+                                                                                    self.fp_pointer),
+            (createType("double * restrict const"), createType("int")): functools.partial(self.builder.ptrtoint, node,
+                                                                                          self.integer),
+            (createType("int"), createType("double * restrict const")): functools.partial(self.builder.inttoptr, node,
+                                                                                          self.fp_pointer),
+        }
         # TODO float, TEST: const, restrict
         # TODO bitcast, addrspacecast
         # TODO unsigned/signed fills
@@ -199,6 +245,43 @@ class LLVMPrinter(Printer):
         gep = self.builder.gep(ptr, [index])
         return self.builder.load(gep, name=indexed.base.label.name)
 
+    def _print_Piecewise(self, piece):
+        if not piece.args[-1].cond:
+            # We need the last conditional to be a True, otherwise the resulting
+            # function may not return a result.
+            raise ValueError("All Piecewise expressions must contain an "
+                             "(expr, True) statement to be used as a default "
+                             "condition. Without one, the generated "
+                             "expression may not evaluate to anything under "
+                             "some condition.")
+        if piece.has(Assignment):
+            raise NotImplementedError('The llvm-backend does not support assignments'
+                                      'in the Piecewise function. It is questionable'
+                                      'whether to implement it. So far there is no'
+                                      'use-case to test it.')
+        else:
+            phiData = []
+            after_block = self.builder.append_basic_block()
+            for (expr, condition) in piece.args:
+                if condition == True:  # Don't use 'is' use '=='!
+                    phiData.append((self._print(expr), self.builder.block))
+                    self.builder.branch(after_block)
+                    self.builder.position_at_end(after_block)
+                else:
+                    cond = self._print(condition)
+                    trueBlock = self.builder.append_basic_block()
+                    falseBlock = self.builder.append_basic_block()
+                    self.builder.cbranch(cond, trueBlock, falseBlock)
+                    self.builder.position_at_end(trueBlock)
+                    phiData.append((self._print(expr), trueBlock))
+                    self.builder.branch(after_block)
+                    self.builder.position_at_end(falseBlock)
+
+            phi = self.builder.phi(to_llvm_type(getTypeOfExpression(piece)))
+            for (val, block) in phiData:
+                phi.add_incoming(val, block)
+            return phi
+
     # Should have a list of math library functions to validate this.
     # TODO function calls to libs
     def _print_Function(self, expr):
@@ -212,5 +295,10 @@ class LLVMPrinter(Printer):
         return self.builder.call(fn, [e0], name)
 
     def emptyPrinter(self, expr):
-        raise TypeError("Unsupported type for LLVM JIT conversion: %s %s"
-                        % (type(expr), expr))
+        try:
+            import inspect
+            mro = inspect.getmro(expr)
+        except AttributeError:
+            mro = "None"
+        raise TypeError("Unsupported type for LLVM JIT conversion: Expression:\"%s\", Type:\"%s\", MRO:%s"
+                        % (expr, type(expr), mro))
diff --git a/transformations/stage2.py b/transformations/stage2.py
index 42264a309..ebfc5309d 100644
--- a/transformations/stage2.py
+++ b/transformations/stage2.py
@@ -36,7 +36,7 @@ def insertCasts(node):
         pointer = None
         newArgs = []
         for arg, dataType in args:
-            if dataType.func == PointerType:
+            if dataType.func is PointerType:
                 assert pointer is None
                 pointer = arg
         for arg, dataType in args:
@@ -51,33 +51,44 @@ def insertCasts(node):
     args = []
     for arg in node.args:
         args.append(insertCasts(arg))
-    # TODO indexed, SympyAssignment, LoopOverCoordinate
-    if node.func in (sp.Add, sp.Mul, sp.Pow): # TODO fix pow, don't cast integer on double
+    # TODO indexed, LoopOverCoordinate
+    if node.func in (sp.Add, sp.Mul, sp.Or, sp.And, sp.Pow, sp.Eq, sp.Ne, sp.Lt, sp.Le, sp.Gt, sp.Ge):
+        # TODO optimize pow, don't cast integer on double
         types = [getTypeOfExpression(arg) for arg in args]
         assert len(types) > 0
         target = collateTypes(types)
         zipped = list(zip(args, types))
-        if target.func == PointerType:
-            assert node.func == sp.Add
+        if target.func is PointerType:
+            assert node.func is sp.Add
             return pointerArithmetic(zipped)
         else:
             return node.func(*cast(zipped, target))
-    elif node.func == ast.SympyAssignment:
-        # TODO casting of rhs/lhs
-        return node.func(*args)
-    elif node.func == ast.ResolvedFieldAccess:
-        #print("Node:", node, type(node), node.__class__.mro())
+    elif node.func is ast.SympyAssignment:
+        lhs = args[0]
+        rhs = args[1]
+        target = getTypeOfExpression(lhs)
+        if target.func is PointerType:
+            return node.func(*args)  # TODO fix, not complete
+        else:
+            return node.func(lhs, *cast([(rhs, getTypeOfExpression(rhs))], target))
+    elif node.func is ast.ResolvedFieldAccess:
         return node
-    elif node.func == ast.Block:
+    elif node.func is ast.Block:
         for oldArg, newArg in zip(node.args, args):
             node.replace(oldArg, newArg)
         return node
-    elif node.func == ast.LoopOverCoordinate:
+    elif node.func is ast.LoopOverCoordinate:
         for oldArg, newArg in zip(node.args, args):
             node.replace(oldArg, newArg)
         return node
+    elif node.func is sp.Piecewise:
+        exprs = [expr for (expr, _) in args]
+        types = [getTypeOfExpression(expr) for expr in exprs]
+        target = collateTypes(types)
+        zipped = list(zip(exprs, types))
+        casted_exprs = cast(zipped, target)
+        args = [arg.func(*[expr, arg.cond]) for (arg, expr) in zip(args, casted_exprs)]
 
-    #print(node.func(*args))
     return node.func(*args)
 
 
-- 
GitLab