Commit 860cf788 authored by Jan Hönig's avatar Jan Hönig
Browse files

Jan's rest of Master Thesis and followup Work:

Added LLVM: CodePrinter and a compiler
Updated data_types
Added tests
Added jupyter notebooks
Fixed bugs
Restructured transformation functions
parent 0db06926
......@@ -61,6 +61,10 @@ class Node(object):
for a in self.args:
a.subs(*args, **kwargs)
@property
def func(self):
return self.__class__
def atoms(self, argType):
"""
Returns a set of all children which are an instance of the given argType
......@@ -224,6 +228,7 @@ class Block(Node):
def __init__(self, listOfNodes):
super(Node, self).__init__()
self._nodes = listOfNodes
self.parent = None
for n in self._nodes:
n.parent = self
......@@ -324,6 +329,17 @@ class LoopOverCoordinate(Node):
result.append(e)
return result
def replace(self, child, replacement):
if child == self.body:
self.body = replacement
elif child == self.start:
self.start = replacement
elif child == self.step:
self.step = replacement
elif child == self.stop:
self.stop = replacement
@property
def symbolsDefined(self):
return set([self.loopCounterSymbol])
......@@ -372,11 +388,15 @@ class LoopOverCoordinate(Node):
return len(self.atoms(LoopOverCoordinate)) == 0
def __str__(self):
return 'loop:{!s} in {!s}:{!s}:{!s}\n{!s}'.format(self.loopCounterName, self.start, self.stop, self.step,
("\t" + "\t".join(str(self.body).splitlines(True))))
return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})\n{!s}'.format(self.loopCounterName, self.start,
self.loopCounterName, self.stop,
self.loopCounterName, self.step,
("\t" + "\t".join(str(self.body).splitlines(True))))
def __repr__(self):
return 'loop:{!s} in {!s}:{!s}:{!s}'.format(self.loopCounterName, self.start, self.stop, self.step)
return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})'.format(self.loopCounterName, self.start,
self.loopCounterName, self.stop,
self.loopCounterName, self.step)
class SympyAssignment(Node):
......@@ -488,141 +508,52 @@ class TemporaryMemoryFree(Node):
return []
# TODO implement defined & undefinedSymbols
class Conversion(Node):
def __init__(self, child, dtype, parent=None):
super(Conversion, self).__init__(parent)
self._args = [child]
self.dtype = dtype
@property
def args(self):
"""Returns all arguments/children of this node"""
return self._args
@args.setter
def args(self, value):
self._args = value
@property
def symbolsDefined(self):
"""Set of symbols which are defined by this node. """
return set()
@property
def undefinedSymbols(self):
"""Symbols which are use but are not defined inside this node"""
raise set()
def __repr__(self):
return '(%s(%s))' % (repr(self.dtype), repr(self.args[0].dtype)) + repr(self.args)
# TODO Pow
_expr_dict = {'Add': ' + ', 'Mul': ' * ', 'Pow': '**'}
class Expr(Node):
def __init__(self, args, parent=None):
super(Expr, self).__init__(parent)
self._args = list(args)
self.dtype = None
@property
def args(self):
return self._args
@args.setter
def args(self, value):
self._args = value
def replace(self, child, replacements):
idx = self.args.index(child)
del self.args[idx]
if type(replacements) is list:
for e in replacements:
e.parent = self
self.args = self.args[:idx] + replacements + self.args[idx:]
else:
replacements.parent = self
self.args.insert(idx, replacements)
@property
def symbolsDefined(self):
return set() # Todo fix for symbol analysis
@property
def undefinedSymbols(self):
return set() # Todo fix for symbol analysis
def __repr__(self):
return _expr_dict[self.__class__.__name__].join(repr(arg) for arg in self.args)
class Mul(Expr):
pass
class Add(Expr):
pass
class Pow(Expr):
pass
class Indexed(Expr):
def __init__(self, args, base, parent=None):
super(Indexed, self).__init__(args, parent)
self.base = base
# Get dtype from label, and unpointer it
self.dtype = createType(base.label.dtype.baseType)
def __repr__(self):
return '%s[%s]' % (self.args[0], self.args[1])
class PointerArithmetic(Expr):
def __init__(self, args, pointer, parent=None):
super(PointerArithmetic, self).__init__([args] + [pointer], parent)
self.pointer = pointer
self.offset = args
self.dtype = pointer.dtype
def __repr__(self):
return '*(%s + %s)' % (self.pointer, self.args)
class Number(Node, sp.AtomicExpr):
def __init__(self, number, parent=None):
super(Number, self).__init__(parent)
self.dtype, self.value = get_type_from_sympy(number)
self._args = tuple()
@property
def args(self):
"""Returns all arguments/children of this node"""
return self._args
@property
def symbolsDefined(self):
"""Set of symbols which are defined by this node. """
return set()
@property
def undefinedSymbols(self):
"""Symbols which are use but are not defined inside this node"""
raise set()
def __repr__(self):
return repr(self.value)
def __float__(self):
return float(self.value)
def __int__(self):
return int(self.value)
#_expr_dict = {'Add': ' + ', 'Mul': ' * ', 'Pow': '**'}
#
#
#class Expr(Node):
# def __init__(self, args, parent=None):
# super(Expr, self).__init__(parent)
# self._args = list(args)
# self.dtype = None
#
# @property
# def args(self):
# return self._args
#
# @args.setter
# def args(self, value):
# self._args = value
#
# def replace(self, child, replacements):
# idx = self.args.index(child)
# del self.args[idx]
# if type(replacements) is list:
# for e in replacements:
# e.parent = self
# self.args = self.args[:idx] + replacements + self.args[idx:]
# else:
# replacements.parent = self
# self.args.insert(idx, replacements)
#
# @property
# def symbolsDefined(self):
# return set() # Todo fix for symbol analysis
#
# @property
# def undefinedSymbols(self):
# return set() # Todo fix for symbol analysis
#
# def __repr__(self):
# return _expr_dict[self.__class__.__name__].join(repr(arg) for arg in self.args)
#
#
#class PointerArithmetic(Expr):
# def __init__(self, args, pointer, parent=None):
# super(PointerArithmetic, self).__init__([args] + [pointer], parent)
# self.pointer = pointer
# self.offset = args
# self.dtype = pointer.dtype
#
# def __repr__(self):
# return '*(%s + %s)' % (self.pointer, self.args)
from sympy.printing.printer import Printer
from graphviz import Digraph, lang
import graphviz
class DotPrinter(Printer):
......@@ -14,7 +15,6 @@ class DotPrinter(Printer):
self.dot.quote_edge = lang.quote
def _print_KernelFunction(self, function):
print(self._nodeToStrFunction(function))
self.dot.node(self._nodeToStrFunction(function), style='filled', fillcolor='#E69F00')
self._print(function.body)
......@@ -75,13 +75,18 @@ def dotprint(node, view=False, short=False, full=False, **kwargs):
:param kwargs: is directly passed to the DotPrinter class: http://graphviz.readthedocs.io/en/latest/api.html#digraph
:return: string in DOT format
"""
nodeToStrFunction = __shortened if short else lambda expr: repr(type(expr)) + repr(expr) if full else repr
nodeToStrFunction = repr
if short:
nodeToStrFunction = __shortened
elif full:
nodeToStrFunction = lambda expr: repr(type(expr)) + repr(expr)
printer = DotPrinter(nodeToStrFunction, full, **kwargs)
dot = printer.doprint(node)
if view:
printer.dot.render(view=view)
return graphviz.Source(dot)
return dot
if __name__ == "__main__":
from pystencils import Field
import sympy as sp
......
import ctypes
import sympy as sp
import numpy as np
import llvmlite.ir as ir
from sympy.core.cache import cacheit
from pystencils.cache import memorycache
......@@ -18,6 +19,16 @@ class castFunc(sp.Function, sp.Rel):
raise NotImplementedError()
class pointerArithmeticFunc(sp.Function, sp.Rel):
@property
def canonical(self):
if hasattr(self.args[0], 'canonical'):
return self.args[0].canonical
else:
raise NotImplementedError()
class TypedSymbol(sp.Symbol):
def __new__(cls, *args, **kwds):
obj = TypedSymbol.__xnew_cached_(cls, *args, **kwds)
......@@ -93,7 +104,10 @@ def createTypeFromString(specification):
if basePart[0][-1] == "*":
basePart[0] = basePart[0][:-1]
parts.append('*')
baseType = BasicType(basePart[0], const)
try:
baseType = BasicType(basePart[0], const)
except TypeError:
baseType = BasicType(createTypeFromString.map[basePart[0]], const)
currentType = baseType
# Parse pointer parts
for part in parts:
......@@ -109,6 +123,13 @@ def createTypeFromString(specification):
currentType = PointerType(currentType, const, restrict)
return currentType
createTypeFromString.map = {
'i64': np.int64,
'i32': np.int32,
'i16': np.int16,
'i8': np.int8,
}
def getBaseType(type):
while type.baseType is not None:
......@@ -145,6 +166,60 @@ toCtypes.map = {
}
def ctypes_from_llvm(data_type):
if isinstance(data_type, ir.PointerType):
ctype = ctypes_from_llvm(data_type.pointee)
if ctype is None:
return ctypes.c_void_p
else:
return ctypes.POINTER(ctype)
elif isinstance(data_type, ir.IntType):
if data_type.width == 8:
return ctypes.c_int8
elif data_type.width == 16:
return ctypes.c_int16
elif data_type.width == 32:
return ctypes.c_int32
elif data_type.width == 64:
return ctypes.c_int64
else:
raise ValueError("Int width %d is not supported" % data_type.width)
elif isinstance(data_type, ir.FloatType):
return ctypes.c_float
elif isinstance(data_type, ir.DoubleType):
return ctypes.c_double
elif isinstance(data_type, ir.VoidType):
return None # Void type is not supported by ctypes
else:
raise NotImplementedError('Data type %s of %s is not supported yet' % (type(data_type), data_type))
def to_llvm_type(data_type):
"""
Transforms a given type into ctypes
:param data_type: Subclass of Type
:return: llvmlite type object
"""
if isinstance(data_type, PointerType):
return to_llvm_type(data_type.baseType).as_pointer()
else:
return to_llvm_type.map[data_type.numpyDtype]
to_llvm_type.map = {
np.dtype(np.int8): ir.IntType(8),
np.dtype(np.int16): ir.IntType(16),
np.dtype(np.int32): ir.IntType(32),
np.dtype(np.int64): ir.IntType(64),
np.dtype(np.uint8): ir.IntType(8),
np.dtype(np.uint16): ir.IntType(16),
np.dtype(np.uint32): ir.IntType(32),
np.dtype(np.uint64): ir.IntType(64),
np.dtype(np.float32): ir.FloatType(),
np.dtype(np.float64): ir.DoubleType(),
}
def peelOffType(dtype, typeToPeelOff):
while type(dtype) is typeToPeelOff:
dtype = dtype.baseType
......@@ -210,7 +285,7 @@ def getTypeOfExpression(expr):
return collateTypes(tuple(getTypeOfExpression(a) for a in branchResults))
elif isinstance(expr, sp.Indexed):
typedSymbol = expr.base.label
return typedSymbol.dtype
return typedSymbol.dtype.baseType
elif isinstance(expr, sp.boolalg.Boolean):
# if any arg is of vector type return a vector boolean, else return a normal scalar boolean
result = createTypeFromString("bool")
......@@ -222,31 +297,36 @@ def getTypeOfExpression(expr):
types = tuple(getTypeOfExpression(a) for a in expr.args)
return collateTypes(types)
raise NotImplementedError("Could not determine type for " + str(expr))
raise NotImplementedError("Could not determine type for", expr, type(expr))
class Type(sp.Basic):
def __new__(cls, *args, **kwargs):
return sp.Basic.__new__(cls)
def __lt__(self, other):
def __lt__(self, other): # deprecated
# Needed for sorting the types inside an expression
if isinstance(self, BasicType):
if isinstance(other, BasicType):
return self.numpyDtype < other.numpyDtype # TODO const
if isinstance(other, PointerType):
return self.numpyDtype > other.numpyDtype # TODO const
elif isinstance(other, PointerType):
return False
if isinstance(other, StructType):
else: # isinstance(other, StructType):
raise NotImplementedError("Struct type comparison is not yet implemented")
if isinstance(self, PointerType):
elif isinstance(self, PointerType):
if isinstance(other, BasicType):
return True
if isinstance(other, PointerType):
return self.baseType < other.baseType # TODO const, restrict
if isinstance(other, StructType):
elif isinstance(other, PointerType):
return self.baseType > other.baseType # TODO const, restrict
else: # isinstance(other, StructType):
raise NotImplementedError("Struct type comparison is not yet implemented")
if isinstance(self, StructType):
elif isinstance(self, StructType):
raise NotImplementedError("Struct type comparison is not yet implemented")
else:
raise NotImplementedError
def _sympystr(self, *args, **kwargs):
return str(self)
class BasicType(Type):
......@@ -317,6 +397,9 @@ class BasicType(Type):
result += " const"
return result
def __repr__(self):
return str(self)
def __eq__(self, other):
if not isinstance(other, BasicType):
return False
......@@ -397,6 +480,9 @@ class PointerType(Type):
def __str__(self):
return "%s *%s%s" % (self.baseType, " RESTRICT " if self.restrict else "", " const " if self.const else "")
def __repr__(self):
return str(self)
def __hash__(self):
return hash(str(self))
......@@ -444,6 +530,9 @@ class StructType(object):
result += " const"
return result
def __repr__(self):
return str(self)
def __hash__(self):
return hash((self.numpyDtype, self.const))
......@@ -475,6 +564,7 @@ def get_type_from_sympy(node):
elif isinstance(node, sp.Integer):
return createType('int'), int(node)
elif isinstance(node, sp.Rational):
raise NotImplementedError('Rationals are not supported yet')
# TODO is it always float?
return createType('double'), float(node.p/node.q)
else:
raise TypeError(node, ' is not a supported type (yet)!')
......@@ -317,6 +317,10 @@ class Field(object):
def offsets(self):
return self._offsets
@offsets.setter
def offsets(self, value):
self._offsets = value
@property
def requiredGhostLayers(self):
return int(np.max(np.abs(self._offsets)))
......
from .kernelcreation import createKernel
from .jit import compileLLVM
\ No newline at end of file
from .kernelcreation import createKernel, createIndexedKernel
from .llvmjit import compileLLVM, generate_and_jit, Jit, make_python_function
from .llvm import generateLLVM
import sympy as sp
from pystencils.astnodes import SympyAssignment, Block, LoopOverCoordinate, KernelFunction
from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain, typingFromSympyInspection, \
typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop, \
desympy_ast, insert_casts
from pystencils.data_types import TypedSymbol
typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop, insertCasts#, \
#desympy_ast, insert_casts
from pystencils.data_types import TypedSymbol, BasicType, StructType
from pystencils.field import Field
import pystencils.astnodes as ast
......@@ -54,17 +55,85 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl
typedSplitGroups = [[typeSymbol(s) for s in splitGroup] for splitGroup in splitGroups]
splitInnerLoop(code, typedSplitGroups)
basePointerInfo = [['spatialInner0'], ['spatialInner1']]
basePointerInfo = []
for i in range(len(loopOrder)):
basePointerInfo.append(['spatialInner%d' % i])
basePointerInfos = {field.name: parseBasePointerInfo(basePointerInfo, loopOrder, field) for field in allFields}
resolveFieldAccesses(code, readOnlyFields, fieldToBasePointerInfo=basePointerInfos)
moveConstantsBeforeLoop(code)
print('Ast:')
#print('Ast:')
#print(code)
#desympy_ast(code)
#print('Desympied ast:')
#print(code)
#insert_casts(code)
print(code)
desympy_ast(code)
print('Desympied ast:')
code = insertCasts(code)
print(code)
insert_casts(code)
return code
def createIndexedKernel(listOfEquations, indexFields, functionName="kernel", typeForSymbol=None,
coordinateNames=('x', 'y', 'z')):
"""
Similar to :func:`createKernel`, but here not all cells of a field are updated but only cells with
coordinates which are stored in an index field. This traversal method can e.g. be used for boundary handling.
The coordinates are stored in a separated indexField, which is a one dimensional array with struct data type.
This struct has to contain fields named 'x', 'y' and for 3D fields ('z'). These names are configurable with the
'coordinateNames' parameter. The struct can have also other fields that can be read and written in the kernel, for
example boundary parameters.
:param listOfEquations: list of update equations or AST nodes
:param indexFields: list of index fields, i.e. 1D fields with struct data type
:param typeForSymbol: see documentation of :func:`createKernel`
:param functionName: see documentation of :func:`createKernel`
:param coordinateNames: name of the coordinate fields in the struct data type
:return: abstract syntax tree
"""
fieldsRead, fieldsWritten, assignments = typeAllEquations(listOfEquations, typeForSymbol)
allFields = fieldsRead.union(fieldsWritten)
for indexField in indexFields:
indexField.isIndexField = True
assert indexField.spatialDimensions == 1, "Index fields have to be 1D"
nonIndexFields = [f for f in allFields if f not in indexFields]
spatialCoordinates = {f.spatialDimensions for f in nonIndexFields}
assert len(spatialCoordinates) == 1, "Non-index fields do not have the same number of spatial coordinates"
spatialCoordinates = list(spatialCoordinates)[0]
def getCoordinateSymbolAssignment(name):
for indexField in indexFields:
assert isinstance(indexField.dtype, StructType), "Index fields have to have a struct datatype"
dataType = indexField.dtype
if dataType.hasElement(name):
rhs = indexField[0](name)
lhs = TypedSymbol(name, BasicType(dataType.getElementType(name)))
return SympyAssignment(lhs, rhs)
raise ValueError("Index %s not found in any of the passed index fields" % (name,))
coordinateSymbolAssignments = [getCoordinateSymbolAssignment(n) for n in coordinateNames[:spatialCoordinates]]
coordinateTypedSymbols = [eq.lhs for eq in coordinateSymbolAssignments]
assignments = coordinateSymbolAssignments + assignments
# make 1D loop over index fields
loopBody = Block([])
loopNode = LoopOverCoordinate(loopBody, coordinateToLoopOver=0, start=0, stop=indexFields[0].shape[0])
for assignment in assignments:
loopBody.append(assignment)
functionBody = Block([loopNode])
ast = KernelFunction(functionBody, allFields, functionName)
fixedCoordinateMapping = {f.name: coordinateTypedSymbols for f in nonIndexFields}
resolveFieldAccesses(ast, set(['indexField']), fieldToFixedCoordinates=fixedCoordinateMapping)
moveConstantsBeforeLoop(ast)
desympy_ast(ast)
insert_casts(ast)
return ast
......@@ -6,16 +6,20 @@ from sympy import S
# S is numbers?
from pystencils.llvm.control_flow import Loop
from ..data_types import createType
from ..astnodes import Indexed
from pystencils.data_types import createType, to_llvm_type, getTypeOfExpression
from sympy import Indexed # TODO used astnodes, this should not work!
def generateLLVM(ast_node, module=ir.Module(), builder=ir.IRBuilder()):
def generateLLVM(ast_node, module=None, builder