From b444ae2573a105162c4656c867a824e49c1d4a13 Mon Sep 17 00:00:00 2001
From: Jan Hoenig <hrominium@gmail.com>
Date: Thu, 2 Mar 2017 15:41:48 +0100
Subject: [PATCH] work work

---
 astnodes.py            | 14 ++++++--------
 backends/llvm.py       | 17 +++++++++++++++--
 llvm/kernelcreation.py |  8 +++-----
 transformations.py     | 18 ++++++++----------
 types.py               |  1 -
 5 files changed, 32 insertions(+), 26 deletions(-)

diff --git a/astnodes.py b/astnodes.py
index 28a4d6c67..f7a0abac8 100644
--- a/astnodes.py
+++ b/astnodes.py
@@ -1,7 +1,6 @@
 import sympy as sp
-from sympy.tensor import IndexedBase, Indexed
 from pystencils.field import Field
-from pystencils.types import TypedSymbol, DataType, get_type_from_sympy
+from pystencils.types import TypedSymbol, DataType, get_type_from_sympy, _c_dtype_dict
 
 
 class Node(object):
@@ -294,7 +293,7 @@ class SympyAssignment(Node):
         self._lhsSymbol = lhsSymbol
         self.rhs = rhsTerm
         self._isDeclaration = True
-        if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, IndexedBase):
+        if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, sp.IndexedBase):
             self._isDeclaration = False
         self._isConst = isConst
 
@@ -393,8 +392,6 @@ class TemporaryMemoryFree(Node):
 
 
 # TODO implement defined & undefinedSymbols
-
-
 class Conversion(Node):
     def __init__(self, child, dtype, parent=None):
         super(Conversion, self).__init__(parent)
@@ -421,9 +418,9 @@ class Conversion(Node):
         raise set()
 
     def __repr__(self):
-        return '(%s)' % (_c_dtype_dict(self.dtype)) + repr(self.args)
+        return '(%s(%s))' % (repr(self.dtype), repr(self.args[0].dtype)) + repr(self.args)
 
-# TODO everything which is not Atomic expression: Pow)
+# TODO Pow
 
 
 _expr_dict = {'Add': ' + ', 'Mul': ' * ', 'Pow': '**'}
@@ -482,6 +479,8 @@ class Indexed(Expr):
     def __init__(self, args, base, parent=None):
         super(Indexed, self).__init__(args, parent)
         self.base = base
+        #Get dtype from label, and unpointer it
+        self.dtype = DataType(base.label.dtype.dtype)
 
     def __repr__(self):
         return '%s[%s]' % (self.args[0], self.args[1])
@@ -492,7 +491,6 @@ class Number(Node, sp.AtomicExpr):
         super(Number, self).__init__(parent)
 
         self.dtype, self.value = get_type_from_sympy(number)
-        #TODO why does it have to be a tuple()?
         self._args = tuple()
 
     @property
diff --git a/backends/llvm.py b/backends/llvm.py
index fe11e77a4..a70627f02 100644
--- a/backends/llvm.py
+++ b/backends/llvm.py
@@ -36,10 +36,10 @@ class LLVMPrinter(Printer):
         self.tmp_var[name] = value
 
     def _print_Number(self, n, **kwargs):
-        return ir.Constant(self.fp_type, float(n))
+        return ir.Constant(self.fp_type, n)
 
     def _print_Float(self, expr):
-        return ir.Constant(self.fp_type, float(expr.p))
+        return ir.Constant(self.fp_type, expr.p)
 
     def _print_Integer(self, expr):
         return ir.Constant(self.integer, expr.p)
@@ -134,6 +134,19 @@ class LLVMPrinter(Printer):
     def _print_SympyAssignment(self, assignment):
         expr = self._print(assignment.rhs)
 
+    def _print_Conversion(self, conversion):
+        to_dtype = conversion.dtype
+        from_dtype = conversion.args[0].dtype
+        print(to_dtype, from_dtype)
+        # fp -> int: fptosi
+        # int -> fp: sitofp
+        # ptr -> int: ptrtoint
+        # int -> ptr: inttoptr
+        # ?bitcast, ?addrspacecast
+
+    def _print_Indexed(self, indexed):
+        pass
+
 
 
         #  Should have a list of math library functions to validate this.
diff --git a/llvm/kernelcreation.py b/llvm/kernelcreation.py
index 8df0f9bd3..b07a5fb8f 100644
--- a/llvm/kernelcreation.py
+++ b/llvm/kernelcreation.py
@@ -60,11 +60,9 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl
     resolveFieldAccesses(code, readOnlyFields, fieldToBasePointerInfo=basePointerInfos)
     moveConstantsBeforeLoop(code)
 
-    print(code)
+    # print(code)
     desympy_ast(code)
-    print(code)
+    # print(code)
     insert_casts(code)
 
-
-
-    return code
\ No newline at end of file
+    return code
diff --git a/transformations.py b/transformations.py
index 74aba9d87..dc709985f 100644
--- a/transformations.py
+++ b/transformations.py
@@ -552,24 +552,22 @@ def insert_casts(node):
     :param node: ast which should be traversed
     :return: node
     """
-    def add_conversion(node, dtype):
-        return node
-
     for arg in node.args:
-        print(arg)
         insert_casts(arg)
     if isinstance(node, ast.Indexed):
-        node.dtype = node.base.label.dtype
+        #TODO revmove this
+        pass
     elif isinstance(node, ast.Expr):
-        print(node)
-        print([(arg, type(arg), arg.dtype, type(arg.dtype)) for arg in node.args])
         args = sorted((arg for arg in node.args), key=attrgetter('dtype'))
         target = args[0]
         for i in range(len(args)):
-            args[i] = add_conversion(args[i], target.dtype)
+            if args[i].dtype != target.dtype:
+                args[i] = ast.Conversion(args[i], target.dtype, node)
         node.args = args
         node.dtype = target.dtype
-        print(node)
+    elif isinstance(node, ast.SympyAssignment):
+        if node.lhs.dtype != node.rhs.dtype:
+            node.replace(node.rhs, ast.Conversion(node.rhs, node.lhs.dtype))
     elif isinstance(node, ast.LoopOverCoordinate):
         pass
     return node
@@ -601,7 +599,7 @@ def desympy_ast(node):
         #elif isinstance(arg, sp.containers.Tuple):
         #
         else:
-            print('Not transforming:', arg, type(arg))
+            print('Not transforming:', type(arg), arg)
     for arg in node.args:
         desympy_ast(arg)
     return node
diff --git a/types.py b/types.py
index 17ecae91a..3550de398 100644
--- a/types.py
+++ b/types.py
@@ -82,7 +82,6 @@ def get_type_from_sympy(node):
         raise TypeError(node, 'is not a sp.Number')
 
     if isinstance(node, sp.Float) or isinstance(node, sp.RealNumber):
-        # TODO when float?
         return DataType('double'), float(node)
     elif isinstance(node, sp.Integer):
         return DataType('int'), int(node)
-- 
GitLab