Commit ed011ead authored by Martin Bauer's avatar Martin Bauer
Browse files

Support for float

parent c27719f0
...@@ -7,12 +7,16 @@ def generateC(astNode): ...@@ -7,12 +7,16 @@ def generateC(astNode):
""" """
Prints the abstract syntax tree as C function Prints the abstract syntax tree as C function
""" """
printer = CBackend(cuda=False) fieldTypes = set([f.dtype for f in astNode.fieldsAccessed])
useFloatConstants = "double" not in fieldTypes
printer = CBackend(cuda=False, constantsAsFloats=useFloatConstants)
return printer(astNode) return printer(astNode)
def generateCUDA(astNode): def generateCUDA(astNode):
printer = CBackend(cuda=True) fieldTypes = set([f.dtype for f in astNode.fieldsAccessed])
useFloatConstants = "double" not in fieldTypes
printer = CBackend(cuda=True, constantsAsFloats=useFloatConstants)
return printer(astNode) return printer(astNode)
# --------------------------------------- Backend Specific Nodes ------------------------------------------------------- # --------------------------------------- Backend Specific Nodes -------------------------------------------------------
...@@ -52,10 +56,10 @@ class PrintNode(CustomCppCode): ...@@ -52,10 +56,10 @@ class PrintNode(CustomCppCode):
class CBackend: class CBackend:
def __init__(self, cuda=False, sympyPrinter=None): def __init__(self, cuda=False, constantsAsFloats=False, sympyPrinter=None):
self.cuda = cuda self.cuda = cuda
if sympyPrinter is None: if sympyPrinter is None:
self.sympyPrinter = CustomSympyPrinter() self.sympyPrinter = CustomSympyPrinter(constantsAsFloats)
else: else:
self.sympyPrinter = sympyPrinter self.sympyPrinter = sympyPrinter
...@@ -121,6 +125,11 @@ class CBackend: ...@@ -121,6 +125,11 @@ class CBackend:
class CustomSympyPrinter(CCodePrinter): class CustomSympyPrinter(CCodePrinter):
def __init__(self, constantsAsFloats=False):
self._constantsAsFloats = constantsAsFloats
super(CustomSympyPrinter, self).__init__()
def _print_Pow(self, expr): def _print_Pow(self, expr):
"""Don't use std::pow function, for small integer exponents, write as multiplication""" """Don't use std::pow function, for small integer exponents, write as multiplication"""
if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8: if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
...@@ -130,7 +139,10 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -130,7 +139,10 @@ class CustomSympyPrinter(CCodePrinter):
def _print_Rational(self, expr): def _print_Rational(self, expr):
"""Evaluate all rationals i.e. print 0.25 instead of 1.0/4.0""" """Evaluate all rationals i.e. print 0.25 instead of 1.0/4.0"""
return str(expr.evalf().num) res = str(expr.evalf().num)
if self._constantsAsFloats:
res += "f"
return res
def _print_Equality(self, expr): def _print_Equality(self, expr):
"""Equality operator is not printable in default printer""" """Equality operator is not printable in default printer"""
...@@ -139,4 +151,10 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -139,4 +151,10 @@ class CustomSympyPrinter(CCodePrinter):
def _print_Piecewise(self, expr): def _print_Piecewise(self, expr):
"""Print piecewise in one line (remove newlines)""" """Print piecewise in one line (remove newlines)"""
result = super(CustomSympyPrinter, self)._print_Piecewise(expr) result = super(CustomSympyPrinter, self)._print_Piecewise(expr)
return result.replace("\n", "") return result.replace("\n", "")
\ No newline at end of file
def _print_Float(self, expr):
res = str(expr)
if self._constantsAsFloats:
res += "f"
return res
...@@ -28,8 +28,10 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl ...@@ -28,8 +28,10 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl
:return: :class:`pystencils.ast.KernelFunction` node :return: :class:`pystencils.ast.KernelFunction` node
""" """
if not typeForSymbol: if not typeForSymbol or typeForSymbol == 'double':
typeForSymbol = typingFromSympyInspection(listOfEquations, "double") typeForSymbol = typingFromSympyInspection(listOfEquations, "double")
elif typeForSymbol == 'float':
typeForSymbol = typingFromSympyInspection(listOfEquations, "float")
def typeSymbol(term): def typeSymbol(term):
if isinstance(term, Field.Access) or isinstance(term, TypedSymbol): if isinstance(term, Field.Access) or isinstance(term, TypedSymbol):
......
...@@ -71,7 +71,7 @@ class Field: ...@@ -71,7 +71,7 @@ class Field:
return Field(fieldName, npArray.dtype, spatialLayout, shape, strides) return Field(fieldName, npArray.dtype, spatialLayout, shape, strides)
@staticmethod @staticmethod
def createGeneric(fieldName, spatialDimensions, dtype=np.float64, indexDimensions=0, layout=None): def createGeneric(fieldName, spatialDimensions, dtype=np.float64, indexDimensions=0, layout='numpy'):
""" """
Creates a generic field where the field size is not fixed i.e. can be called with arrays of different sizes Creates a generic field where the field size is not fixed i.e. can be called with arrays of different sizes
:param fieldName: symbolic name for the field :param fieldName: symbolic name for the field
...@@ -79,11 +79,13 @@ class Field: ...@@ -79,11 +79,13 @@ class Field:
:param spatialDimensions: see documentation of Field :param spatialDimensions: see documentation of Field
:param indexDimensions: see documentation of Field :param indexDimensions: see documentation of Field
:param layout: tuple specifying the loop ordering of the spatial dimensions e.g. (2, 1, 0 ) means that :param layout: tuple specifying the loop ordering of the spatial dimensions e.g. (2, 1, 0 ) means that
the outer loop loops over dimension 2, the second outer over dimension 1, and the inner loop the outer loop loops over dimension 2, the second outer over dimension 1, and the inner loop
over dimension 0 over dimension 0. Also allowed: the strings 'numpy' (0,1,..d) or 'reverseNumpy' (d, ..., 1, 0)
""" """
if not layout: if layout == 'numpy':
layout = tuple(range(spatialDimensions)) layout = tuple(range(spatialDimensions))
elif layout == 'reverseNumpy':
layout = tuple(reversed(range(spatialDimensions)))
if len(layout) != spatialDimensions: if len(layout) != spatialDimensions:
raise ValueError("Layout") raise ValueError("Layout")
shapeSymbol = IndexedBase(TypedSymbol(Field.SHAPE_PREFIX + fieldName, Field.SHAPE_DTYPE), shape=(1,)) shapeSymbol = IndexedBase(TypedSymbol(Field.SHAPE_PREFIX + fieldName, Field.SHAPE_DTYPE), shape=(1,))
......
from collections import defaultdict
import sympy as sp import sympy as sp
from pystencils.transformations import resolveFieldAccesses, typeAllEquations, parseBasePointerInfo from pystencils.transformations import resolveFieldAccesses, typeAllEquations, \
parseBasePointerInfo, typingFromSympyInspection
from pystencils.ast import Block, KernelFunction from pystencils.ast import Block, KernelFunction
from pystencils import Field from pystencils import Field
...@@ -31,7 +30,12 @@ def getLinewiseCoordinates(field, ghostLayers): ...@@ -31,7 +30,12 @@ def getLinewiseCoordinates(field, ghostLayers):
return [i + ghostLayers for i in result], getCallParameters return [i + ghostLayers for i in result], getCallParameters
def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=defaultdict(lambda: "double")): def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=None):
if not typeForSymbol or typeForSymbol == 'double':
typeForSymbol = typingFromSympyInspection(listOfEquations, "double")
elif typeForSymbol == 'float':
typeForSymbol = typingFromSympyInspection(listOfEquations, "float")
fieldsRead, fieldsWritten, assignments = typeAllEquations(listOfEquations, typeForSymbol) fieldsRead, fieldsWritten, assignments = typeAllEquations(listOfEquations, typeForSymbol)
readOnlyFields = set([f.name for f in fieldsRead - fieldsWritten]) readOnlyFields = set([f.name for f in fieldsRead - fieldsWritten])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment