Skip to content
Snippets Groups Projects
Commit 4a8659e8 authored by Jan Hoenig's avatar Jan Hoenig
Browse files

it actually somehow comiles

parent b444ae25
Branches
Tags
No related merge requests found
......@@ -511,4 +511,10 @@ class Number(Node, sp.AtomicExpr):
def __repr__(self):
return repr(self.value)
def __float__(self):
return float(self.value)
def __int__(self):
return int(self.value)
import llvmlite.ir as ir
import functools
from sympy.printing.printer import Printer
from sympy import S
# S is numbers?
from pystencils.llvm.control_flow import Loop
from ..types import DataType
from ..astnodes import Indexed
def generateLLVM(ast_node):
......@@ -25,6 +28,7 @@ class LLVMPrinter(Printer):
self.fp_type = ir.DoubleType()
self.fp_pointer = self.fp_type.as_pointer()
self.integer = ir.IntType(64)
self.integer_pointer = self.integer.as_pointer()
self.void = ir.VoidType()
self.module = module
self.builder = builder
......@@ -35,8 +39,13 @@ class LLVMPrinter(Printer):
def _add_tmp_var(self, name, value):
self.tmp_var[name] = value
def _print_Number(self, n, **kwargs):
return ir.Constant(self.fp_type, n)
def _print_Number(self, n):
if n.dtype == DataType("int"):
return ir.Constant(self.integer, int(n))
elif n.dtype == DataType("double"):
return ir.Constant(self.fp_type, float(n))
else:
raise NotImplementedError("Numbers can only have int and double", n)
def _print_Float(self, expr):
return ir.Constant(self.fp_type, expr.p)
......@@ -81,16 +90,23 @@ class LLVMPrinter(Printer):
def _print_Mul(self, expr):
nodes = [self._print(a) for a in expr.args]
e = nodes[0]
if expr.dtype == DataType('double'):
mul = self.builder.fmul
else: # int TODO others?
mul = self.builder.mul
for node in nodes[1:]:
e = self.builder.fmul(e, node)
e = mul(e, node)
return e
def _print_Add(self, expr):
nodes = [self._print(a) for a in expr.args]
e = nodes[0]
if expr.dtype == DataType('double'):
add = self.builder.fadd
else: # int TODO others?
add = self.builder.add
for node in nodes[1:]:
print(e, node)
e = self.builder.fadd(e, node)
e = add(e, node)
return e
def _print_KernelFunction(self, function):
......@@ -118,6 +134,7 @@ class LLVMPrinter(Printer):
block = fn.append_basic_block(name="entry")
self.builder = ir.IRBuilder(block)
self._print(function.body)
self.builder.ret_void()
self.fn = fn
return fn
......@@ -129,29 +146,47 @@ class LLVMPrinter(Printer):
with Loop(self.builder, self._print(loop.start), self._print(loop.stop), self._print(loop.step),
loop.loopCounterName, loop.loopCounterSymbol.name) as i:
self._add_tmp_var(loop.loopCounterSymbol, i)
# TODO remove tmp var
self._print(loop.body)
def _print_SympyAssignment(self, assignment):
expr = self._print(assignment.rhs)
lhs = assignment.lhs
if isinstance(lhs, Indexed):
ptr = self._print(lhs.base.label)
index = self._print(lhs.args[1])
gep = self.builder.gep(ptr, [index])
return self.builder.store(expr, gep)
self.func_arg_map[assignment.lhs.name] = expr
return expr
def _print_Conversion(self, conversion):
node = self._print(conversion.args[0])
to_dtype = conversion.dtype
from_dtype = conversion.args[0].dtype
print(to_dtype, from_dtype)
# fp -> int: fptosi
# int -> fp: sitofp
# ptr -> int: ptrtoint
# int -> ptr: inttoptr
# ?bitcast, ?addrspacecast
# (From, to)
decision = {
(DataType("int"), DataType("double")): functools.partial(self.builder.sitofp, node, self.fp_type),
(DataType("double"), DataType("int")): functools.partial(self.builder.fptosi, node, self.integer),
(DataType("double *"), DataType("int")): functools.partial(self.builder.ptrtoint, node, self.integer),
(DataType("int"), DataType("double *")): functools.partial(self.builder.inttoptr, node, self.fp_pointer),
(DataType("double * __restrict__"), DataType("int")): functools.partial(self.builder.ptrtoint, node, self.integer),
(DataType("int"), DataType("double * __restrict__")): functools.partial(self.builder.inttoptr, node, self.fp_pointer),
(DataType("const double * __restrict__"), DataType("int")): functools.partial(self.builder.ptrtoint, node, self.integer),
(DataType("int"), DataType("const double * __restrict__")): functools.partial(self.builder.inttoptr, node, self.fp_pointer),
}
# TODO float, const, restrict
# TODO bitcast, addrspacecast
return decision[(from_dtype, to_dtype)]()
def _print_Indexed(self, indexed):
pass
ptr = self._print(indexed.base.label)
index = self._print(indexed.args[1])
gep = self.builder.gep(ptr, [index])
return self.builder.load(gep, name=indexed.base.label.name)
# Should have a list of math library functions to validate this.
# TODO delete this -> NO this should be a function call
# Should have a list of math library functions to validate this.
# TODO function calls
def _print_Function(self, expr):
name = expr.func.__name__
e0 = self._print(expr.args[0])
......@@ -163,5 +198,5 @@ class LLVMPrinter(Printer):
return self.builder.call(fn, [e0], name)
def emptyPrinter(self, expr):
raise TypeError("Unsupported type for LLVM JIT conversion: %s"
% type(expr))
raise TypeError("Unsupported type for LLVM JIT conversion: %s %s"
% type(expr), expr)
from .kernelcreation import createKernel
from .jit import compileLLVM
\ No newline at end of file
import llvmlite.binding as llvm
import logging.config
logger = logging.getLogger(__name__)
def compileLLVM(module):
return Eval().compile(module)
class Eval(object):
def __init__(self):
......@@ -63,9 +69,3 @@ class Eval(object):
# result = fptr(2, 3)
# print(result)
return 0
if __name__ == "__main__":
logger = logging.getLogger(__name__)
else:
logger = logging.getLogger(__name__)
......@@ -70,6 +70,9 @@ class DataType(object):
if self.dtype > other.dtype:
return True
def __hash__(self):
return hash(repr(self))
def get_type_from_sympy(node):
# Rational, NumberSymbol?
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment