Commit f58ed945 authored by Martin Bauer's avatar Martin Bauer
Browse files

Indexed Kernels & different cast mechanism with sympy functions

parent ec3faf51
......@@ -294,7 +294,8 @@ class SympyAssignment(Node):
self._lhsSymbol = lhsSymbol
self.rhs = rhsTerm
self._isDeclaration = True
if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, IndexedBase):
isCast = str(self._lhsSymbol.func).lower() == 'cast'
if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, IndexedBase) or isCast:
self._isDeclaration = False
self._isConst = isConst
......@@ -306,7 +307,8 @@ class SympyAssignment(Node):
def lhs(self, newValue):
self._lhsSymbol = newValue
self._isDeclaration = True
if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, sp.Indexed):
isCast = str(self._lhsSymbol.func).lower() == 'cast'
if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, sp.Indexed) or isCast:
self._isDeclaration = False
@property
......
......@@ -151,12 +151,10 @@ class CustomSympyPrinter(CCodePrinter):
res += "f"
return res
def _print_Indexed(self, expr):
result = super(CustomSympyPrinter, self)._print_Indexed(expr)
typedSymbol = expr.base.label
if typedSymbol.castTo is not None:
newType = typedSymbol.castTo
# e.g. *((double *)(& val[200]))
return "*((%s)(& %s))" % (PointerType(newType), result)
def _print_Function(self, expr):
name = str(expr.func).lower()
if name == 'cast':
arg, type = expr.args
return "*((%s)(& %s))" % (PointerType(type), self._print(arg))
else:
return result
return super(CustomSympyPrinter, self)._print_Function(expr)
from pystencils.cpu.kernelcreation import createKernel, addOpenMP
from pystencils.cpu.kernelcreation import createKernel, createIndexedKernel, addOpenMP
from pystencils.cpu.cpujit import makePythonFunction
from pystencils.backends.cbackend import generateC
import sympy as sp
from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain, typingFromSympyInspection, \
from pystencils.astnodes import SympyAssignment, Block, LoopOverCoordinate, KernelFunction
from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain, \
typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop
from pystencils.types import TypedSymbol
from pystencils.types import TypedSymbol, BasicType, StructType
from pystencils.field import Field
import pystencils.astnodes as ast
......@@ -28,11 +30,6 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl
:return: :class:`pystencils.ast.KernelFunction` node
"""
if not typeForSymbol or typeForSymbol == 'double':
typeForSymbol = typingFromSympyInspection(listOfEquations, "double")
elif typeForSymbol == 'float':
typeForSymbol = typingFromSympyInspection(listOfEquations, "float")
def typeSymbol(term):
if isinstance(term, Field.Access) or isinstance(term, TypedSymbol):
return term
......@@ -63,6 +60,64 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl
return code
def createIndexedKernel(listOfEquations, indexFields, typeForSymbol=None, coordinateNames=('x', 'y', 'z')):
    """
    Similar to :func:`createKernel`, but instead of updating all cells of a field, only the cells whose
    coordinates are stored in an index field are updated. This traversal can e.g. be used for boundary handling.

    The coordinates are stored in a separate index field, which is a one dimensional array with a struct data type.
    This struct has to contain elements named 'x', 'y' and, for 3D fields, 'z'. These names are configurable with
    the 'coordinateNames' parameter. The struct may contain further elements that can be read and written in the
    kernel, for example boundary parameters.

    Every passed index field is modified in place: its ``isIndexField`` attribute is set to True.

    :param listOfEquations: list of update equations or AST nodes
    :param indexFields: non-empty sequence of index fields, i.e. 1D fields with struct data type; the loop range
                        is taken from the first entry
    :param typeForSymbol: see documentation of :func:`createKernel`
    :param coordinateNames: names of the coordinate elements in the struct data type
    :return: abstract syntax tree
    :raises ValueError: if no index field is passed, or if a required coordinate name is not an element of any
                        index field
    """
    if not indexFields:
        # the loop bound below is read from indexFields[0], so fail early with a clear message
        raise ValueError("At least one index field is required")

    fieldsRead, fieldsWritten, assignments = typeAllEquations(listOfEquations, typeForSymbol)
    allFields = fieldsRead.union(fieldsWritten)

    for indexField in indexFields:
        indexField.isIndexField = True
        assert indexField.spatialDimensions == 1, "Index fields have to be 1D"

    nonIndexFields = [f for f in allFields if f not in indexFields]
    spatialDimensionSet = {f.spatialDimensions for f in nonIndexFields}
    assert len(spatialDimensionSet) == 1, "Non-index fields do not have the same number of spatial coordinates"
    numSpatialCoordinates = list(spatialDimensionSet)[0]

    def getCoordinateSymbolAssignment(name):
        # The first index field whose struct has an element called `name` provides that coordinate
        for indexField in indexFields:
            assert isinstance(indexField.dtype, StructType), "Index fields have to have a struct datatype"
            dataType = indexField.dtype
            if dataType.hasElement(name):
                rhs = indexField[0](name)
                lhs = TypedSymbol(name, BasicType(dataType.getElementType(name)))
                return SympyAssignment(lhs, rhs)
        raise ValueError("Index %s not found in any of the passed index fields" % (name,))

    coordinateSymbolAssignments = [getCoordinateSymbolAssignment(n)
                                   for n in coordinateNames[:numSpatialCoordinates]]
    coordinateTypedSymbols = [eq.lhs for eq in coordinateSymbolAssignments]
    # the coordinates have to be loaded before the user equations that use them
    assignments = coordinateSymbolAssignments + assignments

    # make 1D loop over index fields
    loopBody = Block([])
    loopNode = LoopOverCoordinate(loopBody, coordinateToLoopOver=0, start=0, stop=indexFields[0].shape[0])

    for assignment in assignments:
        loopBody.append(assignment)

    functionBody = Block([loopNode])
    # not called `ast`, to avoid shadowing the module alias of pystencils.astnodes
    kernelAst = KernelFunction(functionBody, allFields.union(indexFields))

    fixedCoordinateMapping = {f.name: coordinateTypedSymbols for f in nonIndexFields}

    # every field (index fields included) that is never written is read-only
    readOnlyFieldNames = {f.name for f in allFields.union(indexFields) if f not in fieldsWritten}
    resolveFieldAccesses(kernelAst, readOnlyFieldNames, fieldToFixedCoordinates=fixedCoordinateMapping)
    moveConstantsBeforeLoop(kernelAst)
    return kernelAst
def addOpenMP(astNode, schedule="static", numThreads=None):
"""
Parallelizes the outer loop with OpenMP
......
......@@ -271,11 +271,12 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn
_, offset = createIntermediateBasePointer(fieldAccess, coordDict, lastPointer)
baseArr = IndexedBase(lastPointer, shape=(1,))
result = baseArr[offset]
castFunc = sp.Function("cast")
if isinstance(getBaseType(fieldAccess.field.dtype), StructType):
typedSymbol = result.base.label
newType = fieldAccess.field.dtype.getElementType(fieldAccess.index[0])
typedSymbol.castTo = newType
return result
result = castFunc(result, newType)
return visitSympyExpr(result, enclosingBlock, sympyAssignment)
else:
newArgs = [visitSympyExpr(e, enclosingBlock, sympyAssignment) for e in expr.args]
kwargs = {'evaluate': False} if type(expr) is sp.Add or type(expr) is sp.Mul else {}
......@@ -427,6 +428,11 @@ def typeAllEquations(eqs, typeForSymbol):
:return: ``fieldsRead, fieldsWritten, typedEquations`` set of read fields, set of written fields, list of equations
where symbols have been replaced by typed symbols
"""
if not typeForSymbol or typeForSymbol == 'double':
typeForSymbol = typingFromSympyInspection(eqs, "double")
elif typeForSymbol == 'float':
typeForSymbol = typingFromSympyInspection(eqs, "float")
fieldsWritten = set()
fieldsRead = set()
......@@ -485,6 +491,8 @@ def typingFromSympyInspection(eqs, defaultType="double"):
"""
result = defaultdict(lambda: defaultType)
for eq in eqs:
if isinstance(eq, ast.Node):
continue
# problematic case here is when rhs is a symbol: then it is impossible to decide here without
# further information what type the left hand side is - default fallback is the dict value then
if isinstance(eq.rhs, Boolean) and not isinstance(eq.rhs, sp.Symbol):
......
......@@ -5,15 +5,13 @@ from sympy.core.cache import cacheit
class TypedSymbol(sp.Symbol):
def __new__(cls, name, *args, **kwds):
obj = TypedSymbol.__xnew_cached_(cls, name, *args, **kwds)
return obj
def __new_stage2__(cls, name, dtype, castTo=None):
def __new_stage2__(cls, name, dtype):
obj = super(TypedSymbol, cls).__xnew__(cls, name)
obj._dtype = createType(dtype)
obj.castTo = castTo
return obj
__xnew__ = staticmethod(__new_stage2__)
......@@ -25,11 +23,30 @@ class TypedSymbol(sp.Symbol):
def _hashable_content(self):
superClassContents = list(super(TypedSymbol, self)._hashable_content())
t = tuple(superClassContents + [hash(repr(self._dtype) + repr(self.castTo))])
t = tuple(superClassContents + [hash(repr(self._dtype))])
return t
def __getnewargs__(self):
return self.name, self.dtype, self.castTo
return self.name, self.dtype
#class IndexedWithCast(sp.tensor.Indexed):
# def __new__(cls, base, castTo, *args):
# obj = super(IndexedWithCast, cls).__new__(cls, base, *args)
# obj._castTo = castTo
# return obj
#
# @property
# def castTo(self):
# return self._castTo
#
# def _hashable_content(self):
# superClassContents = list(super(IndexedWithCast, self)._hashable_content())
# t = tuple(superClassContents + [hash(repr(self._castTo))])
# return t
#
# def __getnewargs__(self):
# return self.base, self.castTo
def createType(specification):
......@@ -113,8 +130,9 @@ toCtypes.map = {
}
class Type(object):
pass
class Type(sp.Basic):
    """Base class of the type descriptor classes; a sympy ``Basic`` so that types can appear inside expressions."""

    def __new__(cls, *args, **kwargs):
        # The constructor arguments are accepted but not forwarded to Basic.__new__;
        # presumably subclasses store their state in __init__ instead - TODO confirm.
        instance = sp.Basic.__new__(cls)
        return instance
class BasicType(Type):
......@@ -135,11 +153,17 @@ class BasicType(Type):
def __init__(self, dtype, const=False):
self.const = const
self._dtype = np.dtype(dtype)
if isinstance(dtype, Type):
self._dtype = dtype.numpyDtype
else:
self._dtype = np.dtype(dtype)
assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type"
assert self._dtype.hasobject is False
assert self._dtype.subdtype is None
def __getnewargs__(self):
return self.numpyDtype, self.const
@property
def baseType(self):
return None
......@@ -174,6 +198,9 @@ class PointerType(Type):
self.const = const
self.restrict = restrict
def __getnewargs__(self):
return self.baseType, self.const, self.restrict
@property
def alias(self):
return not self.restrict
......@@ -204,6 +231,9 @@ class StructType(object):
self.const = const
self._dtype = np.dtype(numpyType)
def __getnewargs__(self):
return self.numpyDtype, self.const
@property
def baseType(self):
return None
......@@ -223,6 +253,9 @@ class StructType(object):
npElementType = self.numpyDtype.fields[elementName][0]
return BasicType(npElementType, self.const)
def hasElement(self, elementName):
return elementName in self.numpyDtype.fields
def __eq__(self, other):
if not isinstance(other, StructType):
return False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment