Commit bd56d31f authored by Martin Bauer
Browse files

Major refactoring: separated Ast and generateC code

parent 61541046
import sympy as sp
from sympy.tensor import IndexedBase, Indexed
from pystencils.field import Field
from pystencils.typedsymbol import TypedSymbol
class Node:
    """Base class for all nodes of the abstract syntax tree.

    A node knows its parent, exposes its children through ``args`` and
    reports which symbols it defines and reads.  The base implementation
    has no children and neither defines nor reads any symbol.
    """

    def __init__(self, parent=None):
        """Create a node, optionally attached to ``parent``."""
        self.parent = parent

    @property
    def args(self):
        """List of children of this node (empty for the base class).

        This is a property, like in every subclass, so that ``atoms`` can
        iterate over ``self.args`` for plain ``Node`` instances as well.
        """
        return []

    @property
    def symbolsDefined(self):
        """Set of symbols that are defined by this node."""
        return set()

    @property
    def symbolsRead(self):
        """Set of symbols that are read by this node."""
        return set()

    def atoms(self, argType):
        """Return the set of all descendants that are instances of ``argType``.

        Every child is asked for its own ``atoms`` as well, so the search is
        recursive; children therefore have to provide an ``atoms`` method.
        """
        result = set()
        for arg in self.args:
            if isinstance(arg, argType):
                result.add(arg)
            result.update(arg.atoms(argType))
        return result
class KernelFunction(Node):
    """Root node of a kernel: a named function wrapping a body node.

    The function parameters are not stored explicitly; they are derived on
    access from the symbols the body reads but does not define.
    """

    class Argument:
        """Description of a single kernel parameter.

        The name is inspected for the ``Field`` data/shape/stride prefixes;
        when one matches, the corresponding flag is set and ``fieldName``
        holds the name with that prefix removed.
        """

        def __init__(self, name, dtype):
            self.name = name
            self.dtype = dtype
            self.isFieldPtrArgument = False
            self.isFieldShapeArgument = False
            self.isFieldStrideArgument = False
            self.isFieldArgument = False
            self.fieldName = ""
            self.coordinate = None

            # Checked in this order; the first matching prefix wins.
            prefixToFlag = (
                ("DATA_PREFIX", "isFieldPtrArgument"),
                ("SHAPE_PREFIX", "isFieldShapeArgument"),
                ("STRIDE_PREFIX", "isFieldStrideArgument"),
            )
            for prefixAttribute, flagAttribute in prefixToFlag:
                prefix = getattr(Field, prefixAttribute)
                if not name.startswith(prefix):
                    continue
                setattr(self, flagAttribute, True)
                self.isFieldArgument = True
                self.fieldName = name[len(prefix):]
                break

    def __init__(self, body, functionName="kernel"):
        """Wrap ``body`` in a function called ``functionName``."""
        super().__init__()
        self._body = body
        self._parameters = None
        self._functionName = functionName
        self._body.parent = self
        # Symbols in this set never become function parameters.
        self.variablesToIgnore = set()

    @property
    def symbolsDefined(self):
        """A kernel function defines no symbols towards the outside."""
        return set()

    @property
    def symbolsRead(self):
        """A kernel function reads no symbols from the outside."""
        return set()

    @property
    def parameters(self):
        """List of ``Argument`` objects, recomputed on every access."""
        self._updateParameters()
        return self._parameters

    @property
    def body(self):
        """The node forming the function body."""
        return self._body

    @property
    def args(self):
        """The body is the only child."""
        return [self._body]

    @property
    def functionName(self):
        """Name of the generated function."""
        return self._functionName

    def _updateParameters(self):
        """Rebuild the parameter list from the free symbols of the body."""
        body = self._body
        freeSymbols = body.symbolsRead - body.symbolsDefined - self.variablesToIgnore

        def ordering(argument):
            return (argument.fieldName,
                    argument.isFieldPtrArgument,
                    argument.isFieldShapeArgument,
                    argument.isFieldStrideArgument,
                    argument.name)

        arguments = [KernelFunction.Argument(symbol.name, symbol.dtype) for symbol in freeSymbols]
        self._parameters = sorted(arguments, key=ordering, reverse=True)
class Block(Node):
    """Node holding an ordered list of child nodes."""

    def __init__(self, listOfNodes):
        """Take over ``listOfNodes`` as children and become their parent.

        The list object itself is stored, not a copy.
        """
        # super(Block, ...) so that Node.__init__ runs and ``parent`` exists;
        # super(Node, ...) would skip straight to object.__init__.
        super(Block, self).__init__()
        self._nodes = listOfNodes
        for n in self._nodes:
            n.parent = self

    @property
    def args(self):
        """The list of child nodes (the internal list, not a copy)."""
        return self._nodes

    def insertFront(self, node):
        """Insert ``node`` as first child."""
        node.parent = self
        self._nodes.insert(0, node)

    def append(self, node):
        """Add ``node`` as last child."""
        node.parent = self
        self._nodes.append(node)

    def takeChildNodes(self):
        """Remove all children from this block and return them as a list."""
        tmp = self._nodes
        self._nodes = []
        return tmp

    def replace(self, child, replacements):
        """Replace ``child`` by a single node or by a list of nodes.

        Raises ``ValueError`` (from ``list.index``) if ``child`` is not a
        direct child of this block.
        """
        idx = self._nodes.index(child)
        del self._nodes[idx]
        if isinstance(replacements, list):
            for e in replacements:
                e.parent = self
            self._nodes = self._nodes[:idx] + replacements + self._nodes[idx:]
        else:
            replacements.parent = self
            self._nodes.insert(idx, replacements)

    @property
    def symbolsDefined(self):
        """Union of the symbols defined by all children."""
        result = set()
        for a in self.args:
            result.update(a.symbolsDefined)
        return result

    @property
    def symbolsRead(self):
        """Union of the symbols read by all children."""
        result = set()
        for a in self.args:
            result.update(a.symbolsRead)
        return result
class PragmaBlock(Block):
    """A block of nodes that additionally carries a pragma line."""

    def __init__(self, pragmaLine, listOfNodes):
        """Create the block from ``listOfNodes`` and remember ``pragmaLine``."""
        super().__init__(listOfNodes)
        self.pragmaLine = pragmaLine
class LoopOverCoordinate(Node):
    """Loop over one coordinate of an iteration region described by ``shape``.

    The loop counter is a ``TypedSymbol`` of type ``"int"`` whose name is
    built from ``LOOP_COUNTER_NAME_PREFIX`` and the coordinate.
    """

    LOOP_COUNTER_NAME_PREFIX = "ctr"

    def __init__(self, body, coordinateToLoopOver, shape, increment=1, ghostLayers=1,
                 isInnermostLoop=False, isOutermostLoop=False):
        """Create a loop around ``body``.

        :param body: node executed in every iteration; its parent is set to this loop
        :param coordinateToLoopOver: index into ``shape`` selecting the coordinate
        :param shape: indexable object holding the extent for every coordinate
        :param increment: stored loop increment
        :param ghostLayers: number of ghost layers, subtracted from the extent
                            to obtain ``iterationEnd``
        :param isInnermostLoop: flag exposed through the property of the same name
        :param isOutermostLoop: flag exposed through the property of the same name
        """
        # Initialise the base class so that ``parent`` always exists.
        super(LoopOverCoordinate, self).__init__()
        self._body = body
        self._coordinateToLoopOver = coordinateToLoopOver
        self._shape = shape
        self._increment = increment
        self._ghostLayers = ghostLayers
        self._body.parent = self
        self.prefixLines = []
        self._isInnermostLoop = isInnermostLoop
        self._isOutermostLoop = isOutermostLoop

    def newLoopWithDifferentBody(self, newBody):
        """Return a new loop with the same settings but ``newBody`` as body.

        The ``prefixLines`` list is shared with the new loop, not copied.
        """
        result = LoopOverCoordinate(newBody, self._coordinateToLoopOver, self._shape, self._increment,
                                    self._ghostLayers, self._isInnermostLoop, self._isOutermostLoop)
        result.prefixLines = self.prefixLines
        return result

    @property
    def args(self):
        """The body, followed by the loop limit if it is a node or sympy object."""
        result = [self._body]
        limit = self._shape[self._coordinateToLoopOver]
        if isinstance(limit, (Node, sp.Basic)):
            result.append(limit)
        return result

    @property
    def body(self):
        """The node executed in every iteration."""
        return self._body

    @property
    def iterationEnd(self):
        """Extent of the looped coordinate minus the number of ghost layers."""
        return self._shape[self.coordinateToLoopOver] - self.ghostLayers

    @property
    def coordinateToLoopOver(self):
        """Index of the coordinate this loop iterates over."""
        return self._coordinateToLoopOver

    @property
    def symbolsDefined(self):
        """Symbols defined by the body plus the loop counter symbol."""
        # Copy first: the body may hand out one of its internal sets.
        result = set(self._body.symbolsDefined)
        result.add(self.loopCounterSymbol)
        return result

    @property
    def loopCounterName(self):
        """Name of the loop counter variable."""
        return "%s_%s" % (LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX, self._coordinateToLoopOver)

    @property
    def loopCounterSymbol(self):
        """The loop counter as ``TypedSymbol`` of type ``"int"``."""
        return TypedSymbol(self.loopCounterName, "int")

    @property
    def symbolsRead(self):
        """Symbols read by the body plus the symbols of a symbolic loop limit."""
        # Copy first: the body may hand out one of its internal sets.
        result = set(self._body.symbolsRead)
        limit = self._shape[self._coordinateToLoopOver]
        if isinstance(limit, sp.Basic):
            result.update(limit.atoms(sp.Symbol))
        return result

    @property
    def isOutermostLoop(self):
        """Flag passed to the constructor as ``isOutermostLoop``."""
        return self._isOutermostLoop

    @property
    def isInnermostLoop(self):
        """Flag passed to the constructor as ``isInnermostLoop``."""
        return self._isInnermostLoop

    @property
    def iterationRegionWithGhostLayer(self):
        """Extent of the looped coordinate, without subtracting ghost layers."""
        return self._shape[self._coordinateToLoopOver]

    @property
    def ghostLayers(self):
        """Number of ghost layers."""
        return self._ghostLayers
class SympyAssignment(Node):
    """Assignment of a right hand side term to a left hand side symbol.

    The assignment counts as a declaration of the left hand side unless that
    is a ``Field.Access``, an ``Indexed`` or an ``IndexedBase``.
    """

    def __init__(self, lhsSymbol, rhsTerm, isConst=True):
        """Create the assignment ``lhsSymbol = rhsTerm``.

        :param lhsSymbol: symbol or access object that is written
        :param rhsTerm: expression that is assigned
        :param isConst: flag exposed through the ``isConst`` property
        """
        # Initialise the base class so that ``parent`` always exists.
        super(SympyAssignment, self).__init__()
        self._lhsSymbol = lhsSymbol
        self.rhs = rhsTerm
        self._isDeclaration = self._declaresSymbol(lhsSymbol)
        self._isConst = isConst

    @staticmethod
    def _declaresSymbol(lhs):
        """Return True if assigning to ``lhs`` introduces a new symbol.

        Shared by the constructor and the ``lhs`` setter so both apply the
        same rule; previously one tested ``IndexedBase`` and the other
        ``Indexed``.  Both types are accepted here, so neither code path
        loses a case it handled before.
        """
        return not isinstance(lhs, (Field.Access, Indexed, IndexedBase))

    @property
    def lhs(self):
        """The left hand side of the assignment."""
        return self._lhsSymbol

    @lhs.setter
    def lhs(self, newValue):
        self._lhsSymbol = newValue
        self._isDeclaration = self._declaresSymbol(newValue)

    @property
    def args(self):
        """Left hand side followed by right hand side."""
        return [self._lhsSymbol, self.rhs]

    @property
    def symbolsDefined(self):
        """The left hand side symbol if this is a declaration, else nothing."""
        if not self._isDeclaration:
            return set()
        return {self._lhsSymbol}

    @property
    def symbolsRead(self):
        """All ``sp.Symbol`` atoms of the right and of the left hand side."""
        result = self.rhs.atoms(sp.Symbol)
        result.update(self._lhsSymbol.atoms(sp.Symbol))
        return result

    @property
    def isDeclaration(self):
        """True if the assignment declares its left hand side."""
        return self._isDeclaration

    @property
    def isConst(self):
        """Flag passed to the constructor as ``isConst``."""
        return self._isConst

    def __repr__(self):
        return repr(self.lhs) + " = " + repr(self.rhs)
class TemporaryMemoryAllocation(Node):
    """Node that allocates temporary memory for a typed symbol."""

    def __init__(self, typedSymbol, size):
        """Store the symbol to allocate and the requested size.

        :param typedSymbol: symbol the allocated memory is bound to
        :param size: size of the allocation
        """
        # Initialise the base class so that ``parent`` always exists.
        super(TemporaryMemoryAllocation, self).__init__()
        self.symbol = typedSymbol
        self.size = size

    @property
    def symbolsDefined(self):
        """The allocated symbol."""
        # Reads the public ``symbol`` attribute set in __init__; the former
        # ``self._symbol`` was never assigned and raised AttributeError.
        return {self.symbol}

    @property
    def symbolsRead(self):
        """No symbols are reported as read."""
        return set()

    @property
    def args(self):
        """The allocated symbol is the only argument."""
        return [self.symbol]
class TemporaryMemoryFree(Node):
    """Node that frees memory bound to a typed symbol."""

    def __init__(self, typedSymbol):
        """Store the symbol whose memory is to be freed."""
        # Initialise the base class so that ``parent`` always exists.
        super(TemporaryMemoryFree, self).__init__()
        self._symbol = typedSymbol

    @property
    def symbol(self):
        """The symbol whose memory is freed.

        Public accessor matching ``TemporaryMemoryAllocation.symbol``; code
        that prints this node reads ``node.symbol``.
        """
        return self._symbol

    @property
    def symbolsDefined(self):
        """Freeing memory defines no symbols."""
        return set()

    @property
    def symbolsRead(self):
        """No symbols are reported as read."""
        return set()

    @property
    def args(self):
        """This node has no children."""
        return []
import cgen as c
from sympy.utilities.codegen import CCodePrinter
from pystencils.ast import Node
def printCCode(astNode):
    """Print ``astNode`` with a ``CBackend`` that has CUDA disabled."""
    return CBackend(cuda=False)(astNode)
def printCudaCode(astNode):
    """Print ``astNode`` with a ``CBackend`` that has CUDA enabled."""
    return CBackend(cuda=True)(astNode)
# --------------------------------------- Backend Specific Nodes -------------------------------------------------------
class CustomCppCode(Node):
    """Node wrapping a literal piece of code.

    Since the code is opaque, the symbols it reads and defines have to be
    passed in explicitly.
    """

    def __init__(self, code, symbolsRead, symbolsDefined):
        """Store the code (prefixed with a newline) and its symbol sets.

        :param code: source text to emit
        :param symbolsRead: iterable of symbols the code reads
        :param symbolsDefined: iterable of symbols the code defines
        """
        # Initialise the base class so that ``parent`` always exists.
        super(CustomCppCode, self).__init__()
        self._code = "\n" + code
        self._symbolsRead = set(symbolsRead)
        self._symbolsDefined = set(symbolsDefined)

    @property
    def code(self):
        """The stored code, including the leading newline."""
        return self._code

    @property
    def args(self):
        """This node has no children."""
        return []

    @property
    def symbolsDefined(self):
        """Copy of the set of symbols defined by the code."""
        # A copy, so callers that add to the result cannot alter this node.
        return set(self._symbolsDefined)

    @property
    def symbolsRead(self):
        """Copy of the set of symbols read by the code."""
        # A copy, so callers that add to the result cannot alter this node.
        return set(self._symbolsRead)

    def generateC(self):
        """Return the code wrapped in ``c.LiteralLines``."""
        return c.LiteralLines(self._code)
class PrintNode(CustomCppCode):
    """Custom code node that writes a symbol's name and value to ``std::cout``."""

    def __init__(self, symbolToPrint):
        """Build the output statement for ``symbolToPrint``, which is reported as read."""
        name = symbolToPrint.name
        statement = '\nstd::cout << "%s = " << %s << std::endl; \n' % (name, name)
        super().__init__(statement, symbolsRead=[symbolToPrint], symbolsDefined=set())
# ------------------------------------------- Printer ------------------------------------------------------------------
class CBackend:
    """Translates AST nodes into ``cgen`` objects.

    For a node, the classes in its method resolution order are searched for
    a matching ``_print_<ClassName>`` method; the first one found is used.
    """

    def __init__(self, cuda=False):
        """Create a backend; with ``cuda=True`` kernels get a ``__global__`` prefix."""
        self.cuda = cuda
        self.sympyPrinter = CustomSympyPrinter()

    def __call__(self, node):
        return self._print(node)

    def _print(self, node):
        """Dispatch ``node`` to the ``_print_*`` method of its most specific class.

        Raises ``NotImplementedError`` naming the node's type if no class in
        its method resolution order has a printing method.
        """
        for cls in type(node).__mro__:
            methodName = "_print_" + cls.__name__
            if hasattr(self, methodName):
                return getattr(self, methodName)(node)
        # Report the node's own type; the loop variable ends up at ``object``.
        raise NotImplementedError("CBackend does not support node of type " + type(node).__name__)

    def _print_KernelFunction(self, node):
        """Print the function declaration with its parameters and the body."""
        functionArguments = [MyPOD(s.dtype, s.name) for s in node.parameters]
        prefix = "__global__ void" if self.cuda else "void"
        functionPOD = MyPOD(prefix, node.functionName)
        funcDeclaration = c.FunctionDeclaration(functionPOD, functionArguments)
        return c.FunctionBody(funcDeclaration, self._print(node.body))

    def _print_Block(self, node):
        """Print all children of the block into a ``c.Block``."""
        return c.Block([self._print(child) for child in node.args])

    def _print_PragmaBlock(self, node):
        """Print the block preceded by its pragma line."""
        class PragmaGenerable(c.Generable):
            def __init__(self, line, block):
                self._line = line
                self._block = block

            def generate(self):
                yield self._line
                for e in self._block.generate():
                    yield e

        return PragmaGenerable(node.pragmaLine, self._print_Block(node))

    def _print_LoopOverCoordinate(self, node):
        """Print a ``for`` loop, preceded by the node's prefix lines."""
        class LoopWithOptionalPrefix(c.CustomLoop):
            def __init__(self, intro_line, body, prefixLines=None):
                super(LoopWithOptionalPrefix, self).__init__(intro_line, body)
                # None instead of a mutable ``[]`` default shared between calls.
                self.prefixLines = [] if prefixLines is None else prefixLines

            def generate(self):
                for l in self.prefixLines:
                    yield l
                for e in super(LoopWithOptionalPrefix, self).generate():
                    yield e

        counterVar = node.loopCounterName
        # The counter starts at the number of ghost layers and runs up to iterationEnd.
        start = "int %s = %d" % (counterVar, node.ghostLayers)
        condition = "%s < %s" % (counterVar, self.sympyPrinter.doprint(node.iterationEnd))
        update = "++%s" % (counterVar,)
        loopStr = "for (%s; %s; %s)" % (start, condition, update)
        return LoopWithOptionalPrefix(loopStr, self._print(node.body), prefixLines=node.prefixLines)

    def _print_SympyAssignment(self, node):
        """Print an assignment; declarations get a (possibly const) type prefix."""
        dtype = ""
        if node.isDeclaration:
            if node.isConst:
                dtype = "const " + node.lhs.dtype + " "
            else:
                dtype = node.lhs.dtype + " "
        return c.Assign(dtype + self.sympyPrinter.doprint(node.lhs),
                        self.sympyPrinter.doprint(node.rhs))

    def _print_TemporaryMemoryAllocation(self, node):
        """Print a pointer declaration initialised with ``new dtype[size]``."""
        return c.Assign(node.symbol.dtype + " * " + self.sympyPrinter.doprint(node.symbol),
                        "new %s[%s]" % (node.symbol.dtype, self.sympyPrinter.doprint(node.size)))

    def _print_TemporaryMemoryFree(self, node):
        """Print a ``delete []`` statement for the node's symbol."""
        return c.Statement("delete [] %s" % (self.sympyPrinter.doprint(node.symbol),))

    def _print_CustomCppCode(self, node):
        """Emit the node's code verbatim."""
        return c.LiteralLines(node.code)
# ------------------------------------------ Helper function & classes -------------------------------------------------
class CustomSympyPrinter(CCodePrinter):
    """C code printer with adapted output for powers, rationals, equalities and piecewise terms."""

    def _print_Pow(self, expr):
        """Write powers with a numeric integer exponent between 1 and 7 as a product.

        All other powers are handled by the base class printer.
        """
        exponent = expr.exp
        if exponent.is_integer and exponent.is_number and 0 < exponent < 8:
            factor = "(" + self._print(expr.base) + ")"
            return "(" + "*".join([factor] * exponent) + ")"
        return super()._print_Pow(expr)

    def _print_Rational(self, expr):
        """Print the numerically evaluated value, e.g. 0.25 rather than 1.0/4.0."""
        evaluated = expr.evalf()
        return str(evaluated.num)

    def _print_Equality(self, expr):
        """Print an equality as a parenthesised ``==`` comparison."""
        left = self._print(expr.lhs)
        right = self._print(expr.rhs)
        return "((" + left + ") == (" + right + "))"

    def _print_Piecewise(self, expr):
        """Print a piecewise term on a single line by dropping all newlines."""
        multiLine = super()._print_Piecewise(expr)
        return multiLine.replace("\n", "")
class MyPOD(c.Declarator):
    """Declarator consisting of a type string and a name."""

    def __init__(self, dtype, name):
        """Remember the type string and the name of the declared entity."""
        self.dtype, self.name = dtype, name

    def get_decl_pair(self):
        """Return a one-element list holding the type, together with the name."""
        typeLines = [self.dtype]
        return typeLines, self.name
......@@ -300,6 +300,7 @@ class Field:
SHAPE_PREFIX = PREFIX + "shape_"
STRIDE_DTYPE = "const int *"
SHAPE_DTYPE = "const int *"
DATA_PREFIX = PREFIX + "d_"
class Access(sp.Symbol):
def __new__(cls, name, *args, **kwargs):
......
import numpy as np
import sympy as sp
from pystencils.generator import Field
from pystencils.field import Field
def __upDownOffsets(d, dim):
......
......@@ -2,7 +2,7 @@ import os
import subprocess
from ctypes import cdll, c_double, c_float, sizeof
from tempfile import TemporaryDirectory
from pystencils.backends.cbackend import printCCode
import numpy as np
......@@ -67,7 +67,7 @@ def compileAndLoad(kernelFunctionNode):
print('#include <iostream>', file=sourceFile)
print("#include <cmath>", file=sourceFile)
print('extern "C" { ', file=sourceFile)
print(kernelFunctionNode.generateC(), file=sourceFile)
print(printCCode(kernelFunctionNode), file=sourceFile)
print('}', file=sourceFile)
compilerCmd = [CONFIG['compiler']] + CONFIG['flags'].split()
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment