From 663dd123a07be3571e203abf8c3a7bf05f8fe065 Mon Sep 17 00:00:00 2001 From: Martin Bauer <martin.bauer@fau.de> Date: Thu, 10 Nov 2016 17:05:11 +0100 Subject: [PATCH] C backend: no dependence to cgen package any more --- ast.py | 6 ++-- backends/cbackend.py | 69 +++++++++++++------------------------------- transformations.py | 2 +- 3 files changed, 24 insertions(+), 53 deletions(-) diff --git a/ast.py b/ast.py index 9f027ac1d..4669f3b3e 100644 --- a/ast.py +++ b/ast.py @@ -341,7 +341,7 @@ class TemporaryMemoryAllocation(Node): @property def symbolsDefined(self): - return set([self._symbol]) + return set([self.symbol]) @property def symbolsRead(self): @@ -349,12 +349,12 @@ class TemporaryMemoryAllocation(Node): @property def args(self): - return [self._symbol] + return [self.symbol] class TemporaryMemoryFree(Node): def __init__(self, typedSymbol): - self._symbol = typedSymbol + self.symbol = typedSymbol @property def symbolsDefined(self): diff --git a/backends/cbackend.py b/backends/cbackend.py index bf5f92a06..314c301bf 100644 --- a/backends/cbackend.py +++ b/backends/cbackend.py @@ -1,4 +1,4 @@ -import cgen as c +import textwrap from sympy.utilities.codegen import CCodePrinter from pystencils.ast import Node @@ -40,9 +40,6 @@ class CustomCppCode(Node): def symbolsRead(self): return self._symbolsRead - def generateC(self): - return c.LiteralLines(self._code) - class PrintNode(CustomCppCode): def __init__(self, symbolToPrint): @@ -58,6 +55,7 @@ class CBackend: def __init__(self, cuda=False): self.cuda = cuda self.sympyPrinter = CustomSympyPrinter() + self._indent = " " def __call__(self, node): return str(self._print(node)) @@ -70,46 +68,30 @@ class CBackend: raise NotImplementedError("CBackend does not support node of type " + cls.__name__) def _print_KernelFunction(self, node): - functionArguments = [MyPOD(s.dtype, s.name) for s in node.parameters] + functionArguments = ["%s %s" % (s.dtype, s.name) for s in node.parameters] prefix = "__global__ void" if self.cuda else "void" - functionPOD = MyPOD(prefix, node.functionName, ) - funcDeclaration = c.FunctionDeclaration(functionPOD, functionArguments) - return c.FunctionBody(funcDeclaration, self._print(node.body)) + funcDeclaration = "%s %s(%s)" % (prefix, node.functionName, ", ".join(functionArguments)) + body = self._print(node.body) + return funcDeclaration + "\n" + body def _print_Block(self, node): - return c.Block([self._print(child) for child in node.args]) + blockContents = "\n".join([self._print(child) for child in node.args]) + return "{\n%s\n}" % (textwrap.indent(blockContents, self._indent)) def _print_PragmaBlock(self, node): - class PragmaGenerable(c.Generable): - def __init__(self, line, block): - self._line = line - self._block = block - - def generate(self): - yield self._line - for e in self._block.generate(): - yield e - return PragmaGenerable(node.pragmaLine, self._print_Block(node)) + return "%s\n%s" % (node.pragmaLine, self._print_Block(node)) def _print_LoopOverCoordinate(self, node): - class LoopWithOptionalPrefix(c.CustomLoop): - def __init__(self, intro_line, body, prefixLines=[]): - super(LoopWithOptionalPrefix, self).__init__(intro_line, body) - self.prefixLines = prefixLines - - def generate(self): - for l in self.prefixLines: - yield l - - for e in super(LoopWithOptionalPrefix, self).generate(): - yield e - counterVar = node.loopCounterName start = "int %s = %s" % (counterVar, self.sympyPrinter.doprint(node.start)) condition = "%s < %s" % (counterVar, self.sympyPrinter.doprint(node.stop)) update = "%s += %s" % (counterVar, self.sympyPrinter.doprint(node.step),) loopStr = "for (%s; %s; %s)" % (start, condition, update) - return LoopWithOptionalPrefix(loopStr, self._print(node.body), prefixLines=node.prefixLines) + + prefix = "\n".join(node.prefixLines) + if prefix: + prefix += "\n" + return "%s%s\n%s" % (prefix, loopStr, self._print(node.body)) def _print_SympyAssignment(self, node): dtype = "" @@ -118,19 +100,17 @@ class CBackend: dtype = "const " + node.lhs.dtype + " " else: dtype = node.lhs.dtype + " " - - return c.Assign(dtype + self.sympyPrinter.doprint(node.lhs), - self.sympyPrinter.doprint(node.rhs)) + return "%s %s = %s;" % (dtype, self.sympyPrinter.doprint(node.lhs), self.sympyPrinter.doprint(node.rhs)) def _print_TemporaryMemoryAllocation(self, node): - return c.Assign(node.symbol.dtype + " * " + self.sympyPrinter.doprint(node.symbol), - "new %s[%s]" % (node.symbol.dtype, self.sympyPrinter.doprint(node.size))) + return "%s * %s = new %s[%s];" % (node.symbol.dtype, self.sympyPrinter.doprint(node.symbol), + node.symbol.dtype, self.sympyPrinter.doprint(node.size)) def _print_TemporaryMemoryFree(self, node): - return c.Statement("delete [] %s" % (self.sympyPrinter.doprint(node.symbol),)) + return "delete [] %s;" % (self.sympyPrinter.doprint(node.symbol),) def _print_CustomCppCode(self, node): - return c.LiteralLines(node.code) + return node.code # ------------------------------------------ Helper function & classes ------------------------------------------------- @@ -155,13 +135,4 @@ class CustomSympyPrinter(CCodePrinter): def _print_Piecewise(self, expr): """Print piecewise in one line (remove newlines)""" result = super(CustomSympyPrinter, self)._print_Piecewise(expr) - return result.replace("\n", "") - - -class MyPOD(c.Declarator): - def __init__(self, dtype, name): - self.dtype = dtype - self.name = name - - def get_decl_pair(self): - return [self.dtype], self.name + return result.replace("\n", "") \ No newline at end of file diff --git a/transformations.py b/transformations.py index 2b47cbde3..62d400b5a 100644 --- a/transformations.py +++ b/transformations.py @@ -364,7 +364,7 @@ def splitInnerLoop(astNode, symbolGroups): assignmentGroups.append(assignmentGroup) newLoops = [innerLoop.newLoopWithDifferentBody(ast.Block(group)) for group in assignmentGroups] - innerLoop.parent.replace(innerLoop, newLoops) + innerLoop.parent.replace(innerLoop, ast.Block(newLoops)) for tmpArray in symbolsWithTemporaryArray: outerLoop.parent.insertFront(ast.TemporaryMemoryAllocation(tmpArray, innerLoop.stop)) -- GitLab