From c8b455fef8dabfe1413f168d42fe6b9b4952ad32 Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Mon, 13 Mar 2017 20:49:34 +0100
Subject: [PATCH] pystencils: Cleaned up type system

- use data type class consistently instead of strings (in TypedSymbol, Field and jit module)
- new datatype class is based on numpy types with additional specifier information (const and restrict)
- translation between data type class and other modules (numpy, ctypes)
---
 astnodes.py            |   6 +-
 backends/cbackend.py   |   3 +-
 cpu/cpujit.py          |  59 +++----------
 cpu/kernelcreation.py  |   4 +-
 field.py               |  15 +---
 llvm/kernelcreation.py |   4 +-
 transformations.py     |  23 ++----
 types.py               | 184 ++++++++++++++++++++++++++++++++++-------
 8 files changed, 185 insertions(+), 113 deletions(-)

diff --git a/astnodes.py b/astnodes.py
index e4a0c0dc6..fe89cb318 100644
--- a/astnodes.py
+++ b/astnodes.py
@@ -1,7 +1,7 @@
 import sympy as sp
-from sympy.tensor import IndexedBase, Indexed
+from sympy.tensor import IndexedBase
 from pystencils.field import Field
-from pystencils.types import TypedSymbol, DataType, _c_dtype_dict
+from pystencils.types import TypedSymbol
 
 
 class Node(object):
@@ -262,7 +262,7 @@ class LoopOverCoordinate(Node):
 
     @staticmethod
     def getLoopCounterSymbol(coordinateToLoopOver):
-        return TypedSymbol(LoopOverCoordinate.getLoopCounterName(coordinateToLoopOver), DataType('int'))
+        return TypedSymbol(LoopOverCoordinate.getLoopCounterName(coordinateToLoopOver), 'int')
 
     @property
     def loopCounterSymbol(self):
diff --git a/backends/cbackend.py b/backends/cbackend.py
index e4a635414..63172de8a 100644
--- a/backends/cbackend.py
+++ b/backends/cbackend.py
@@ -1,5 +1,6 @@
 from sympy.utilities.codegen import CCodePrinter
 from pystencils.astnodes import Node
+from pystencils.types import createType
 
 
 def generateC(astNode):
@@ -7,7 +8,7 @@ def generateC(astNode):
     Prints the abstract syntax tree as C function
     """
     fieldTypes = set([f.dtype for f in astNode.fieldsAccessed])
-    useFloatConstants = "double" not in fieldTypes
+    useFloatConstants = createType("double") not in fieldTypes
     printer = CBackend(constantsAsFloats=useFloatConstants)
     return printer(astNode)
 
diff --git a/cpu/cpujit.py b/cpu/cpujit.py
index 8ec4e3756..365b0462c 100644
--- a/cpu/cpujit.py
+++ b/cpu/cpujit.py
@@ -62,17 +62,17 @@ compiled into the shared library. Then, the same script can be run from the comp
 from __future__ import print_function
 import os
 import subprocess
-from ctypes import cdll, c_double, c_float, sizeof
-import shutil
-from pystencils.backends.cbackend import generateC
-import numpy as np
 import hashlib
 import json
 import platform
 import glob
 import atexit
+import shutil
+from ctypes import cdll, sizeof
+from pystencils.backends.cbackend import generateC
 from collections import OrderedDict, Mapping
 from pystencils.transformations import symbolNameToVariableName
+from pystencils.types import toCtypes, getBaseType, createType
 
 
 def makePythonFunction(kernelFunctionNode, argumentDict={}):
@@ -173,7 +173,7 @@ def readConfig():
         defaultCompilerConfig = OrderedDict([
             ('os', 'linux'),
             ('command', 'g++'),
-            ('flags', '-Ofast -DNDEBUG -fPIC -march=native -fopenmp'),
+            ('flags', '-Ofast -DNDEBUG -fPIC -march=native -fopenmp -std=c++11'),
             ('restrictQualifier', '__restrict__')
         ])
         defaultCacheConfig = OrderedDict([
@@ -240,41 +240,6 @@ def getCacheConfig():
     return _config['cache']
 
 
-def ctypeFromString(typename, includePointers=True):
-    import ctypes as ct
-
-    typename = str(typename).replace("*", " * ")
-    typeComponents = typename.split()
-
-    basicTypeMap = {
-        'double': ct.c_double,
-        'float': ct.c_float,
-        'int': ct.c_int,
-        'long': ct.c_long,
-    }
-
-    resultType = None
-    for typeComponent in typeComponents:
-        typeComponent = typeComponent.strip()
-        if typeComponent == "const" or typeComponent == "restrict" or typeComponent == "volatile":
-            continue
-        if typeComponent in basicTypeMap:
-            resultType = basicTypeMap[typeComponent]
-        elif typeComponent == "*" and includePointers:
-            assert resultType is not None
-            resultType = ct.POINTER(resultType)
-
-    return resultType
-
-
-def ctypeFromNumpyType(numpyType):
-    typeMap = {
-        np.dtype('float64'): c_double,
-        np.dtype('float32'): c_float,
-    }
-    return typeMap[numpyType]
-
-
 def hashToFunctionName(h):
     res = "func_%s" % (h,)
     return res.replace('-', 'm')
@@ -344,7 +309,7 @@ def compileLinux(ast, codeHashStr, srcFile, libFile):
     objectFile = os.path.join(cacheConfig['objectCache'], codeHashStr + '.o')
     # Compilation
     if not os.path.exists(objectFile):
-        generateCode(ast, ['iostream', 'cmath'], compilerConfig['restrictQualifier'], '', srcFile)
+        generateCode(ast, ['iostream', 'cmath', 'cstdint'], compilerConfig['restrictQualifier'], '', srcFile)
         compileCmd = [compilerConfig['command'], '-c'] + compilerConfig['flags'].split()
         compileCmd += ['-o', objectFile, srcFile]
         runCompileStep(compileCmd)
@@ -360,7 +325,7 @@ def compileWindows(ast, codeHashStr, srcFile, libFile):
     objectFile = os.path.join(cacheConfig['objectCache'], codeHashStr + '.obj')
     # Compilation
     if not os.path.exists(objectFile):
-        generateCode(ast, ['iostream', 'cmath'], compilerConfig['restrictQualifier'],
+        generateCode(ast, ['iostream', 'cmath', 'cstdint'], compilerConfig['restrictQualifier'],
                      '__declspec(dllexport)', srcFile)
 
         # /c compiles only, /EHsc turns of exception handling in c code
@@ -407,7 +372,7 @@ def buildCTypeArgumentList(parameterSpecification, argumentDict):
 
             symbolicField = arg.field
             if arg.isFieldPtrArgument:
-                ctArguments.append(field.ctypes.data_as(ctypeFromString(arg.dtype)))
+                ctArguments.append(field.ctypes.data_as(toCtypes(arg.dtype)))
                 if symbolicField.hasFixedShape:
                     if tuple(int(i) for i in symbolicField.shape) != field.shape:
                         raise ValueError("Passed array '%s' has shape %s which does not match expected shape %s" %
@@ -420,11 +385,11 @@ def buildCTypeArgumentList(parameterSpecification, argumentDict):
                 if not symbolicField.isIndexField:
                     arrayShapes.add(field.shape[:symbolicField.spatialDimensions])
             elif arg.isFieldShapeArgument:
-                dataType = ctypeFromString(arg.dtype, includePointers=False)
+                dataType = toCtypes(getBaseType(arg.dtype))
                 ctArguments.append(field.ctypes.shape_as(dataType))
             elif arg.isFieldStrideArgument:
-                dataType = ctypeFromString(arg.dtype, includePointers=False)
-                baseFieldType = ctypeFromNumpyType(field.dtype)
+                dataType = toCtypes(getBaseType(arg.dtype))
+                baseFieldType = toCtypes(createType(field.dtype))
                 strides = field.ctypes.strides_as(dataType)
                 for i in range(len(field.shape)):
                     assert strides[i] % sizeof(baseFieldType) == 0
@@ -437,7 +402,7 @@ def buildCTypeArgumentList(parameterSpecification, argumentDict):
                 param = argumentDict[arg.name]
             except KeyError:
                 raise KeyError("Missing parameter for kernel call " + arg.name)
-            expectedType = ctypeFromString(arg.dtype)
+            expectedType = toCtypes(arg.dtype)
             ctArguments.append(expectedType(param))
 
     if len(arrayShapes) > 1:
diff --git a/cpu/kernelcreation.py b/cpu/kernelcreation.py
index f5fb46569..f4e306a35 100644
--- a/cpu/kernelcreation.py
+++ b/cpu/kernelcreation.py
@@ -1,7 +1,7 @@
 import sympy as sp
 from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain, typingFromSympyInspection, \
     typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop
-from pystencils.types import TypedSymbol, DataType
+from pystencils.types import TypedSymbol
 from pystencils.field import Field
 import pystencils.astnodes as ast
 
@@ -37,7 +37,7 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl
         if isinstance(term, Field.Access) or isinstance(term, TypedSymbol):
             return term
         elif isinstance(term, sp.Symbol):
-            return TypedSymbol(term.name, DataType(typeForSymbol[term.name]))
+            return TypedSymbol(term.name, typeForSymbol[term.name])
         else:
             raise ValueError("Term has to be field access or symbol")
 
diff --git a/field.py b/field.py
index acb31391f..546e049d2 100644
--- a/field.py
+++ b/field.py
@@ -3,7 +3,7 @@ import numpy as np
 import sympy as sp
 from sympy.core.cache import cacheit
 from sympy.tensor import IndexedBase
-from pystencils.types import TypedSymbol
+from pystencils.types import TypedSymbol, createType
 
 
 class Field(object):
@@ -122,7 +122,7 @@ class Field(object):
     def __init__(self, fieldName, dtype, layout, shape, strides):
         """Do not use directly. Use static create* methods"""
         self._fieldName = fieldName
-        self._dtype = numpyDataTypeToC(dtype)
+        self._dtype = createType(dtype)
         self._layout = normalizeLayout(layout)
         self.shape = shape
         self.strides = strides
@@ -372,17 +372,6 @@ def computeStrides(shape, layout):
     return tuple(strides)
 
 
-def numpyDataTypeToC(dtype):
-    """Mapping numpy data types to C data types"""
-    if dtype == np.float64:
-        return "double"
-    elif dtype == np.float32:
-        return "float"
-    elif dtype == np.int32:
-        return "int"
-    raise NotImplementedError("Cannot convert type " + str(dtype))
-
-
 def offsetComponentToDirectionString(coordinateId, value):
     """
     Translates numerical offset to string notation.
diff --git a/llvm/kernelcreation.py b/llvm/kernelcreation.py
index d67565d65..75bac76cc 100644
--- a/llvm/kernelcreation.py
+++ b/llvm/kernelcreation.py
@@ -2,7 +2,7 @@ import sympy as sp
 from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain, typingFromSympyInspection, \
     typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop, \
     desympy_ast, insert_casts
-from pystencils.types import TypedSymbol, DataType
+from pystencils.types import TypedSymbol
 from pystencils.field import Field
 import pystencils.astnodes as ast
 
@@ -36,7 +36,7 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl
         if isinstance(term, Field.Access) or isinstance(term, TypedSymbol):
             return term
         elif isinstance(term, sp.Symbol):
-            return TypedSymbol(term.name, DataType(typeForSymbol[term.name]))
+            return TypedSymbol(term.name, typeForSymbol[term.name])
         else:
             raise ValueError("Term has to be field access or symbol")
 
diff --git a/transformations.py b/transformations.py
index e07b1de9c..35a6dbab2 100644
--- a/transformations.py
+++ b/transformations.py
@@ -6,7 +6,7 @@ from sympy.logic.boolalg import Boolean
 from sympy.tensor import IndexedBase
 
 from pystencils.field import Field, offsetComponentToDirectionString
-from pystencils.types import TypedSymbol, DataType
+from pystencils.types import TypedSymbol, createType, PointerType
 from pystencils.slicing import normalizeSlice
 import pystencils.astnodes as ast
 
@@ -109,7 +109,7 @@ def createIntermediateBasePointer(fieldAccess, coordinates, previousPtr):
     Example:
         >>> field = Field.createGeneric('myfield', spatialDimensions=2, indexDimensions=1)
         >>> x, y = sp.symbols("x y")
-        >>> prevPointer = TypedSymbol("ptr", DataType("double"))
+        >>> prevPointer = TypedSymbol("ptr", "double")
         >>> createIntermediateBasePointer(field[1,-2](5), {0: x}, prevPointer)
         (ptr_E, x*fstride_myfield[0] + fstride_myfield[0])
         >>> createIntermediateBasePointer(field[1,-2](5), {0: x, 1 : y }, prevPointer)
@@ -140,7 +140,7 @@ def createIntermediateBasePointer(fieldAccess, coordinates, previousPtr):
     if len(listToHash) > 0:
         name += "%0.6X" % (abs(hash(tuple(listToHash))))
 
-    newPtr = TypedSymbol(previousPtr.name + name, DataType(previousPtr.dtype))
+    newPtr = TypedSymbol(previousPtr.name + name, previousPtr.dtype)
     return newPtr, offset
 
 
@@ -234,12 +234,7 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn
             else:
                 basePointerInfo = [list(range(field.indexDimensions + field.spatialDimensions))]
 
-            dtype = DataType(field.dtype)
-            dtype.alias = False
-            dtype.ptr = True
-            if field.name in readOnlyFieldNames:
-                dtype.const = True
-
+            dtype = PointerType(field.dtype, const=field.name in readOnlyFieldNames, restrict=True)
             fieldPtr = TypedSymbol("%s%s" % (Field.DATA_PREFIX, symbolNameToVariableName(field.name)), dtype)
 
             lastPointer = fieldPtr
@@ -252,7 +247,7 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn
                             coordDict[e] = fieldToFixedCoordinates[field.name][e]
                         else:
                             ctrName = ast.LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX
-                            coordDict[e] = TypedSymbol("%s_%d" % (ctrName, e), DataType('int'))
+                            coordDict[e] = TypedSymbol("%s_%d" % (ctrName, e), 'int')
                     else:
                         coordDict[e] = fieldAccess.index[e-field.spatialDimensions]
                 return coordDict
@@ -433,7 +428,7 @@ def typeAllEquations(eqs, typeForSymbol):
         elif isinstance(term, TypedSymbol):
             return term
         elif isinstance(term, sp.Symbol):
-            return TypedSymbol(symbolNameToVariableName(term.name), DataType(typeForSymbol[term.name]))
+            return TypedSymbol(symbolNameToVariableName(term.name), typeForSymbol[term.name])
         else:
             newArgs = [processRhs(arg) for arg in term.args]
             return term.func(*newArgs) if newArgs else term
@@ -446,7 +441,7 @@ def typeAllEquations(eqs, typeForSymbol):
         elif isinstance(term, TypedSymbol):
             return term
         elif isinstance(term, sp.Symbol):
-            return TypedSymbol(term.name, DataType(typeForSymbol[term.name]))
+            return TypedSymbol(term.name, typeForSymbol[term.name])
         else:
             assert False, "Expected a symbol as left-hand-side"
 
@@ -539,9 +534,9 @@ def get_type(node):
     # TODO sp.NumberSymbol
     elif isinstance(node, sp.Number):
         if isinstance(node, sp.Float):
-            return DataType('double')
+            return createType('double')
         elif isinstance(node, sp.Integer):
-            return DataType('int')
+            return createType('int')
         else:
             raise NotImplemented('Not yet supported: %s %s' % (node, type(node)))
     else:
diff --git a/types.py b/types.py
index 328863970..4a6516142 100644
--- a/types.py
+++ b/types.py
@@ -1,4 +1,6 @@
+import ctypes
 import sympy as sp
+import numpy as np
 from sympy.core.cache import cacheit
 
 
@@ -10,7 +12,7 @@ class TypedSymbol(sp.Symbol):
 
     def __new_stage2__(cls, name, dtype):
         obj = super(TypedSymbol, cls).__xnew__(cls, name)
-        obj._dtype = DataType(dtype) if isinstance(dtype, str) else dtype
+        obj._dtype = createType(dtype)
         return obj
 
     __xnew__ = staticmethod(__new_stage2__)
@@ -29,41 +31,161 @@ class TypedSymbol(sp.Symbol):
         return self.name, self.dtype
 
 
-_c_dtype_dict = {0: 'bool', 1: 'int', 2: 'float', 3: 'double'}
-_dtype_dict = {'bool': 0, 'int': 1, 'float': 2, 'double': 3}
-
-
-class DataType(object):
-    def __init__(self, dtype):
-        self.alias = True
-        self.const = False
-        self.ptr = False
-        self.dtype = 0
-        if isinstance(dtype, str):
-            for s in dtype.split():
-                if s == 'const':
-                    self.const = True
-                elif s == '*':
-                    self.ptr = True
-                elif s == 'RESTRICT':
-                    self.alias = False
-                else:
-                    self.dtype = _dtype_dict[s]
-        elif isinstance(dtype, DataType):
-            self.__dict__.update(dtype.__dict__)
+def createType(specification):
+    if isinstance(specification, Type):
+        return specification
+    elif isinstance(specification, str):
+        return createTypeFromString(specification)
+    else:
+        npDataType = np.dtype(specification)
+        return BasicType(npDataType, const=False)
+
+
+def createTypeFromString(specification):
+    specification = specification.lower().split()
+    parts = []
+    current = []
+    for s in specification:
+        if s == '*':
+            parts.append(current)
+            current = [s]
         else:
-            self.dtype = dtype
+            current.append(s)
+    if len(current) > 0:
+        parts.append(current)
+
+    # Parse native part
+    basePart = parts.pop(0)
+    const = False
+    if 'const' in basePart:
+        const = True
+        basePart.remove('const')
+    assert len(basePart) == 1
+    baseType = BasicType(basePart[0], const)
+
+    currentType = baseType
+    # Parse pointer parts
+    for part in parts:
+        restrict = False
+        const = False
+        if 'restrict' in part:
+            restrict = True
+            part.remove('restrict')
+        if 'const' in part:
+            const = True
+            part.remove("const")
+        assert len(part) == 1 and part[0] == '*'
+        currentType = PointerType(currentType, const, restrict)
+    return currentType
+
+
+def getBaseType(type):
+    while type.baseType is not None:
+        type = type.baseType
+    return type
+
+
+def toCtypes(dataType):
+    if isinstance(dataType, PointerType):
+        return ctypes.POINTER(toCtypes(dataType.baseType))
+    else:
+        return toCtypes.map[dataType.numpyDtype]
+
+toCtypes.map = {
+    np.dtype(np.int8): ctypes.c_int8,
+    np.dtype(np.int16): ctypes.c_int16,
+    np.dtype(np.int32): ctypes.c_int32,
+    np.dtype(np.int64): ctypes.c_int64,
+
+    np.dtype(np.uint8): ctypes.c_uint8,
+    np.dtype(np.uint16): ctypes.c_uint16,
+    np.dtype(np.uint32): ctypes.c_uint32,
+    np.dtype(np.uint64): ctypes.c_uint64,
+
+    np.dtype(np.float32): ctypes.c_float,
+    np.dtype(np.float64): ctypes.c_double,
+}
+
+
+class Type(object):
+    pass
+
+
+class BasicType(Type):
+    @staticmethod
+    def numpyNameToC(name):
+        if name == 'float64': return 'double'
+        elif name == 'float32': return 'float'
+        elif name.startswith('int'):
+            width = int(name[len("int"):])
+            return "int%d_t" % (width,)
+        elif name.startswith('uint'):
+            width = int(name[len("int"):])
+            return "uint%d_t" % (width,)
+        elif name == 'bool':
+            return 'bool'
+        else:
+            raise NotImplemented("Can map numpy to C name for %s" % (name,))
+
+    def __init__(self, dtype, const=False):
+        self.const = const
+        self._dtype = np.dtype(dtype)
+        assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type"
+        assert self._dtype.hasobject is False
+        assert self._dtype.subdtype is None
+
+    @property
+    def baseType(self):
+        return None
 
-    def __repr__(self):
-        return "{!s} {!s}{!s} {!s}".format("const" if self.const else "", _c_dtype_dict[self.dtype],
-                                           "*" if self.ptr else "", "RESTRICT" if not self.alias else "")
+    @property
+    def numpyDtype(self):
+        return self._dtype
+
+    def __str__(self):
+        result = BasicType.numpyNameToC(str(self._dtype))
+        if self.const:
+            result += " const"
+        return result
 
     def __eq__(self, other):
-        if self.alias == other.alias and self.const == other.const and self.ptr == other.ptr and self.dtype == other.dtype:
-            return True
+        if not isinstance(other, BasicType):
+            return False
         else:
+            return (self.numpyDtype, self.const) == (other.numpyDtype, other.const)
+
+    def __hash__(self):
+        return hash(str(self))
+
+
+class PointerType(Type):
+    def __init__(self, baseType, const=False, restrict=True):
+        self._baseType = baseType
+        self.const = const
+        self.restrict = restrict
+
+    @property
+    def alias(self):
+        return not self.restrict
+
+    @property
+    def baseType(self):
+        return self._baseType
+
+    def __eq__(self, other):
+        if not isinstance(other, PointerType):
             return False
+        else:
+            return (self.baseType, self.const, self.restrict) == (other.baseType, other.const, other.restrict)
+
+    def __str__(self):
+        return "%s * %s%s" % (self.baseType, "RESTRICT " if self.restrict else "", "const " if self.const else "")
+
+    def __hash__(self):
+        return hash(str(self))
+
 
+class StructType(object):
+    def __init__(self, numpyType):
+        self._dtype = np.dtype(numpyType)
 
-def get_type_from_sympy(node):
-    return DataType('int')
\ No newline at end of file
-- 
GitLab