From b444ae2573a105162c4656c867a824e49c1d4a13 Mon Sep 17 00:00:00 2001 From: Jan Hoenig <hrominium@gmail.com> Date: Thu, 2 Mar 2017 15:41:48 +0100 Subject: [PATCH] work work --- astnodes.py | 14 ++++++-------- backends/llvm.py | 17 +++++++++++++++-- llvm/kernelcreation.py | 8 +++----- transformations.py | 18 ++++++++---------- types.py | 1 - 5 files changed, 32 insertions(+), 26 deletions(-) diff --git a/astnodes.py b/astnodes.py index 28a4d6c67..f7a0abac8 100644 --- a/astnodes.py +++ b/astnodes.py @@ -1,7 +1,6 @@ import sympy as sp -from sympy.tensor import IndexedBase, Indexed from pystencils.field import Field -from pystencils.types import TypedSymbol, DataType, get_type_from_sympy +from pystencils.types import TypedSymbol, DataType, get_type_from_sympy, _c_dtype_dict class Node(object): @@ -294,7 +293,7 @@ class SympyAssignment(Node): self._lhsSymbol = lhsSymbol self.rhs = rhsTerm self._isDeclaration = True - if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, IndexedBase): + if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, sp.IndexedBase): self._isDeclaration = False self._isConst = isConst @@ -393,8 +392,6 @@ class TemporaryMemoryFree(Node): # TODO implement defined & undefinedSymbols - - class Conversion(Node): def __init__(self, child, dtype, parent=None): super(Conversion, self).__init__(parent) @@ -421,9 +418,9 @@ class Conversion(Node): raise set() def __repr__(self): - return '(%s)' % (_c_dtype_dict(self.dtype)) + repr(self.args) + return '(%s(%s))' % (repr(self.dtype), repr(self.args[0].dtype)) + repr(self.args) -# TODO everything which is not Atomic expression: Pow) +# TODO Pow _expr_dict = {'Add': ' + ', 'Mul': ' * ', 'Pow': '**'} @@ -482,6 +479,8 @@ class Indexed(Expr): def __init__(self, args, base, parent=None): super(Indexed, self).__init__(args, parent) self.base = base + #Get dtype from label, and unpointer it + self.dtype = DataType(base.label.dtype.dtype) def __repr__(self): return '%s[%s]' % (self.args[0], self.args[1]) @@ -492,7 +491,6 @@ class Number(Node, sp.AtomicExpr): super(Number, self).__init__(parent) self.dtype, self.value = get_type_from_sympy(number) - #TODO why does it have to be a tuple()? self._args = tuple() @property diff --git a/backends/llvm.py b/backends/llvm.py index fe11e77a4..a70627f02 100644 --- a/backends/llvm.py +++ b/backends/llvm.py @@ -36,10 +36,10 @@ class LLVMPrinter(Printer): self.tmp_var[name] = value def _print_Number(self, n, **kwargs): - return ir.Constant(self.fp_type, float(n)) + return ir.Constant(self.fp_type, n) def _print_Float(self, expr): - return ir.Constant(self.fp_type, float(expr.p)) + return ir.Constant(self.fp_type, expr.p) def _print_Integer(self, expr): return ir.Constant(self.integer, expr.p) @@ -134,6 +134,19 @@ class LLVMPrinter(Printer): def _print_SympyAssignment(self, assignment): expr = self._print(assignment.rhs) + def _print_Conversion(self, conversion): + to_dtype = conversion.dtype + from_dtype = conversion.args[0].dtype + print(to_dtype, from_dtype) + # fp -> int: fptosi + # int -> fp: sitofp + # ptr -> int: ptrtoint + # int -> ptr: inttoptr + # ?bitcast, ?addrspacecast + + def _print_Indexed(self, indexed): + pass + # Should have a list of math library functions to validate this. diff --git a/llvm/kernelcreation.py b/llvm/kernelcreation.py index 8df0f9bd3..b07a5fb8f 100644 --- a/llvm/kernelcreation.py +++ b/llvm/kernelcreation.py @@ -60,11 +60,9 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl resolveFieldAccesses(code, readOnlyFields, fieldToBasePointerInfo=basePointerInfos) moveConstantsBeforeLoop(code) - print(code) + # print(code) desympy_ast(code) - print(code) + # print(code) insert_casts(code) - - - return code \ No newline at end of file + return code diff --git a/transformations.py b/transformations.py index 74aba9d87..dc709985f 100644 --- a/transformations.py +++ b/transformations.py @@ -552,24 +552,22 @@ def insert_casts(node): :param node: ast which should be traversed :return: node """ - def add_conversion(node, dtype): - return node - for arg in node.args: - print(arg) insert_casts(arg) if isinstance(node, ast.Indexed): - node.dtype = node.base.label.dtype + #TODO revmove this + pass elif isinstance(node, ast.Expr): - print(node) - print([(arg, type(arg), arg.dtype, type(arg.dtype)) for arg in node.args]) args = sorted((arg for arg in node.args), key=attrgetter('dtype')) target = args[0] for i in range(len(args)): - args[i] = add_conversion(args[i], target.dtype) + if args[i].dtype != target.dtype: + args[i] = ast.Conversion(args[i], target.dtype, node) node.args = args node.dtype = target.dtype - print(node) + elif isinstance(node, ast.SympyAssignment): + if node.lhs.dtype != node.rhs.dtype: + node.replace(node.rhs, ast.Conversion(node.rhs, node.lhs.dtype)) elif isinstance(node, ast.LoopOverCoordinate): pass return node @@ -601,7 +599,7 @@ def desympy_ast(node): #elif isinstance(arg, sp.containers.Tuple): # else: - print('Not transforming:', arg, type(arg)) + print('Not transforming:', type(arg), arg) for arg in node.args: desympy_ast(arg) return node diff --git a/types.py b/types.py index 17ecae91a..3550de398 100644 --- a/types.py +++ b/types.py @@ -82,7 +82,6 @@ def get_type_from_sympy(node): raise TypeError(node, 'is not a sp.Number') if isinstance(node, sp.Float) or isinstance(node, sp.RealNumber): - # TODO when float? return DataType('double'), float(node) elif isinstance(node, sp.Integer): return DataType('int'), int(node) -- GitLab