Skip to content
Snippets Groups Projects
Commit 98176304 authored by Jan Hoenig's avatar Jan Hoenig
Browse files

not done yet

parent f9e81a26
No related merge requests found
import sympy as sp import sympy as sp
from sympy.tensor import IndexedBase, Indexed from sympy.tensor import IndexedBase, Indexed
from pystencils.field import Field from pystencils.field import Field
from pystencils.types import TypedSymbol, DataType from pystencils.types import TypedSymbol, DataType, _c_dtype_dict
class Node(object): class Node(object):
...@@ -391,6 +391,37 @@ class TemporaryMemoryFree(Node): ...@@ -391,6 +391,37 @@ class TemporaryMemoryFree(Node):
return [] return []
# TODO implement defined & undefinedSymbols
class Conversion(Node):
def __init__(self, child, dtype, parent=None):
super(Conversion, self).__init__(parent)
self._args = [child]
self.dtype = dtype
@property
def args(self):
"""Returns all arguments/children of this node"""
return self._args
@args.setter
def args(self, value):
self._args = value
@property
def symbolsDefined(self):
"""Set of symbols which are defined by this node. """
return set()
@property
def undefinedSymbols(self):
"""Symbols which are use but are not defined inside this node"""
raise set()
def __repr__(self):
return '(%s)' % (_c_dtype_dict(self.dtype)) + repr(self.args)
# TODO everything which is not Atomic expression: Pow) # TODO everything which is not Atomic expression: Pow)
...@@ -401,6 +432,7 @@ class Expr(Node): ...@@ -401,6 +432,7 @@ class Expr(Node):
def __init__(self, args, parent=None): def __init__(self, args, parent=None):
super(Expr, self).__init__(parent) super(Expr, self).__init__(parent)
self._args = list(args) self._args = list(args)
self.dtype = None
@property @property
def args(self): def args(self):
...@@ -430,7 +462,7 @@ class Expr(Node): ...@@ -430,7 +462,7 @@ class Expr(Node):
return set() # Todo fix for symbol analysis return set() # Todo fix for symbol analysis
def __repr__(self): def __repr__(self):
return _expr_dict[self.__class__.__name__].join(repr(arg) for arg in self.args) # TODO test this return _expr_dict[self.__class__.__name__].join(repr(arg) for arg in self.args)
class Mul(Expr): class Mul(Expr):
...@@ -449,4 +481,28 @@ class Indexed(Expr): ...@@ -449,4 +481,28 @@ class Indexed(Expr):
def __repr__(self): def __repr__(self):
return '%s[%s]' % (self.args[0], self.args[1]) return '%s[%s]' % (self.args[0], self.args[1])
class Number(Node):
def __init__(self, number, parent=None):
super(Number, self).__init__(parent)
self._args = None
self.dtype = dtype
@property
def args(self):
"""Returns all arguments/children of this node"""
return self._args
@property
def symbolsDefined(self):
"""Set of symbols which are defined by this node. """
return set()
@property
def undefinedSymbols(self):
"""Symbols which are use but are not defined inside this node"""
raise set()
def __repr__(self):
return '(%s)' % (_c_dtype_dict(self.dtype)) + repr(self.args)
import sympy as sp import sympy as sp
from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain, typingFromSympyInspection, \ from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain, typingFromSympyInspection, \
typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop, \
desympy_ast, insert_casts
from pystencils.types import TypedSymbol, DataType from pystencils.types import TypedSymbol, DataType
from pystencils.field import Field from pystencils.field import Field
import pystencils.astnodes as ast import pystencils.astnodes as ast
...@@ -59,4 +60,9 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl ...@@ -59,4 +60,9 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl
resolveFieldAccesses(code, readOnlyFields, fieldToBasePointerInfo=basePointerInfos) resolveFieldAccesses(code, readOnlyFields, fieldToBasePointerInfo=basePointerInfos)
moveConstantsBeforeLoop(code) moveConstantsBeforeLoop(code)
desympy_ast(code)
insert_casts(code)
return code return code
\ No newline at end of file
from collections import defaultdict from collections import defaultdict
from operator import attrgetter
import sympy as sp import sympy as sp
from sympy.logic.boolalg import Boolean from sympy.logic.boolalg import Boolean
from sympy.tensor import IndexedBase from sympy.tensor import IndexedBase
...@@ -527,24 +529,56 @@ def getLoopHierarchy(astNode): ...@@ -527,24 +529,56 @@ def getLoopHierarchy(astNode):
return reversed(result) return reversed(result)
def get_type(node):
if isinstance(node, ast.Indexed):
return node.args[0].dtype
elif isinstance(node, ast.Node):
return node.dtype
# TODO sp.NumberSymbol
elif isinstance(node, sp.Number):
if isinstance(node, sp.Float):
return DataType('double')
elif isinstance(node, sp.Integer):
return DataType('int')
else:
raise NotImplemented('Not yet supported: %s %s' % (node, type(node)))
else:
raise NotImplemented('Not yet supported: %s %s' % (node, type(node)))
def insert_casts(node): def insert_casts(node):
if isinstance(node, ast.SympyAssignment): """
Inserts casts where needed
:param node: ast which should be traversed
:return: node
"""
def add_conversion(node, dtype):
return node
for arg in node.args:
insert_casts(arg)
if isinstance(node, ast.Indexed):
pass pass
elif isinstance(node, sp.Expr): elif isinstance(node, ast.Expr):
args = sorted((arg.dtype for arg in node.args), key=attrgetter('ptr', 'dtype'))
target = args[0]
for i in range(len(args)):
args[i] = add_conversion(args[i], target.dtype)
node.args = args
elif isinstance(node, ast.LoopOverCoordinate):
pass pass
else: return node
for arg in node.args:
insert_casts(arg)
def desympy_ast(node): def desympy_ast(node):
# if isinstance(node, sp.Expr) and not isinstance(node, sp.AtomicExpr) and not isinstance(node, sp.tensor.IndexedBase): """
# print(node, type(node)) Remove Sympy Expressions, which have more then one argument.
This is necessary for further changes in the tree.
:param node: ast which should be traversed. Only node's children will be modified.
:return: (modified) node
"""
for i in range(len(node.args)): for i in range(len(node.args)):
arg = node.args[i] arg = node.args[i]
if isinstance(node, ast.SympyAssignment):
print(node, type(arg))
if isinstance(arg, sp.Add): if isinstance(arg, sp.Add):
node.replace(arg, ast.Add(arg.args, node)) node.replace(arg, ast.Add(arg.args, node))
elif isinstance(arg, sp.Mul): elif isinstance(arg, sp.Mul):
...@@ -555,3 +589,4 @@ def desympy_ast(node): ...@@ -555,3 +589,4 @@ def desympy_ast(node):
node.replace(arg, ast.Indexed(arg.args, node)) node.replace(arg, ast.Indexed(arg.args, node))
for arg in node.args: for arg in node.args:
desympy_ast(arg) desympy_ast(arg)
return node
...@@ -29,8 +29,8 @@ class TypedSymbol(sp.Symbol): ...@@ -29,8 +29,8 @@ class TypedSymbol(sp.Symbol):
return self.name, self.dtype return self.name, self.dtype
_c_dtype_dict = {0: 'int', 1: 'double', 2: 'float', 3: 'bool'} _c_dtype_dict = {0: 'bool', 1: 'int', 2: 'float', 3: 'double'}
_dtype_dict = {'int': 0, 'double': 1, 'float': 2, 'bool': 3} _dtype_dict = {'bool': 0, 'int': 1, 'float': 2, 'double': 3}
class DataType(object): class DataType(object):
...@@ -63,3 +63,6 @@ class DataType(object): ...@@ -63,3 +63,6 @@ class DataType(object):
return True return True
else: else:
return False return False
def get_type_from_sympy(node):
return DataType('int')
\ No newline at end of file
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment