Commit ec3faf51 authored by Martin Bauer's avatar Martin Bauer
Browse files

pystencils: fields can now contain structs

- this extension is necessary for more generic boundary treatment
- cells can now be structs, i.e. contain different data types
- instead of having numeric index dimensions, one can use the index per cell to adress struct elements
parent c8b455fe
......@@ -420,7 +420,7 @@ class Conversion(Node):
raise set()
def __repr__(self):
return '(%s)' % (_c_dtype_dict(self.dtype)) + repr(self.args)
return '(%s)' % (self.dtype,) + repr(self.args)
# TODO everything which is not Atomic expression: Pow)
......@@ -481,6 +481,7 @@ class Indexed(Expr):
def __repr__(self):
return '%s[%s]' % (self.args[0], self.args[1])
class Number(Node):
def __init__(self, number, parent=None):
super(Number, self).__init__(parent)
......@@ -503,6 +504,6 @@ class Number(Node):
raise set()
def __repr__(self):
return '(%s)' % (_c_dtype_dict(self.dtype)) + repr(self.args)
return '(%s)' % (self.dtype,) + repr(self.args)
from sympy.utilities.codegen import CCodePrinter
from pystencils.astnodes import Node
from pystencils.types import createType
from pystencils.types import createType, PointerType
def generateC(astNode):
......@@ -150,3 +150,13 @@ class CustomSympyPrinter(CCodePrinter):
if self._constantsAsFloats:
res += "f"
return res
def _print_Indexed(self, expr):
result = super(CustomSympyPrinter, self)._print_Indexed(expr)
typedSymbol = expr.base.label
if typedSymbol.castTo is not None:
newType = typedSymbol.castTo
# e.g. *((double *)(& val[200]))
return "*((%s)(& %s))" % (PointerType(newType), result)
return result
......@@ -72,7 +72,7 @@ from ctypes import cdll, sizeof
from pystencils.backends.cbackend import generateC
from collections import OrderedDict, Mapping
from pystencils.transformations import symbolNameToVariableName
from pystencils.types import toCtypes, getBaseType, createType
from pystencils.types import toCtypes, getBaseType, createType, StructType
def makePythonFunction(kernelFunctionNode, argumentDict={}):
......@@ -366,34 +366,39 @@ def buildCTypeArgumentList(parameterSpecification, argumentDict):
for arg in parameterSpecification:
if arg.isFieldArgument:
field = argumentDict[arg.fieldName]
fieldArr = argumentDict[arg.fieldName]
except KeyError:
raise KeyError("Missing field parameter for kernel call " + arg.fieldName)
symbolicField = arg.field
if arg.isFieldPtrArgument:
if symbolicField.hasFixedShape:
if tuple(int(i) for i in symbolicField.shape) != field.shape:
symbolicFieldShape = tuple(int(i) for i in symbolicField.shape)
if isinstance(symbolicField.dtype, StructType):
symbolicFieldShape = symbolicFieldShape[:-1]
if symbolicFieldShape != fieldArr.shape:
raise ValueError("Passed array '%s' has shape %s which does not match expected shape %s" %
(arg.fieldName, str(field.shape), str(symbolicField.shape)))
(arg.fieldName, str(fieldArr.shape), str(symbolicField.shape)))
if symbolicField.hasFixedShape:
if tuple(int(i) * field.itemsize for i in symbolicField.strides) != field.strides:
symbolicFieldStrides = tuple(int(i) * fieldArr.itemsize for i in symbolicField.strides)
if isinstance(symbolicField.dtype, StructType):
symbolicFieldStrides = symbolicFieldStrides[:-1]
if symbolicFieldStrides != fieldArr.strides:
raise ValueError("Passed array '%s' has strides %s which does not match expected strides %s" %
(arg.fieldName, str(field.strides), str(symbolicField.strides)))
(arg.fieldName, str(fieldArr.strides), str(symbolicFieldStrides)))
if not symbolicField.isIndexField:
elif arg.isFieldShapeArgument:
dataType = toCtypes(getBaseType(arg.dtype))
elif arg.isFieldStrideArgument:
dataType = toCtypes(getBaseType(arg.dtype))
baseFieldType = toCtypes(createType(field.dtype))
strides = field.ctypes.strides_as(dataType)
for i in range(len(field.shape)):
assert strides[i] % sizeof(baseFieldType) == 0
strides[i] //= sizeof(baseFieldType)
strides = fieldArr.ctypes.strides_as(dataType)
for i in range(len(fieldArr.shape)):
assert strides[i] % fieldArr.itemsize == 0
strides[i] //= fieldArr.itemsize
assert False
......@@ -3,7 +3,7 @@ import numpy as np
import sympy as sp
from sympy.core.cache import cacheit
from sympy.tensor import IndexedBase
from pystencils.types import TypedSymbol, createType
from pystencils.types import TypedSymbol, createType, StructType
class Field(object):
......@@ -72,6 +72,14 @@ class Field(object):
totalDimensions = spatialDimensions + indexDimensions
shape = tuple([shapeSymbol[i] for i in range(totalDimensions)])
strides = tuple([strideSymbol[i] for i in range(totalDimensions)])
npDataType = np.dtype(dtype)
if npDataType.fields is not None:
if indexDimensions != 0:
raise ValueError("Structured arrays/fields are not allowed to have an index dimension")
shape += (1,)
strides += (1,)
return Field(fieldName, dtype, layout, shape, strides)
......@@ -94,6 +102,13 @@ class Field(object):
strides = tuple([s // np.dtype(npArray.dtype).itemsize for s in npArray.strides])
shape = tuple(int(s) for s in npArray.shape)
npDataType = np.dtype(npArray.dtype)
if npDataType.fields is not None:
if indexDimensions != 0:
raise ValueError("Structured arrays/fields are not allowed to have an index dimension")
shape += (1,)
strides += (1,)
return Field(fieldName, npArray.dtype, spatialLayout, shape, strides)
......@@ -117,6 +132,14 @@ class Field(object):
shape = tuple(int(s) for s in shape)
strides = computeStrides(shape, layout)
npDataType = np.dtype(dtype)
if npDataType.fields is not None:
if indexDimensions != 0:
raise ValueError("Structured arrays/fields are not allowed to have an index dimension")
shape += (1,)
strides += (1,)
return Field(fieldName, dtype, layout[:spatialDimensions], shape, strides)
def __init__(self, fieldName, dtype, layout, shape, strides):
......@@ -6,7 +6,7 @@ from sympy.logic.boolalg import Boolean
from sympy.tensor import IndexedBase
from pystencils.field import Field, offsetComponentToDirectionString
from pystencils.types import TypedSymbol, createType, PointerType
from pystencils.types import TypedSymbol, createType, PointerType, StructType, getBaseType
from pystencils.slicing import normalizeSlice
import pystencils.astnodes as ast
......@@ -248,8 +248,15 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn
ctrName = ast.LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX
coordDict[e] = TypedSymbol("%s_%d" % (ctrName, e), 'int')
coordDict[e] *= field.dtype.itemSize
coordDict[e] = fieldAccess.index[e-field.spatialDimensions]
if isinstance(field.dtype, StructType):
assert field.indexDimensions == 1
accessedFieldName = fieldAccess.index[0]
assert isinstance(accessedFieldName, str)
coordDict[e] = field.dtype.getElementOffset(accessedFieldName)
coordDict[e] = fieldAccess.index[e - field.spatialDimensions]
return coordDict
for group in reversed(basePointerInfo[1:]):
......@@ -260,10 +267,15 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn
enclosingBlock.insertBefore(newAssignment, sympyAssignment)
lastPointer = newPtr
_, offset = createIntermediateBasePointer(fieldAccess, createCoordinateDict(basePointerInfo[0]),
coordDict = createCoordinateDict(basePointerInfo[0])
_, offset = createIntermediateBasePointer(fieldAccess, coordDict, lastPointer)
baseArr = IndexedBase(lastPointer, shape=(1,))
return baseArr[offset]
result = baseArr[offset]
if isinstance(getBaseType(fieldAccess.field.dtype), StructType):
typedSymbol = result.base.label
newType = fieldAccess.field.dtype.getElementType(fieldAccess.index[0])
typedSymbol.castTo = newType
return result
newArgs = [visitSympyExpr(e, enclosingBlock, sympyAssignment) for e in expr.args]
kwargs = {'evaluate': False} if type(expr) is sp.Add or type(expr) is sp.Mul else {}
......@@ -10,9 +10,10 @@ class TypedSymbol(sp.Symbol):
obj = TypedSymbol.__xnew_cached_(cls, name, *args, **kwds)
return obj
def __new_stage2__(cls, name, dtype):
def __new_stage2__(cls, name, dtype, castTo=None):
obj = super(TypedSymbol, cls).__xnew__(cls, name)
obj._dtype = createType(dtype)
obj.castTo = castTo
return obj
__xnew__ = staticmethod(__new_stage2__)
......@@ -24,11 +25,11 @@ class TypedSymbol(sp.Symbol):
def _hashable_content(self):
superClassContents = list(super(TypedSymbol, self)._hashable_content())
t = tuple(superClassContents + [hash(repr(self._dtype))])
t = tuple(superClassContents + [hash(repr(self._dtype) + repr(self.castTo))])
return t
def __getnewargs__(self):
return, self.dtype
return, self.dtype, self.castTo
def createType(specification):
......@@ -38,7 +39,10 @@ def createType(specification):
return createTypeFromString(specification)
npDataType = np.dtype(specification)
return BasicType(npDataType, const=False)
if npDataType.fields is None:
return BasicType(npDataType, const=False)
return StructType(npDataType, const=False)
def createTypeFromString(specification):
......@@ -88,6 +92,8 @@ def getBaseType(type):
def toCtypes(dataType):
if isinstance(dataType, PointerType):
return ctypes.POINTER(toCtypes(dataType.baseType))
elif isinstance(dataType, StructType):
return ctypes.POINTER(ctypes.c_uint8)
......@@ -120,7 +126,7 @@ class BasicType(Type):
width = int(name[len("int"):])
return "int%d_t" % (width,)
elif name.startswith('uint'):
width = int(name[len("int"):])
width = int(name[len("uint"):])
return "uint%d_t" % (width,)
elif name == 'bool':
return 'bool'
......@@ -142,6 +148,10 @@ class BasicType(Type):
def numpyDtype(self):
return self._dtype
def itemSize(self):
return 1
def __str__(self):
result = BasicType.numpyNameToC(str(self._dtype))
if self.const:
......@@ -172,6 +182,10 @@ class PointerType(Type):
def baseType(self):
return self._baseType
def itemSize(self):
return self.baseType.itemSize
def __eq__(self, other):
if not isinstance(other, PointerType):
return False
......@@ -179,13 +193,48 @@ class PointerType(Type):
return (self.baseType, self.const, self.restrict) == (other.baseType, other.const, other.restrict)
def __str__(self):
return "%s * %s%s" % (self.baseType, "RESTRICT " if self.restrict else "", "const " if self.const else "")
return "%s *%s%s" % (self.baseType, " RESTRICT" if self.restrict else "", " const" if self.const else "")
def __hash__(self):
return hash(str(self))
class StructType(object):
def __init__(self, numpyType):
def __init__(self, numpyType, const=False):
self.const = const
self._dtype = np.dtype(numpyType)
def baseType(self):
return None
def numpyDtype(self):
return self._dtype
def itemSize(self):
return self.numpyDtype.itemsize
def getElementOffset(self, elementName):
return self.numpyDtype.fields[elementName][1]
def getElementType(self, elementName):
npElementType = self.numpyDtype.fields[elementName][0]
return BasicType(npElementType, self.const)
def __eq__(self, other):
if not isinstance(other, StructType):
return False
return (self.numpyDtype, self.const) == (other.numpyDtype, other.const)
def __str__(self):
# structs are handled byte-wise
result = "uint8_t"
if self.const:
result += " const"
return result
def __hash__(self):
return hash((self.numpyDtype, self.const))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment