diff --git a/ast.py b/ast.py index 9f027ac1dd8ad0e9f1a0670e1320dfbd040b8b2f..4669f3b3efa34f83d46c9dd814aa52e5dd3370a5 100644 --- a/ast.py +++ b/ast.py @@ -341,7 +341,7 @@ class TemporaryMemoryAllocation(Node): @property def symbolsDefined(self): - return set([self._symbol]) + return set([self.symbol]) @property def symbolsRead(self): @@ -349,12 +349,12 @@ class TemporaryMemoryAllocation(Node): @property def args(self): - return [self._symbol] + return [self.symbol] class TemporaryMemoryFree(Node): def __init__(self, typedSymbol): - self._symbol = typedSymbol + self.symbol = typedSymbol @property def symbolsDefined(self): diff --git a/backends/cbackend.py b/backends/cbackend.py index bf5f92a06e24dc55025d49a3165e6a7380fa47b7..314c301bfd84d45e75f7657bdb6518afdfad6dc2 100644 --- a/backends/cbackend.py +++ b/backends/cbackend.py @@ -1,4 +1,4 @@ -import cgen as c +import textwrap from sympy.utilities.codegen import CCodePrinter from pystencils.ast import Node @@ -40,9 +40,6 @@ class CustomCppCode(Node): def symbolsRead(self): return self._symbolsRead - def generateC(self): - return c.LiteralLines(self._code) - class PrintNode(CustomCppCode): def __init__(self, symbolToPrint): @@ -58,6 +55,7 @@ class CBackend: def __init__(self, cuda=False): self.cuda = cuda self.sympyPrinter = CustomSympyPrinter() + self._indent = " " def __call__(self, node): return str(self._print(node)) @@ -70,46 +68,30 @@ class CBackend: raise NotImplementedError("CBackend does not support node of type " + cls.__name__) def _print_KernelFunction(self, node): - functionArguments = [MyPOD(s.dtype, s.name) for s in node.parameters] + functionArguments = ["%s %s" % (s.dtype, s.name) for s in node.parameters] prefix = "__global__ void" if self.cuda else "void" - functionPOD = MyPOD(prefix, node.functionName, ) - funcDeclaration = c.FunctionDeclaration(functionPOD, functionArguments) - return c.FunctionBody(funcDeclaration, self._print(node.body)) + funcDeclaration = "%s %s(%s)" % (prefix, node.functionName, ", ".join(functionArguments)) + body = self._print(node.body) + return funcDeclaration + "\n" + body def _print_Block(self, node): - return c.Block([self._print(child) for child in node.args]) + blockContents = "\n".join([self._print(child) for child in node.args]) + return "{\n%s\n}" % (textwrap.indent(blockContents, self._indent)) def _print_PragmaBlock(self, node): - class PragmaGenerable(c.Generable): - def __init__(self, line, block): - self._line = line - self._block = block - - def generate(self): - yield self._line - for e in self._block.generate(): - yield e - return PragmaGenerable(node.pragmaLine, self._print_Block(node)) + return "%s\n%s" % (node.pragmaLine, self._print_Block(node)) def _print_LoopOverCoordinate(self, node): - class LoopWithOptionalPrefix(c.CustomLoop): - def __init__(self, intro_line, body, prefixLines=[]): - super(LoopWithOptionalPrefix, self).__init__(intro_line, body) - self.prefixLines = prefixLines - - def generate(self): - for l in self.prefixLines: - yield l - - for e in super(LoopWithOptionalPrefix, self).generate(): - yield e - counterVar = node.loopCounterName start = "int %s = %s" % (counterVar, self.sympyPrinter.doprint(node.start)) condition = "%s < %s" % (counterVar, self.sympyPrinter.doprint(node.stop)) update = "%s += %s" % (counterVar, self.sympyPrinter.doprint(node.step),) loopStr = "for (%s; %s; %s)" % (start, condition, update) - return LoopWithOptionalPrefix(loopStr, self._print(node.body), prefixLines=node.prefixLines) + + prefix = "\n".join(node.prefixLines) + if prefix: + prefix += "\n" + return "%s%s\n%s" % (prefix, loopStr, self._print(node.body)) def _print_SympyAssignment(self, node): dtype = "" @@ -118,19 +100,17 @@ class CBackend: dtype = "const " + node.lhs.dtype + " " else: dtype = node.lhs.dtype + " " - - return c.Assign(dtype + self.sympyPrinter.doprint(node.lhs), - self.sympyPrinter.doprint(node.rhs)) + return "%s %s = %s;" % (dtype, self.sympyPrinter.doprint(node.lhs), self.sympyPrinter.doprint(node.rhs)) def _print_TemporaryMemoryAllocation(self, node): - return c.Assign(node.symbol.dtype + " * " + self.sympyPrinter.doprint(node.symbol), - "new %s[%s]" % (node.symbol.dtype, self.sympyPrinter.doprint(node.size))) + return "%s * %s = new %s[%s];" % (node.symbol.dtype, self.sympyPrinter.doprint(node.symbol), + node.symbol.dtype, self.sympyPrinter.doprint(node.size)) def _print_TemporaryMemoryFree(self, node): - return c.Statement("delete [] %s" % (self.sympyPrinter.doprint(node.symbol),)) + return "delete [] %s;" % (self.sympyPrinter.doprint(node.symbol),) def _print_CustomCppCode(self, node): - return c.LiteralLines(node.code) + return node.code # ------------------------------------------ Helper function & classes ------------------------------------------------- @@ -155,13 +135,4 @@ class CustomSympyPrinter(CCodePrinter): def _print_Piecewise(self, expr): """Print piecewise in one line (remove newlines)""" result = super(CustomSympyPrinter, self)._print_Piecewise(expr) - return result.replace("\n", "") - - -class MyPOD(c.Declarator): - def __init__(self, dtype, name): - self.dtype = dtype - self.name = name - - def get_decl_pair(self): - return [self.dtype], self.name + return result.replace("\n", "") \ No newline at end of file diff --git a/transformations.py b/transformations.py index 2b47cbde3b7b742759455c365c236168b9ccecc6..62d400b5a2d316a92cd6fd015bb9ab1137c41ce5 100644 --- a/transformations.py +++ b/transformations.py @@ -364,7 +364,7 @@ def splitInnerLoop(astNode, symbolGroups): assignmentGroups.append(assignmentGroup) newLoops = [innerLoop.newLoopWithDifferentBody(ast.Block(group)) for group in assignmentGroups] - innerLoop.parent.replace(innerLoop, newLoops) + innerLoop.parent.replace(innerLoop, ast.Block(newLoops)) for tmpArray in symbolsWithTemporaryArray: outerLoop.parent.insertFront(ast.TemporaryMemoryAllocation(tmpArray, innerLoop.stop))