diff --git a/llvm/kernelcreation.py b/llvm/kernelcreation.py index a22c390407dcf2ab834ca39683b67df882496de7..d6287c87e98d1c7c36cfde6c5597b14fea5d190c 100644 --- a/llvm/kernelcreation.py +++ b/llvm/kernelcreation.py @@ -69,9 +69,9 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl #print('Desympied ast:') #print(code) #insert_casts(code) - print(code) + #print(code) code = insertCasts(code) - print(code) + #print(code) return code diff --git a/llvm/llvm.py b/llvm/llvm.py index 14cbcc681239913c283d2c216ac3ff210b0a059a..b48ec47bb564add71af36a5396ff59ca478c4022 100644 --- a/llvm/llvm.py +++ b/llvm/llvm.py @@ -201,7 +201,7 @@ class LLVMPrinter(Printer): # Should have a list of math library functions to validate this. # TODO function calls to libs def _print_Function(self, expr): - name = expr.name + name = expr.func.__name__ e0 = self._print(expr.args[0]) fn = self.ext_fn.get(name) if not fn: diff --git a/transformations/stage2.py b/transformations/stage2.py index d1cfd2d156f64d1f9e003e2b91bfec6d6de371c4..92df679d6ff6f3631ee19517873eb515cdd0a96c 100644 --- a/transformations/stage2.py +++ b/transformations/stage2.py @@ -6,7 +6,7 @@ from pystencils.data_types import TypedSymbol, createType, PointerType, StructTy import pystencils.astnodes as ast -def insertCasts(node): # TODO test casts!!!, edit testcase +def insertCasts(node): """ Checks the types and inserts casts and pointer arithmetic where necessary :param node: the head node of the ast @@ -21,7 +21,7 @@ def insertCasts(node): # TODO test casts!!!, edit testcase """ casted_args = [] for arg, dataType in zippedArgsTypes: - if dataType.numpyDtype != target.numpyDtype: # TODO ignoring const + if dataType.numpyDtype != target.numpyDtype: # ignoring const casted_args.append(castFunc(arg, target)) else: casted_args.append(arg) @@ -52,12 +52,11 @@ def insertCasts(node): # TODO test casts!!!, edit testcase for arg in node.args: args.append(insertCasts(arg)) # TODO indexed, SympyAssignment, LoopOverCoordinate, Pow - if node.func in (sp.Add, sp.Mul): + if node.func in (sp.Add, sp.Mul, sp.Pow): types = [getTypeOfExpression(arg) for arg in args] assert len(types) > 0 target = collateTypes(types) zipped = list(zip(args, types)) - print(zipped) if target.func == PointerType: assert node.func == sp.Add return pointerArithmetic(zipped) @@ -68,7 +67,6 @@ def insertCasts(node): # TODO test casts!!!, edit testcase return node.func(*args) elif node.func == ast.ResolvedFieldAccess: #print("Node:", node, type(node), node.__class__.mro()) - # TODO Everything return node elif node.func == ast.Block: for oldArg, newArg in zip(node.args, args):