From b7d48508ac5629c1414749dea544d15c83668894 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20H=C3=B6nig?= <jan.hoenig@fau.de> Date: Fri, 13 Oct 2017 10:54:42 +0200 Subject: [PATCH] Creating and updating more demo jupyter notebooks. Fixed function bug. --- llvm/kernelcreation.py | 4 ++-- llvm/llvm.py | 2 +- transformations/stage2.py | 8 +++----- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/llvm/kernelcreation.py b/llvm/kernelcreation.py index a22c39040..d6287c87e 100644 --- a/llvm/kernelcreation.py +++ b/llvm/kernelcreation.py @@ -69,9 +69,9 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl #print('Desympied ast:') #print(code) #insert_casts(code) - print(code) + #print(code) code = insertCasts(code) - print(code) + #print(code) return code diff --git a/llvm/llvm.py b/llvm/llvm.py index 14cbcc681..b48ec47bb 100644 --- a/llvm/llvm.py +++ b/llvm/llvm.py @@ -201,7 +201,7 @@ class LLVMPrinter(Printer): # Should have a list of math library functions to validate this. # TODO function calls to libs def _print_Function(self, expr): - name = expr.name + name = expr.func.__name__ e0 = self._print(expr.args[0]) fn = self.ext_fn.get(name) if not fn: diff --git a/transformations/stage2.py b/transformations/stage2.py index d1cfd2d15..92df679d6 100644 --- a/transformations/stage2.py +++ b/transformations/stage2.py @@ -6,7 +6,7 @@ from pystencils.data_types import TypedSymbol, createType, PointerType, StructTy import pystencils.astnodes as ast -def insertCasts(node): # TODO test casts!!!, edit testcase +def insertCasts(node): """ Checks the types and inserts casts and pointer arithmetic where necessary :param node: the head node of the ast @@ -21,7 +21,7 @@ def insertCasts(node): # TODO test casts!!!, edit testcase """ casted_args = [] for arg, dataType in zippedArgsTypes: - if dataType.numpyDtype != target.numpyDtype: # TODO ignoring const + if dataType.numpyDtype != target.numpyDtype: # ignoring const casted_args.append(castFunc(arg, target)) else: casted_args.append(arg) @@ -52,12 +52,11 @@ def insertCasts(node): # TODO test casts!!!, edit testcase for arg in node.args: args.append(insertCasts(arg)) # TODO indexed, SympyAssignment, LoopOverCoordinate, Pow - if node.func in (sp.Add, sp.Mul): + if node.func in (sp.Add, sp.Mul, sp.Pow): types = [getTypeOfExpression(arg) for arg in args] assert len(types) > 0 target = collateTypes(types) zipped = list(zip(args, types)) - print(zipped) if target.func == PointerType: assert node.func == sp.Add return pointerArithmetic(zipped) @@ -68,7 +67,6 @@ def insertCasts(node): # TODO test casts!!!, edit testcase return node.func(*args) elif node.func == ast.ResolvedFieldAccess: #print("Node:", node, type(node), node.__class__.mro()) - # TODO Everything return node elif node.func == ast.Block: for oldArg, newArg in zip(node.args, args): -- GitLab