From ec3faf5138954c74b358d862647d5bf7e9007fae Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Tue, 14 Mar 2017 19:08:45 +0100
Subject: [PATCH] pystencils: fields can now contain structs

- this extension is necessary for more generic boundary treatment
- cells can now be structs, i.e. contain different data types
- instead of having numeric index dimensions, one can use the index per cell to adress struct elements
---
 astnodes.py          |  5 ++--
 backends/cbackend.py | 12 ++++++++-
 cpu/cpujit.py        | 33 +++++++++++++----------
 field.py             | 25 +++++++++++++++++-
 transformations.py   | 22 ++++++++++++----
 types.py             | 63 +++++++++++++++++++++++++++++++++++++++-----
 6 files changed, 130 insertions(+), 30 deletions(-)

diff --git a/astnodes.py b/astnodes.py
index fe89cb318..ef8f4866a 100644
--- a/astnodes.py
+++ b/astnodes.py
@@ -420,7 +420,7 @@ class Conversion(Node):
         raise set()
 
     def __repr__(self):
-        return '(%s)' % (_c_dtype_dict(self.dtype)) + repr(self.args)
+        return '(%s)' % (self.dtype,) + repr(self.args)
 
 # TODO everything which is not Atomic expression: Pow)
 
@@ -481,6 +481,7 @@ class Indexed(Expr):
     def __repr__(self):
         return '%s[%s]' % (self.args[0], self.args[1])
 
+
 class Number(Node):
     def __init__(self, number, parent=None):
         super(Number, self).__init__(parent)
@@ -503,6 +504,6 @@ class Number(Node):
         raise set()
 
     def __repr__(self):
-        return '(%s)' % (_c_dtype_dict(self.dtype)) + repr(self.args)
+        return '(%s)' % (self.dtype,) + repr(self.args)
 
 
diff --git a/backends/cbackend.py b/backends/cbackend.py
index 63172de8a..2bc370499 100644
--- a/backends/cbackend.py
+++ b/backends/cbackend.py
@@ -1,6 +1,6 @@
 from sympy.utilities.codegen import CCodePrinter
 from pystencils.astnodes import Node
-from pystencils.types import createType
+from pystencils.types import createType, PointerType
 
 
 def generateC(astNode):
@@ -150,3 +150,13 @@ class CustomSympyPrinter(CCodePrinter):
         if self._constantsAsFloats:
             res += "f"
         return res
+
+    def _print_Indexed(self, expr):
+        result = super(CustomSympyPrinter, self)._print_Indexed(expr)
+        typedSymbol = expr.base.label
+        if typedSymbol.castTo is not None:
+            newType = typedSymbol.castTo
+            # e.g.  *((double *)(& val[200]))
+            return "*((%s)(& %s))" % (PointerType(newType), result)
+        else:
+            return result
diff --git a/cpu/cpujit.py b/cpu/cpujit.py
index 365b0462c..5be04bb0c 100644
--- a/cpu/cpujit.py
+++ b/cpu/cpujit.py
@@ -72,7 +72,7 @@ from ctypes import cdll, sizeof
 from pystencils.backends.cbackend import generateC
 from collections import OrderedDict, Mapping
 from pystencils.transformations import symbolNameToVariableName
-from pystencils.types import toCtypes, getBaseType, createType
+from pystencils.types import toCtypes, getBaseType, createType, StructType
 
 
 def makePythonFunction(kernelFunctionNode, argumentDict={}):
@@ -366,34 +366,39 @@ def buildCTypeArgumentList(parameterSpecification, argumentDict):
     for arg in parameterSpecification:
         if arg.isFieldArgument:
             try:
-                field = argumentDict[arg.fieldName]
+                fieldArr = argumentDict[arg.fieldName]
             except KeyError:
                 raise KeyError("Missing field parameter for kernel call " + arg.fieldName)
 
             symbolicField = arg.field
             if arg.isFieldPtrArgument:
-                ctArguments.append(field.ctypes.data_as(toCtypes(arg.dtype)))
+                ctArguments.append(fieldArr.ctypes.data_as(toCtypes(arg.dtype)))
                 if symbolicField.hasFixedShape:
-                    if tuple(int(i) for i in symbolicField.shape) != field.shape:
+                    symbolicFieldShape = tuple(int(i) for i in symbolicField.shape)
+                    if isinstance(symbolicField.dtype, StructType):
+                        symbolicFieldShape = symbolicFieldShape[:-1]
+                    if symbolicFieldShape != fieldArr.shape:
                         raise ValueError("Passed array '%s' has shape %s which does not match expected shape %s" %
-                                         (arg.fieldName, str(field.shape), str(symbolicField.shape)))
+                                         (arg.fieldName, str(fieldArr.shape), str(symbolicField.shape)))
                 if symbolicField.hasFixedShape:
-                    if tuple(int(i) * field.itemsize for i in symbolicField.strides) != field.strides:
+                    symbolicFieldStrides = tuple(int(i) * fieldArr.itemsize for i in symbolicField.strides)
+                    if isinstance(symbolicField.dtype, StructType):
+                        symbolicFieldStrides = symbolicFieldStrides[:-1]
+                    if symbolicFieldStrides != fieldArr.strides:
                         raise ValueError("Passed array '%s' has strides %s which does not match expected strides %s" %
-                                         (arg.fieldName, str(field.strides), str(symbolicField.strides)))
+                                         (arg.fieldName, str(fieldArr.strides), str(symbolicFieldStrides)))
 
                 if not symbolicField.isIndexField:
-                    arrayShapes.add(field.shape[:symbolicField.spatialDimensions])
+                    arrayShapes.add(fieldArr.shape[:symbolicField.spatialDimensions])
             elif arg.isFieldShapeArgument:
                 dataType = toCtypes(getBaseType(arg.dtype))
-                ctArguments.append(field.ctypes.shape_as(dataType))
+                ctArguments.append(fieldArr.ctypes.shape_as(dataType))
             elif arg.isFieldStrideArgument:
                 dataType = toCtypes(getBaseType(arg.dtype))
-                baseFieldType = toCtypes(createType(field.dtype))
-                strides = field.ctypes.strides_as(dataType)
-                for i in range(len(field.shape)):
-                    assert strides[i] % sizeof(baseFieldType) == 0
-                    strides[i] //= sizeof(baseFieldType)
+                strides = fieldArr.ctypes.strides_as(dataType)
+                for i in range(len(fieldArr.shape)):
+                    assert strides[i] % fieldArr.itemsize == 0
+                    strides[i] //= fieldArr.itemsize
                 ctArguments.append(strides)
             else:
                 assert False
diff --git a/field.py b/field.py
index 546e049d2..66cc1181d 100644
--- a/field.py
+++ b/field.py
@@ -3,7 +3,7 @@ import numpy as np
 import sympy as sp
 from sympy.core.cache import cacheit
 from sympy.tensor import IndexedBase
-from pystencils.types import TypedSymbol, createType
+from pystencils.types import TypedSymbol, createType, StructType
 
 
 class Field(object):
@@ -72,6 +72,14 @@ class Field(object):
         totalDimensions = spatialDimensions + indexDimensions
         shape = tuple([shapeSymbol[i] for i in range(totalDimensions)])
         strides = tuple([strideSymbol[i] for i in range(totalDimensions)])
+
+        npDataType = np.dtype(dtype)
+        if npDataType.fields is not None:
+            if indexDimensions != 0:
+                raise ValueError("Structured arrays/fields are not allowed to have an index dimension")
+            shape += (1,)
+            strides += (1,)
+
         return Field(fieldName, dtype, layout, shape, strides)
 
     @staticmethod
@@ -94,6 +102,13 @@ class Field(object):
         strides = tuple([s // np.dtype(npArray.dtype).itemsize for s in npArray.strides])
         shape = tuple(int(s) for s in npArray.shape)
 
+        npDataType = np.dtype(npArray.dtype)
+        if npDataType.fields is not None:
+            if indexDimensions != 0:
+                raise ValueError("Structured arrays/fields are not allowed to have an index dimension")
+            shape += (1,)
+            strides += (1,)
+
         return Field(fieldName, npArray.dtype, spatialLayout, shape, strides)
 
     @staticmethod
@@ -117,6 +132,14 @@ class Field(object):
 
         shape = tuple(int(s) for s in shape)
         strides = computeStrides(shape, layout)
+
+        npDataType = np.dtype(dtype)
+        if npDataType.fields is not None:
+            if indexDimensions != 0:
+                raise ValueError("Structured arrays/fields are not allowed to have an index dimension")
+            shape += (1,)
+            strides += (1,)
+
         return Field(fieldName, dtype, layout[:spatialDimensions], shape, strides)
 
     def __init__(self, fieldName, dtype, layout, shape, strides):
diff --git a/transformations.py b/transformations.py
index 35a6dbab2..f85733f6a 100644
--- a/transformations.py
+++ b/transformations.py
@@ -6,7 +6,7 @@ from sympy.logic.boolalg import Boolean
 from sympy.tensor import IndexedBase
 
 from pystencils.field import Field, offsetComponentToDirectionString
-from pystencils.types import TypedSymbol, createType, PointerType
+from pystencils.types import TypedSymbol, createType, PointerType, StructType, getBaseType
 from pystencils.slicing import normalizeSlice
 import pystencils.astnodes as ast
 
@@ -248,8 +248,15 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn
                         else:
                             ctrName = ast.LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX
                             coordDict[e] = TypedSymbol("%s_%d" % (ctrName, e), 'int')
+                        coordDict[e] *= field.dtype.itemSize
                     else:
-                        coordDict[e] = fieldAccess.index[e-field.spatialDimensions]
+                        if isinstance(field.dtype, StructType):
+                            assert field.indexDimensions == 1
+                            accessedFieldName = fieldAccess.index[0]
+                            assert isinstance(accessedFieldName, str)
+                            coordDict[e] = field.dtype.getElementOffset(accessedFieldName)
+                        else:
+                            coordDict[e] = fieldAccess.index[e - field.spatialDimensions]
                 return coordDict
 
             for group in reversed(basePointerInfo[1:]):
@@ -260,10 +267,15 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn
                     enclosingBlock.insertBefore(newAssignment, sympyAssignment)
                 lastPointer = newPtr
 
-            _, offset = createIntermediateBasePointer(fieldAccess, createCoordinateDict(basePointerInfo[0]),
-                                                      lastPointer)
+            coordDict = createCoordinateDict(basePointerInfo[0])
+            _, offset = createIntermediateBasePointer(fieldAccess, coordDict, lastPointer)
             baseArr = IndexedBase(lastPointer, shape=(1,))
-            return baseArr[offset]
+            result = baseArr[offset]
+            if isinstance(getBaseType(fieldAccess.field.dtype), StructType):
+                typedSymbol = result.base.label
+                newType = fieldAccess.field.dtype.getElementType(fieldAccess.index[0])
+                typedSymbol.castTo = newType
+            return result
         else:
             newArgs = [visitSympyExpr(e, enclosingBlock, sympyAssignment) for e in expr.args]
             kwargs = {'evaluate': False} if type(expr) is sp.Add or type(expr) is sp.Mul else {}
diff --git a/types.py b/types.py
index 4a6516142..32deb2811 100644
--- a/types.py
+++ b/types.py
@@ -10,9 +10,10 @@ class TypedSymbol(sp.Symbol):
         obj = TypedSymbol.__xnew_cached_(cls, name, *args, **kwds)
         return obj
 
-    def __new_stage2__(cls, name, dtype):
+    def __new_stage2__(cls, name, dtype, castTo=None):
         obj = super(TypedSymbol, cls).__xnew__(cls, name)
         obj._dtype = createType(dtype)
+        obj.castTo = castTo
         return obj
 
     __xnew__ = staticmethod(__new_stage2__)
@@ -24,11 +25,11 @@ class TypedSymbol(sp.Symbol):
 
     def _hashable_content(self):
         superClassContents = list(super(TypedSymbol, self)._hashable_content())
-        t = tuple(superClassContents + [hash(repr(self._dtype))])
+        t = tuple(superClassContents + [hash(repr(self._dtype) + repr(self.castTo))])
         return t
 
     def __getnewargs__(self):
-        return self.name, self.dtype
+        return self.name, self.dtype, self.castTo
 
 
 def createType(specification):
@@ -38,7 +39,10 @@ def createType(specification):
         return createTypeFromString(specification)
     else:
         npDataType = np.dtype(specification)
-        return BasicType(npDataType, const=False)
+        if npDataType.fields is None:
+            return BasicType(npDataType, const=False)
+        else:
+            return StructType(npDataType, const=False)
 
 
 def createTypeFromString(specification):
@@ -88,6 +92,8 @@ def getBaseType(type):
 def toCtypes(dataType):
     if isinstance(dataType, PointerType):
         return ctypes.POINTER(toCtypes(dataType.baseType))
+    elif isinstance(dataType, StructType):
+        return ctypes.POINTER(ctypes.c_uint8)
     else:
         return toCtypes.map[dataType.numpyDtype]
 
@@ -120,7 +126,7 @@ class BasicType(Type):
             width = int(name[len("int"):])
             return "int%d_t" % (width,)
         elif name.startswith('uint'):
-            width = int(name[len("int"):])
+            width = int(name[len("uint"):])
             return "uint%d_t" % (width,)
         elif name == 'bool':
             return 'bool'
@@ -142,6 +148,10 @@ class BasicType(Type):
     def numpyDtype(self):
         return self._dtype
 
+    @property
+    def itemSize(self):
+        return 1
+
     def __str__(self):
         result = BasicType.numpyNameToC(str(self._dtype))
         if self.const:
@@ -172,6 +182,10 @@ class PointerType(Type):
     def baseType(self):
         return self._baseType
 
+    @property
+    def itemSize(self):
+        return self.baseType.itemSize
+
     def __eq__(self, other):
         if not isinstance(other, PointerType):
             return False
@@ -179,13 +193,48 @@ class PointerType(Type):
             return (self.baseType, self.const, self.restrict) == (other.baseType, other.const, other.restrict)
 
     def __str__(self):
-        return "%s * %s%s" % (self.baseType, "RESTRICT " if self.restrict else "", "const " if self.const else "")
+        return "%s *%s%s" % (self.baseType, " RESTRICT" if self.restrict else "", " const" if self.const else "")
 
     def __hash__(self):
         return hash(str(self))
 
 
 class StructType(object):
-    def __init__(self, numpyType):
+    def __init__(self, numpyType, const=False):
+        self.const = const
         self._dtype = np.dtype(numpyType)
 
+    @property
+    def baseType(self):
+        return None
+
+    @property
+    def numpyDtype(self):
+        return self._dtype
+
+    @property
+    def itemSize(self):
+        return self.numpyDtype.itemsize
+
+    def getElementOffset(self, elementName):
+        return self.numpyDtype.fields[elementName][1]
+
+    def getElementType(self, elementName):
+        npElementType = self.numpyDtype.fields[elementName][0]
+        return BasicType(npElementType, self.const)
+
+    def __eq__(self, other):
+        if not isinstance(other, StructType):
+            return False
+        else:
+            return (self.numpyDtype, self.const) == (other.numpyDtype, other.const)
+
+    def __str__(self):
+        # structs are handled byte-wise
+        result = "uint8_t"
+        if self.const:
+            result += " const"
+        return result
+
+    def __hash__(self):
+        return hash((self.numpyDtype, self.const))
-- 
GitLab