Commit e8286130 authored by Jan Hoenig
Browse files

Merged llvm_generate branch. The llvm generation can be used now

parents e54f702c 805f6cc8
import sympy as sp
from sympy.tensor import IndexedBase
from pystencils.field import Field
from pystencils.types import TypedSymbol
from pystencils.types import TypedSymbol, createType, get_type_from_sympy
class Node(object):
......@@ -294,7 +294,7 @@ class SympyAssignment(Node):
self._lhsSymbol = lhsSymbol
self.rhs = rhsTerm
self._isDeclaration = True
isCast = str(self._lhsSymbol.func).lower() == 'cast'
isCast = str(self._lhsSymbol.func).lower() == 'cast' if hasattr(self._lhsSymbol, "func") else False
if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, IndexedBase) or isCast:
self._isDeclaration = False
self._isConst = isConst
......@@ -307,7 +307,7 @@ class SympyAssignment(Node):
def lhs(self, newValue):
self._lhsSymbol = newValue
self._isDeclaration = True
isCast = str(self._lhsSymbol.func).lower() == 'cast'
isCast = str(self._lhsSymbol.func).lower() == 'cast' if hasattr(self._lhsSymbol, "func") else False
if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, sp.Indexed) or isCast:
self._isDeclaration = False
......@@ -344,7 +344,8 @@ class SympyAssignment(Node):
def replace(self, child, replacement):
if child == self.lhs:
self.lhs = child
replacement.parent = self
self.lhs = replacement
elif child == self.rhs:
replacement.parent = self
self.rhs = replacement
......@@ -394,8 +395,6 @@ class TemporaryMemoryFree(Node):
# TODO implement defined & undefinedSymbols
class Conversion(Node):
def __init__(self, child, dtype, parent=None):
super(Conversion, self).__init__(parent)
......@@ -422,9 +421,9 @@ class Conversion(Node):
raise set()
def __repr__(self):
return '(%s)' % (self.dtype,) + repr(self.args)
return '(%s(%s))' % (repr(self.dtype), repr(self.args[0].dtype)) + repr(self.args)
# TODO everything which is not Atomic expression: Pow)
# TODO Pow
_expr_dict = {'Add': ' + ', 'Mul': ' * ', 'Pow': '**'}
......@@ -480,15 +479,33 @@ class Pow(Expr):
class Indexed(Expr):
    """Expression node for an element access of the form ``base[index]``.

    :param args: operands handed on to ``Expr``; ``__repr__`` reads the first two entries
    :param base: indexed base object; ``base.label.dtype.baseType`` determines the node's dtype
    :param parent: optional parent node, forwarded to ``Expr``
    """

    def __init__(self, args, base, parent=None):
        super(Indexed, self).__init__(args, parent)
        self.base = base
        # The dtype comes from the label with the pointer stripped off.
        element_type = base.label.dtype.baseType
        self.dtype = createType(element_type)

    def __repr__(self):
        target, index = self.args[0], self.args[1]
        return '%s[%s]' % (target, index)
class Number(Node):
class PointerArithmetic(Expr):
    """Expression node combining a pointer with an offset expression.

    :param args: the offset expression; kept as ``offset``
    :param pointer: the pointer operand; the node takes over its ``dtype``
    :param parent: optional parent node, forwarded to ``Expr``
    """

    def __init__(self, args, pointer, parent=None):
        operands = [args, pointer]
        super(PointerArithmetic, self).__init__(operands, parent)
        self.pointer = pointer
        self.offset = args
        self.dtype = pointer.dtype

    def __repr__(self):
        # NOTE(review): the second slot is filled with the complete ``self.args``, not ``self.offset`` -
        # confirm that this is the intended output.
        return '*(%s + %s)' % (self.pointer, self.args)
class Number(Node, sp.AtomicExpr):
def __init__(self, number, parent=None):
super(Number, self).__init__(parent)
self._args = None
self.dtype = dtype
self.dtype, self.value = get_type_from_sympy(number)
self._args = tuple()
@property
def args(self):
......@@ -506,6 +523,12 @@ class Number(Node):
raise set()
def __repr__(self):
return '(%s)' % (self.dtype,) + repr(self.args)
return repr(self.value)
def __float__(self):
return float(self.value)
def __int__(self):
return int(self.value)
try:
from .llvm import generateLLVM
except ImportError:
pass
from .llvm import generateLLVM
from .cbackend import generateC
from .dot import dotprint
......@@ -6,9 +6,10 @@ class DotPrinter(Printer):
"""
A printer which converts ast to DOT (graph description language).
"""
def __init__(self, nodeToStrFunction, **kwargs):
def __init__(self, nodeToStrFunction, full, **kwargs):
super(DotPrinter, self).__init__()
self._nodeToStrFunction = nodeToStrFunction
self.full = full
self.dot = Digraph(**kwargs)
self.dot.quote_edge = lang.quote
......@@ -30,6 +31,21 @@ class DotPrinter(Printer):
# Adds a graph node for the assignment, labelled via ``self._nodeToStrFunction``.
# When ``self.full`` is set, every entry of ``assignment.args`` is printed as well.
# Edges are drawn from the assignment to each entry of ``assignment.args``.
# NOTE(review): indentation is lost in this view, so it cannot be told whether the edge loop
# belongs to the ``if self.full`` branch - confirm against the real file.
def _print_SympyAssignment(self, assignment):
self.dot.node(self._nodeToStrFunction(assignment))
if self.full:
for node in assignment.args:
self._print(node)
for node in assignment.args:
self.dot.edge(self._nodeToStrFunction(assignment), self._nodeToStrFunction(node))
def emptyPrinter(self, expr):
    """Fallback for nodes that have no dedicated ``_print_*`` method.

    In ``full`` mode the node is added to the graph, every entry of ``expr.args`` is printed and an
    edge is drawn from the node to each of these entries.

    :param expr: node to print; ``expr.args`` is iterated in ``full`` mode
    :raises NotImplementedError: if the printer is not in ``full`` mode
    """
    if not self.full:
        # ``NotImplemented`` is a constant and cannot be called; the exception class is
        # ``NotImplementedError``.
        raise NotImplementedError('Dotprinter cannot print', expr)

    label = self._nodeToStrFunction(expr)
    self.dot.node(label)
    for child in expr.args:
        self._print(child)
    for child in expr.args:
        self.dot.edge(label, self._nodeToStrFunction(child))
def doprint(self, expr):
self._print(expr)
......@@ -48,17 +64,20 @@ def __shortened(node):
return "Assignment: " + repr(node.lhs)
def dotprint(node, view=False, short=False, full=False, **kwargs):
    """
    Returns a string which can be used to generate a DOT-graph
    :param node: The ast which should be generated
    :param view: Boolean, if rendering of the image directly should occur.
    :param short: Uses the __shortened output
    :param full: Prints the whole tree with type information
    :param kwargs: is directly passed to the DotPrinter class: http://graphviz.readthedocs.io/en/latest/api.html#digraph
    :return: string in DOT format
    """
    # The label function is chosen with plain statements: written as
    # ``lambda expr: ... if full else other`` the conditional becomes part of the lambda body, so
    # with ``full=False`` every label would be a function object instead of a string.
    if full:
        def nodeToStrFunction(expr):
            return repr(type(expr)) + repr(expr)
    elif short:
        nodeToStrFunction = __shortened
    else:
        nodeToStrFunction = repr
    printer = DotPrinter(nodeToStrFunction, full, **kwargs)
    dot = printer.doprint(node)
    if view:
        printer.dot.render(view=view)
    return dot
......@@ -80,4 +99,4 @@ if __name__ == "__main__":
from pystencils.cpu import createKernel
ast = createKernel([updateRule])
print(dotprint(ast, short=True))
\ No newline at end of file
print(dotprint(ast, short=True))
import llvmlite.ir as ir
import functools
from sympy.printing.printer import Printer
from sympy import S
# S is numbers?
from pystencils.llvm.control_flow import Loop
from ..types import createType
from ..astnodes import Indexed
def generateLLVM(ast_node, module=None, builder=None):
    """
    Prints the ast as llvm code

    :param ast_node: root node handed to the ``LLVMPrinter``
    :param module: llvmlite module to emit into; a new ``ir.Module()`` is created when omitted
    :param builder: llvmlite IR builder to use; a new ``ir.IRBuilder()`` is created when omitted
    :return: the result of ``LLVMPrinter._print`` for ``ast_node``
    """
    # Create the defaults per call: objects constructed in the signature are evaluated once and
    # would be shared by every call that omits them.
    if module is None:
        module = ir.Module()
    if builder is None:
        builder = ir.IRBuilder()
    printer = LLVMPrinter(module, builder)
    return printer._print(ast_node)
......@@ -25,6 +26,7 @@ class LLVMPrinter(Printer):
self.fp_type = ir.DoubleType()
self.fp_pointer = self.fp_type.as_pointer()
self.integer = ir.IntType(64)
self.integer_pointer = self.integer.as_pointer()
self.void = ir.VoidType()
self.module = module
self.builder = builder
......@@ -35,11 +37,16 @@ class LLVMPrinter(Printer):
def _add_tmp_var(self, name, value):
self.tmp_var[name] = value
def _print_Number(self, n, **kwargs):
return ir.Constant(self.fp_type, float(n))
def _print_Number(self, n):
    """Turn a typed number node into an llvmlite constant.

    :param n: number node with a ``dtype``; it is converted with ``int()`` or ``float()``
    :return: ``ir.Constant`` of the printer's integer or floating point type
    :raises NotImplementedError: if the dtype is neither ``int`` nor ``double``
    """
    dtype = n.dtype
    if dtype == createType("int"):
        return ir.Constant(self.integer, int(n))
    if dtype == createType("double"):
        return ir.Constant(self.fp_type, float(n))
    raise NotImplementedError("Numbers can only have int and double", n)
def _print_Float(self, expr):
    """Return ``expr`` as a constant of the printer's floating point type.

    :param expr: floating point number; it is converted with ``float()``
    :return: ``ir.Constant`` of type ``self.fp_type``
    """
    # ``p`` is the numerator attribute of SymPy rationals/integers; a SymPy Float does not
    # provide it, so the value is taken from the number itself.
    return ir.Constant(self.fp_type, float(expr))
def _print_Integer(self, expr):
return ir.Constant(self.integer, expr.p)
......@@ -81,24 +88,31 @@ class LLVMPrinter(Printer):
def _print_Mul(self, expr):
    """Emit a left-to-right chain of multiplications for the operands of ``expr``.

    :param expr: node with ``args`` (the factors) and a ``dtype``
    :return: the value of the last emitted instruction (the first operand if there is only one)
    """
    nodes = [self._print(a) for a in expr.args]
    if expr.dtype == createType('double'):
        mul = self.builder.fmul
    else:  # int TODO others?
        mul = self.builder.mul
    e = nodes[0]
    # Exactly one instruction per additional factor; applying fmul and then mul would multiply twice.
    for node in nodes[1:]:
        e = mul(e, node)
    return e
def _print_Add(self, expr):
    """Emit a left-to-right chain of additions for the operands of ``expr``.

    :param expr: node with ``args`` (the summands) and a ``dtype``
    :return: the value of the last emitted instruction (the first operand if there is only one)
    """
    nodes = [self._print(a) for a in expr.args]
    if expr.dtype == createType('double'):
        add = self.builder.fadd
    else:  # int TODO others?
        add = self.builder.add
    e = nodes[0]
    # Exactly one instruction per additional summand; no debug output while printing.
    for node in nodes[1:]:
        e = add(e, node)
    return e
def _print_KernelFunction(self, function):
return_type = self.void
# TODO argument in their own call?
# TODO argument in their own call? -> nope
parameter_type = []
for parameter in function.parameters:
# TODO what bout ptr shape and stride argument?
# TODO what about ptr shape and stride argument?
if parameter.isFieldArgument:
parameter_type.append(self.fp_pointer)
else:
......@@ -118,6 +132,7 @@ class LLVMPrinter(Printer):
block = fn.append_basic_block(name="entry")
self.builder = ir.IRBuilder(block)
self._print(function.body)
self.builder.ret_void()
self.fn = fn
return fn
......@@ -129,16 +144,57 @@ class LLVMPrinter(Printer):
with Loop(self.builder, self._print(loop.start), self._print(loop.stop), self._print(loop.step),
loop.loopCounterName, loop.loopCounterSymbol.name) as i:
self._add_tmp_var(loop.loopCounterSymbol, i)
# TODO remove tmp var
self._print(loop.body)
def _print_SympyAssignment(self, assignment):
    """Emit code for an assignment node.

    The right-hand side is printed first. An ``Indexed`` left-hand side becomes a ``store`` through
    a ``gep`` on its base label; any other left-hand side is recorded in ``func_arg_map`` under
    its name.

    :param assignment: node providing ``lhs`` and ``rhs``
    :return: the store instruction for indexed targets, otherwise the printed right-hand side
    """
    value = self._print(assignment.rhs)
    target = assignment.lhs
    if not isinstance(target, Indexed):
        self.func_arg_map[assignment.lhs.name] = value
        return value
    # TODO delete this -> NO this should be a function call
    base_pointer = self._print(target.base.label)
    offset = self._print(target.args[1])
    address = self.builder.gep(base_pointer, [offset])
    return self.builder.store(value, address)
def _print_Conversion(self, conversion):
    """Emit the cast instruction for a conversion node.

    The cast is looked up by the pair (dtype of the operand, dtype of the conversion); a pair that
    is not in the table raises ``KeyError``.

    :param conversion: node whose first argument is the operand and whose ``dtype`` is the target
    :return: the instruction produced by the selected builder method
    """
    operand = conversion.args[0]
    value = self._print(operand)
    int_type = createType("int")
    double_type = createType("double")
    # Keys are (from, to)
    casts = {
        (int_type, double_type): functools.partial(self.builder.sitofp, value, self.fp_type),
        (double_type, int_type): functools.partial(self.builder.fptosi, value, self.integer),
    }
    for spelling in ("double *", "double * restrict", "double * restrict const"):
        pointer_type = createType(spelling)
        casts[(pointer_type, int_type)] = functools.partial(self.builder.ptrtoint, value, self.integer)
        casts[(int_type, pointer_type)] = functools.partial(self.builder.inttoptr, value, self.fp_pointer)
    # TODO float, const, restrict
    # TODO bitcast, addrspacecast
    return casts[(operand.dtype, conversion.dtype)]()
def _print_Indexed(self, indexed):
    """Emit a load of the element addressed by an indexed node.

    :param indexed: node whose ``base.label`` is printed as the pointer and whose second argument
        is printed as the index
    :return: the load instruction, named after the base label
    """
    label = indexed.base.label
    base_pointer = self._print(label)
    offset = self._print(indexed.args[1])
    address = self.builder.gep(base_pointer, [offset])
    return self.builder.load(address, name=label.name)
def _print_PointerArithmetic(self, pointer):
    """Emit a ``gep`` for a pointer arithmetic node.

    :param pointer: node whose ``pointer`` and ``offset`` attributes are printed, in that order
    :return: the ``gep`` instruction built from the printed pointer and the printed offset
    """
    base_pointer = self._print(pointer.pointer)
    offset = self._print(pointer.offset)
    return self.builder.gep(base_pointer, [offset])
# Should have a list of math library functions to validate this.
# TODO function calls
def _print_Function(self, expr):
name = expr.func.__name__
e0 = self._print(expr.args[0])
......@@ -150,5 +206,5 @@ class LLVMPrinter(Printer):
return self.builder.call(fn, [e0], name)
def emptyPrinter(self, expr):
    """Reject every node that has no dedicated ``_print_*`` method.

    :param expr: the node that cannot be printed
    :raises TypeError: always; the message names both the type and the offending node
    """
    raise TypeError("Unsupported type for LLVM JIT conversion: %s %s"
                    % (type(expr), expr))
from .kernelcreation import createKernel
from .jit import compileLLVM
\ No newline at end of file
import llvmlite.ir as ir
import llvmlite.binding as llvm
import logging.config
from ..types import toCtypes, createType
import ctypes as ct
class Eval(object):
def compileLLVM(module):
    """Run ``module`` through ``parse``, ``optimize`` and ``compile`` of a new ``Jit``.

    :param module: module object handed to ``Jit.parse``
    :return: the ``Jit`` instance after all three steps
    """
    engine = Jit()
    engine.parse(module)
    engine.optimize()
    engine.compile()
    return engine
class Jit(object):
def __init__(self):
    """Initialise the llvmlite binding and create a target machine for the host CPU."""
    for initialise in (llvm.initialize, llvm.initialize_all_targets,
                       llvm.initialize_native_target, llvm.initialize_native_asmprinter):
        initialise()
    # Filled in by later steps.
    self.module = None
    self.llvmmod = None
    self.target = llvm.Target.from_default_triple()
    self.cpu = llvm.get_host_cpu_name()
    self.cpu_features = llvm.get_host_cpu_features()
    feature_string = self.cpu_features.flatten()
    self.target_machine = self.target.create_target_machine(cpu=self.cpu, features=feature_string, opt=2)
    self.ee = None
    self.fptr = None
def compile(self, module):
logger.debug('=============Preparse')
logger.debug(str(module))
def parse(self, module):
self.module = module
llvmmod = llvm.parse_assembly(str(module))
llvmmod.verify()
logger.debug('=============Function in IR')
logger.debug(str(llvmmod))
# TODO cpu, features, opt
cpu = llvm.get_host_cpu_name()
features = llvm.get_host_cpu_features()
logger.debug('=======Things')
logger.debug(cpu)
logger.debug(features.flatten())
target_machine = self.target.create_target_machine(cpu=cpu, features=features.flatten(), opt=2)
logger.debug('Machine = ' + str(target_machine.target_data))
self.llvmmod = llvmmod
with open('gen.ll', 'w') as f:
f.write(str(llvmmod))
optimize = True
if optimize:
pmb = llvm.create_pass_manager_builder()
pmb.opt_level = 2
pmb.disable_unit_at_a_time = False
pmb.loop_vectorize = True
pmb.slp_vectorize = True
# TODO possible to pass for functions
pm = llvm.create_module_pass_manager()
pm.add_instruction_combining_pass()
pm.add_function_attrs_pass()
pm.add_constant_merge_pass()
pm.add_licm_pass()
pmb.populate(pm)
pm.run(llvmmod)
logger.debug("==========Opt")
logger.debug(str(llvmmod))
with open('gen_opt.ll', 'w') as f:
f.write(str(llvmmod))
def write_ll(self, file):
    """Write the string form of ``self.llvmmod`` to ``file``.

    :param file: path of the file to create or overwrite
    """
    with open(file, 'w') as out:
        text = str(self.llvmmod)
        out.write(text)
with llvm.create_mcjit_compiler(llvmmod, target_machine) as ee:
ee.finalize_object()
def optimize(self):
    """Run a module pass manager over ``self.llvmmod``.

    Four passes are added explicitly before the pass manager builder (opt level 2, loop and SLP
    vectorisation switched on) populates the manager.
    """
    builder = llvm.create_pass_manager_builder()
    for option, value in (("opt_level", 2),
                          ("disable_unit_at_a_time", False),
                          ("loop_vectorize", True),
                          ("slp_vectorize", True)):
        setattr(builder, option, value)
    # TODO possible to pass for functions
    manager = llvm.create_module_pass_manager()
    for pass_name in ("add_instruction_combining_pass",
                      "add_function_attrs_pass",
                      "add_constant_merge_pass",
                      "add_licm_pass"):
        getattr(manager, pass_name)()
    builder.populate(manager)
    manager.run(self.llvmmod)
logger.debug('==========Machine code')
logger.debug(target_machine.emit_assembly(llvmmod))
with open('gen.S', 'w') as f:
f.write(target_machine.emit_assembly(llvmmod))
with open('gen.o', 'wb') as f:
f.write(target_machine.emit_object(llvmmod))
def compile(self, assembly_file=None, object_file=None):
ee = llvm.create_mcjit_compiler(self.llvmmod, self.target_machine)
ee.finalize_object()
# fptr = CFUNCTYPE(c_double, c_double, c_double)(ee.get_function_address('add2'))
# result = fptr(2, 3)
# print(result)
return 0
if assembly_file is not None:
with open(assembly_file, 'w') as f:
f.write(self.target_machine.emit_assembly(self.llvmmod))
if object_file is not None:
with open(object_file, 'wb') as f:
f.write(self.target_machine.emit_object(self.llvmmod))
fptr = {}
for function in self.module.functions:
if not function.is_declaration:
return_type = None
if function.ftype.return_type != ir.VoidType():
return_type = toCtypes(createType(str(function.ftype.return_type)))
args = [toCtypes(createType(str(arg))) for arg in function.ftype.args]
function_address = ee.get_function_address(function.name)
fptr[function.name] = ct.CFUNCTYPE(return_type, *args)(function_address)
self.ee = ee
self.fptr = fptr
if __name__ == "__main__":
logger = logging.getLogger(__name__)
else:
logger = logging.getLogger(__name__)
def __call__(self, function, *args, **kwargs):
    """Call a compiled function by name.

    :param function: key into ``self.fptr``
    :param args: positional arguments passed on to the compiled function
    :param kwargs: keyword arguments passed on to the compiled function
    :return: whatever the compiled function returns
    :raises KeyError: if ``function`` is not a key of ``self.fptr``
    """
    # Hand the result back to the caller instead of dropping it.
    return self.fptr[function](*args, **kwargs)
......@@ -60,9 +60,11 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl
resolveFieldAccesses(code, readOnlyFields, fieldToBasePointerInfo=basePointerInfos)
moveConstantsBeforeLoop(code)
print('Ast:')
print(code)
desympy_ast(code)
print('Desympied ast:')
print(code)
insert_casts(code)
return code
\ No newline at end of file
return code
......@@ -566,23 +566,45 @@ def get_type(node):
def insert_casts(node):
"""
Inserts casts where needed
Inserts casts and dtype where needed
:param node: ast which should be traversed
:return: node
"""
def add_conversion(node, dtype):
return node
def conversion(args):
target = args[0]
if isinstance(target.dtype, PointerType):
# Pointer arithmetic
for arg in args[1:]:
# Check validness
if not arg.dtype.is_int() and not arg.dtype.is_uint():
raise ValueError("Impossible pointer arithmetic", target, arg)
pointer = ast.PointerArithmetic(ast.Add(args[1:]), target)
return [pointer]
else:
for i in range(len(args)):
if args[i].dtype != target.dtype:
args[i] = ast.Conversion(args[i], target.dtype, node)
return args
for arg in node.args:
insert_casts(arg)
if isinstance(node, ast.Indexed):
# TODO remove this
pass
elif isinstance(node, ast.Expr):
args = sorted((arg.dtype for arg in node.args), key=attrgetter('ptr', 'dtype'))
#print(node, node.args)
#print([type(arg) for arg in node.args])
#print([arg.dtype for arg in node.args])
args = sorted((arg for arg in node.args), key=attrgetter('dtype'))
target = args[0]
for i in range(len(args)):
args[i] = add_conversion(args[i], target.dtype)
node.args = args
node.args = conversion(args)
node.dtype = target.dtype
#print(node.dtype)
#print(node)
elif isinstance(node, ast.SympyAssignment):
if node.lhs.dtype != node.rhs.dtype:
node.replace(node.rhs, ast.Conversion(node.rhs, node.lhs.dtype))
elif isinstance(node, ast.LoopOverCoordinate):
pass
return node
......@@ -595,16 +617,45 @@ def desympy_ast(node):
:param node: ast which should be traversed. Only node's children will be modified.
:return: (modified) node
"""
if node.args is None:
return node
for i in range(len(node.args)):
arg = node.args[i]
if isinstance(arg, sp.Add):
node.replace(arg, ast.Add(arg.args, node))
elif isinstance(arg, sp.Number):
node.replace(arg, ast.Number(arg, node))
elif isinstance(arg, sp.Mul):
node.replace(arg, ast.Mul(arg.args, node))
elif isinstance(arg, sp.Pow):
node.replace(arg, ast.Pow(arg.args, node))
elif isinstance(arg, sp.tensor.Indexed):
node.replace(arg, ast.Indexed(arg.args, node))
elif isinstance(arg, sp.tensor.Indexed) or isinstance(arg, sp.tensor.indexed.Indexed):
node.replace(arg, ast.Indexed(arg.args, arg.base, node))
elif isinstance(arg, sp.tensor.IndexedBase):
node.replace(arg, arg.label)
#elif isinstance(arg, sp.containers.Tuple):
#
else:
#print('Not transforming:', type(arg), arg)
pass
for arg in node.args:
desympy_ast(arg)
return node
# Walks the tree below ``node`` by calling itself on the entries of ``node.args``.
# Every type branch is currently a placeholder ``pass``, so no check is performed yet.
# NOTE(review): indentation is lost in this view, so it cannot be told whether the recursion
# belongs to the final ``else`` branch - confirm against the real file.
def check_dtype(node):
if isinstance(node, ast.KernelFunction):
pass
elif isinstance(node, ast.Block):
pass
elif isinstance(node, ast.LoopOverCoordinate):
pass
elif isinstance(node, ast.SympyAssignment):
pass
else:
#print(node)
#print(node.dtype)
pass
for arg in node.args:
check_dtype(arg)
import ctypes
import sympy as sp
import numpy as np
# import llvmlite.ir as ir
from sympy.core.cache import cacheit
......@@ -50,6 +51,11 @@ class TypedSymbol(sp.Symbol):