Commit ed011ead authored by Martin Bauer's avatar Martin Bauer
Browse files

Support for float

parent c27719f0
......@@ -7,12 +7,16 @@ def generateC(astNode):
"""
Prints the abstract syntax tree as C function
"""
printer = CBackend(cuda=False)
fieldTypes = set([f.dtype for f in astNode.fieldsAccessed])
useFloatConstants = "double" not in fieldTypes
printer = CBackend(cuda=False, constantsAsFloats=useFloatConstants)
return printer(astNode)
def generateCUDA(astNode):
printer = CBackend(cuda=True)
fieldTypes = set([f.dtype for f in astNode.fieldsAccessed])
useFloatConstants = "double" not in fieldTypes
printer = CBackend(cuda=True, constantsAsFloats=useFloatConstants)
return printer(astNode)
# --------------------------------------- Backend Specific Nodes -------------------------------------------------------
......@@ -52,10 +56,10 @@ class PrintNode(CustomCppCode):
class CBackend:
def __init__(self, cuda=False, sympyPrinter=None):
def __init__(self, cuda=False, constantsAsFloats=False, sympyPrinter=None):
self.cuda = cuda
if sympyPrinter is None:
self.sympyPrinter = CustomSympyPrinter()
self.sympyPrinter = CustomSympyPrinter(constantsAsFloats)
else:
self.sympyPrinter = sympyPrinter
......@@ -121,6 +125,11 @@ class CBackend:
class CustomSympyPrinter(CCodePrinter):
def __init__(self, constantsAsFloats=False):
self._constantsAsFloats = constantsAsFloats
super(CustomSympyPrinter, self).__init__()
def _print_Pow(self, expr):
"""Don't use std::pow function, for small integer exponents, write as multiplication"""
if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
......@@ -130,7 +139,10 @@ class CustomSympyPrinter(CCodePrinter):
def _print_Rational(self, expr):
"""Evaluate all rationals i.e. print 0.25 instead of 1.0/4.0"""
return str(expr.evalf().num)
res = str(expr.evalf().num)
if self._constantsAsFloats:
res += "f"
return res
def _print_Equality(self, expr):
"""Equality operator is not printable in default printer"""
......@@ -139,4 +151,10 @@ class CustomSympyPrinter(CCodePrinter):
def _print_Piecewise(self, expr):
"""Print piecewise in one line (remove newlines)"""
result = super(CustomSympyPrinter, self)._print_Piecewise(expr)
return result.replace("\n", "")
\ No newline at end of file
return result.replace("\n", "")
def _print_Float(self, expr):
res = str(expr)
if self._constantsAsFloats:
res += "f"
return res
......@@ -28,8 +28,10 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl
:return: :class:`pystencils.ast.KernelFunction` node
"""
if not typeForSymbol:
if not typeForSymbol or typeForSymbol == 'double':
typeForSymbol = typingFromSympyInspection(listOfEquations, "double")
elif typeForSymbol == 'float':
typeForSymbol = typingFromSympyInspection(listOfEquations, "float")
def typeSymbol(term):
if isinstance(term, Field.Access) or isinstance(term, TypedSymbol):
......
......@@ -71,7 +71,7 @@ class Field:
return Field(fieldName, npArray.dtype, spatialLayout, shape, strides)
@staticmethod
def createGeneric(fieldName, spatialDimensions, dtype=np.float64, indexDimensions=0, layout=None):
def createGeneric(fieldName, spatialDimensions, dtype=np.float64, indexDimensions=0, layout='numpy'):
"""
Creates a generic field where the field size is not fixed i.e. can be called with arrays of different sizes
:param fieldName: symbolic name for the field
......@@ -79,11 +79,13 @@ class Field:
:param spatialDimensions: see documentation of Field
:param indexDimensions: see documentation of Field
:param layout: tuple specifying the loop ordering of the spatial dimensions e.g. (2, 1, 0 ) means that
the outer loop loops over dimension 2, the second outer over dimension 1, and the inner loop
over dimension 0
the outer loop loops over dimension 2, the second outer over dimension 1, and the inner loop
over dimension 0. Also allowed: the strings 'numpy' (0,1,..d) or 'reverseNumpy' (d, ..., 1, 0)
"""
if not layout:
if layout == 'numpy':
layout = tuple(range(spatialDimensions))
elif layout == 'reverseNumpy':
layout = tuple(reversed(range(spatialDimensions)))
if len(layout) != spatialDimensions:
raise ValueError("Layout")
shapeSymbol = IndexedBase(TypedSymbol(Field.SHAPE_PREFIX + fieldName, Field.SHAPE_DTYPE), shape=(1,))
......
from collections import defaultdict
import sympy as sp
from pystencils.transformations import resolveFieldAccesses, typeAllEquations, parseBasePointerInfo
from pystencils.transformations import resolveFieldAccesses, typeAllEquations, \
parseBasePointerInfo, typingFromSympyInspection
from pystencils.ast import Block, KernelFunction
from pystencils import Field
......@@ -31,7 +30,12 @@ def getLinewiseCoordinates(field, ghostLayers):
return [i + ghostLayers for i in result], getCallParameters
def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=defaultdict(lambda: "double")):
def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=None):
if not typeForSymbol or typeForSymbol == 'double':
typeForSymbol = typingFromSympyInspection(listOfEquations, "double")
elif typeForSymbol == 'float':
typeForSymbol = typingFromSympyInspection(listOfEquations, "float")
fieldsRead, fieldsWritten, assignments = typeAllEquations(listOfEquations, typeForSymbol)
readOnlyFields = set([f.name for f in fieldsRead - fieldsWritten])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment