Skip to content
Snippets Groups Projects
Commit 663dd123 authored by Martin Bauer's avatar Martin Bauer
Browse files

C backend: no dependence to cgen package any more

parent a2080a92
No related merge requests found
......@@ -341,7 +341,7 @@ class TemporaryMemoryAllocation(Node):
@property
def symbolsDefined(self):
return set([self._symbol])
return set([self.symbol])
@property
def symbolsRead(self):
......@@ -349,12 +349,12 @@ class TemporaryMemoryAllocation(Node):
@property
def args(self):
return [self._symbol]
return [self.symbol]
class TemporaryMemoryFree(Node):
def __init__(self, typedSymbol):
self._symbol = typedSymbol
self.symbol = typedSymbol
@property
def symbolsDefined(self):
......
import cgen as c
import textwrap
from sympy.utilities.codegen import CCodePrinter
from pystencils.ast import Node
......@@ -40,9 +40,6 @@ class CustomCppCode(Node):
def symbolsRead(self):
return self._symbolsRead
def generateC(self):
return c.LiteralLines(self._code)
class PrintNode(CustomCppCode):
def __init__(self, symbolToPrint):
......@@ -58,6 +55,7 @@ class CBackend:
def __init__(self, cuda=False):
self.cuda = cuda
self.sympyPrinter = CustomSympyPrinter()
self._indent = " "
def __call__(self, node):
return str(self._print(node))
......@@ -70,46 +68,30 @@ class CBackend:
raise NotImplementedError("CBackend does not support node of type " + cls.__name__)
def _print_KernelFunction(self, node):
functionArguments = [MyPOD(s.dtype, s.name) for s in node.parameters]
functionArguments = ["%s %s" % (s.dtype, s.name) for s in node.parameters]
prefix = "__global__ void" if self.cuda else "void"
functionPOD = MyPOD(prefix, node.functionName, )
funcDeclaration = c.FunctionDeclaration(functionPOD, functionArguments)
return c.FunctionBody(funcDeclaration, self._print(node.body))
funcDeclaration = "%s %s(%s)" % (prefix, node.functionName, ", ".join(functionArguments))
body = self._print(node.body)
return funcDeclaration + "\n" + body
def _print_Block(self, node):
return c.Block([self._print(child) for child in node.args])
blockContents = "\n".join([self._print(child) for child in node.args])
return "{\n%s\n}" % (textwrap.indent(blockContents, self._indent))
def _print_PragmaBlock(self, node):
class PragmaGenerable(c.Generable):
def __init__(self, line, block):
self._line = line
self._block = block
def generate(self):
yield self._line
for e in self._block.generate():
yield e
return PragmaGenerable(node.pragmaLine, self._print_Block(node))
return "%s\n%s" % (node.pragmaLine, self._print_Block(node))
def _print_LoopOverCoordinate(self, node):
class LoopWithOptionalPrefix(c.CustomLoop):
def __init__(self, intro_line, body, prefixLines=[]):
super(LoopWithOptionalPrefix, self).__init__(intro_line, body)
self.prefixLines = prefixLines
def generate(self):
for l in self.prefixLines:
yield l
for e in super(LoopWithOptionalPrefix, self).generate():
yield e
counterVar = node.loopCounterName
start = "int %s = %s" % (counterVar, self.sympyPrinter.doprint(node.start))
condition = "%s < %s" % (counterVar, self.sympyPrinter.doprint(node.stop))
update = "%s += %s" % (counterVar, self.sympyPrinter.doprint(node.step),)
loopStr = "for (%s; %s; %s)" % (start, condition, update)
return LoopWithOptionalPrefix(loopStr, self._print(node.body), prefixLines=node.prefixLines)
prefix = "\n".join(node.prefixLines)
if prefix:
prefix += "\n"
return "%s%s\n%s" % (prefix, loopStr, self._print(node.body))
def _print_SympyAssignment(self, node):
dtype = ""
......@@ -118,19 +100,17 @@ class CBackend:
dtype = "const " + node.lhs.dtype + " "
else:
dtype = node.lhs.dtype + " "
return c.Assign(dtype + self.sympyPrinter.doprint(node.lhs),
self.sympyPrinter.doprint(node.rhs))
return "%s %s = %s;" % (dtype, self.sympyPrinter.doprint(node.lhs), self.sympyPrinter.doprint(node.rhs))
def _print_TemporaryMemoryAllocation(self, node):
return c.Assign(node.symbol.dtype + " * " + self.sympyPrinter.doprint(node.symbol),
"new %s[%s]" % (node.symbol.dtype, self.sympyPrinter.doprint(node.size)))
return "%s * %s = new %s[%s];" % (node.symbol.dtype, self.sympyPrinter.doprint(node.symbol),
node.symbol.dtype, self.sympyPrinter.doprint(node.size))
def _print_TemporaryMemoryFree(self, node):
return c.Statement("delete [] %s" % (self.sympyPrinter.doprint(node.symbol),))
return "delete [] %s;" % (self.sympyPrinter.doprint(node.symbol),)
def _print_CustomCppCode(self, node):
return c.LiteralLines(node.code)
return node.code
# ------------------------------------------ Helper function & classes -------------------------------------------------
......@@ -155,13 +135,4 @@ class CustomSympyPrinter(CCodePrinter):
def _print_Piecewise(self, expr):
"""Print piecewise in one line (remove newlines)"""
result = super(CustomSympyPrinter, self)._print_Piecewise(expr)
return result.replace("\n", "")
class MyPOD(c.Declarator):
def __init__(self, dtype, name):
self.dtype = dtype
self.name = name
def get_decl_pair(self):
return [self.dtype], self.name
return result.replace("\n", "")
\ No newline at end of file
......@@ -364,7 +364,7 @@ def splitInnerLoop(astNode, symbolGroups):
assignmentGroups.append(assignmentGroup)
newLoops = [innerLoop.newLoopWithDifferentBody(ast.Block(group)) for group in assignmentGroups]
innerLoop.parent.replace(innerLoop, newLoops)
innerLoop.parent.replace(innerLoop, ast.Block(newLoops))
for tmpArray in symbolsWithTemporaryArray:
outerLoop.parent.insertFront(ast.TemporaryMemoryAllocation(tmpArray, innerLoop.stop))
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment