From 860cf788a3d0970aa1faa2661cbd21a1d2a2fba1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20H=C3=B6nig?= <jan.hoenig@fau.de> Date: Wed, 11 Oct 2017 15:44:35 +0200 Subject: [PATCH] Jan's rest of Master Thesis and followup Work: Added LLVM: CodePrinter and a compiler Updated data_types Added tests Added jupyter notebooks Fixed bugs Restructured transformation functions --- astnodes.py | 213 ++++++------------ backends/dot.py | 11 +- data_types.py | 116 ++++++++-- field.py | 4 + llvm/__init__.py | 5 +- llvm/jit.py | 81 ------- llvm/kernelcreation.py | 87 ++++++- {backends => llvm}/llvm.py | 67 +++--- llvm/llvmjit.py | 182 +++++++++++++++ transformations/__init__.py | 2 + transformations/stage2.py | 159 +++++++++++++ .../transformations.py | 1 - 12 files changed, 647 insertions(+), 281 deletions(-) delete mode 100644 llvm/jit.py rename {backends => llvm}/llvm.py (81%) create mode 100644 llvm/llvmjit.py create mode 100644 transformations/__init__.py create mode 100644 transformations/stage2.py rename transformations.py => transformations/transformations.py (99%) diff --git a/astnodes.py b/astnodes.py index 09a23075f..2255e9455 100644 --- a/astnodes.py +++ b/astnodes.py @@ -61,6 +61,10 @@ class Node(object): for a in self.args: a.subs(*args, **kwargs) + @property + def func(self): + return self.__class__ + def atoms(self, argType): """ Returns a set of all children which are an instance of the given argType @@ -224,6 +228,7 @@ class Block(Node): def __init__(self, listOfNodes): super(Node, self).__init__() self._nodes = listOfNodes + self.parent = None for n in self._nodes: n.parent = self @@ -324,6 +329,17 @@ class LoopOverCoordinate(Node): result.append(e) return result + def replace(self, child, replacement): + if child == self.body: + self.body = replacement + elif child == self.start: + self.start = replacement + elif child == self.step: + self.step = replacement + elif child == self.stop: + self.stop = replacement + + @property def symbolsDefined(self): return 
set([self.loopCounterSymbol]) @@ -372,11 +388,15 @@ class LoopOverCoordinate(Node): return len(self.atoms(LoopOverCoordinate)) == 0 def __str__(self): - return 'loop:{!s} in {!s}:{!s}:{!s}\n{!s}'.format(self.loopCounterName, self.start, self.stop, self.step, - ("\t" + "\t".join(str(self.body).splitlines(True)))) + return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})\n{!s}'.format(self.loopCounterName, self.start, + self.loopCounterName, self.stop, + self.loopCounterName, self.step, + ("\t" + "\t".join(str(self.body).splitlines(True)))) def __repr__(self): - return 'loop:{!s} in {!s}:{!s}:{!s}'.format(self.loopCounterName, self.start, self.stop, self.step) + return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})'.format(self.loopCounterName, self.start, + self.loopCounterName, self.stop, + self.loopCounterName, self.step) class SympyAssignment(Node): @@ -488,141 +508,52 @@ class TemporaryMemoryFree(Node): return [] -# TODO implement defined & undefinedSymbols -class Conversion(Node): - def __init__(self, child, dtype, parent=None): - super(Conversion, self).__init__(parent) - self._args = [child] - self.dtype = dtype - - @property - def args(self): - """Returns all arguments/children of this node""" - return self._args - - @args.setter - def args(self, value): - self._args = value - - @property - def symbolsDefined(self): - """Set of symbols which are defined by this node. 
""" - return set() - - @property - def undefinedSymbols(self): - """Symbols which are use but are not defined inside this node""" - raise set() - - def __repr__(self): - return '(%s(%s))' % (repr(self.dtype), repr(self.args[0].dtype)) + repr(self.args) - -# TODO Pow - - -_expr_dict = {'Add': ' + ', 'Mul': ' * ', 'Pow': '**'} - - -class Expr(Node): - def __init__(self, args, parent=None): - super(Expr, self).__init__(parent) - self._args = list(args) - self.dtype = None - - @property - def args(self): - return self._args - - @args.setter - def args(self, value): - self._args = value - - def replace(self, child, replacements): - idx = self.args.index(child) - del self.args[idx] - if type(replacements) is list: - for e in replacements: - e.parent = self - self.args = self.args[:idx] + replacements + self.args[idx:] - else: - replacements.parent = self - self.args.insert(idx, replacements) - - @property - def symbolsDefined(self): - return set() # Todo fix for symbol analysis - - @property - def undefinedSymbols(self): - return set() # Todo fix for symbol analysis - - def __repr__(self): - return _expr_dict[self.__class__.__name__].join(repr(arg) for arg in self.args) - - -class Mul(Expr): - pass - - -class Add(Expr): - pass - - -class Pow(Expr): - pass - - -class Indexed(Expr): - def __init__(self, args, base, parent=None): - super(Indexed, self).__init__(args, parent) - self.base = base - # Get dtype from label, and unpointer it - self.dtype = createType(base.label.dtype.baseType) - - def __repr__(self): - return '%s[%s]' % (self.args[0], self.args[1]) - - -class PointerArithmetic(Expr): - def __init__(self, args, pointer, parent=None): - super(PointerArithmetic, self).__init__([args] + [pointer], parent) - self.pointer = pointer - self.offset = args - self.dtype = pointer.dtype - - def __repr__(self): - return '*(%s + %s)' % (self.pointer, self.args) - - -class Number(Node, sp.AtomicExpr): - def __init__(self, number, parent=None): - super(Number, 
self).__init__(parent) - - self.dtype, self.value = get_type_from_sympy(number) - self._args = tuple() - - @property - def args(self): - """Returns all arguments/children of this node""" - return self._args - - @property - def symbolsDefined(self): - """Set of symbols which are defined by this node. """ - return set() - - @property - def undefinedSymbols(self): - """Symbols which are use but are not defined inside this node""" - raise set() - - def __repr__(self): - return repr(self.value) - - def __float__(self): - return float(self.value) - - def __int__(self): - return int(self.value) - - +#_expr_dict = {'Add': ' + ', 'Mul': ' * ', 'Pow': '**'} +# +# +#class Expr(Node): +# def __init__(self, args, parent=None): +# super(Expr, self).__init__(parent) +# self._args = list(args) +# self.dtype = None +# +# @property +# def args(self): +# return self._args +# +# @args.setter +# def args(self, value): +# self._args = value +# +# def replace(self, child, replacements): +# idx = self.args.index(child) +# del self.args[idx] +# if type(replacements) is list: +# for e in replacements: +# e.parent = self +# self.args = self.args[:idx] + replacements + self.args[idx:] +# else: +# replacements.parent = self +# self.args.insert(idx, replacements) +# +# @property +# def symbolsDefined(self): +# return set() # Todo fix for symbol analysis +# +# @property +# def undefinedSymbols(self): +# return set() # Todo fix for symbol analysis +# +# def __repr__(self): +# return _expr_dict[self.__class__.__name__].join(repr(arg) for arg in self.args) +# +# +#class PointerArithmetic(Expr): +# def __init__(self, args, pointer, parent=None): +# super(PointerArithmetic, self).__init__([args] + [pointer], parent) +# self.pointer = pointer +# self.offset = args +# self.dtype = pointer.dtype +# +# def __repr__(self): +# return '*(%s + %s)' % (self.pointer, self.args) diff --git a/backends/dot.py b/backends/dot.py index 63ff4651f..a36c8a5a3 100644 --- a/backends/dot.py +++ b/backends/dot.py @@ -1,5 
+1,6 @@ from sympy.printing.printer import Printer from graphviz import Digraph, lang +import graphviz class DotPrinter(Printer): @@ -14,7 +15,6 @@ class DotPrinter(Printer): self.dot.quote_edge = lang.quote def _print_KernelFunction(self, function): - print(self._nodeToStrFunction(function)) self.dot.node(self._nodeToStrFunction(function), style='filled', fillcolor='#E69F00') self._print(function.body) @@ -75,13 +75,18 @@ def dotprint(node, view=False, short=False, full=False, **kwargs): :param kwargs: is directly passed to the DotPrinter class: http://graphviz.readthedocs.io/en/latest/api.html#digraph :return: string in DOT format """ - nodeToStrFunction = __shortened if short else lambda expr: repr(type(expr)) + repr(expr) if full else repr + nodeToStrFunction = repr + if short: + nodeToStrFunction = __shortened + elif full: + nodeToStrFunction = lambda expr: repr(type(expr)) + repr(expr) printer = DotPrinter(nodeToStrFunction, full, **kwargs) dot = printer.doprint(node) if view: - printer.dot.render(view=view) + return graphviz.Source(dot) return dot + if __name__ == "__main__": from pystencils import Field import sympy as sp diff --git a/data_types.py b/data_types.py index 6442147cd..6f766b7aa 100644 --- a/data_types.py +++ b/data_types.py @@ -1,6 +1,7 @@ import ctypes import sympy as sp import numpy as np +import llvmlite.ir as ir from sympy.core.cache import cacheit from pystencils.cache import memorycache @@ -18,6 +19,16 @@ class castFunc(sp.Function, sp.Rel): raise NotImplementedError() +class pointerArithmeticFunc(sp.Function, sp.Rel): + + @property + def canonical(self): + if hasattr(self.args[0], 'canonical'): + return self.args[0].canonical + else: + raise NotImplementedError() + + class TypedSymbol(sp.Symbol): def __new__(cls, *args, **kwds): obj = TypedSymbol.__xnew_cached_(cls, *args, **kwds) @@ -93,7 +104,10 @@ def createTypeFromString(specification): if basePart[0][-1] == "*": basePart[0] = basePart[0][:-1] parts.append('*') - baseType = 
BasicType(basePart[0], const) + try: + baseType = BasicType(basePart[0], const) + except TypeError: + baseType = BasicType(createTypeFromString.map[basePart[0]], const) currentType = baseType # Parse pointer parts for part in parts: @@ -109,6 +123,13 @@ def createTypeFromString(specification): currentType = PointerType(currentType, const, restrict) return currentType +createTypeFromString.map = { + 'i64': np.int64, + 'i32': np.int32, + 'i16': np.int16, + 'i8': np.int8, +} + def getBaseType(type): while type.baseType is not None: @@ -145,6 +166,60 @@ toCtypes.map = { } +def ctypes_from_llvm(data_type): + if isinstance(data_type, ir.PointerType): + ctype = ctypes_from_llvm(data_type.pointee) + if ctype is None: + return ctypes.c_void_p + else: + return ctypes.POINTER(ctype) + elif isinstance(data_type, ir.IntType): + if data_type.width == 8: + return ctypes.c_int8 + elif data_type.width == 16: + return ctypes.c_int16 + elif data_type.width == 32: + return ctypes.c_int32 + elif data_type.width == 64: + return ctypes.c_int64 + else: + raise ValueError("Int width %d is not supported" % data_type.width) + elif isinstance(data_type, ir.FloatType): + return ctypes.c_float + elif isinstance(data_type, ir.DoubleType): + return ctypes.c_double + elif isinstance(data_type, ir.VoidType): + return None # Void type is not supported by ctypes + else: + raise NotImplementedError('Data type %s of %s is not supported yet' % (type(data_type), data_type)) + + +def to_llvm_type(data_type): + """ + Transforms a given type into ctypes + :param data_type: Subclass of Type + :return: llvmlite type object + """ + if isinstance(data_type, PointerType): + return to_llvm_type(data_type.baseType).as_pointer() + else: + return to_llvm_type.map[data_type.numpyDtype] + +to_llvm_type.map = { + np.dtype(np.int8): ir.IntType(8), + np.dtype(np.int16): ir.IntType(16), + np.dtype(np.int32): ir.IntType(32), + np.dtype(np.int64): ir.IntType(64), + + np.dtype(np.uint8): ir.IntType(8), + np.dtype(np.uint16): 
ir.IntType(16), + np.dtype(np.uint32): ir.IntType(32), + np.dtype(np.uint64): ir.IntType(64), + + np.dtype(np.float32): ir.FloatType(), + np.dtype(np.float64): ir.DoubleType(), +} + def peelOffType(dtype, typeToPeelOff): while type(dtype) is typeToPeelOff: dtype = dtype.baseType @@ -210,7 +285,7 @@ def getTypeOfExpression(expr): return collateTypes(tuple(getTypeOfExpression(a) for a in branchResults)) elif isinstance(expr, sp.Indexed): typedSymbol = expr.base.label - return typedSymbol.dtype + return typedSymbol.dtype.baseType elif isinstance(expr, sp.boolalg.Boolean): # if any arg is of vector type return a vector boolean, else return a normal scalar boolean result = createTypeFromString("bool") @@ -222,31 +297,36 @@ def getTypeOfExpression(expr): types = tuple(getTypeOfExpression(a) for a in expr.args) return collateTypes(types) - raise NotImplementedError("Could not determine type for " + str(expr)) + raise NotImplementedError("Could not determine type for", expr, type(expr)) class Type(sp.Basic): def __new__(cls, *args, **kwargs): return sp.Basic.__new__(cls) - def __lt__(self, other): + def __lt__(self, other): # deprecated # Needed for sorting the types inside an expression if isinstance(self, BasicType): if isinstance(other, BasicType): - return self.numpyDtype < other.numpyDtype # TODO const - if isinstance(other, PointerType): + return self.numpyDtype > other.numpyDtype # TODO const + elif isinstance(other, PointerType): return False - if isinstance(other, StructType): + else: # isinstance(other, StructType): raise NotImplementedError("Struct type comparison is not yet implemented") - if isinstance(self, PointerType): + elif isinstance(self, PointerType): if isinstance(other, BasicType): return True - if isinstance(other, PointerType): - return self.baseType < other.baseType # TODO const, restrict - if isinstance(other, StructType): + elif isinstance(other, PointerType): + return self.baseType > other.baseType # TODO const, restrict + else: # 
isinstance(other, StructType): raise NotImplementedError("Struct type comparison is not yet implemented") - if isinstance(self, StructType): + elif isinstance(self, StructType): raise NotImplementedError("Struct type comparison is not yet implemented") + else: + raise NotImplementedError + + def _sympystr(self, *args, **kwargs): + return str(self) class BasicType(Type): @@ -317,6 +397,9 @@ class BasicType(Type): result += " const" return result + def __repr__(self): + return str(self) + def __eq__(self, other): if not isinstance(other, BasicType): return False @@ -397,6 +480,9 @@ class PointerType(Type): def __str__(self): return "%s *%s%s" % (self.baseType, " RESTRICT " if self.restrict else "", " const " if self.const else "") + def __repr__(self): + return str(self) + def __hash__(self): return hash(str(self)) @@ -444,6 +530,9 @@ class StructType(object): result += " const" return result + def __repr__(self): + return str(self) + def __hash__(self): return hash((self.numpyDtype, self.const)) @@ -475,6 +564,7 @@ def get_type_from_sympy(node): elif isinstance(node, sp.Integer): return createType('int'), int(node) elif isinstance(node, sp.Rational): - raise NotImplementedError('Rationals are not supported yet') + # TODO is it always float? 
+ return createType('double'), float(node.p/node.q) else: raise TypeError(node, ' is not a supported type (yet)!') diff --git a/field.py b/field.py index 98e388c49..3fb5b72e3 100644 --- a/field.py +++ b/field.py @@ -317,6 +317,10 @@ class Field(object): def offsets(self): return self._offsets + @offsets.setter + def offsets(self, value): + self._offsets = value + @property def requiredGhostLayers(self): return int(np.max(np.abs(self._offsets))) diff --git a/llvm/__init__.py b/llvm/__init__.py index f34532f8b..16cd3d751 100644 --- a/llvm/__init__.py +++ b/llvm/__init__.py @@ -1,2 +1,3 @@ -from .kernelcreation import createKernel -from .jit import compileLLVM \ No newline at end of file +from .kernelcreation import createKernel, createIndexedKernel +from .llvmjit import compileLLVM, generate_and_jit, Jit, make_python_function +from .llvm import generateLLVM diff --git a/llvm/jit.py b/llvm/jit.py deleted file mode 100644 index 918c202f3..000000000 --- a/llvm/jit.py +++ /dev/null @@ -1,81 +0,0 @@ -import llvmlite.ir as ir -import llvmlite.binding as llvm -from ..data_types import toCtypes, createType - -import ctypes as ct - - -def compileLLVM(module): - jit = Jit() - jit.parse(module) - jit.optimize() - jit.compile() - return jit - - -class Jit(object): - def __init__(self): - llvm.initialize() - llvm.initialize_all_targets() - llvm.initialize_native_target() - llvm.initialize_native_asmprinter() - - self.module = None - self.llvmmod = None - self.target = llvm.Target.from_default_triple() - self.cpu = llvm.get_host_cpu_name() - self.cpu_features = llvm.get_host_cpu_features() - self.target_machine = self.target.create_target_machine(cpu=self.cpu, features=self.cpu_features.flatten(), opt=2) - self.ee = None - self.fptr = None - - def parse(self, module): - self.module = module - llvmmod = llvm.parse_assembly(str(module)) - llvmmod.verify() - self.llvmmod = llvmmod - - def write_ll(self, file): - with open(file, 'w') as f: - f.write(str(self.llvmmod)) - - def 
optimize(self): - pmb = llvm.create_pass_manager_builder() - pmb.opt_level = 2 - pmb.disable_unit_at_a_time = False - pmb.loop_vectorize = True - pmb.slp_vectorize = True - # TODO possible to pass for functions - pm = llvm.create_module_pass_manager() - pm.add_instruction_combining_pass() - pm.add_function_attrs_pass() - pm.add_constant_merge_pass() - pm.add_licm_pass() - pmb.populate(pm) - pm.run(self.llvmmod) - - def compile(self, assembly_file=None, object_file=None): - ee = llvm.create_mcjit_compiler(self.llvmmod, self.target_machine) - ee.finalize_object() - - if assembly_file is not None: - with open(assembly_file, 'w') as f: - f.write(self.target_machine.emit_assembly(self.llvmmod)) - if object_file is not None: - with open(object_file, 'wb') as f: - f.write(self.target_machine.emit_object(self.llvmmod)) - - fptr = {} - for function in self.module.functions: - if not function.is_declaration: - return_type = None - if function.ftype.return_type != ir.VoidType(): - return_type = toCtypes(createType(str(function.ftype.return_type))) - args = [toCtypes(createType(str(arg))) for arg in function.ftype.args] - function_address = ee.get_function_address(function.name) - fptr[function.name] = ct.CFUNCTYPE(return_type, *args)(function_address) - self.ee = ee - self.fptr = fptr - - def __call__(self, function, *args, **kwargs): - self.fptr[function](*args, **kwargs) diff --git a/llvm/kernelcreation.py b/llvm/kernelcreation.py index 403c9bb53..a22c39040 100644 --- a/llvm/kernelcreation.py +++ b/llvm/kernelcreation.py @@ -1,8 +1,9 @@ import sympy as sp +from pystencils.astnodes import SympyAssignment, Block, LoopOverCoordinate, KernelFunction from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain, typingFromSympyInspection, \ - typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop, \ - desympy_ast, insert_casts -from pystencils.data_types import TypedSymbol + typeAllEquations, 
getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop, insertCasts#, \ + #desympy_ast, insert_casts +from pystencils.data_types import TypedSymbol, BasicType, StructType from pystencils.field import Field import pystencils.astnodes as ast @@ -54,17 +55,85 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl typedSplitGroups = [[typeSymbol(s) for s in splitGroup] for splitGroup in splitGroups] splitInnerLoop(code, typedSplitGroups) - basePointerInfo = [['spatialInner0'], ['spatialInner1']] + basePointerInfo = [] + for i in range(len(loopOrder)): + basePointerInfo.append(['spatialInner%d' % i]) basePointerInfos = {field.name: parseBasePointerInfo(basePointerInfo, loopOrder, field) for field in allFields} resolveFieldAccesses(code, readOnlyFields, fieldToBasePointerInfo=basePointerInfos) moveConstantsBeforeLoop(code) - print('Ast:') + #print('Ast:') + #print(code) + #desympy_ast(code) + #print('Desympied ast:') + #print(code) + #insert_casts(code) print(code) - desympy_ast(code) - print('Desympied ast:') + code = insertCasts(code) print(code) - insert_casts(code) - return code + + +def createIndexedKernel(listOfEquations, indexFields, functionName="kernel", typeForSymbol=None, + coordinateNames=('x', 'y', 'z')): + """ + Similar to :func:`createKernel`, but here not all cells of a field are updated but only cells with + coordinates which are stored in an index field. This traversal method can e.g. be used for boundary handling. + + The coordinates are stored in a separated indexField, which is a one dimensional array with struct data type. + This struct has to contain fields named 'x', 'y' and for 3D fields ('z'). These names are configurable with the + 'coordinateNames' parameter. The struct can have also other fields that can be read and written in the kernel, for + example boundary parameters. 
+ + :param listOfEquations: list of update equations or AST nodes + :param indexFields: list of index fields, i.e. 1D fields with struct data type + :param typeForSymbol: see documentation of :func:`createKernel` + :param functionName: see documentation of :func:`createKernel` + :param coordinateNames: name of the coordinate fields in the struct data type + :return: abstract syntax tree + """ + fieldsRead, fieldsWritten, assignments = typeAllEquations(listOfEquations, typeForSymbol) + allFields = fieldsRead.union(fieldsWritten) + + for indexField in indexFields: + indexField.isIndexField = True + assert indexField.spatialDimensions == 1, "Index fields have to be 1D" + + nonIndexFields = [f for f in allFields if f not in indexFields] + spatialCoordinates = {f.spatialDimensions for f in nonIndexFields} + assert len(spatialCoordinates) == 1, "Non-index fields do not have the same number of spatial coordinates" + spatialCoordinates = list(spatialCoordinates)[0] + + def getCoordinateSymbolAssignment(name): + for indexField in indexFields: + assert isinstance(indexField.dtype, StructType), "Index fields have to have a struct datatype" + dataType = indexField.dtype + if dataType.hasElement(name): + rhs = indexField[0](name) + lhs = TypedSymbol(name, BasicType(dataType.getElementType(name))) + return SympyAssignment(lhs, rhs) + raise ValueError("Index %s not found in any of the passed index fields" % (name,)) + + coordinateSymbolAssignments = [getCoordinateSymbolAssignment(n) for n in coordinateNames[:spatialCoordinates]] + coordinateTypedSymbols = [eq.lhs for eq in coordinateSymbolAssignments] + assignments = coordinateSymbolAssignments + assignments + + # make 1D loop over index fields + loopBody = Block([]) + loopNode = LoopOverCoordinate(loopBody, coordinateToLoopOver=0, start=0, stop=indexFields[0].shape[0]) + + for assignment in assignments: + loopBody.append(assignment) + + functionBody = Block([loopNode]) + ast = KernelFunction(functionBody, allFields, 
functionName) + + fixedCoordinateMapping = {f.name: coordinateTypedSymbols for f in nonIndexFields} + resolveFieldAccesses(ast, set(['indexField']), fieldToFixedCoordinates=fixedCoordinateMapping) + moveConstantsBeforeLoop(ast) + + desympy_ast(ast) + insert_casts(ast) + + return ast diff --git a/backends/llvm.py b/llvm/llvm.py similarity index 81% rename from backends/llvm.py rename to llvm/llvm.py index 34fed3765..14cbcc681 100644 --- a/backends/llvm.py +++ b/llvm/llvm.py @@ -6,16 +6,20 @@ from sympy import S # S is numbers? from pystencils.llvm.control_flow import Loop -from ..data_types import createType -from ..astnodes import Indexed +from pystencils.data_types import createType, to_llvm_type, getTypeOfExpression +from sympy import Indexed # TODO used astnodes, this should not work! -def generateLLVM(ast_node, module=ir.Module(), builder=ir.IRBuilder()): +def generateLLVM(ast_node, module=None, builder=None): """ Prints the ast as llvm code """ + if module is None: + module = ir.Module() + if builder is None: + builder = ir.IRBuilder() printer = LLVMPrinter(module, builder) - return printer._print(ast_node) + return printer._print(ast_node) #TODO use doprint() instead??? 
class LLVMPrinter(Printer): @@ -37,19 +41,22 @@ class LLVMPrinter(Printer): def _add_tmp_var(self, name, value): self.tmp_var[name] = value + def _remove_tmp_var(self, name): + del self.tmp_var[name] + def _print_Number(self, n): - if n.dtype == createType("int"): + if getTypeOfExpression(n) == createType("int"): return ir.Constant(self.integer, int(n)) - elif n.dtype == createType("double"): + elif getTypeOfExpression(n) == createType("double"): return ir.Constant(self.fp_type, float(n)) else: raise NotImplementedError("Numbers can only have int and double", n) def _print_Float(self, expr): - return ir.Constant(self.fp_type, expr.p) + return ir.Constant(self.fp_type, float(expr)) def _print_Integer(self, expr): - return ir.Constant(self.integer, expr.p) + return ir.Constant(self.integer, int(expr)) def _print_int(self, i): return ir.Constant(self.integer, i) @@ -64,6 +71,7 @@ class LLVMPrinter(Printer): return val def _print_Pow(self, expr): + #print(expr) base0 = self._print(expr.base) if expr.exp == S.NegativeOne: return self.builder.fdiv(ir.Constant(self.fp_type, 1.0), base0) @@ -88,9 +96,9 @@ class LLVMPrinter(Printer): def _print_Mul(self, expr): nodes = [self._print(a) for a in expr.args] e = nodes[0] - if expr.dtype == createType('double'): + if getTypeOfExpression(expr) == createType('double'): mul = self.builder.fmul - else: # int TODO others? + else: # int TODO unsigned/signed mul = self.builder.mul for node in nodes[1:]: e = mul(e, node) @@ -99,24 +107,20 @@ class LLVMPrinter(Printer): def _print_Add(self, expr): nodes = [self._print(a) for a in expr.args] e = nodes[0] - if expr.dtype == createType('double'): + if getTypeOfExpression(expr) == createType('double'): add = self.builder.fadd - else: # int TODO others? 
+ else: # int TODO unsigned/signed add = self.builder.add for node in nodes[1:]: e = add(e, node) return e def _print_KernelFunction(self, function): + # KernelFunction does not posses a return type return_type = self.void - # TODO argument in their own call? -> nope parameter_type = [] for parameter in function.parameters: - # TODO what about ptr shape and stride argument? - if parameter.isFieldArgument: - parameter_type.append(self.fp_pointer) - else: - parameter_type.append(self.fp_type) + parameter_type.append(to_llvm_type(parameter.dtype)) func_type = ir.FunctionType(return_type, tuple(parameter_type)) name = function.functionName fn = ir.Function(self.module, func_type, name) @@ -130,7 +134,7 @@ class LLVMPrinter(Printer): # func.attributes.add("inlinehint") # func.attributes.add("argmemonly") block = fn.append_basic_block(name="entry") - self.builder = ir.IRBuilder(block) + self.builder = ir.IRBuilder(block) #TODO use goto_block instead self._print(function.body) self.builder.ret_void() self.fn = fn @@ -144,8 +148,8 @@ class LLVMPrinter(Printer): with Loop(self.builder, self._print(loop.start), self._print(loop.stop), self._print(loop.step), loop.loopCounterName, loop.loopCounterSymbol.name) as i: self._add_tmp_var(loop.loopCounterSymbol, i) - # TODO remove tmp var self._print(loop.body) + self._remove_tmp_var(loop.loopCounterSymbol) def _print_SympyAssignment(self, assignment): expr = self._print(assignment.rhs) @@ -158,10 +162,10 @@ class LLVMPrinter(Printer): self.func_arg_map[assignment.lhs.name] = expr return expr - def _print_Conversion(self, conversion): + def _print_castFunc(self, conversion): node = self._print(conversion.args[0]) - to_dtype = conversion.dtype - from_dtype = conversion.args[0].dtype + to_dtype = getTypeOfExpression(conversion) + from_dtype = getTypeOfExpression(conversion.args[0]) # (From, to) decision = { (createType("int"), createType("double")): functools.partial(self.builder.sitofp, node, self.fp_type), @@ -173,8 +177,9 @@ class 
LLVMPrinter(Printer): (createType("double * restrict const"), createType("int")): functools.partial(self.builder.ptrtoint, node, self.integer), (createType("int"), createType("double * restrict const")): functools.partial(self.builder.inttoptr, node, self.fp_pointer), } - # TODO float, const, restrict + # TODO float, TEST: const, restrict # TODO bitcast, addrspacecast + # TODO unsigned/signed fills # print([x for x in decision.keys()]) # print("Types:") # print([(type(x), type(y)) for (x, y) in decision.keys()]) @@ -182,21 +187,21 @@ class LLVMPrinter(Printer): # print((from_dtype, to_dtype)) return decision[(from_dtype, to_dtype)]() + def _print_pointerArithmeticFunc(self, pointer): + ptr = self._print(pointer.args[0]) + index = self._print(pointer.args[1]) + return self.builder.gep(ptr, [index]) + def _print_Indexed(self, indexed): ptr = self._print(indexed.base.label) index = self._print(indexed.args[1]) gep = self.builder.gep(ptr, [index]) return self.builder.load(gep, name=indexed.base.label.name) - def _print_PointerArithmetic(self, pointer): - ptr = self._print(pointer.pointer) - index = self._print(pointer.offset) - return self.builder.gep(ptr, [index]) - # Should have a list of math library functions to validate this. 
-    # TODO function calls
+    # TODO function calls to libs
     def _print_Function(self, expr):
-        name = expr.func.__name__
+        name = expr.name
         e0 = self._print(expr.args[0])
         fn = self.ext_fn.get(name)
         if not fn:
diff --git a/llvm/llvmjit.py b/llvm/llvmjit.py
new file mode 100644
index 000000000..d767c0ee5
--- /dev/null
+++ b/llvm/llvmjit.py
@@ -0,0 +1,259 @@
+import llvmlite.ir as ir
+import llvmlite.binding as llvm
+import numpy as np
+import ctypes as ct
+import subprocess
+import shutil
+
+from ..data_types import toCtypes, createType, ctypes_from_llvm
+from .llvm import generateLLVM
+from ..cpu.cpujit import buildCTypeArgumentList
+
+
+def generate_and_jit(ast):
+    """
+    Generates LLVM IR for the given kernel ast and compiles it
+    :param ast: kernel ast that is handed to generateLLVM
+    :return: Jit object holding the compiled module
+    """
+    gen = generateLLVM(ast)
+    # generateLLVM yields either the ir.Module itself or an object carrying it as `module`
+    module = gen if isinstance(gen, ir.Module) else gen.module
+    return compileLLVM(module)
+
+
+def make_python_function(ast, argumentDict=None, func=None):
+    """
+    Creates a Python callable for the kernel
+    :param ast: kernel ast, its `parameters` and `functionName` are used
+    :param argumentDict: dict with the kernel arguments, None is treated like an empty dict
+    :param func: already compiled function, if None the ast is compiled here
+    :return: callable without arguments; if building the argument list raises a KeyError the
+             result of make_python_function_incomplete is returned instead
+    """
+    if argumentDict is None:  # avoids a mutable default argument that is shared between calls
+        argumentDict = {}
+    try:
+        args = buildCTypeArgumentList(ast.parameters, argumentDict)
+    except KeyError:
+        # not all parameters specified yet
+        return make_python_function_incomplete(ast, argumentDict, func)
+    if func is None:
+        jit = generate_and_jit(ast)
+        func = jit.get_function_ptr(ast.functionName)
+    return lambda: func(*args)
+
+
+def make_python_function_incomplete(ast, argumentDict, func=None):
+    """
+    Creates a Python callable for the kernel that receives the remaining arguments when called
+    :param ast: kernel ast, its `parameters` and `functionName` are used
+    :param argumentDict: dict with the arguments that are already known
+    :param func: already compiled function, if None the ast is compiled here
+    :return: function that takes the remaining arguments as keyword arguments
+    """
+    if func is None:
+        jit = generate_and_jit(ast)
+        func = jit.get_function_ptr(ast.functionName)
+    parameters = ast.parameters
+
+    cache = {}
+
+    def wrapper(**kwargs):
+        # The sorted tuple itself is the key (not its hash): argument sets with colliding hashes
+        # must not share a cache entry and the keyword order must not matter.
+        # NOTE(review): id() values may be reused after an argument object was freed - confirm
+        # that callers keep their arguments alive while they rely on the cache.
+        key = tuple(sorted((k, id(v)) for k, v in kwargs.items()))
+        try:
+            args = cache[key]
+        except KeyError:
+            fullArguments = argumentDict.copy()
+            fullArguments.update(kwargs)
+            args = buildCTypeArgumentList(parameters, fullArguments)
+            cache[key] = args
+        func(*args)
+
+    return wrapper
+
+
+def compileLLVM(module):
+    """
+    Parses, optimizes and compiles a llvmlite module
+    :param module: llvmlite ir.Module
+    :return: Jit object with the function pointers of the module
+    """
+    jit = Jit()
+    jit.parse(module)
+    jit.optimize()
+    jit.compile()
+    return jit
+
+
+class Jit(object):
+    """
+    Wraps a llvmlite MCJIT execution engine for the host cpu and the module it executes.
+
+    `module` is the llvmlite ir.Module, `llvmmod` the parsed module that was added to the execution
+    engine and `fptr` maps the names of the defined functions to ctypes function pointers.
+    """
+    def __init__(self):
+        llvm.initialize()
+        llvm.initialize_all_targets()
+        llvm.initialize_native_target()
+        llvm.initialize_native_asmprinter()
+
+        self.module = None
+        self._llvmmod = llvm.parse_assembly("")
+        self.target = llvm.Target.from_default_triple()
+        self.cpu = llvm.get_host_cpu_name()
+        self.cpu_features = llvm.get_host_cpu_features()
+        self.target_machine = self.target.create_target_machine(cpu=self.cpu, features=self.cpu_features.flatten(), opt=2)
+        llvm.check_jit_execution()
+        self.ee = llvm.create_mcjit_compiler(self.llvmmod, self.target_machine)
+        self.ee.finalize_object()
+        self.fptr = None
+
+    @property
+    def llvmmod(self):
+        return self._llvmmod
+
+    @llvmmod.setter
+    def llvmmod(self, mod):
+        # exchange the module of the execution engine and refresh the function pointers
+        self.ee.remove_module(self.llvmmod)
+        self.ee.add_module(mod)
+        self.ee.finalize_object()
+        # stored before compile() so that the attribute matches the engine even if compile() raises
+        self._llvmmod = mod
+        self.compile()
+
+    def parse(self, module):
+        """
+        Parses and verifies the module and hands it to the execution engine
+        :param module: llvmlite ir.Module, its string representation is parsed
+        """
+        self.module = module
+        llvmmod = llvm.parse_assembly(str(module))
+        llvmmod.verify()
+        llvmmod.triple = self.target.triple
+        llvmmod.name = 'module'
+        self.llvmmod = llvmmod
+
+    def write_ll(self, file):
+        """Writes the string representation of the current module to the given file"""
+        with open(file, 'w') as f:
+            f.write(str(self.llvmmod))
+
+    def write_assembly(self, file):
+        """Writes the assembly emitted by the target machine for the current module to the given file"""
+        with open(file, 'w') as f:
+            f.write(self.target_machine.emit_assembly(self.llvmmod))
+
+    def write_object_file(self, file):
+        """Writes the object code emitted by the target machine for the current module to the given file"""
+        with open(file, 'wb') as f:
+            f.write(self.target_machine.emit_object(self.llvmmod))
+
+    def optimize(self):
+        """
+        Runs the passes of a pass manager builder (opt level 2, loop and slp vectorization) and some
+        explicitly added passes on the current module.
+
+        NOTE(review): compileLLVM calls this after parse() has added the module to the execution engine
+        and finalized it - confirm that the passes still influence the generated machine code.
+        """
+        pmb = llvm.create_pass_manager_builder()
+        pmb.opt_level = 2
+        pmb.disable_unit_at_a_time = False
+        pmb.loop_vectorize = True
+        pmb.slp_vectorize = True
+        # TODO possible to pass for functions
+        pm = llvm.create_module_pass_manager()
+        pm.add_instruction_combining_pass()
+        pm.add_function_attrs_pass()
+        pm.add_constant_merge_pass()
+        pm.add_licm_pass()
+        pmb.populate(pm)
+        pm.run(self.llvmmod)
+
+    def optimize_polly(self, opt):
+        """
+        Optimizes the current module with polly by sending its bitcode through the `opt` executable
+        :param opt: name or path of the `opt` executable, looked up with shutil.which
+        :raises subprocess.CalledProcessError: if a run of `opt` exits with a non-zero return code
+        """
+        if shutil.which(opt) is None:
+            print('Path to the executable is wrong')
+            return
+
+        def run_opt(arguments, bitcode):
+            # Every stage runs to completion before the next one starts, so the output of a stage is
+            # read by exactly one consumer and its exit code can be checked.
+            command = [opt] + arguments
+            process = subprocess.Popen(command, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
+            output, _ = process.communicate(input=bitcode)
+            if process.returncode != 0:
+                raise subprocess.CalledProcessError(process.returncode, command)
+            return output
+
+        bitcode = run_opt(['-polly-canonicalize'], self.llvmmod.as_bitcode())
+        bitcode = run_opt(['-polly-codegen', '-polly-vectorizer=polly', '-polly-parallel',
+                           '-polly-process-unprofitable', '-f'], bitcode)
+        bitcode = run_opt(['-O3', '-f'], bitcode)
+        llvmmod = llvm.parse_bitcode(bitcode)
+        llvmmod.verify()
+        self.llvmmod = llvmmod
+
+    def compile(self):
+        """Creates ctypes function pointers for the functions defined in `module` and stores them in `fptr`"""
+        fptr = {}
+        for function in self.module.functions:
+            if function.is_declaration:
+                continue
+            return_type = None
+            if function.ftype.return_type != ir.VoidType():
+                return_type = toCtypes(createType(str(function.ftype.return_type)))
+            args = [toCtypes(createType(str(arg))) for arg in function.ftype.args]
+            function_address = self.ee.get_function_address(function.name)
+            fptr[function.name] = ct.CFUNCTYPE(return_type, *args)(function_address)
+        self.fptr = fptr
+
+    def __call__(self, function, *args, **kwargs):
+        """
+        Calls a compiled function, numpy arrays are passed as pointers to their data
+        :param function: name of the function
+        :param args: positional arguments of the function
+        :param kwargs: not used
+        :return: return value of the compiled function
+        :raises KeyError: if `module` has no function with this name
+        """
+        target_function = next((f for f in self.module.functions if f.name == function), None)
+        if target_function is None:
+            raise KeyError(function)
+        arg_types = [ctypes_from_llvm(arg.type) for arg in target_function.args]
+
+        transformed_args = []
+        for i, arg in enumerate(args):
+            if isinstance(arg, np.ndarray):
+                transformed_args.append(arg.ctypes.data_as(arg_types[i]))
+            else:
+                transformed_args.append(arg)
+
+        # hand the result back to the caller instead of dropping it
+        return self.fptr[function](*transformed_args)
+
+    def print_functions(self):
+        """Prints return type, name and arguments of every function in `module`"""
+        for f in self.module.functions:
+            print(f.ftype.return_type, f.name, f.args)
+
+    def get_function_ptr(self, name):
+        """
+        Returns the ctypes function pointer of a compiled function
+        :param name: name of the function
+        :return: function pointer, the Jit object is attached to it as attribute `jit`
+        """
+        fptr = self.fptr[name]
+        fptr.jit = self
+        return fptr
diff --git a/transformations/__init__.py b/transformations/__init__.py
new file mode 100644
index 000000000..a8ba7d85a
--- /dev/null
+++ b/transformations/__init__.py
@@ -0,0 +1,2 @@
+from .transformations import *
+from .stage2 import *
diff --git a/transformations/stage2.py b/transformations/stage2.py
new file mode 100644
index 000000000..d1cfd2d15
--- /dev/null
+++ b/transformations/stage2.py
@@ -0,0 +1,159 @@
+from operator import attrgetter
+
+import sympy as sp
+
+from pystencils.data_types import TypedSymbol, createType, PointerType, StructType, getBaseType, getTypeOfExpression, collateTypes, castFunc, pointerArithmeticFunc
+import pystencils.astnodes as ast
+
+
+def insertCasts(node):  # TODO test casts!!!, edit testcase
+    """
+    Checks the types and inserts casts and pointer arithmetic where necessary.
+
+    The children are processed first. Add and Mul expressions are rebuilt with arguments cast to the
+    collated type, Block and LoopOverCoordinate nodes get their children replaced in place, resolved
+    field accesses are returned as they are and every other node is rebuilt from the processed children.
+    :param node: the head node of the ast
+    :return: modified ast
+    """
+    def cast(zippedArgsTypes, target):
+        """
+        Adds casts to the arguments if their type differs from the target type
+        :param zippedArgsTypes: a zipped list of args and types
+        :param target: The target data type
+        :return: args with possible casts
+        """
+        # TODO ignoring const
+        return [castFunc(arg, target) if dataType.numpyDtype != target.numpyDtype else arg
+                for arg, dataType in zippedArgsTypes]
+
+    def pointerArithmetic(args):
+        """
+        Creates a valid pointer arithmetic function
+        :param args: (argument, type) pairs of the add expression, at most one may be a pointer
+        :return: pointerArithmeticFunc
+        """
+        pointer = None
+        for arg, dataType in args:
+            if dataType.func == PointerType:
+                assert pointer is None
+                pointer = arg
+        offsets = []
+        for arg, dataType in args:
+            if arg != pointer:
+                assert dataType.is_int() or dataType.is_uint()
+                offsets.append(arg)
+        # sp.Add() without arguments is zero, so an expression and never an empty list is passed on
+        return pointerArithmeticFunc(pointer, sp.Add(*offsets))
+
+    # atoms have no arguments that could need a cast
+    if isinstance(node, sp.AtomicExpr):
+        return node
+    # bottom-up: the children are processed before the node itself
+    args = [insertCasts(arg) for arg in node.args]
+    # TODO indexed, SympyAssignment, LoopOverCoordinate, Pow
+    if node.func in (sp.Add, sp.Mul):
+        types = [getTypeOfExpression(arg) for arg in args]
+        assert len(types) > 0
+        target = collateTypes(types)
+        zipped = list(zip(args, types))
+        if target.func == PointerType:
+            assert node.func == sp.Add
+            return pointerArithmetic(zipped)
+        return node.func(*cast(zipped, target))
+    elif node.func == ast.SympyAssignment:
+        # TODO casting of rhs/lhs
+        return node.func(*args)
+    elif node.func == ast.ResolvedFieldAccess:
+        # TODO Everything
+        return node
+    elif node.func in (ast.Block, ast.LoopOverCoordinate):
+        # these nodes are modified in place instead of being rebuilt
+        for oldArg, newArg in zip(node.args, args):
+            node.replace(oldArg, newArg)
+        return node
+
+    return node.func(*args)
+
+
+def insert_casts(node):
+    """
+    Inserts casts and dtype where needed
+
+    NOTE(review): uses ast.Indexed, ast.Expr, ast.Add, ast.PointerArithmetic and ast.Conversion; this
+    patch appears to remove node classes like Conversion from astnodes.py - confirm they still exist.
+    :param node: ast which should be traversed
+    :return: node
+    """
+    def conversion(args):
+        """
+        Converts the arguments to the type of the first argument
+        :param args: list of nodes with dtype, the first one is the target; modified in place
+        :return: args with conversions, or a list with a single pointer arithmetic node
+        """
+        target = args[0]
+        if isinstance(target.dtype, PointerType):
+            # Pointer arithmetic
+            for arg in args[1:]:
+                # Check validness
+                if not arg.dtype.is_int() and not arg.dtype.is_uint():
+                    raise ValueError("Impossible pointer arithmetic", target, arg)
+            return [ast.PointerArithmetic(ast.Add(args[1:]), target)]
+
+        for i, arg in enumerate(args):
+            if arg.dtype.numpyDtype != target.dtype.numpyDtype:  # TODO ignoring const -> valid behavior?
+                args[i] = ast.Conversion(arg, createType(target.dtype), node)
+        return args
+
+    for arg in node.args:
+        insert_casts(arg)
+    if isinstance(node, ast.Indexed):
+        # TODO need to do something here?
+        pass
+    elif isinstance(node, ast.Expr):
+        # the first argument in dtype sort order determines the type of the expression
+        args = sorted(node.args, key=attrgetter('dtype'))
+        target = args[0]
+        node.args = conversion(args)
+        node.dtype = target.dtype
+    elif isinstance(node, ast.SympyAssignment):
+        if node.lhs.dtype != node.rhs.dtype:
+            node.replace(node.rhs, ast.Conversion(node.rhs, node.lhs.dtype))
+    elif isinstance(node, ast.LoopOverCoordinate):
+        pass
+    return node
+
+
+#def desympy_ast(node):
+#    """
+#    Remove Sympy Expressions, which have more than one argument.
+#    This is necessary for further changes in the tree.
+#    :param node: ast which should be traversed. Only node's children will be modified.
+#    :return: (modified) node
+#    """
+#    if node.args is None:
+#        return node
+#    for i in range(len(node.args)):
+#        arg = node.args[i]
+#        if isinstance(arg, sp.Add):
+#            node.replace(arg, ast.Add(arg.args, node))
+#        elif isinstance(arg, sp.Number):
+#            node.replace(arg, ast.Number(arg, node))
+#        elif isinstance(arg, sp.Mul):
+#            node.replace(arg, ast.Mul(arg.args, node))
+#        elif isinstance(arg, sp.Pow):
+#            node.replace(arg, ast.Pow(arg, node))
+#        elif isinstance(arg, sp.tensor.Indexed) or isinstance(arg, sp.tensor.indexed.Indexed):
+#            node.replace(arg, ast.Indexed(arg.args, arg.base, node))
+#        elif isinstance(arg, sp.tensor.IndexedBase):
+#            node.replace(arg, arg.label)
+#        elif isinstance(arg, sp.Function):
+#            node.replace(arg, ast.Function(arg.func, arg.args, node))
+#        #elif isinstance(arg, sp.containers.Tuple):
+#        #
+#        else:
+#            #print('Not transforming:', type(arg), arg)
+#            pass
+#    for arg in node.args:
+#        desympy_ast(arg)
+#    return node
diff --git a/transformations.py b/transformations/transformations.py
similarity index 99%
rename from transformations.py
rename to transformations/transformations.py
index 5c1fe5698..7cc41daf1 100644
--- a/transformations.py
+++ b/transformations/transformations.py
@@ -195,7 +195,6 @@ def parseBasePointerInfo(basePointerSpecification, loopOrder, field):
             if i in specifiedCoordinates:
                 raise ValueError("Coordinate %d specified two times" % (i,))
             specifiedCoordinates.add(i)
-
         for element in specGroup:
             if type(element) is int:
                 addNewElement(element)
-- 
GitLab