diff --git a/ast.py b/ast.py index 7600e946f0195a6d30722970d9b0030182c598c8..a7e8cac346adb45c47557c45fabc505908888353 100644 --- a/ast.py +++ b/ast.py @@ -4,7 +4,7 @@ from pystencils.field import Field from pystencils.typedsymbol import TypedSymbol -class Node: +class Node(object): """Base class for all AST nodes""" def __init__(self, parent=None): @@ -35,6 +35,12 @@ class Node: result.update(arg.atoms(argType)) return result + def parents(self): + return None + + def children(self): + return None + class KernelFunction(Node): @@ -62,6 +68,9 @@ class KernelFunction(Node): self.isFieldArgument = True self.fieldName = name[len(Field.STRIDE_PREFIX):] + def __repr__(self): + return '<{0} {1}>'.format(self.dtype, self.name) + def __init__(self, body, functionName="kernel"): super(KernelFunction, self).__init__() self._body = body @@ -102,6 +111,13 @@ class KernelFunction(Node): l.isFieldStrideArgument, l.name), reverse=True) + def children(self): + yield self.body + + def __repr__(self): + self._updateParameters() + return '{0} {1}({2})\n{3}'.format(type(self).__name__, self.functionName, self.parameters, self.body) + class Block(Node): def __init__(self, listOfNodes): @@ -156,6 +172,12 @@ class Block(Node): result.update(a.symbolsRead) return result + def children(self): + yield self._nodes + + def __repr__(self): + return ''.join('\t{!r}\n'.format(node) for node in self._nodes) + class PragmaBlock(Block): def __init__(self, pragmaLine, listOfNodes): @@ -252,6 +274,13 @@ class LoopOverCoordinate(Node): def coordinateToLoopOver(self): return self._coordinateToLoopOver + def children(self): + return self.body + + def __repr__(self): + return 'loop:{!s} {!s} in {!s}:{!s}:{!s}\n'.format(self.loopCounterName, self.coordinateToLoopOver, self.start, + self.stop, self.step) + '\t{!r}\n'.format(self.body) + class SympyAssignment(Node): diff --git a/backends/llvmbackend.py b/backends/llvmbackend.py new file mode 100644 index 0000000000000000000000000000000000000000..0d1eb1a86a3fde4636cf0ee901eb191d27ca0bb0 --- /dev/null +++ b/backends/llvmbackend.py @@ -0,0 +1,169 @@ +import llvmlite.ir as ir +import llvmlite.binding as llvm +import logging.config + +from sympy.utilities.codegen import CCodePrinter +from pystencils.ast import Node + +from sympy.printing.printer import Printer +from sympy import S +# S is numbers? + + +def generateLLVM(astNode): + return None + + +class LLVMPrinter(Printer): + """Convert expressions to LLVM IR""" + def __init__(self, module, builder, fn, *args, **kwargs): + self.func_arg_map = kwargs.pop("func_arg_map", {}) + super(LLVMPrinter, self).__init__(*args, **kwargs) + self.fp_type = ir.DoubleType() + #self.integer = ir.IntType(64) + self.module = module + self.builder = builder + self.fn = fn + self.ext_fn = {} # keep track of wrappers to external functions + self.tmp_var = {} + + def _add_tmp_var(self, name, value): + self.tmp_var[name] = value + + def _print_Number(self, n, **kwargs): + return ir.Constant(self.fp_type, float(n)) + + def _print_Integer(self, expr): + return ir.Constant(self.fp_type, float(expr.p)) + + def _print_Symbol(self, s): + val = self.tmp_var.get(s) + if not val: + # look up parameter with name s + val = self.func_arg_map.get(s) + if not val: + raise LookupError("Symbol not found: %s" % s) + return val + + def _print_Pow(self, expr): + base0 = self._print(expr.base) + if expr.exp == S.NegativeOne: + return self.builder.fdiv(ir.Constant(self.fp_type, 1.0), base0) + if expr.exp == S.Half: + fn = self.ext_fn.get("sqrt") + if not fn: + fn_type = ir.FunctionType(self.fp_type, [self.fp_type]) + fn = ir.Function(self.module, fn_type, "sqrt") + self.ext_fn["sqrt"] = fn + return self.builder.call(fn, [base0], "sqrt") + if expr.exp == 2: + return self.builder.fmul(base0, base0) + + exp0 = self._print(expr.exp) + fn = self.ext_fn.get("pow") + if not fn: + fn_type = ir.FunctionType(self.fp_type, [self.fp_type, self.fp_type]) + fn = ir.Function(self.module, fn_type, "pow") + self.ext_fn["pow"] = fn + return self.builder.call(fn, [base0, exp0], "pow") + + def _print_Mul(self, expr): + nodes = [self._print(a) for a in expr.args] + e = nodes[0] + for node in nodes[1:]: + e = self.builder.fmul(e, node) + return e + + def _print_Add(self, expr): + nodes = [self._print(a) for a in expr.args] + e = nodes[0] + for node in nodes[1:]: + e = self.builder.fadd(e, node) + return e + + # TODO - assumes all called functions take one double precision argument. + # Should have a list of math library functions to validate this. + + def _print_Function(self, expr): + name = expr.func.__name__ + e0 = self._print(expr.args[0]) + fn = self.ext_fn.get(name) + if not fn: + fn_type = ir.FunctionType(self.fp_type, [self.fp_type]) + fn = ir.Function(self.module, fn_type, name) + self.ext_fn[name] = fn + return self.builder.call(fn, [e0], name) + + def emptyPrinter(self, expr): + raise TypeError("Unsupported type for LLVM JIT conversion: %s" + % type(expr)) + + +class Eval(object): + def __init__(self): + llvm.initialize() + llvm.initialize_all_targets() + llvm.initialize_native_target() + llvm.initialize_native_asmprinter() + self.target = llvm.Target.from_default_triple() + + def compile(self, module): + logger.debug('=============Preparse') + logger.debug(str(module)) + llvmmod = llvm.parse_assembly(str(module)) + llvmmod.verify() + logger.debug('=============Function in IR') + logger.debug(str(llvmmod)) + # TODO cpu, features, opt + cpu = llvm.get_host_cpu_name() + features = llvm.get_host_cpu_features() + logger.debug('=======Things') + logger.debug(cpu) + logger.debug(features.flatten()) + target_machine = self.target.create_target_machine(cpu=cpu, features=features.flatten(), opt=2) + + logger.debug('Machine = ' + str(target_machine.target_data)) + + with open('gen.ll', 'w') as f: + f.write(str(llvmmod)) + optimize = True + if optimize: + pmb = llvm.create_pass_manager_builder() + pmb.opt_level = 2 + pmb.disable_unit_at_a_time = False + pmb.loop_vectorize = True + pmb.slp_vectorize = True + # TODO possible to pass for functions + pm = llvm.create_module_pass_manager() + pm.add_instruction_combining_pass() + pm.add_function_attrs_pass() + pm.add_constant_merge_pass() + pm.add_licm_pass() + pmb.populate(pm) + pm.run(llvmmod) + logger.debug("==========Opt") + logger.debug(str(llvmmod)) + with open('gen_opt.ll', 'w') as f: + f.write(str(llvmmod)) + + with llvm.create_mcjit_compiler(llvmmod, target_machine) as ee: + ee.finalize_object() + + logger.debug('==========Machine code') + logger.debug(target_machine.emit_assembly(llvmmod)) + with open('gen.S', 'w') as f: + f.write(target_machine.emit_assembly(llvmmod)) + with open('gen.o', 'wb') as f: + f.write(target_machine.emit_object(llvmmod)) + + # fptr = CFUNCTYPE(c_double, c_double, c_double)(ee.get_function_address('add2')) + # result = fptr(2, 3) + # print(result) + return 0 + + +if __name__ == "__main__": + logger = logging.getLogger(__name__) +else: + logger = logging.getLogger(__name__) + diff --git a/backends/logging.json b/backends/logging.json new file mode 100644 index 0000000000000000000000000000000000000000..42617b3878e3fefc3b298e35e857d560e37447b5 --- /dev/null +++ b/backends/logging.json @@ -0,0 +1,31 @@ +{ + "version" : 1, + "disable_existing_loggers" : false, + "formatters" : { + "simple" :{ + "format" : "[%(levelname)s]: %(message)s" + } + }, + "handlers" : { + "console": { + "class": "logging.StreamHandler", + "level": "INFO", + "formatter": "simple", + "stream": "ext://sys.stdout" + }, + "log_file": { + "class": "logging.FileHandler", + "level": "DEBUG", + "formatter": "simple", + "filename": "gen.log", + "mode" : "w", + "encoding": "utf8" + } + }, + "loggers" : { + "generator" : { + "level" : "DEBUG", + "handlers" : ["console", "log_file"] + } + } +} \ No newline at end of file diff --git a/llvm/__init__.py b/llvm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/llvm/kernelcreation.py b/llvm/kernelcreation.py new file mode 100644 index 0000000000000000000000000000000000000000..89eca80b86f18f9fdf4a3be208e0f32885539af2 --- /dev/null +++ b/llvm/kernelcreation.py @@ -0,0 +1,65 @@ +import sympy as sp +from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain, typingFromSympyInspection, \ + typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop +from pystencils.typedsymbol import TypedSymbol +from pystencils.field import Field +import pystencils.ast as ast + + +def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, splitGroups=(), + iterationSlice=None, ghostLayers=None): + """ + Creates an abstract syntax tree for a kernel function, by taking a list of update rules. + + Loops are created according to the field accesses in the equations. + + :param listOfEquations: list of sympy equations, containing accesses to :class:`pystencils.field.Field`. + Defining the update rules of the kernel + :param functionName: name of the generated function - only important if generated code is written out + :param typeForSymbol: a map from symbol name to a C type specifier. If not specified all symbols are assumed to + be of type 'double' except symbols which occur on the left hand side of equations where the + right hand side is a sympy Boolean which are assumed to be 'bool' . + :param splitGroups: Specification on how to split up inner loop into multiple loops. For details see + transformation :func:`pystencils.transformation.splitInnerLoop` + :param iterationSlice: if not None, iteration is done only over this slice of the field + :param ghostLayers: a sequence of pairs for each coordinate with lower and upper nr of ghost layers + if None, the number of ghost layers is determined automatically and assumed to be equal for a + all dimensions + + :return: :class:`pystencils.ast.KernelFunction` node + """ + if not typeForSymbol: + typeForSymbol = typingFromSympyInspection(listOfEquations, "double") + + def typeSymbol(term): + if isinstance(term, Field.Access) or isinstance(term, TypedSymbol): + return term + elif isinstance(term, sp.Symbol): + return TypedSymbol(term.name, typeForSymbol[term.name]) + else: + raise ValueError("Term has to be field access or symbol") + + fieldsRead, fieldsWritten, assignments = typeAllEquations(listOfEquations, typeForSymbol) + allFields = fieldsRead.union(fieldsWritten) + + for field in allFields: + field.setReadOnly(False) + for field in fieldsRead - fieldsWritten: + field.setReadOnly() + + body = ast.Block(assignments) + code = makeLoopOverDomain(body, functionName, iterationSlice=iterationSlice, ghostLayers=ghostLayers) + + if splitGroups: + typedSplitGroups = [[typeSymbol(s) for s in splitGroup] for splitGroup in splitGroups] + splitInnerLoop(code, typedSplitGroups) + + loopOrder = getOptimalLoopOrdering(allFields) + + basePointerInfo = [['spatialInner0'], ['spatialInner1']] + basePointerInfos = {field.name: parseBasePointerInfo(basePointerInfo, loopOrder, field) for field in allFields} + + resolveFieldAccesses(code, fieldToBasePointerInfo=basePointerInfos) + moveConstantsBeforeLoop(code) + + return code \ No newline at end of file