From 0168aaedcdba5295319954186c655e0682ef1585 Mon Sep 17 00:00:00 2001
From: Jan Hoenig <hrominium@gmail.com>
Date: Thu, 8 Dec 2016 10:18:53 +0100
Subject: [PATCH] Created DataType class for storing information about data
 inside a class and not as a string. Changed name of the file TypedSymbol to
 types. Fixed usage of dtype accordingly, however i might not have found every
 usage of dtype.

---
 __init__.py            |  2 +-
 ast.py                 |  2 +-
 backends/cbackend.py   |  8 ++++----
 cpu/kernelcreation.py  |  2 +-
 field.py               |  2 +-
 llvm/kernelcreation.py |  2 +-
 transformations.py     |  7 ++++---
 types.py               | 16 ++++++++++++++++
 8 files changed, 29 insertions(+), 12 deletions(-)

diff --git a/__init__.py b/__init__.py
index 2f6f6bf2f..1a8183006 100644
--- a/__init__.py
+++ b/__init__.py
@@ -1,3 +1,3 @@
 from pystencils.field import Field, extractCommonSubexpressions
-from pystencils.typedsymbol import TypedSymbol
+from pystencils.types import TypedSymbol
 from pystencils.slicing import makeSlice
diff --git a/ast.py b/ast.py
index 5d36c1e0d..04ab30cf9 100644
--- a/ast.py
+++ b/ast.py
@@ -2,7 +2,7 @@ import sympy as sp
 import textwrap as textwrap
 from sympy.tensor import IndexedBase, Indexed
 from pystencils.field import Field
-from pystencils.typedsymbol import TypedSymbol
+from pystencils.types import TypedSymbol
 
 
 class Node(object):
diff --git a/backends/cbackend.py b/backends/cbackend.py
index 8bdf166a8..f59741d5c 100644
--- a/backends/cbackend.py
+++ b/backends/cbackend.py
@@ -76,7 +76,7 @@ class CBackend:
         raise NotImplementedError("CBackend does not support node of type " + cls.__name__)
 
     def _print_KernelFunction(self, node):
-        functionArguments = ["%s %s" % (s.dtype, s.name) for s in node.parameters]
+        functionArguments = ["%s %s" % (str(s.dtype), s.name) for s in node.parameters]
         prefix = "__global__ void" if self.cuda else "void"
         funcDeclaration = "%s %s(%s)" % (prefix, node.functionName, ", ".join(functionArguments))
         body = self._print(node.body)
@@ -105,10 +105,10 @@ class CBackend:
         dtype = ""
         if node.isDeclaration:
             if node.isConst:
-                dtype = "const " + node.lhs.dtype + " "
+                dtype = "const " + str(node.lhs.dtype) + " "
             else:
-                dtype = node.lhs.dtype + " "
-        return "%s %s = %s;" % (dtype, self.sympyPrinter.doprint(node.lhs), self.sympyPrinter.doprint(node.rhs))
+                dtype = str(node.lhs.dtype) + " "
+        return "%s %s = %s;" % (str(dtype), self.sympyPrinter.doprint(node.lhs), self.sympyPrinter.doprint(node.rhs))
 
     def _print_TemporaryMemoryAllocation(self, node):
         return "%s * %s = new %s[%s];" % (node.symbol.dtype, self.sympyPrinter.doprint(node.symbol),
diff --git a/cpu/kernelcreation.py b/cpu/kernelcreation.py
index ac2db27ef..e8e722c13 100644
--- a/cpu/kernelcreation.py
+++ b/cpu/kernelcreation.py
@@ -1,7 +1,7 @@
 import sympy as sp
 from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain, typingFromSympyInspection, \
     typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop
-from pystencils.typedsymbol import TypedSymbol
+from pystencils.types import TypedSymbol
 from pystencils.field import Field
 import pystencils.ast as ast
 
diff --git a/field.py b/field.py
index 50e6d36ae..e0835bc6f 100644
--- a/field.py
+++ b/field.py
@@ -3,7 +3,7 @@ import numpy as np
 import sympy as sp
 from sympy.core.cache import cacheit
 from sympy.tensor import IndexedBase
-from pystencils.typedsymbol import TypedSymbol
+from pystencils.types import TypedSymbol
 
 
 class Field:
diff --git a/llvm/kernelcreation.py b/llvm/kernelcreation.py
index 2fd0245a0..54d4ed0ce 100644
--- a/llvm/kernelcreation.py
+++ b/llvm/kernelcreation.py
@@ -1,7 +1,7 @@
 import sympy as sp
 from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain, typingFromSympyInspection, \
     typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop
-from pystencils.typedsymbol import TypedSymbol
+from pystencils.types import TypedSymbol
 from pystencils.field import Field
 import pystencils.ast as ast
 
diff --git a/transformations.py b/transformations.py
index ddeef910e..f83fa9c64 100644
--- a/transformations.py
+++ b/transformations.py
@@ -4,7 +4,7 @@ from sympy.logic.boolalg import Boolean
 from sympy.tensor import IndexedBase
 
 from pystencils.field import Field, offsetComponentToDirectionString
-from pystencils.typedsymbol import TypedSymbol
+from pystencils.types import TypedSymbol, DataType
 from pystencils.slicing import normalizeSlice
 import pystencils.ast as ast
 
@@ -220,9 +220,10 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn
             else:
                 basePointerInfo = [list(range(field.indexDimensions + field.spatialDimensions))]
 
-            dtype = "%s * __restrict__" % field.dtype
+            dtype = DataType(field.dtype)
+            dtype.alias = False
             if field.name in readOnlyFieldNames:
-                dtype = "const " + dtype
+                dtype.const = True
 
             fieldPtr = TypedSymbol("%s%s" % (Field.DATA_PREFIX, field.name), dtype)
 
diff --git a/types.py b/types.py
index 72ad8fe69..e4f579e68 100644
--- a/types.py
+++ b/types.py
@@ -28,3 +28,19 @@ class TypedSymbol(sp.Symbol):
     def __getnewargs__(self):
         return self.name, self.dtype
 
+
+_c_dtype_dict = {0: 'int', 1: 'double', 2: 'float'}
+_dtype_dict = {'int': 0, 'double': 1, 'float': 2}
+
+
+class DataType(object):
+    def __init__(self, dtype):
+        self.alias = True
+        self.const = False
+        if isinstance(dtype, str):
+            self.dtype = _dtype_dict[dtype]
+        else:
+            self.dtype = dtype
+
+    def __repr__(self):
+        return "{!s} {!s} {!s}".format("const" if self.const else "", "__restrict__" if not self.alias else "", _c_dtype_dict[self.dtype])
-- 
GitLab