diff --git a/astnodes.py b/astnodes.py index ef8f4866ae836e84fb195d988193a03f2f647a4e..9a36744db7145b45e941decda8b1c90f436cbe03 100644 --- a/astnodes.py +++ b/astnodes.py @@ -294,7 +294,8 @@ class SympyAssignment(Node): self._lhsSymbol = lhsSymbol self.rhs = rhsTerm self._isDeclaration = True - if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, IndexedBase): + isCast = str(self._lhsSymbol.func).lower() == 'cast' + if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, IndexedBase) or isCast: self._isDeclaration = False self._isConst = isConst @@ -306,7 +307,8 @@ class SympyAssignment(Node): def lhs(self, newValue): self._lhsSymbol = newValue self._isDeclaration = True - if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, sp.Indexed): + isCast = str(self._lhsSymbol.func).lower() == 'cast' + if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, sp.Indexed) or isCast: self._isDeclaration = False @property diff --git a/backends/cbackend.py b/backends/cbackend.py index 2bc3704999961f9590e83c354edf446a5608c8c0..c85e9be72769c2d79e91fb1319699a7bf34235cc 100644 --- a/backends/cbackend.py +++ b/backends/cbackend.py @@ -151,12 +151,10 @@ class CustomSympyPrinter(CCodePrinter): res += "f" return res - def _print_Indexed(self, expr): - result = super(CustomSympyPrinter, self)._print_Indexed(expr) - typedSymbol = expr.base.label - if typedSymbol.castTo is not None: - newType = typedSymbol.castTo - # e.g. *((double *)(& val[200])) - return "*((%s)(& %s))" % (PointerType(newType), result) + def _print_Function(self, expr): + name = str(expr.func).lower() + if name == 'cast': + arg, type = expr.args + return "*((%s)(& %s))" % (PointerType(type), self._print(arg)) else: - return result + return super(CustomSympyPrinter, self)._print_Function(expr) diff --git a/cpu/__init__.py b/cpu/__init__.py index 039f12025743f687d26c7ff548a81ab97cbb4f03..cec36fd9298f5645d1669edac101c24207cfd655 100644 --- a/cpu/__init__.py +++ b/cpu/__init__.py @@ -1,3 +1,3 @@ -from pystencils.cpu.kernelcreation import createKernel, addOpenMP +from pystencils.cpu.kernelcreation import createKernel, createIndexedKernel, addOpenMP from pystencils.cpu.cpujit import makePythonFunction from pystencils.backends.cbackend import generateC diff --git a/cpu/kernelcreation.py b/cpu/kernelcreation.py index f4e306a35a348bc8445b3b68f2c9adea6a311a63..26d710d1b0c605244a98cb7f4be7c5b310603b60 100644 --- a/cpu/kernelcreation.py +++ b/cpu/kernelcreation.py @@ -1,7 +1,9 @@ import sympy as sp -from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain, typingFromSympyInspection, \ + +from pystencils.astnodes import SympyAssignment, Block, LoopOverCoordinate, KernelFunction +from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain, \ typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop -from pystencils.types import TypedSymbol +from pystencils.types import TypedSymbol, BasicType, StructType from pystencils.field import Field import pystencils.astnodes as ast @@ -28,11 +30,6 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl :return: :class:`pystencils.ast.KernelFunction` node """ - if not typeForSymbol or typeForSymbol == 'double': - typeForSymbol = typingFromSympyInspection(listOfEquations, "double") - elif typeForSymbol == 'float': - typeForSymbol = typingFromSympyInspection(listOfEquations, "float") - def typeSymbol(term): if isinstance(term, Field.Access) or isinstance(term, TypedSymbol): return term @@ -63,6 +60,64 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl return code +def createIndexedKernel(listOfEquations, indexFields, typeForSymbol=None, coordinateNames=('x', 'y', 'z')): + """ + Similar to :func:`createKernel`, but here not all cells of a field are updated but only cells with + coordinates which are stored in an index field. This traversal method can e.g. be used for boundary handling. + + The coordinates are stored in a separated indexField, which is a one dimensional array with struct data type. + This struct has to contain fields named 'x', 'y' and for 3D fields ('z'). These names are configurable with the + 'coordinateNames' parameter. The struct can have also other fields that can be read and written in the kernel, for + example boundary parameters. + + :param listOfEquations: list of update equations or AST nodes + :param indexFields: list of index fields, i.e. 1D fields with struct data type + :param typeForSymbol: see documentation of :func:`createKernel` + :param coordinateNames: name of the coordinate fields in the struct data type + :return: abstract syntax tree + """ + fieldsRead, fieldsWritten, assignments = typeAllEquations(listOfEquations, typeForSymbol) + allFields = fieldsRead.union(fieldsWritten) + + for indexField in indexFields: + indexField.isIndexField = True + assert indexField.spatialDimensions == 1, "Index fields have to be 1D" + + nonIndexFields = [f for f in allFields if f not in indexFields] + spatialCoordinates = {f.spatialDimensions for f in nonIndexFields} + assert len(spatialCoordinates) == 1, "Non-index fields do not have the same number of spatial coordinates" + spatialCoordinates = list(spatialCoordinates)[0] + + def getCoordinateSymbolAssignment(name): + for indexField in indexFields: + assert isinstance(indexField.dtype, StructType), "Index fields have to have a struct datatype" + dataType = indexField.dtype + if dataType.hasElement(name): + rhs = indexField[0](name) + lhs = TypedSymbol(name, BasicType(dataType.getElementType(name))) + return SympyAssignment(lhs, rhs) + raise ValueError("Index %s not found in any of the passed index fields" % (name,)) + + coordinateSymbolAssignments = [getCoordinateSymbolAssignment(n) for n in coordinateNames[:spatialCoordinates]] + coordinateTypedSymbols = [eq.lhs for eq in coordinateSymbolAssignments] + assignments = coordinateSymbolAssignments + assignments + + # make 1D loop over index fields + loopBody = Block([]) + loopNode = LoopOverCoordinate(loopBody, coordinateToLoopOver=0, start=0, stop=indexFields[0].shape[0]) + + for assignment in assignments: + loopBody.append(assignment) + + functionBody = Block([loopNode]) + ast = KernelFunction(functionBody, allFields.union(indexFields)) + + fixedCoordinateMapping = {f.name: coordinateTypedSymbols for f in nonIndexFields} + resolveFieldAccesses(ast, set(['indexField']), fieldToFixedCoordinates=fixedCoordinateMapping) + moveConstantsBeforeLoop(ast) + return ast + + def addOpenMP(astNode, schedule="static", numThreads=None): """ Parallelizes the outer loop with OpenMP diff --git a/transformations.py b/transformations.py index f85733f6abc292fa1f7131eed3f842e8d87c3c60..93346f7aa7043ef16c72ba442929939d938f95e7 100644 --- a/transformations.py +++ b/transformations.py @@ -271,11 +271,12 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn _, offset = createIntermediateBasePointer(fieldAccess, coordDict, lastPointer) baseArr = IndexedBase(lastPointer, shape=(1,)) result = baseArr[offset] + castFunc = sp.Function("cast") if isinstance(getBaseType(fieldAccess.field.dtype), StructType): - typedSymbol = result.base.label newType = fieldAccess.field.dtype.getElementType(fieldAccess.index[0]) - typedSymbol.castTo = newType - return result + result = castFunc(result, newType) + + return visitSympyExpr(result, enclosingBlock, sympyAssignment) else: newArgs = [visitSympyExpr(e, enclosingBlock, sympyAssignment) for e in expr.args] kwargs = {'evaluate': False} if type(expr) is sp.Add or type(expr) is sp.Mul else {} @@ -427,6 +428,11 @@ def typeAllEquations(eqs, typeForSymbol): :return: ``fieldsRead, fieldsWritten, typedEquations`` set of read fields, set of written fields, list of equations where symbols have been replaced by typed symbols """ + if not typeForSymbol or typeForSymbol == 'double': + typeForSymbol = typingFromSympyInspection(eqs, "double") + elif typeForSymbol == 'float': + typeForSymbol = typingFromSympyInspection(eqs, "float") + fieldsWritten = set() fieldsRead = set() @@ -485,6 +491,8 @@ def typingFromSympyInspection(eqs, defaultType="double"): """ result = defaultdict(lambda: defaultType) for eq in eqs: + if isinstance(eq, ast.Node): + continue # problematic case here is when rhs is a symbol: then it is impossible to decide here without # further information what type the left hand side is - default fallback is the dict value then if isinstance(eq.rhs, Boolean) and not isinstance(eq.rhs, sp.Symbol): diff --git a/types.py b/types.py index 32deb2811e4dab57fc3fb2055452b7405a7e96a4..251688ce873fff3a8b0eaf5edc300ad8b7bf4e77 100644 --- a/types.py +++ b/types.py @@ -5,15 +5,13 @@ from sympy.core.cache import cacheit class TypedSymbol(sp.Symbol): - def __new__(cls, name, *args, **kwds): obj = TypedSymbol.__xnew_cached_(cls, name, *args, **kwds) return obj - def __new_stage2__(cls, name, dtype, castTo=None): + def __new_stage2__(cls, name, dtype): obj = super(TypedSymbol, cls).__xnew__(cls, name) obj._dtype = createType(dtype) - obj.castTo = castTo return obj __xnew__ = staticmethod(__new_stage2__) @@ -25,11 +23,30 @@ class TypedSymbol(sp.Symbol): def _hashable_content(self): superClassContents = list(super(TypedSymbol, self)._hashable_content()) - t = tuple(superClassContents + [hash(repr(self._dtype) + repr(self.castTo))]) + t = tuple(superClassContents + [hash(repr(self._dtype))]) return t def __getnewargs__(self): - return self.name, self.dtype, self.castTo + return self.name, self.dtype + + +#class IndexedWithCast(sp.tensor.Indexed): +# def __new__(cls, base, castTo, *args): +# obj = super(IndexedWithCast, cls).__new__(cls, base, *args) +# obj._castTo = castTo +# return obj +# +# @property +# def castTo(self): +# return self._castTo +# +# def _hashable_content(self): +# superClassContents = list(super(IndexedWithCast, self)._hashable_content()) +# t = tuple(superClassContents + [hash(repr(self._castTo))]) +# return t +# +# def __getnewargs__(self): +# return self.base, self.castTo def createType(specification): @@ -113,8 +130,9 @@ toCtypes.map = { } -class Type(object): - pass +class Type(sp.Basic): + def __new__(cls, *args, **kwargs): + return sp.Basic.__new__(cls) class BasicType(Type): @@ -135,11 +153,17 @@ class BasicType(Type): def __init__(self, dtype, const=False): self.const = const - self._dtype = np.dtype(dtype) + if isinstance(dtype, Type): + self._dtype = dtype.numpyDtype + else: + self._dtype = np.dtype(dtype) assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type" assert self._dtype.hasobject is False assert self._dtype.subdtype is None + def __getnewargs__(self): + return self.numpyDtype, self.const + @property def baseType(self): return None @@ -174,6 +198,9 @@ class PointerType(Type): self.const = const self.restrict = restrict + def __getnewargs__(self): + return self.baseType, self.const, self.restrict + @property def alias(self): return not self.restrict @@ -204,6 +231,9 @@ class StructType(object): self.const = const self._dtype = np.dtype(numpyType) + def __getnewargs__(self): + return self.numpyDtype, self.const + @property def baseType(self): return None @@ -223,6 +253,9 @@ class StructType(object): npElementType = self.numpyDtype.fields[elementName][0] return BasicType(npElementType, self.const) + def hasElement(self, elementName): + return elementName in self.numpyDtype.fields + def __eq__(self, other): if not isinstance(other, StructType): return False