Skip to content
Snippets Groups Projects
Commit 98176304 authored by Jan Hoenig's avatar Jan Hoenig
Browse files

not done yet

parent f9e81a26
No related merge requests found
import sympy as sp
from sympy.tensor import IndexedBase, Indexed
from pystencils.field import Field
from pystencils.types import TypedSymbol, DataType
from pystencils.types import TypedSymbol, DataType, _c_dtype_dict
class Node(object):
......@@ -391,6 +391,37 @@ class TemporaryMemoryFree(Node):
return []
# TODO implement defined & undefinedSymbols
class Conversion(Node):
def __init__(self, child, dtype, parent=None):
super(Conversion, self).__init__(parent)
self._args = [child]
self.dtype = dtype
@property
def args(self):
"""Returns all arguments/children of this node"""
return self._args
@args.setter
def args(self, value):
self._args = value
@property
def symbolsDefined(self):
"""Set of symbols which are defined by this node. """
return set()
@property
def undefinedSymbols(self):
"""Symbols which are use but are not defined inside this node"""
raise set()
def __repr__(self):
return '(%s)' % (_c_dtype_dict(self.dtype)) + repr(self.args)
# TODO everything which is not Atomic expression: Pow)
......@@ -401,6 +432,7 @@ class Expr(Node):
def __init__(self, args, parent=None):
super(Expr, self).__init__(parent)
self._args = list(args)
self.dtype = None
@property
def args(self):
......@@ -430,7 +462,7 @@ class Expr(Node):
return set() # Todo fix for symbol analysis
def __repr__(self):
return _expr_dict[self.__class__.__name__].join(repr(arg) for arg in self.args) # TODO test this
return _expr_dict[self.__class__.__name__].join(repr(arg) for arg in self.args)
class Mul(Expr):
......@@ -449,4 +481,28 @@ class Indexed(Expr):
def __repr__(self):
return '%s[%s]' % (self.args[0], self.args[1])
class Number(Node):
def __init__(self, number, parent=None):
super(Number, self).__init__(parent)
self._args = None
self.dtype = dtype
@property
def args(self):
"""Returns all arguments/children of this node"""
return self._args
@property
def symbolsDefined(self):
"""Set of symbols which are defined by this node. """
return set()
@property
def undefinedSymbols(self):
"""Symbols which are use but are not defined inside this node"""
raise set()
def __repr__(self):
return '(%s)' % (_c_dtype_dict(self.dtype)) + repr(self.args)
import sympy as sp
from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain, typingFromSympyInspection, \
typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop
typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop, \
desympy_ast, insert_casts
from pystencils.types import TypedSymbol, DataType
from pystencils.field import Field
import pystencils.astnodes as ast
......@@ -59,4 +60,9 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl
resolveFieldAccesses(code, readOnlyFields, fieldToBasePointerInfo=basePointerInfos)
moveConstantsBeforeLoop(code)
desympy_ast(code)
insert_casts(code)
return code
\ No newline at end of file
from collections import defaultdict
from operator import attrgetter
import sympy as sp
from sympy.logic.boolalg import Boolean
from sympy.tensor import IndexedBase
......@@ -527,24 +529,56 @@ def getLoopHierarchy(astNode):
return reversed(result)
def get_type(node):
if isinstance(node, ast.Indexed):
return node.args[0].dtype
elif isinstance(node, ast.Node):
return node.dtype
# TODO sp.NumberSymbol
elif isinstance(node, sp.Number):
if isinstance(node, sp.Float):
return DataType('double')
elif isinstance(node, sp.Integer):
return DataType('int')
else:
raise NotImplemented('Not yet supported: %s %s' % (node, type(node)))
else:
raise NotImplemented('Not yet supported: %s %s' % (node, type(node)))
def insert_casts(node):
if isinstance(node, ast.SympyAssignment):
"""
Inserts casts where needed
:param node: ast which should be traversed
:return: node
"""
def add_conversion(node, dtype):
return node
for arg in node.args:
insert_casts(arg)
if isinstance(node, ast.Indexed):
pass
elif isinstance(node, sp.Expr):
elif isinstance(node, ast.Expr):
args = sorted((arg.dtype for arg in node.args), key=attrgetter('ptr', 'dtype'))
target = args[0]
for i in range(len(args)):
args[i] = add_conversion(args[i], target.dtype)
node.args = args
elif isinstance(node, ast.LoopOverCoordinate):
pass
else:
for arg in node.args:
insert_casts(arg)
return node
def desympy_ast(node):
# if isinstance(node, sp.Expr) and not isinstance(node, sp.AtomicExpr) and not isinstance(node, sp.tensor.IndexedBase):
# print(node, type(node))
"""
Remove Sympy Expressions, which have more then one argument.
This is necessary for further changes in the tree.
:param node: ast which should be traversed. Only node's children will be modified.
:return: (modified) node
"""
for i in range(len(node.args)):
arg = node.args[i]
if isinstance(node, ast.SympyAssignment):
print(node, type(arg))
if isinstance(arg, sp.Add):
node.replace(arg, ast.Add(arg.args, node))
elif isinstance(arg, sp.Mul):
......@@ -555,3 +589,4 @@ def desympy_ast(node):
node.replace(arg, ast.Indexed(arg.args, node))
for arg in node.args:
desympy_ast(arg)
return node
......@@ -29,8 +29,8 @@ class TypedSymbol(sp.Symbol):
return self.name, self.dtype
_c_dtype_dict = {0: 'int', 1: 'double', 2: 'float', 3: 'bool'}
_dtype_dict = {'int': 0, 'double': 1, 'float': 2, 'bool': 3}
_c_dtype_dict = {0: 'bool', 1: 'int', 2: 'float', 3: 'double'}
_dtype_dict = {'bool': 0, 'int': 1, 'float': 2, 'double': 3}
class DataType(object):
......@@ -63,3 +63,6 @@ class DataType(object):
return True
else:
return False
def get_type_from_sympy(node):
return DataType('int')
\ No newline at end of file
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment