diff --git a/astnodes.py b/astnodes.py index 8856fd421087a9cc35f30be3d89d34c3dd60ee41..e4a0c0dc631c81bc71458346339e07f63de45bb1 100644 --- a/astnodes.py +++ b/astnodes.py @@ -1,7 +1,7 @@ import sympy as sp from sympy.tensor import IndexedBase, Indexed from pystencils.field import Field -from pystencils.types import TypedSymbol, DataType +from pystencils.types import TypedSymbol, DataType, _c_dtype_dict class Node(object): @@ -391,6 +391,37 @@ class TemporaryMemoryFree(Node): return [] +# TODO implement defined & undefinedSymbols + + +class Conversion(Node): + def __init__(self, child, dtype, parent=None): + super(Conversion, self).__init__(parent) + self._args = [child] + self.dtype = dtype + + @property + def args(self): + """Returns all arguments/children of this node""" + return self._args + + @args.setter + def args(self, value): + self._args = value + + @property + def symbolsDefined(self): + """Set of symbols which are defined by this node. """ + return set() + + @property + def undefinedSymbols(self): + """Symbols which are use but are not defined inside this node""" + raise set() + + def __repr__(self): + return '(%s)' % (_c_dtype_dict(self.dtype)) + repr(self.args) + # TODO everything which is not Atomic expression: Pow) @@ -401,6 +432,7 @@ class Expr(Node): def __init__(self, args, parent=None): super(Expr, self).__init__(parent) self._args = list(args) + self.dtype = None @property def args(self): @@ -430,7 +462,7 @@ class Expr(Node): return set() # Todo fix for symbol analysis def __repr__(self): - return _expr_dict[self.__class__.__name__].join(repr(arg) for arg in self.args) # TODO test this + return _expr_dict[self.__class__.__name__].join(repr(arg) for arg in self.args) class Mul(Expr): @@ -449,4 +481,28 @@ class Indexed(Expr): def __repr__(self): return '%s[%s]' % (self.args[0], self.args[1]) +class Number(Node): + def __init__(self, number, parent=None): + super(Number, self).__init__(parent) + self._args = None + self.dtype = dtype + + @property + def args(self): + """Returns all arguments/children of this node""" + return self._args + + @property + def symbolsDefined(self): + """Set of symbols which are defined by this node. """ + return set() + + @property + def undefinedSymbols(self): + """Symbols which are use but are not defined inside this node""" + raise set() + + def __repr__(self): + return '(%s)' % (_c_dtype_dict(self.dtype)) + repr(self.args) + diff --git a/llvm/kernelcreation.py b/llvm/kernelcreation.py index a13001973936d7ea95cb740a878c0a106ff660c9..d67565d65635026513454c551aa61f464c686f8c 100644 --- a/llvm/kernelcreation.py +++ b/llvm/kernelcreation.py @@ -1,6 +1,7 @@ import sympy as sp from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain, typingFromSympyInspection, \ - typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop + typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop, \ + desympy_ast, insert_casts from pystencils.types import TypedSymbol, DataType from pystencils.field import Field import pystencils.astnodes as ast @@ -59,4 +60,9 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl resolveFieldAccesses(code, readOnlyFields, fieldToBasePointerInfo=basePointerInfos) moveConstantsBeforeLoop(code) + desympy_ast(code) + insert_casts(code) + + + return code \ No newline at end of file diff --git a/transformations.py b/transformations.py index 1d2dfd5244a4297ed56b9f40e0f2211c0e207402..e026cd3797a9c0e3388680387761256282fbf1b6 100644 --- a/transformations.py +++ b/transformations.py @@ -1,4 +1,6 @@ from collections import defaultdict +from operator import attrgetter + import sympy as sp from sympy.logic.boolalg import Boolean from sympy.tensor import IndexedBase @@ -527,24 +529,56 @@ def getLoopHierarchy(astNode): return reversed(result) +def get_type(node): + if isinstance(node, ast.Indexed): + return node.args[0].dtype + elif isinstance(node, ast.Node): + return node.dtype + # TODO sp.NumberSymbol + elif isinstance(node, sp.Number): + if isinstance(node, sp.Float): + return DataType('double') + elif isinstance(node, sp.Integer): + return DataType('int') + else: + raise NotImplemented('Not yet supported: %s %s' % (node, type(node))) + else: + raise NotImplemented('Not yet supported: %s %s' % (node, type(node))) + + def insert_casts(node): - if isinstance(node, ast.SympyAssignment): + """ + Inserts casts where needed + :param node: ast which should be traversed + :return: node + """ + def add_conversion(node, dtype): + return node + + for arg in node.args: + insert_casts(arg) + if isinstance(node, ast.Indexed): pass - elif isinstance(node, sp.Expr): + elif isinstance(node, ast.Expr): + args = sorted((arg.dtype for arg in node.args), key=attrgetter('ptr', 'dtype')) + target = args[0] + for i in range(len(args)): + args[i] = add_conversion(args[i], target.dtype) + node.args = args + elif isinstance(node, ast.LoopOverCoordinate): pass - else: - for arg in node.args: - insert_casts(arg) + return node def desympy_ast(node): - # if isinstance(node, sp.Expr) and not isinstance(node, sp.AtomicExpr) and not isinstance(node, sp.tensor.IndexedBase): - # print(node, type(node)) - + """ + Remove Sympy Expressions, which have more then one argument. + This is necessary for further changes in the tree. + :param node: ast which should be traversed. Only node's children will be modified. + :return: (modified) node + """ for i in range(len(node.args)): arg = node.args[i] - if isinstance(node, ast.SympyAssignment): - print(node, type(arg)) if isinstance(arg, sp.Add): node.replace(arg, ast.Add(arg.args, node)) elif isinstance(arg, sp.Mul): @@ -555,3 +589,4 @@ def desympy_ast(node): node.replace(arg, ast.Indexed(arg.args, node)) for arg in node.args: desympy_ast(arg) + return node diff --git a/types.py b/types.py index 8d964c8ec75a884cbb4f1f37f91d88f67eb1c184..85d1b9124b8bafce0dd5a8a836676ecf13424a02 100644 --- a/types.py +++ b/types.py @@ -29,8 +29,8 @@ class TypedSymbol(sp.Symbol): return self.name, self.dtype -_c_dtype_dict = {0: 'int', 1: 'double', 2: 'float', 3: 'bool'} -_dtype_dict = {'int': 0, 'double': 1, 'float': 2, 'bool': 3} +_c_dtype_dict = {0: 'bool', 1: 'int', 2: 'float', 3: 'double'} +_dtype_dict = {'bool': 0, 'int': 1, 'float': 2, 'double': 3} class DataType(object): @@ -63,3 +63,6 @@ class DataType(object): return True else: return False + +def get_type_from_sympy(node): + return DataType('int') \ No newline at end of file