From b7d48508ac5629c1414749dea544d15c83668894 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jan=20H=C3=B6nig?= <jan.hoenig@fau.de>
Date: Fri, 13 Oct 2017 10:54:42 +0200
Subject: [PATCH] Creating and updating more demo jupyter notebooks. Fixed
 function bug.

---
 llvm/kernelcreation.py    | 4 ++--
 llvm/llvm.py              | 2 +-
 transformations/stage2.py | 8 +++-----
 3 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/llvm/kernelcreation.py b/llvm/kernelcreation.py
index a22c39040..d6287c87e 100644
--- a/llvm/kernelcreation.py
+++ b/llvm/kernelcreation.py
@@ -69,9 +69,9 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl
     #print('Desympied ast:')
     #print(code)
     #insert_casts(code)
-    print(code)
+    #print(code)
     code = insertCasts(code)
-    print(code)
+    #print(code)
     return code
 
 
diff --git a/llvm/llvm.py b/llvm/llvm.py
index 14cbcc681..b48ec47bb 100644
--- a/llvm/llvm.py
+++ b/llvm/llvm.py
@@ -201,7 +201,7 @@ class LLVMPrinter(Printer):
     # Should have a list of math library functions to validate this.
     # TODO function calls to libs
     def _print_Function(self, expr):
-        name = expr.name
+        name = expr.func.__name__
         e0 = self._print(expr.args[0])
         fn = self.ext_fn.get(name)
         if not fn:
diff --git a/transformations/stage2.py b/transformations/stage2.py
index d1cfd2d15..92df679d6 100644
--- a/transformations/stage2.py
+++ b/transformations/stage2.py
@@ -6,7 +6,7 @@ from pystencils.data_types import TypedSymbol, createType, PointerType, StructTy
 import pystencils.astnodes as ast
 
 
-def insertCasts(node): # TODO test casts!!!, edit testcase
+def insertCasts(node):
     """
     Checks the types and inserts casts and pointer arithmetic where necessary
     :param node: the head node of the ast
@@ -21,7 +21,7 @@ def insertCasts(node): # TODO test casts!!!, edit testcase
         """
         casted_args = []
         for arg, dataType in zippedArgsTypes:
-            if dataType.numpyDtype != target.numpyDtype: # TODO ignoring const
+            if dataType.numpyDtype != target.numpyDtype:  # ignoring const
                 casted_args.append(castFunc(arg, target))
             else:
                 casted_args.append(arg)
@@ -52,12 +52,11 @@ def insertCasts(node): # TODO test casts!!!, edit testcase
     for arg in node.args:
         args.append(insertCasts(arg))
     # TODO indexed, SympyAssignment, LoopOverCoordinate, Pow
-    if node.func in (sp.Add, sp.Mul):
+    if node.func in (sp.Add, sp.Mul, sp.Pow):
         types = [getTypeOfExpression(arg) for arg in args]
         assert len(types) > 0
         target = collateTypes(types)
         zipped = list(zip(args, types))
-        print(zipped)
         if target.func == PointerType:
             assert node.func == sp.Add
             return pointerArithmetic(zipped)
@@ -68,7 +67,6 @@ def insertCasts(node): # TODO test casts!!!, edit testcase
         return node.func(*args)
     elif node.func == ast.ResolvedFieldAccess:
         #print("Node:", node, type(node), node.__class__.mro())
-        # TODO Everything
         return node
     elif node.func == ast.Block:
         for oldArg, newArg in zip(node.args, args):
-- 
GitLab