import sympy as sp
import functools
from sympy import S, Indexed
from sympy.printing.printer import Printer
import llvmlite.ir as ir
from pystencils.assignment import Assignment
from pystencils.llvm.control_flow import Loop
from pystencils.data_types import create_type, to_llvm_type, get_type_of_expression, collate_types, \
    create_composite_type_from_string


def generate_llvm(ast_node, module=None, builder=None):
    """Translate a pystencils AST node into LLVM IR.

    Args:
        ast_node: node to print (e.g. a ``KernelFunction``).
        module: llvmlite ``ir.Module`` that receives generated functions;
            a fresh module is created when omitted.
        builder: llvmlite ``ir.IRBuilder`` used for emitting instructions;
            a fresh, unpositioned builder is created when omitted.

    Returns:
        Whatever ``LLVMPrinter`` produces for ``ast_node`` (for a
        ``KernelFunction`` this is the generated ``ir.Function``).
    """
    target_module = ir.Module() if module is None else module
    target_builder = ir.IRBuilder() if builder is None else builder
    return LLVMPrinter(target_module, target_builder)._print(ast_node)


# noinspection PyPep8Naming
class LLVMPrinter(Printer):
    """Convert pystencils ASTs and sympy expressions to LLVM IR via llvmlite.

    Every ``_print_<ClassName>`` method emits instructions through ``self.builder``
    into ``self.module`` and returns the resulting llvmlite value (or function).
    Scalar literals are emitted as ``double`` or ``i64`` constants.
    """

    def __init__(self, module, builder, fn=None, *args, **kwargs):
        """Create a printer.

        Args:
            module: llvmlite ``ir.Module`` receiving generated and external functions.
            builder: llvmlite ``ir.IRBuilder`` used to emit instructions.
            fn: function being generated, if any; ``_print_KernelFunction`` sets it.
            func_arg_map: optional keyword argument mapping symbol names to llvmlite
                values (function arguments and values of assigned symbols).
            *args, **kwargs: remaining arguments are forwarded to sympy's ``Printer``.
        """
        self.func_arg_map = kwargs.pop("func_arg_map", {})
        super(LLVMPrinter, self).__init__(*args, **kwargs)
        self.fp_type = ir.DoubleType()
        self.fp_pointer = self.fp_type.as_pointer()
        self.integer = ir.IntType(64)
        self.integer_pointer = self.integer.as_pointer()
        self.void = ir.VoidType()
        self.module = module
        self.builder = builder
        self.fn = fn
        self.ext_fn = {}  # keep track of wrappers to external functions
        self.tmp_var = {}  # symbol -> llvmlite value, e.g. loop counters

    def _add_tmp_var(self, name, value):
        self.tmp_var[name] = value

    def _remove_tmp_var(self, name):
        del self.tmp_var[name]

    def _print_Number(self, n):
        """Print a generic sympy number as an ``i64`` or ``double`` constant."""
        dtype = get_type_of_expression(n)
        if dtype == create_type("int"):
            return ir.Constant(self.integer, int(n))
        elif dtype == create_type("double"):
            return ir.Constant(self.fp_type, float(n))
        raise NotImplementedError("Numbers can only have int and double", n)

    def _print_Float(self, expr):
        return ir.Constant(self.fp_type, float(expr))

    def _print_Integer(self, expr):
        return ir.Constant(self.integer, int(expr))

    def _print_int(self, i):
        return ir.Constant(self.integer, i)

    def _print_Symbol(self, s):
        """Return the llvmlite value bound to symbol ``s``.

        Values registered with ``_add_tmp_var`` (keyed by symbol) take precedence
        over entries of ``func_arg_map`` (keyed by symbol name).

        Raises:
            LookupError: if the symbol is bound in neither map.
        """
        val = self.tmp_var.get(s)
        if val is None:
            # look up parameter / assigned value with name s
            val = self.func_arg_map.get(s.name)
        if val is None:
            raise LookupError("Symbol not found: %s" % s)
        return val

    def _print_Pow(self, expr):
        """Print a power.

        ``x**-1`` becomes a division, ``x**(1/2)`` a call to ``sqrt``, ``x**2`` and
        ``x**3`` repeated multiplications; everything else calls ``pow(double, double)``.
        """
        base0 = self._print(expr.base)
        if expr.exp == S.NegativeOne:
            return self.builder.fdiv(ir.Constant(self.fp_type, 1.0), base0)
        if expr.exp == S.Half:
            fn = self.ext_fn.get("sqrt")
            if not fn:
                fn_type = ir.FunctionType(self.fp_type, [self.fp_type])
                fn = ir.Function(self.module, fn_type, "sqrt")
                self.ext_fn["sqrt"] = fn
            return self.builder.call(fn, [base0], "sqrt")
        if expr.exp == 2:
            return self.builder.fmul(base0, base0)
        elif expr.exp == 3:
            return self.builder.fmul(self.builder.fmul(base0, base0), base0)

        exp0 = self._print(expr.exp)
        if isinstance(exp0.type, ir.IntType):
            # pow is declared as double(double, double): integer exponents such as
            # the 4 in x**4 must be converted, otherwise the call is ill-typed.
            exp0 = self.builder.sitofp(exp0, self.fp_type)
        fn = self.ext_fn.get("pow")
        if not fn:
            fn_type = ir.FunctionType(self.fp_type, [self.fp_type, self.fp_type])
            fn = ir.Function(self.module, fn_type, "pow")
            self.ext_fn["pow"] = fn
        return self.builder.call(fn, [base0, exp0], "pow")

    def _print_Mul(self, expr):
        """Fold the factors with ``fmul`` for double expressions, ``mul`` otherwise."""
        nodes = [self._print(a) for a in expr.args]
        e = nodes[0]
        if get_type_of_expression(expr) == create_type('double'):
            mul = self.builder.fmul
        else:  # int TODO unsigned/signed
            mul = self.builder.mul
        for node in nodes[1:]:
            e = mul(e, node)
        return e

    def _print_Add(self, expr):
        """Fold the summands with ``fadd`` for double expressions, ``add`` otherwise."""
        nodes = [self._print(a) for a in expr.args]
        e = nodes[0]
        if get_type_of_expression(expr) == create_type('double'):
            add = self.builder.fadd
        else:  # int TODO unsigned/signed
            add = self.builder.add
        for node in nodes[1:]:
            e = add(e, node)
        return e

    def _print_Or(self, expr):
        nodes = [self._print(a) for a in expr.args]
        e = nodes[0]
        for node in nodes[1:]:
            e = self.builder.or_(e, node)
        return e

    def _print_And(self, expr):
        nodes = [self._print(a) for a in expr.args]
        e = nodes[0]
        for node in nodes[1:]:
            e = self.builder.and_(e, node)
        return e

    def _print_StrictLessThan(self, expr):
        return self._comparison('<', expr)

    def _print_LessThan(self, expr):
        return self._comparison('<=', expr)

    def _print_StrictGreaterThan(self, expr):
        return self._comparison('>', expr)

    def _print_GreaterThan(self, expr):
        return self._comparison('>=', expr)

    def _print_Unequality(self, expr):
        return self._comparison('!=', expr)

    def _print_Equality(self, expr):
        return self._comparison('==', expr)

    def _comparison(self, cmpop, expr):
        """Emit ``lhs <cmpop> rhs`` as a floating point or signed integer comparison.

        Floating point comparisons follow C semantics: every operator except ``!=``
        is false when an operand is NaN (ordered predicate); ``!=`` is true in that
        case (unordered predicate).
        """
        is_double = collate_types([get_type_of_expression(arg) for arg in expr.args]) == create_type('double')
        lhs = self._print(expr.lhs)
        rhs = self._print(expr.rhs)
        if not is_double:
            return self.builder.icmp_signed(cmpop, lhs, rhs)
        if cmpop == '!=':
            return self.builder.fcmp_unordered(cmpop, lhs, rhs)
        return self.builder.fcmp_ordered(cmpop, lhs, rhs)

    def _print_KernelFunction(self, func):
        """Emit ``func`` as a void LLVM function whose arguments are the kernel parameters.

        Argument values are registered in ``func_arg_map`` under the parameter
        symbol names, so the body can refer to them.
        """
        # KernelFunction does not posses a return type
        return_type = self.void
        parameters = func.get_parameters()
        parameter_type = [to_llvm_type(parameter.symbol.dtype) for parameter in parameters]
        func_type = ir.FunctionType(return_type, tuple(parameter_type))
        name = func.function_name
        fn = ir.Function(self.module, func_type, name)
        self.ext_fn[name] = fn

        # set proper names to arguments
        for arg, parameter in zip(fn.args, parameters):
            arg.name = parameter.symbol.name
            self.func_arg_map[parameter.symbol.name] = arg

        # possible future attributes: "inlinehint", "argmemonly"
        block = fn.append_basic_block(name="entry")
        self.builder = ir.IRBuilder(block)  # TODO use goto_block instead
        self._print(func.body)
        self.builder.ret_void()
        self.fn = fn
        return fn

    def _print_Block(self, block):
        for node in block.args:
            self._print(node)

    def _print_LoopOverCoordinate(self, loop):
        """Emit a counted loop; the counter is visible to the body as a temporary."""
        with Loop(self.builder, self._print(loop.start), self._print(loop.stop), self._print(loop.step),
                  loop.loop_counter_name, loop.loop_counter_symbol.name) as i:
            self._add_tmp_var(loop.loop_counter_symbol, i)
            self._print(loop.body)
            self._remove_tmp_var(loop.loop_counter_symbol)

    def _print_SympyAssignment(self, assignment):
        """Store to memory for an ``Indexed`` lhs, otherwise bind the value to the lhs name."""
        expr = self._print(assignment.rhs)
        lhs = assignment.lhs
        if isinstance(lhs, Indexed):
            ptr = self._print(lhs.base.label)
            index = self._print(lhs.args[1])
            gep = self.builder.gep(ptr, [index])
            return self.builder.store(expr, gep)
        self.func_arg_map[assignment.lhs.name] = expr
        return expr

    def _print_boolean_cast_func(self, conversion):
        return self._print_cast_func(conversion)

    def _print_cast_func(self, conversion):
        """Emit the LLVM conversion instruction for a supported (from, to) type pair.

        Raises:
            KeyError: if the type pair has no conversion rule.
        """
        node = self._print(conversion.args[0])
        to_dtype = get_type_of_expression(conversion)
        from_dtype = get_type_of_expression(conversion.args[0])
        if from_dtype == to_dtype:
            # reuse the value already emitted above instead of printing the operand twice
            return node

        builder = self.builder
        # (from, to) -> (llvmlite builder method, target llvm type)
        conversions = {
            ("int16", "int64"): (builder.sext, self.integer),
            ("int", "double"): (builder.sitofp, self.fp_type),
            ("int16", "double"): (builder.sitofp, self.fp_type),
            ("double", "int"): (builder.fptosi, self.integer),
            ("double *", "int"): (builder.ptrtoint, self.integer),
            ("int", "double *"): (builder.inttoptr, self.fp_pointer),
            ("double * restrict", "int"): (builder.ptrtoint, self.integer),
            ("int", "double * restrict"): (builder.inttoptr, self.fp_pointer),
            ("double * restrict const", "int"): (builder.ptrtoint, self.integer),
            ("int", "double * restrict const"): (builder.inttoptr, self.fp_pointer),
        }
        decision = {(create_composite_type_from_string(src), create_composite_type_from_string(dst)): conversion_rule
                    for (src, dst), conversion_rule in conversions.items()}
        # TODO float, TEST: const, restrict
        # TODO bitcast, addrspacecast
        # TODO unsigned/signed fills
        convert, target_type = decision[(from_dtype, to_dtype)]
        return convert(node, target_type)

    def _print_pointer_arithmetic_func(self, pointer):
        ptr = self._print(pointer.args[0])
        index = self._print(pointer.args[1])
        return self.builder.gep(ptr, [index])

    def _print_Indexed(self, indexed):
        """Load the element ``base[index]`` through a ``gep``."""
        ptr = self._print(indexed.base.label)
        index = self._print(indexed.args[1])
        gep = self.builder.gep(ptr, [index])
        return self.builder.load(gep, name=indexed.base.label.name)

    def _print_Piecewise(self, piece):
        """Lower a Piecewise to conditional branches joined by a ``phi`` node.

        Raises:
            ValueError: if the last condition is not ``True``.
            NotImplementedError: if the Piecewise contains an ``Assignment``.
        """
        if piece.args[-1].cond != sp.sympify(True):
            # We need the last conditional to be a True, otherwise the resulting
            # function may not return a result. (A plain ``not cond`` would raise
            # for relational conditions, hence the explicit comparison.)
            raise ValueError("All Piecewise expressions must contain an "
                             "(expr, True) statement to be used as a default "
                             "condition. Without one, the generated "
                             "expression may not evaluate to anything under "
                             "some condition.")
        if piece.has(Assignment):
            raise NotImplementedError('The llvm-backend does not support assignments '
                                      'in the Piecewise function. It is questionable '
                                      'whether to implement it. So far there is no '
                                      'use-case to test it.')
        phi_data = []
        after_block = self.builder.append_basic_block()
        for (expr, condition) in piece.args:
            if condition == sp.sympify(True):  # Don't use 'is' use '=='!
                value = self._print(expr)
                phi_data.append((value, self.builder.block))
                self.builder.branch(after_block)
                self.builder.position_at_end(after_block)
            else:
                cond = self._print(condition)
                true_block = self.builder.append_basic_block()
                false_block = self.builder.append_basic_block()
                self.builder.cbranch(cond, true_block, false_block)
                self.builder.position_at_end(true_block)
                value = self._print(expr)
                # printing expr may open further blocks (e.g. a nested Piecewise);
                # the phi's incoming edge comes from the block we ended up in
                phi_data.append((value, self.builder.block))
                self.builder.branch(after_block)
                self.builder.position_at_end(false_block)

        phi = self.builder.phi(to_llvm_type(get_type_of_expression(piece)))
        for (val, block) in phi_data:
            phi.add_incoming(val, block)
        return phi

    # Should have a list of math library functions to validate this.
    # TODO function calls to libs
    def _print_Function(self, expr):
        """Call an external ``double(double)`` function named after the sympy function."""
        name = expr.func.__name__
        e0 = self._print(expr.args[0])
        fn = self.ext_fn.get(name)
        if not fn:
            fn_type = ir.FunctionType(self.fp_type, [self.fp_type])
            fn = ir.Function(self.module, fn_type, name)
            self.ext_fn[name] = fn
        return self.builder.call(fn, [e0], name)

    def empty_printer(self, expr):
        """Reject an expression for which no ``_print_<ClassName>`` method exists.

        Raises:
            TypeError: always; the message names the expression, its type and the
                type's method resolution order.
        """
        mro = type(expr).__mro__
        raise TypeError("Unsupported type for LLVM JIT conversion: Expression:\"%s\", Type:\"%s\", MRO:%s"
                        % (expr, type(expr), mro))

    # sympy's Printer._print falls back to ``self.emptyPrinter`` (default: ``str``)
    # when no _print_* method matches; route that fallback here so unsupported
    # nodes raise instead of silently turning into strings.
    emptyPrinter = empty_printer