From 981763046c9071fd824df9e9cbf4d392c8baa63f Mon Sep 17 00:00:00 2001
From: Jan Hoenig <hrominium@gmail.com>
Date: Sun, 18 Dec 2016 17:47:10 +0100
Subject: [PATCH] not done yet

---
 astnodes.py            | 60 ++++++++++++++++++++++++++++++++++++++++--
 llvm/kernelcreation.py |  8 +++++-
 transformations.py     | 55 +++++++++++++++++++++++++++++++-------
 types.py               |  7 +++--
 4 files changed, 115 insertions(+), 15 deletions(-)

diff --git a/astnodes.py b/astnodes.py
index 8856fd421..e4a0c0dc6 100644
--- a/astnodes.py
+++ b/astnodes.py
@@ -1,7 +1,7 @@
 import sympy as sp
 from sympy.tensor import IndexedBase, Indexed
 from pystencils.field import Field
-from pystencils.types import TypedSymbol, DataType
+from pystencils.types import TypedSymbol, DataType, _c_dtype_dict
 
 
 class Node(object):
@@ -391,6 +391,37 @@ class TemporaryMemoryFree(Node):
         return []
 
 
+# TODO implement defined & undefinedSymbols
+
+
+class Conversion(Node):
+    def __init__(self, child, dtype, parent=None):
+        super(Conversion, self).__init__(parent)
+        self._args = [child]
+        self.dtype = dtype
+
+    @property
+    def args(self):
+        """Returns all arguments/children of this node"""
+        return self._args
+
+    @args.setter
+    def args(self, value):
+        self._args = value
+
+    @property
+    def symbolsDefined(self):
+        """Set of symbols which are defined by this node. """
+        return set()
+
+    @property
+    def undefinedSymbols(self):
+        """Symbols which are use but are not defined inside this node"""
+        raise set()
+
+    def __repr__(self):
+        return '(%s)' % (_c_dtype_dict(self.dtype)) + repr(self.args)
+
 # TODO everything which is not Atomic expression: Pow)
 
 
@@ -401,6 +432,7 @@ class Expr(Node):
     def __init__(self, args, parent=None):
         super(Expr, self).__init__(parent)
         self._args = list(args)
+        self.dtype = None
 
     @property
     def args(self):
@@ -430,7 +462,7 @@ class Expr(Node):
         return set()  # Todo fix for symbol analysis
 
     def __repr__(self):
-        return _expr_dict[self.__class__.__name__].join(repr(arg) for arg in self.args) # TODO test this
+        return _expr_dict[self.__class__.__name__].join(repr(arg) for arg in self.args)
 
 
 class Mul(Expr):
@@ -449,4 +481,28 @@ class Indexed(Expr):
     def __repr__(self):
         return '%s[%s]' % (self.args[0], self.args[1])
 
+class Number(Node):
+    def __init__(self, number, parent=None):
+        super(Number, self).__init__(parent)
+        self._args = None
+        self.dtype = dtype
+
+    @property
+    def args(self):
+        """Returns all arguments/children of this node"""
+        return self._args
+
+    @property
+    def symbolsDefined(self):
+        """Set of symbols which are defined by this node. """
+        return set()
+
+    @property
+    def undefinedSymbols(self):
+        """Symbols which are use but are not defined inside this node"""
+        raise set()
+
+    def __repr__(self):
+        return '(%s)' % (_c_dtype_dict(self.dtype)) + repr(self.args)
+
 
diff --git a/llvm/kernelcreation.py b/llvm/kernelcreation.py
index a13001973..d67565d65 100644
--- a/llvm/kernelcreation.py
+++ b/llvm/kernelcreation.py
@@ -1,6 +1,7 @@
 import sympy as sp
 from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain, typingFromSympyInspection, \
-    typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop
+    typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop, \
+    desympy_ast, insert_casts
 from pystencils.types import TypedSymbol, DataType
 from pystencils.field import Field
 import pystencils.astnodes as ast
@@ -59,4 +60,9 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl
     resolveFieldAccesses(code, readOnlyFields, fieldToBasePointerInfo=basePointerInfos)
     moveConstantsBeforeLoop(code)
 
+    desympy_ast(code)
+    insert_casts(code)
+
+
+
     return code
\ No newline at end of file
diff --git a/transformations.py b/transformations.py
index 1d2dfd524..e026cd379 100644
--- a/transformations.py
+++ b/transformations.py
@@ -1,4 +1,6 @@
 from collections import defaultdict
+from operator import attrgetter
+
 import sympy as sp
 from sympy.logic.boolalg import Boolean
 from sympy.tensor import IndexedBase
@@ -527,24 +529,56 @@ def getLoopHierarchy(astNode):
     return reversed(result)
 
 
+def get_type(node):
+    if isinstance(node, ast.Indexed):
+        return node.args[0].dtype
+    elif isinstance(node, ast.Node):
+        return node.dtype
+    # TODO sp.NumberSymbol
+    elif isinstance(node, sp.Number):
+        if isinstance(node, sp.Float):
+            return DataType('double')
+        elif isinstance(node, sp.Integer):
+            return DataType('int')
+        else:
+            raise NotImplemented('Not yet supported: %s %s' % (node, type(node)))
+    else:
+        raise NotImplemented('Not yet supported: %s %s' % (node, type(node)))
+
+
 def insert_casts(node):
-    if isinstance(node, ast.SympyAssignment):
+    """
+    Inserts casts where needed
+    :param node: ast which should be traversed
+    :return: node
+    """
+    def add_conversion(node, dtype):
+        return node
+
+    for arg in node.args:
+        insert_casts(arg)
+    if isinstance(node, ast.Indexed):
         pass
-    elif isinstance(node, sp.Expr):
+    elif isinstance(node, ast.Expr):
+        args = sorted((arg.dtype for arg in node.args), key=attrgetter('ptr', 'dtype'))
+        target = args[0]
+        for i in range(len(args)):
+            args[i] = add_conversion(args[i], target.dtype)
+        node.args = args
+    elif isinstance(node, ast.LoopOverCoordinate):
         pass
-    else:
-        for arg in node.args:
-            insert_casts(arg)
+    return node
 
 
 def desympy_ast(node):
-    # if isinstance(node, sp.Expr) and not isinstance(node, sp.AtomicExpr) and not isinstance(node, sp.tensor.IndexedBase):
-    #    print(node, type(node))
-
+    """
+    Remove Sympy Expressions, which have more then one argument.
+    This is necessary for further changes in the tree.
+    :param node: ast which should be traversed. Only node's children will be modified.
+    :return: (modified) node
+    """
     for i in range(len(node.args)):
         arg = node.args[i]
-        if isinstance(node, ast.SympyAssignment):
-            print(node, type(arg))
         if isinstance(arg, sp.Add):
             node.replace(arg, ast.Add(arg.args, node))
         elif isinstance(arg, sp.Mul):
@@ -555,3 +589,4 @@ def desympy_ast(node):
             node.replace(arg, ast.Indexed(arg.args, node))
     for arg in node.args:
         desympy_ast(arg)
+    return node
diff --git a/types.py b/types.py
index 8d964c8ec..85d1b9124 100644
--- a/types.py
+++ b/types.py
@@ -29,8 +29,8 @@ class TypedSymbol(sp.Symbol):
         return self.name, self.dtype
 
 
-_c_dtype_dict = {0: 'int', 1: 'double', 2: 'float', 3: 'bool'}
-_dtype_dict = {'int': 0, 'double': 1, 'float': 2, 'bool': 3}
+_c_dtype_dict = {0: 'bool', 1: 'int', 2: 'float', 3: 'double'}
+_dtype_dict = {'bool': 0, 'int': 1, 'float': 2, 'double': 3}
 
 
 class DataType(object):
@@ -63,3 +63,6 @@ class DataType(object):
             return True
         else:
             return False
+
+def get_type_from_sympy(node):
+    return DataType('int')
\ No newline at end of file
-- 
GitLab