import sympy as sp
from copy import copy
from typing import List, Optional, Dict, Any, Set, Sequence, Iterator, Iterable
from pystencils.assignment import Assignment
from pystencils.sympyextensions import fast_subs, count_operations, sort_assignments_topologically


class AssignmentCollection:
    """
    A collection of equations with subexpression definitions, also represented as assignments,
    that are used in the main equations. AssignmentCollection can be passed to simplification methods.
    These simplification methods can change the subexpressions, but the number and
    left hand side of the main equations themselves is not altered.
    Additionally a dictionary of simplification hints is stored, which are set by the functions that create
    equation collections to transport information to the simplification system.

    Attributes:
        main_assignments: list of assignments
        subexpressions: list of assignments defining subexpressions used in main equations
        simplification_hints: dict that is used to annotate the equation collection with hints that are
                              used by the simplification system. See documentation of the simplification rules for
                              potentially required hints and their meaning.
        subexpression_symbol_generator: generator for new symbols that are used when new subexpressions are added
                                        used to get new symbols that are unique for this AssignmentCollection

    """

    # ------------------------------- Creation & Inplace Manipulation --------------------------------------------------

    def __init__(self, main_assignments: List[Assignment], subexpressions: List[Assignment],
                 simplification_hints: Optional[Dict[str, Any]] = None,
                 subexpression_symbol_generator: Iterator[sp.Symbol] = None) -> None:
        self.main_assignments = main_assignments
        self.subexpressions = subexpressions

        if simplification_hints is None:
            simplification_hints = {}

        self.simplification_hints = simplification_hints

        if subexpression_symbol_generator is None:
            self.subexpression_symbol_generator = SymbolGen()
        else:
            self.subexpression_symbol_generator = subexpression_symbol_generator

    def add_simplification_hint(self, key: str, value: Any) -> None:
        """Adds an entry to the simplification_hints dictionary and checks that is does not exist yet."""
        assert key not in self.simplification_hints, "This hint already exists"
        self.simplification_hints[key] = value

    def add_subexpression(self, rhs: sp.Expr, lhs: Optional[sp.Symbol] = None, topological_sort=True) -> sp.Symbol:
        """Adds a subexpression to current collection.

        Args:
            rhs: right hand side of new subexpression
            lhs: optional left hand side of new subexpression. If None a new unique symbol is generated.
            topological_sort: sort the subexpressions topologically after insertion, to make sure that
                              definition of a symbol comes before its usage. If False, subexpression is appended.

        Returns:
            left hand side symbol (which could have been generated)
        """
        if lhs is None:
            lhs = next(self.subexpression_symbol_generator)
        eq = Assignment(lhs, rhs)
        self.subexpressions.append(eq)
        if topological_sort:
            self.topological_sort(sort_subexpressions=True, sort_main_assignments=False)
        return lhs

    def topological_sort(self, sort_subexpressions: bool = True, sort_main_assignments: bool = True) -> None:
        """Sorts subexpressions and/or main_equations topologically to make sure symbol usage comes after definition."""
        if sort_subexpressions:
            self.subexpressions = sort_assignments_topologically(self.subexpressions)
        if sort_main_assignments:
            self.main_assignments = sort_assignments_topologically(self.main_assignments)

    # ---------------------------------------------- Properties  -------------------------------------------------------

    @property
    def all_assignments(self) -> List[Assignment]:
        """Subexpression and main equations as a single list."""
        return self.subexpressions + self.main_assignments

    @property
    def free_symbols(self) -> Set[sp.Symbol]:
        """All symbols used in the assignment collection, which do not occur as left hand sides in any assignment."""
        free_symbols = set()
        for eq in self.all_assignments:
            free_symbols.update(eq.rhs.atoms(sp.Symbol))
        return free_symbols - self.bound_symbols

    @property
    def bound_symbols(self) -> Set[sp.Symbol]:
        """All symbols which occur on the left hand side of a main assignment or a subexpression."""
        bound_symbols_set = set([eq.lhs for eq in self.all_assignments])
        assert len(bound_symbols_set) == len(self.subexpressions) + len(self.main_assignments), \
            "Not in SSA form - same symbol assigned multiple times"
        return bound_symbols_set

    @property
    def defined_symbols(self) -> Set[sp.Symbol]:
        """All symbols which occur as left-hand-sides of one of the main equations"""
        return set([assignment.lhs for assignment in self.main_assignments])

    @property
    def operation_count(self):
        """See :func:`count_operations` """
        return count_operations(self.all_assignments, only_type=None)

    def dependent_symbols(self, symbols: Iterable[sp.Symbol]) -> Set[sp.Symbol]:
        """Returns all symbols that depend on one of the passed symbols.

        A symbol 'a' depends on a symbol 'b', if there is an assignment 'a <- some_expression(b)' i.e. when
        'b' is required to compute 'a'.
        """

        queue = list(symbols)

        def add_symbols_from_expr(expr):
            dependent_symbols = expr.atoms(sp.Symbol)
            for ds in dependent_symbols:
                queue.append(ds)

        handled_symbols = set()
        assignment_dict = {e.lhs: e.rhs for e in self.all_assignments}

        while len(queue) > 0:
            e = queue.pop(0)
            if e in handled_symbols:
                continue
            if e in assignment_dict:
                add_symbols_from_expr(assignment_dict[e])
            handled_symbols.add(e)

        return handled_symbols

    def lambdify(self, symbols: Sequence[sp.Symbol], fixed_symbols: Optional[Dict[sp.Symbol, Any]]=None, module=None):
        """Returns a python function to evaluate this equation collection.

        Args:
            symbols: symbol(s) which are the parameter for the created function
            fixed_symbols: dictionary with substitutions, that are applied before sympy's lambdify
            module: same as sympy.lambdify parameter. Defines which module to use e.g. 'numpy'

        Examples:
              >>> a, b, c, d = sp.symbols("a b c d")
              >>> ac = AssignmentCollection([Assignment(c, a + b), Assignment(d, a**2 + b)],
              ...                           subexpressions=[Assignment(b, a + b / 2)])
              >>> python_function = ac.lambdify([a], fixed_symbols={b: 2})
              >>> python_function(4)
              {c: 6, d: 18}
        """
        assignments = self.new_with_substitutions(fixed_symbols, substitute_on_lhs=False) if fixed_symbols else self
        assignments = assignments.new_without_subexpressions().main_assignments
        lambdas = {assignment.lhs: sp.lambdify(symbols, assignment.rhs, module) for assignment in assignments}

        def f(*args, **kwargs):
            return {s: func(*args, **kwargs) for s, func in lambdas.items()}

        return f
    # ---------------------------- Creating new modified collections ---------------------------------------------------

    def copy(self,
             main_assignments: Optional[List[Assignment]] = None,
             subexpressions: Optional[List[Assignment]] = None) -> 'AssignmentCollection':
        """Returns a copy with optionally replaced main_assignments and/or subexpressions."""

        res = copy(self)
        res.simplification_hints = self.simplification_hints.copy()
        res.subexpression_symbol_generator = copy(self.subexpression_symbol_generator)

        if main_assignments is not None:
            res.main_assignments = main_assignments
        else:
            res.main_assignments = self.main_assignments.copy()

        if subexpressions is not None:
            res.subexpressions = subexpressions
        else:
            res.subexpressions = self.subexpressions.copy()

        return res

    def new_with_substitutions(self, substitutions: Dict, add_substitutions_as_subexpressions: bool = False,
                               substitute_on_lhs: bool = True) -> 'AssignmentCollection':
        """Returns new object, where terms are substituted according to the passed substitution dict.

        Args:
            substitutions: dict that is passed to sympy subs, substitutions are done main assignments and subexpressions
            add_substitutions_as_subexpressions: if True, the substitutions are added as assignments to subexpressions
            substitute_on_lhs: if False, the substitutions are done only on the right hand side of assignments

        Returns:
            New AssignmentCollection where substitutions have been applied, self is not altered.
        """
        if substitute_on_lhs:
            new_subexpressions = [fast_subs(eq, substitutions) for eq in self.subexpressions]
            new_equations = [fast_subs(eq, substitutions) for eq in self.main_assignments]
        else:
            new_subexpressions = [Assignment(eq.lhs, fast_subs(eq.rhs, substitutions)) for eq in self.subexpressions]
            new_equations = [Assignment(eq.lhs, fast_subs(eq.rhs, substitutions)) for eq in self.main_assignments]

        if add_substitutions_as_subexpressions:
            new_subexpressions = [Assignment(b, a) for a, b in substitutions.items()] + new_subexpressions
            new_subexpressions = sort_assignments_topologically(new_subexpressions)
        return self.copy(new_equations, new_subexpressions)

    def new_merged(self, other: 'AssignmentCollection') -> 'AssignmentCollection':
        """Returns a new collection which contains self and other. Subexpressions are renamed if they clash."""
        own_definitions = set([e.lhs for e in self.main_assignments])
        other_definitions = set([e.lhs for e in other.main_assignments])
        assert len(own_definitions.intersection(other_definitions)) == 0, \
            "Cannot new_merged, since both collection define the same symbols"

        own_subexpression_symbols = {e.lhs: e.rhs for e in self.subexpressions}
        substitution_dict = {}

        processed_other_subexpression_equations = []
        for other_subexpression_eq in other.subexpressions:
            if other_subexpression_eq.lhs in own_subexpression_symbols:
                if other_subexpression_eq.rhs == own_subexpression_symbols[other_subexpression_eq.lhs]:
                    continue  # exact the same subexpression equation exists already
                else:
                    # different definition - a new name has to be introduced
                    new_lhs = next(self.subexpression_symbol_generator)
                    new_eq = Assignment(new_lhs, fast_subs(other_subexpression_eq.rhs, substitution_dict))
                    processed_other_subexpression_equations.append(new_eq)
                    substitution_dict[other_subexpression_eq.lhs] = new_lhs
            else:
                processed_other_subexpression_equations.append(fast_subs(other_subexpression_eq, substitution_dict))

        processed_other_main_assignments = [fast_subs(eq, substitution_dict) for eq in other.main_assignments]
        return self.copy(self.main_assignments + processed_other_main_assignments,
                         self.subexpressions + processed_other_subexpression_equations)

    def new_filtered(self, symbols_to_extract: Iterable[sp.Symbol]) -> 'AssignmentCollection':
        """Extracts equations that have symbols_to_extract as left hand side, together with necessary subexpressions.

        Returns:
            new AssignmentCollection, self is not altered
        """
        symbols_to_extract = set(symbols_to_extract)
        dependent_symbols = self.dependent_symbols(symbols_to_extract)
        new_assignments = []
        for eq in self.all_assignments:
            if eq.lhs in symbols_to_extract:
                new_assignments.append(eq)

        new_sub_expr = [eq for eq in self.subexpressions
                        if eq.lhs in dependent_symbols and eq.lhs not in symbols_to_extract]
        return AssignmentCollection(new_assignments, new_sub_expr)

    def new_without_unused_subexpressions(self) -> 'AssignmentCollection':
        """Returns new collection that only contains subexpressions required to compute the main assignments."""
        all_lhs = [eq.lhs for eq in self.main_assignments]
        return self.new_filtered(all_lhs)

    def new_with_inserted_subexpression(self, symbol: sp.Symbol) -> 'AssignmentCollection':
        """Eliminates the subexpression with the given symbol on its left hand side, by substituting it everywhere."""
        new_subexpressions = []
        subs_dict = None
        for se in self.subexpressions:
            if se.lhs == symbol:
                subs_dict = {se.lhs: se.rhs}
            else:
                new_subexpressions.append(se)
        if subs_dict is None:
            return self

        new_subexpressions = [Assignment(eq.lhs, fast_subs(eq.rhs, subs_dict)) for eq in new_subexpressions]
        new_eqs = [Assignment(eq.lhs, fast_subs(eq.rhs, subs_dict)) for eq in self.main_assignments]
        return self.copy(new_eqs, new_subexpressions)

    def new_without_subexpressions(self, subexpressions_to_keep: Set[sp.Symbol] = set()) -> 'AssignmentCollection':
        """Returns a new collection where all subexpressions have been inserted."""
        if len(self.subexpressions) == 0:
            return self.copy()

        subexpressions_to_keep = set(subexpressions_to_keep)

        kept_subexpressions = []
        if self.subexpressions[0].lhs in subexpressions_to_keep:
            substitution_dict = {}
            kept_subexpressions = self.subexpressions[0]
        else:
            substitution_dict = {self.subexpressions[0].lhs: self.subexpressions[0].rhs}

        subexpression = [e for e in self.subexpressions]
        for i in range(1, len(subexpression)):
            subexpression[i] = fast_subs(subexpression[i], substitution_dict)
            if subexpression[i].lhs in subexpressions_to_keep:
                kept_subexpressions.append(subexpression[i])
            else:
                substitution_dict[subexpression[i].lhs] = subexpression[i].rhs

        new_assignment = [fast_subs(eq, substitution_dict) for eq in self.main_assignments]
        return self.copy(new_assignment, kept_subexpressions)

    # ----------------------------------------- Display and Printing   -------------------------------------------------

    def _repr_html_(self):
        """Interface to Jupyter notebook, to display as a nicely formatted HTML table"""
        def make_html_equation_table(equations):
            no_border = 'style="border:none"'
            html_table = '<table style="border:none; width: 100%; ">'
            line = '<tr {nb}> <td {nb}>$${eq}$$</td>  </tr> '
            for eq in equations:
                format_dict = {'eq': sp.latex(eq),
                               'nb': no_border, }
                html_table += line.format(**format_dict)
            html_table += "</table>"
            return html_table

        result = ""
        if len(self.subexpressions) > 0:
            result += "<div>Subexpressions:</div>"
            result += make_html_equation_table(self.subexpressions)
        result += "<div>Main Assignments:</div>"
        result += make_html_equation_table(self.main_assignments)
        return result

    def __repr__(self):
        return "Equation Collection for " + ",".join([str(eq.lhs) for eq in self.main_assignments])

    def __str__(self):
        result = "Subexpressions:\n"
        for eq in self.subexpressions:
            result += f"\t{eq}\n"
        result += "Main Assignments:\n"
        for eq in self.main_assignments:
            result += f"\t{eq}\n"
        return result


class SymbolGen:
    """Default symbol generator producing number symbols ζ_0, ζ_1, ..."""
    def __init__(self, symbol="xi"):
        self._ctr = 0
        self._symbol = symbol

    def __iter__(self):
        return self

    def __next__(self):
        name = f"{self._symbol}_{self._ctr}"
        self._ctr += 1
        return sp.Symbol(name)