diff --git a/astnodes.py b/astnodes.py index e4a0c0dc631c81bc71458346339e07f63de45bb1..ce68fd44412d8d94f59b99e954262bf43e10ee6c 100644 --- a/astnodes.py +++ b/astnodes.py @@ -1,7 +1,7 @@ import sympy as sp from sympy.tensor import IndexedBase, Indexed from pystencils.field import Field -from pystencils.types import TypedSymbol, DataType, _c_dtype_dict +from pystencils.types import TypedSymbol, DataType, get_type_from_sympy class Node(object): @@ -481,11 +481,14 @@ class Indexed(Expr): def __repr__(self): return '%s[%s]' % (self.args[0], self.args[1]) -class Number(Node): + +class Number(Node, sp.AtomicExpr): def __init__(self, number, parent=None): super(Number, self).__init__(parent) - self._args = None - self.dtype = dtype + + self.dtype, self.value = get_type_from_sympy(number) + #TODO why does it have to be a tuple()? + self._args = tuple() @property def args(self): @@ -503,6 +506,6 @@ class Number(Node): raise set() def __repr__(self): - return '(%s)' % (_c_dtype_dict(self.dtype)) + repr(self.args) + return repr(self.dtype) + repr(self.value) diff --git a/llvm/kernelcreation.py b/llvm/kernelcreation.py index d67565d65635026513454c551aa61f464c686f8c..8df0f9bd32d5f2cb77e3667f68b6b79222ef5fa3 100644 --- a/llvm/kernelcreation.py +++ b/llvm/kernelcreation.py @@ -60,7 +60,9 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl resolveFieldAccesses(code, readOnlyFields, fieldToBasePointerInfo=basePointerInfos) moveConstantsBeforeLoop(code) + print(code) desympy_ast(code) + print(code) insert_casts(code) diff --git a/transformations.py b/transformations.py index e026cd3797a9c0e3388680387761256282fbf1b6..efa7bd5642b93e0995f9f33a9ad48c9930f8950f 100644 --- a/transformations.py +++ b/transformations.py @@ -556,15 +556,20 @@ def insert_casts(node): return node for arg in node.args: + print(arg) insert_casts(arg) if isinstance(node, ast.Indexed): pass elif isinstance(node, ast.Expr): - args = sorted((arg.dtype for arg in node.args), key=attrgetter('ptr', 'dtype')) + print(node) + print([(arg, type(arg), arg.dtype, type(arg.dtype)) for arg in node.args]) + args = sorted((arg for arg in node.args), key=attrgetter('dtype')) target = args[0] for i in range(len(args)): args[i] = add_conversion(args[i], target.dtype) node.args = args + node.dtype = target.dtype + print(node) elif isinstance(node, ast.LoopOverCoordinate): pass return node @@ -577,16 +582,21 @@ def desympy_ast(node): :param node: ast which should be traversed. Only node's children will be modified. :return: (modified) node """ + if node.args is None: + return node for i in range(len(node.args)): arg = node.args[i] if isinstance(arg, sp.Add): node.replace(arg, ast.Add(arg.args, node)) + elif isinstance(arg, sp.Number): + node.replace(arg, ast.Number(arg, node)) elif isinstance(arg, sp.Mul): node.replace(arg, ast.Mul(arg.args, node)) elif isinstance(arg, sp.Pow): node.replace(arg, ast.Pow(arg.args, node)) elif isinstance(arg, sp.tensor.Indexed): node.replace(arg, ast.Indexed(arg.args, node)) + #elif isinstance(arg, ) for arg in node.args: desympy_ast(arg) return node diff --git a/types.py b/types.py index 85d1b9124b8bafce0dd5a8a836676ecf13424a02..17ecae91a973f1275c7c7df4c26252327290ca69 100644 --- a/types.py +++ b/types.py @@ -64,5 +64,29 @@ class DataType(object): else: return False + def __gt__(self, other): + if self.ptr and not other.ptr: + return True + if self.dtype > other.dtype: + return True + + def get_type_from_sympy(node): - return DataType('int') \ No newline at end of file + # Rational, NumberSymbol? + # Zero, One, NegativeOne )= Integer + # Half )= Rational + # NAN, Infinity, Negative Inifinity, + # Exp1, Imaginary Unit, Pi, EulerGamma, Catalan, Golden Ratio + # Pow, Mul, Add, Mod, Relational + if not isinstance(node, sp.Number): + raise TypeError(node, 'is not a sp.Number') + + if isinstance(node, sp.Float) or isinstance(node, sp.RealNumber): + # TODO when float? + return DataType('double'), float(node) + elif isinstance(node, sp.Integer): + return DataType('int'), int(node) + elif isinstance(node, sp.Rational): + raise NotImplementedError('Rationals are not supported yet') + else: + raise TypeError(node, ' is not a supported type!')