import itertools
import warnings
import operator
from functools import reduce, partial
from collections import defaultdict, Counter
import sympy as sp
from sympy.functions import Abs
from typing import Optional, Union, List, TypeVar, Iterable, Sequence, Callable, Dict, Tuple

from pystencils.data_types import get_type_of_expression, get_base_type, cast_func
from pystencils.assignment import Assignment

# Generic type variable shared by the type hints of prod, scalar_product and fast_subs below.
T = TypeVar('T')


def prod(seq: Iterable[T]) -> T:
    """Multiplies all elements of ``seq`` together and returns the product.

    The accumulation starts from the integer ``1``, so an empty sequence yields ``1``.
    """
    accumulated = 1
    for element in seq:
        accumulated = accumulated * element
    return accumulated


def remove_small_floats(expr, threshold):
    """Replaces every ``sp.Float`` whose absolute value is below ``threshold`` by ``0``.

    The expression tree is rebuilt bottom-up from its ``args``, so nested floats are removed as well.

    >>> expr = sp.sympify("x + 1e-15 * y")
    >>> remove_small_floats(expr, 1e-14)
    x
    """
    if isinstance(expr, sp.Float) and sp.Abs(expr) < threshold:
        return 0
    cleaned_args = [remove_small_floats(arg, threshold) for arg in expr.args]
    if not cleaned_args:
        # leaf node (symbol, number, ...) - nothing to rebuild
        return expr
    return expr.func(*cleaned_args)


def is_integer_sequence(sequence: Iterable) -> bool:
    """Checks if all elements of the passed sequence can be cast to integers.

    Every element is passed to ``int()``; the sequence qualifies when no call fails.
    Elements that ``int()`` rejects with a ``TypeError`` (e.g. ``None``), ``ValueError``
    (e.g. ``"abc"``, ``float('nan')``) or ``OverflowError`` (``float('inf')``) make the result False.
    A non-iterable argument also yields False.
    """
    try:
        for element in sequence:
            int(element)
    except (TypeError, ValueError, OverflowError):
        return False
    return True


def scalar_product(a: Iterable[T], b: Iterable[T]) -> T:
    """Scalar (dot) product between two sequences.

    Elements are multiplied pairwise and the products are summed; iteration stops at the end of the
    shorter sequence and empty input yields ``0``.
    """
    return sum(map(operator.mul, a, b))


def kronecker_delta(*args):
    """Kronecker delta for a variable number of arguments.

    Returns 1 if all arguments compare equal to the first one (also for zero or one argument), otherwise 0.
    """
    mismatch_found = any(candidate != args[0] for candidate in args)
    return 0 if mismatch_found else 1


def tanh_step_function_approximation(x, step_location, kind='right', steepness=0.0001):
    """Approximation of step function by a tanh function

    Args:
        x: evaluation point
        step_location: location of the step; for ``kind='middle'`` a pair ``(x1, x2)``
        kind: 'right' (approx. 1 right of the step), 'left' (approx. 1 left of the step) or
              'middle' (approx. 1 between ``x1`` and ``x2``)
        steepness: divides the distance to the step, i.e. smaller values give a sharper step

    Raises:
        ValueError: if ``kind`` is none of 'left', 'right', 'middle'

    >>> tanh_step_function_approximation(1.2, step_location=1.0, kind='right')
    1.00000000000000
    >>> tanh_step_function_approximation(0.9, step_location=1.0, kind='right')
    0
    >>> tanh_step_function_approximation(1.1, step_location=1.0, kind='left')
    0
    >>> tanh_step_function_approximation(0.9, step_location=1.0, kind='left')
    1.00000000000000
    >>> tanh_step_function_approximation(0.5, step_location=(0, 1), kind='middle')
    1
    """
    if kind == 'left':
        return (1 - sp.tanh((x - step_location) / steepness)) / 2
    elif kind == 'right':
        return (1 + sp.tanh((x - step_location) / steepness)) / 2
    elif kind == 'middle':
        x1, x2 = step_location
        # 1 minus the parts left of x1 and right of x2
        return 1 - (tanh_step_function_approximation(x, x1, 'left', steepness)
                    + tanh_step_function_approximation(x, x2, 'right', steepness))
    # previously an unknown kind silently returned None
    raise ValueError("Unknown kind {!r}, expected 'left', 'right' or 'middle'".format(kind))


def multidimensional_sum(i, dim):
    """Iterates lazily over all index tuples of length ``i`` whose entries lie in ``range(dim)``.

    Tuples are produced in lexicographic order by ``itertools.product``; ``i <= 0`` yields one empty tuple.

    Example:
        >>> list(multidimensional_sum(2, dim=3))
        [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2)]
    """
    axis = range(dim)
    return itertools.product(*(axis for _ in range(i)))


def normalize_product(product: sp.Expr) -> List[sp.Expr]:
    """Splits a sympy expression that can be read as a product into the list of its factors.

    Powers with a positive integer numeric exponent are unrolled, i.e. ``x**3`` contributes ``[x, x, x]``.

    Returns:
        * for a Mul node the list of its factors ('args'), with such powers unrolled
        * for a Pow node with positive integer exponent the unrolled list of factors
        * for other node types ``[product]``
    """
    def expand_power(power):
        exponent = power.exp
        if exponent.is_integer and exponent.is_number and exponent > 0:
            return [power.base] * exponent
        return [power]

    if isinstance(product, sp.Pow):
        return expand_power(product)
    if not isinstance(product, sp.Mul):
        return [product]

    factors = []
    for factor in product.args:
        factors.extend(expand_power(factor) if factor.func == sp.Pow else [factor])
    return factors


def symmetric_product(*args, with_diagonal: bool = True) -> Iterable:
    """Like itertools.product, but only yields combinations whose index tuple is ascending.

    With ``with_diagonal`` the indices must be non-decreasing (values up to and including the diagonal),
    otherwise strictly increasing (values strictly above the diagonal).

    Examples:
        >>> list(symmetric_product([1, 2, 3], ['a', 'b', 'c']))
        [(1, 'a'), (1, 'b'), (1, 'c'), (2, 'b'), (2, 'c'), (3, 'c')]
        >>> list(symmetric_product([1, 2, 3], ['a', 'b', 'c'], with_diagonal=False))
        [(1, 'b'), (1, 'c'), (2, 'c')]
    """
    in_order = operator.le if with_diagonal else operator.lt
    index_ranges = [range(len(seq)) for seq in args]
    for index_tuple in itertools.product(*index_ranges):
        if all(in_order(lower, upper) for lower, upper in zip(index_tuple, index_tuple[1:])):
            yield tuple(seq[position] for seq, position in zip(args, index_tuple))


def fast_subs(expression: T, substitutions: Dict,
              skip: Optional[Callable[[sp.Expr], bool]] = None) -> T:
    """Similar to sympy subs function.

    Args:
        expression: expression where parts should be substituted
        substitutions: dict defining substitutions by mapping from old to new terms
        skip: function that marks expressions to be skipped (if True is returned) - that means that in these skipped
              expressions no substitutions are done

    This version is much faster for big substitution dictionaries than sympy version.
    It matches whole nodes only: a node that is a key of ``substitutions`` is replaced (its children are not
    visited), otherwise the node is rebuilt from its recursively processed children. Objects that provide their
    own ``fast_subs`` method perform the substitution themselves. For an ``sp.Matrix`` every entry is processed
    separately, with the same ``skip`` function.
    """
    if type(expression) is sp.Matrix:
        # forward skip too - otherwise skipped sub-expressions inside matrix entries would still be substituted
        return expression.copy().applyfunc(partial(fast_subs, substitutions=substitutions, skip=skip))

    def visit(expr):
        if skip and skip(expr):
            return expr
        if hasattr(expr, "fast_subs"):
            return expr.fast_subs(substitutions)
        if expr in substitutions:
            return substitutions[expr]
        if not hasattr(expr, 'args'):
            return expr
        param_list = [visit(a) for a in expr.args]
        return expr if not param_list else expr.func(*param_list)

    if len(substitutions) == 0:
        return expression
    else:
        return visit(expression)


def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr,
                  required_match_replacement: Optional[Union[int, float]] = 0.5,
                  required_match_original: Optional[Union[int, float]] = None) -> sp.Expr:
    """Transformation for replacing a given subexpression inside a sum.

    Examples:
        The next example demonstrates the advantage of subs_additive compared to sympy.subs:
        >>> x, y, z, k = sp.symbols("x y z k")
        >>> subs_additive(3*x + 3*y, replacement=k, subexpression=x + y)
        3*k

        Terms that don't match completely can be substituted at the cost of additional terms.
        This trade-off is managed using the required_match parameters.
        >>> subs_additive(3*x + 3*y + z, replacement=k, subexpression=x+y+z, required_match_original=1.0)
        3*x + 3*y + z
        >>> subs_additive(3*x + 3*y + z, replacement=k, subexpression=x+y+z, required_match_original=0.5)
        3*k - 2*z
        >>> subs_additive(3*x + 3*y + z, replacement=k, subexpression=x+y+z, required_match_original=2)
        3*k - 2*z

    Args:
        expr: input expression
        replacement: expression that is inserted for subexpression (if found)
        subexpression: expression to replace
        required_match_replacement:
             * if float: the percentage of terms of the subexpression that has to be matched in order to replace
             * if integer: the total number of terms that has to be matched in order to replace
             * None: is equal to integer 1
             * if both match parameters are given, both restrictions have to be fulfilled (i.e. logical AND)
        required_match_original:
             * if float: the percentage of terms of the original addition expression that has to be matched
             * if integer: the total number of terms that has to be matched in order to replace
             * None: is equal to integer 1

    Returns:
        new expression with replacement

    Raises:
        ValueError: if a match parameter is a float outside [0, 1], a non-positive integer, or of another type
    """
    def normalize_match_parameter(match_parameter, expression_length):
        """Converts a match parameter into the absolute number of terms that have to match."""
        if match_parameter is None:
            return 1
        elif isinstance(match_parameter, float):
            # explicit check instead of assert, which would be stripped when running with -O
            if not 0 <= match_parameter <= 1:
                raise ValueError("Float match parameter has to be in the interval [0, 1]")
            res = int(match_parameter * expression_length)
            return max(res, 1)
        elif isinstance(match_parameter, int):
            if not match_parameter > 0:
                raise ValueError("Integer match parameter has to be positive")
            return match_parameter
        raise ValueError("Invalid parameter")

    normalized_replacement_match = normalize_match_parameter(required_match_replacement, len(subexpression.args))
    # the coefficients of the pattern never change - compute them once instead of for every visited sum
    subexpression_coefficient_dict = subexpression.as_coefficients_dict()

    def visit(current_expr):
        if current_expr.is_Add:
            expr_max_length = max(len(current_expr.args), len(subexpression.args))
            normalized_current_expr_match = normalize_match_parameter(required_match_original, expr_max_length)
            required_matches = max(normalized_replacement_match, normalized_current_expr_match)
            expr_coefficients = current_expr.as_coefficients_dict()
            intersection = set(subexpression_coefficient_dict.keys()).intersection(set(expr_coefficients))
            if len(intersection) >= required_matches:
                # find common factor: ratio of coefficients that occurs for the most shared terms
                factors = defaultdict(int)
                for common_symbol in subexpression_coefficient_dict.keys():
                    if common_symbol not in expr_coefficients:
                        continue
                    factor = expr_coefficients[common_symbol] / subexpression_coefficient_dict[common_symbol]
                    factors[sp.simplify(factor)] += 1

                common_factor = max(factors.items(), key=operator.itemgetter(1))[0]
                if factors[common_factor] >= required_matches:
                    return current_expr - common_factor * subexpression + common_factor * replacement

        # if no subexpression was found
        param_list = [visit(a) for a in current_expr.args]
        if not param_list:
            return current_expr
        else:
            return current_expr.func(*param_list, evaluate=False)

    return visit(expr)


def replace_second_order_products(expr: sp.Expr, search_symbols: Iterable[sp.Symbol],
                                  positive: Optional[bool] = None,
                                  replace_mixed: Optional[List[Assignment]] = None) -> sp.Expr:
    """Replaces second order mixed terms like x*y by ( (x+y)**2 - x**2 - y**2 ) / 2.

    This makes the term longer - simplify usually is undoing these - however this
    transformation can be done to find more common sub-expressions

    Args:
        expr: input expression
        search_symbols: symbols that are searched for
                         for example, given [x,y,z] terms like x*y, x*z, z*y are replaced
        positive: there are two ways to do this substitution, either with term
                 (x+y)**2 or (x-y)**2 . if positive=True the first version is done,
                 if positive=False the second version is done, if positive=None the
                 sign is determined by the sign of the mixed term that is replaced
        replace_mixed: if a list is passed here, the expr x+y or x-y is replaced by a special new symbol
                       and the replacement equation is added to the list
    """
    mixed_symbols_replaced = set([e.lhs for e in replace_mixed]) if replace_mixed is not None else set()

    if expr.is_Mul:
        distinct_search_symbols = set()
        nr_of_search_terms = 0
        # sympy one instead of the python int 1: for a bare product like x*y nothing is multiplied in,
        # and .atoms()/.subs() below would fail on a plain int
        other_factors = sp.Integer(1)
        for t in expr.args:
            if t in search_symbols:
                nr_of_search_terms += 1
                distinct_search_symbols.add(t)
            else:
                other_factors *= t
        if len(distinct_search_symbols) == 2 and nr_of_search_terms == 2:
            u, v = sorted(list(distinct_search_symbols), key=lambda symbol: symbol.name)
            if positive is None:
                # sign of the remaining factor with all its symbols set to 1
                other_factors_without_symbols = other_factors
                for s in other_factors.atoms(sp.Symbol):
                    other_factors_without_symbols = other_factors_without_symbols.subs(s, 1)
                positive = other_factors_without_symbols.is_positive
                assert positive is not None
            sign = 1 if positive else -1
            if replace_mixed is not None:
                new_symbol_str = 'P' if positive else 'M'
                mixed_symbol_name = u.name + new_symbol_str + v.name
                mixed_symbol = sp.Symbol(mixed_symbol_name.replace("_", ""))
                if mixed_symbol not in mixed_symbols_replaced:
                    mixed_symbols_replaced.add(mixed_symbol)
                    replace_mixed.append(Assignment(mixed_symbol, u + sign * v))
            else:
                mixed_symbol = u + sign * v
            # u*v == sign/2 * ((u + sign*v)**2 - u**2 - v**2)
            return sp.Rational(1, 2) * sign * other_factors * (mixed_symbol ** 2 - u ** 2 - v ** 2)

    param_list = [replace_second_order_products(a, search_symbols, positive, replace_mixed) for a in expr.args]
    result = expr.func(*param_list, evaluate=False) if param_list else expr
    return result


def remove_higher_order_terms(expr: sp.Expr, symbols: Sequence[sp.Symbol], order: int = 3) -> sp.Expr:
    """Removes all terms that contain more than 'order' factors of given 'symbols'

    The expression is expanded first; every summand is then kept only if the summed exponents of the
    given symbols in it are at most ``order``. A bare symbol counts as one factor.

    Example:
        >>> x, y = sp.symbols("x y")
        >>> term = x**2 * y + y**2 * x + y**3 + x + y ** 2
        >>> remove_higher_order_terms(term, order=2, symbols=[x, y])
        x + y**2
    """
    expr = expr.expand()

    def symbol_factor_count(factor):
        """Number of factors of ``symbols`` contributed by a single (non-Mul) factor."""
        count = 0
        if type(factor) is sp.Pow and factor.args[0] in symbols:
            count += factor.args[1]
        if factor in symbols:
            count += 1
        return count

    def velocity_factors_in_product(product):
        """Number of factors of ``symbols`` in one summand of the expanded expression."""
        if type(product) is sp.Mul:
            return sum((symbol_factor_count(factor) for factor in product.args), 0)
        return symbol_factor_count(product)

    if type(expr) is not sp.Add:
        return expr if velocity_factors_in_product(expr) <= order else sp.S.Zero

    # sp.Add of an empty list is sympy's zero, so the result is always a sympy object
    return sp.Add(*[term for term in expr.args if velocity_factors_in_product(term) <= order])


def complete_the_square(expr: sp.Expr, symbol_to_complete: sp.Symbol,
                        new_variable: sp.Symbol) -> Tuple[sp.Expr, Optional[Tuple[sp.Symbol, sp.Expr]]]:
    """Transforms second order polynomial into only squared part.

    Examples:
        >>> a, b, c, s, n = sp.symbols("a b c s n")
        >>> expr = a * s**2 + b * s + c
        >>> completed_expr, substitution = complete_the_square(expr, symbol_to_complete=s, new_variable=n)
        >>> completed_expr
        a*n**2 + c - b**2/(4*a)
        >>> substitution
        (n, s + b/(2*a))

    Args:
        expr: expression, expected to be a second order polynomial in ``symbol_to_complete``
        symbol_to_complete: variable in which the square is completed
        new_variable: symbol for the shifted variable ``symbol_to_complete + b / (2a)``

    Returns:
        (replaced_expr, tuple to pass to subs, such that old expr comes out again).
        If ``expr`` is not a polynomial of exactly second order in ``symbol_to_complete``,
        ``(expr, None)`` is returned.
    """
    try:
        p = sp.Poly(expr, symbol_to_complete)
    except sp.polys.polyerrors.PolynomialError:
        # not a polynomial in symbol_to_complete: treated like any other non-quadratic input
        return expr, None
    coefficients = p.all_coeffs()
    if len(coefficients) != 3:
        return expr, None
    a, b, _ = coefficients
    expr = expr.subs(symbol_to_complete, new_variable - b / (2 * a))
    return sp.simplify(expr), (new_variable, symbol_to_complete + b / (2 * a))


def complete_the_squares_in_exp(expr: sp.Expr, symbols_to_complete: Sequence[sp.Symbol]):
    """Completes squares in arguments of exponential which makes them simpler to integrate.

    Every ``exp(...)`` argument is passed through ``complete_the_square`` once per symbol, in the given
    order, using a temporary Dummy as new variable; the dummies are substituted back by the original
    symbols at the end.

    Very useful for integrating Maxwell-Boltzmann equilibria and its moment generating function
    """
    placeholders = [sp.Dummy() for _ in symbols_to_complete]

    def rewrite(node):
        if node.func == sp.exp:
            exponent = node.args[0]
            for symbol, placeholder in zip(symbols_to_complete, placeholders):
                exponent, _ = complete_the_square(exponent, symbol, placeholder)
            return sp.exp(sp.expand(exponent))
        children = [rewrite(child) for child in node.args]
        return node.func(*children) if children else node

    rewritten = rewrite(expr)
    for symbol, placeholder in zip(symbols_to_complete, placeholders):
        rewritten = rewritten.subs(placeholder, symbol)
    return rewritten


def extract_most_common_factor(term):
    """Processes a sum of fractions: determines the most common factor and splits term in common factor and rest.

    The factor is the absolute coefficient value (from ``as_coefficients_dict``) that occurs most often;
    if every value occurs only once and 1 is among them, 1 is used instead.

    Returns:
        tuple ``(common_factor, term / common_factor)``
    """
    magnitude_counts = Counter(Abs(coefficient) for coefficient in term.as_coefficients_dict().values())
    common_factor, occurrences = max(magnitude_counts.items(), key=operator.itemgetter(1))
    if occurrences == 1 and 1 in magnitude_counts:
        common_factor = 1
    return common_factor, term / common_factor


def count_operations(term: Union[sp.Expr, List[sp.Expr]],
                     only_type: Optional[str] = 'real') -> Dict[str, int]:
    """Counts the number of additions, multiplications and division.

    Args:
        term: a sympy expression (term, assignment) or sequence of sympy objects
        only_type: 'real' or 'int' to count only operations on these types, or None for all

    Returns:
        dict with 'adds', 'muls' and 'divs' keys

    Warns:
        UserWarning for sympy node types that are not handled and for powers with non-integer exponents,
        since the counts are inaccurate in these cases.
    """
    result = {'adds': 0, 'muls': 0, 'divs': 0}

    if isinstance(term, Sequence):
        for element in term:
            r = count_operations(element, only_type)
            for operation_name in result.keys():
                result[operation_name] += r[operation_name]
        return result
    elif isinstance(term, Assignment):
        term = term.rhs

    if not hasattr(term, 'evalf'):
        return result

    term = term.evalf()

    def check_type(e):
        """True if operations on expression ``e`` should be counted for the requested ``only_type``."""
        if only_type is None:
            return True
        try:
            base_type = get_base_type(get_type_of_expression(e))
        except ValueError:
            return False
        if only_type == 'int' and (base_type.is_int() or base_type.is_uint()):
            return True
        if only_type == 'real' and (base_type.is_float()):
            return True
        else:
            return base_type == only_type

    def visit(t):
        visit_children = True
        if t.func is sp.Add:
            if check_type(t):
                result['adds'] += len(t.args) - 1
        elif t.func in [sp.Or, sp.And]:
            pass
        elif t.func is sp.Mul:
            if check_type(t):
                result['muls'] += len(t.args) - 1
                # factors of +1/-1 are not counted as multiplications
                for a in t.args:
                    if a == 1 or a == -1:
                        result['muls'] -= 1
        elif isinstance(t, sp.Float) or isinstance(t, sp.Rational):
            pass
        elif isinstance(t, sp.Symbol):
            visit_children = False
        elif isinstance(t, sp.tensor.Indexed):
            visit_children = False
        elif t.is_integer:
            pass
        elif isinstance(t, cast_func):
            visit_children = False
            visit(t.args[0])
        elif t.func is sp.Pow:
            if check_type(t.args[0]):
                visit_children = False
                if t.exp.is_integer and t.exp.is_number:
                    if t.exp >= 0:
                        result['muls'] += int(t.exp) - 1
                    else:
                        # negative power counted as one division; one multiplication is taken back
                        # (presumably the one counted for this factor by an enclosing Mul)
                        result['muls'] -= 1
                        result['divs'] += 1
                        result['muls'] += (-int(t.exp)) - 1
                else:
                    # the warning belongs to the non-integer exponent case, not to the type filter
                    warnings.warn("Counting operations: only integer exponents are supported in Pow, "
                                  "counting will be inaccurate")
        else:
            warnings.warn("Unknown sympy node of type " + str(t.func) + " counting will be inaccurate")

        if visit_children:
            for a in t.args:
                visit(a)

    visit(term)
    return result


def count_operations_in_ast(ast) -> Dict[str, int]:
    """Counts number of operations in an abstract syntax tree, see also :func:`count_operations`

    The tree is walked through the nodes' ``args``; for every ``SympyAssignment`` found, the operations of
    its right-hand side are counted (with the default ``only_type`` of :func:`count_operations`) and summed up.

    Returns:
        dict with 'adds', 'muls' and 'divs' keys
    """
    from pystencils.astnodes import SympyAssignment
    totals = {'adds': 0, 'muls': 0, 'divs': 0}

    def accumulate(node):
        if not isinstance(node, SympyAssignment):
            for child in node.args:
                accumulate(child)
            return
        counts = count_operations(node.rhs)
        for key in ('adds', 'muls', 'divs'):
            totals[key] += counts[key]

    accumulate(ast)
    return totals


def common_denominator(expr: sp.Expr) -> sp.Expr:
    """Finds least common multiple of all denominators occurring in an expression.

    The denominators (``.q``) of all rational number atoms of ``expr`` are collected and passed to ``sp.lcm``.
    """
    rational_atoms = expr.atoms(sp.Rational)
    return sp.lcm([number.q for number in rational_atoms])


def get_symmetric_part(expr: sp.Expr, symbols: Iterable[sp.Symbol]) -> sp.Expr:
    r"""
    Returns the symmetric part of a sympy expression with respect to flipping the sign of the given symbols.

    Args:
        expr: sympy expression, labeled here as :math:`f`
        symbols: sequence of symbols which are considered as degrees of freedom, labeled here as :math:`x_0, x_1,...`

    Returns:
        :math:`\frac{1}{2} [ f(x_0, x_1, ..) + f(-x_0, -x_1, ..) ]`
    """
    # raw docstring above: in a normal string "\f" of "\frac" would become a form-feed character
    substitution_dict = {e: -e for e in symbols}
    return sp.Rational(1, 2) * (expr + expr.subs(substitution_dict))


def sort_assignments_topologically(assignments: Sequence[Assignment]) -> List[Assignment]:
    """Sorts assignments in topological order, such that symbols used on rhs occur first on a lhs

    Args:
        assignments: assignments whose ``(lhs, rhs)`` pairs are sorted

    Returns:
        new list of Assignment objects rebuilt from the sorted ``(lhs, rhs)`` pairs
    """
    # import from the defining module: ``sp.cse_main`` is not guaranteed to be reachable as an attribute
    # of the top-level sympy package
    from sympy.simplify.cse_main import reps_toposort
    res = reps_toposort([[e.lhs, e.rhs] for e in assignments])
    return [Assignment(a, b) for a, b in res]


class SymbolCreator:
    """Creates sympy symbols through attribute access: ``SymbolCreator().alpha`` returns ``sp.Symbol('alpha')``.

    Implemented with ``__getattr__``, which is only consulted when normal attribute lookup fails, so special
    attributes like ``__class__`` and ``__dict__`` resolve normally instead of being turned into symbols.
    Double-underscore names (``__xxx__``) raise AttributeError, so protocol probes such as
    ``hasattr(obj, '__array__')`` report False instead of returning a symbol.
    """

    def __getattr__(self, name):
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)
        return sp.Symbol(name)