Skip to content
Snippets Groups Projects
Commit 663dd123 authored by Martin Bauer's avatar Martin Bauer
Browse files

C backend: no dependence to cgen package any more

parent a2080a92
No related merge requests found
...@@ -341,7 +341,7 @@ class TemporaryMemoryAllocation(Node): ...@@ -341,7 +341,7 @@ class TemporaryMemoryAllocation(Node):
@property @property
def symbolsDefined(self): def symbolsDefined(self):
return set([self._symbol]) return set([self.symbol])
@property @property
def symbolsRead(self): def symbolsRead(self):
...@@ -349,12 +349,12 @@ class TemporaryMemoryAllocation(Node): ...@@ -349,12 +349,12 @@ class TemporaryMemoryAllocation(Node):
@property @property
def args(self): def args(self):
return [self._symbol] return [self.symbol]
class TemporaryMemoryFree(Node): class TemporaryMemoryFree(Node):
def __init__(self, typedSymbol): def __init__(self, typedSymbol):
self._symbol = typedSymbol self.symbol = typedSymbol
@property @property
def symbolsDefined(self): def symbolsDefined(self):
......
import cgen as c import textwrap
from sympy.utilities.codegen import CCodePrinter from sympy.utilities.codegen import CCodePrinter
from pystencils.ast import Node from pystencils.ast import Node
...@@ -40,9 +40,6 @@ class CustomCppCode(Node): ...@@ -40,9 +40,6 @@ class CustomCppCode(Node):
def symbolsRead(self): def symbolsRead(self):
return self._symbolsRead return self._symbolsRead
def generateC(self):
return c.LiteralLines(self._code)
class PrintNode(CustomCppCode): class PrintNode(CustomCppCode):
def __init__(self, symbolToPrint): def __init__(self, symbolToPrint):
...@@ -58,6 +55,7 @@ class CBackend: ...@@ -58,6 +55,7 @@ class CBackend:
def __init__(self, cuda=False): def __init__(self, cuda=False):
self.cuda = cuda self.cuda = cuda
self.sympyPrinter = CustomSympyPrinter() self.sympyPrinter = CustomSympyPrinter()
self._indent = " "
def __call__(self, node): def __call__(self, node):
return str(self._print(node)) return str(self._print(node))
...@@ -70,46 +68,30 @@ class CBackend: ...@@ -70,46 +68,30 @@ class CBackend:
raise NotImplementedError("CBackend does not support node of type " + cls.__name__) raise NotImplementedError("CBackend does not support node of type " + cls.__name__)
def _print_KernelFunction(self, node): def _print_KernelFunction(self, node):
functionArguments = [MyPOD(s.dtype, s.name) for s in node.parameters] functionArguments = ["%s %s" % (s.dtype, s.name) for s in node.parameters]
prefix = "__global__ void" if self.cuda else "void" prefix = "__global__ void" if self.cuda else "void"
functionPOD = MyPOD(prefix, node.functionName, ) funcDeclaration = "%s %s(%s)" % (prefix, node.functionName, ", ".join(functionArguments))
funcDeclaration = c.FunctionDeclaration(functionPOD, functionArguments) body = self._print(node.body)
return c.FunctionBody(funcDeclaration, self._print(node.body)) return funcDeclaration + "\n" + body
def _print_Block(self, node): def _print_Block(self, node):
return c.Block([self._print(child) for child in node.args]) blockContents = "\n".join([self._print(child) for child in node.args])
return "{\n%s\n}" % (textwrap.indent(blockContents, self._indent))
def _print_PragmaBlock(self, node): def _print_PragmaBlock(self, node):
class PragmaGenerable(c.Generable): return "%s\n%s" % (node.pragmaLine, self._print_Block(node))
def __init__(self, line, block):
self._line = line
self._block = block
def generate(self):
yield self._line
for e in self._block.generate():
yield e
return PragmaGenerable(node.pragmaLine, self._print_Block(node))
def _print_LoopOverCoordinate(self, node): def _print_LoopOverCoordinate(self, node):
class LoopWithOptionalPrefix(c.CustomLoop):
def __init__(self, intro_line, body, prefixLines=[]):
super(LoopWithOptionalPrefix, self).__init__(intro_line, body)
self.prefixLines = prefixLines
def generate(self):
for l in self.prefixLines:
yield l
for e in super(LoopWithOptionalPrefix, self).generate():
yield e
counterVar = node.loopCounterName counterVar = node.loopCounterName
start = "int %s = %s" % (counterVar, self.sympyPrinter.doprint(node.start)) start = "int %s = %s" % (counterVar, self.sympyPrinter.doprint(node.start))
condition = "%s < %s" % (counterVar, self.sympyPrinter.doprint(node.stop)) condition = "%s < %s" % (counterVar, self.sympyPrinter.doprint(node.stop))
update = "%s += %s" % (counterVar, self.sympyPrinter.doprint(node.step),) update = "%s += %s" % (counterVar, self.sympyPrinter.doprint(node.step),)
loopStr = "for (%s; %s; %s)" % (start, condition, update) loopStr = "for (%s; %s; %s)" % (start, condition, update)
return LoopWithOptionalPrefix(loopStr, self._print(node.body), prefixLines=node.prefixLines)
prefix = "\n".join(node.prefixLines)
if prefix:
prefix += "\n"
return "%s%s\n%s" % (prefix, loopStr, self._print(node.body))
def _print_SympyAssignment(self, node): def _print_SympyAssignment(self, node):
dtype = "" dtype = ""
...@@ -118,19 +100,17 @@ class CBackend: ...@@ -118,19 +100,17 @@ class CBackend:
dtype = "const " + node.lhs.dtype + " " dtype = "const " + node.lhs.dtype + " "
else: else:
dtype = node.lhs.dtype + " " dtype = node.lhs.dtype + " "
return "%s %s = %s;" % (dtype, self.sympyPrinter.doprint(node.lhs), self.sympyPrinter.doprint(node.rhs))
return c.Assign(dtype + self.sympyPrinter.doprint(node.lhs),
self.sympyPrinter.doprint(node.rhs))
def _print_TemporaryMemoryAllocation(self, node): def _print_TemporaryMemoryAllocation(self, node):
return c.Assign(node.symbol.dtype + " * " + self.sympyPrinter.doprint(node.symbol), return "%s * %s = new %s[%s];" % (node.symbol.dtype, self.sympyPrinter.doprint(node.symbol),
"new %s[%s]" % (node.symbol.dtype, self.sympyPrinter.doprint(node.size))) node.symbol.dtype, self.sympyPrinter.doprint(node.size))
def _print_TemporaryMemoryFree(self, node): def _print_TemporaryMemoryFree(self, node):
return c.Statement("delete [] %s" % (self.sympyPrinter.doprint(node.symbol),)) return "delete [] %s;" % (self.sympyPrinter.doprint(node.symbol),)
def _print_CustomCppCode(self, node): def _print_CustomCppCode(self, node):
return c.LiteralLines(node.code) return node.code
# ------------------------------------------ Helper function & classes ------------------------------------------------- # ------------------------------------------ Helper function & classes -------------------------------------------------
...@@ -155,13 +135,4 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -155,13 +135,4 @@ class CustomSympyPrinter(CCodePrinter):
def _print_Piecewise(self, expr): def _print_Piecewise(self, expr):
"""Print piecewise in one line (remove newlines)""" """Print piecewise in one line (remove newlines)"""
result = super(CustomSympyPrinter, self)._print_Piecewise(expr) result = super(CustomSympyPrinter, self)._print_Piecewise(expr)
return result.replace("\n", "") return result.replace("\n", "")
\ No newline at end of file
class MyPOD(c.Declarator):
def __init__(self, dtype, name):
self.dtype = dtype
self.name = name
def get_decl_pair(self):
return [self.dtype], self.name
...@@ -364,7 +364,7 @@ def splitInnerLoop(astNode, symbolGroups): ...@@ -364,7 +364,7 @@ def splitInnerLoop(astNode, symbolGroups):
assignmentGroups.append(assignmentGroup) assignmentGroups.append(assignmentGroup)
newLoops = [innerLoop.newLoopWithDifferentBody(ast.Block(group)) for group in assignmentGroups] newLoops = [innerLoop.newLoopWithDifferentBody(ast.Block(group)) for group in assignmentGroups]
innerLoop.parent.replace(innerLoop, newLoops) innerLoop.parent.replace(innerLoop, ast.Block(newLoops))
for tmpArray in symbolsWithTemporaryArray: for tmpArray in symbolsWithTemporaryArray:
outerLoop.parent.insertFront(ast.TemporaryMemoryAllocation(tmpArray, innerLoop.stop)) outerLoop.parent.insertFront(ast.TemporaryMemoryAllocation(tmpArray, innerLoop.stop))
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment