"""
@author: Maxim Vavilin maxim.vavilin@kit.edu
"""
import numpy as np
import repscat as rs
import copy
import treams.special as sp
import matplotlib.pyplot as plt
import time


def get_k2_lists(xi, k1_list, k2_domain):
    """Build, for every k1, the grid of k2 values that contribute to it.

    For each k1 the admissible window k1*exp(-2|xi|) <= k2 <= k1*exp(2|xi|)
    is intersected with [k2_domain[0], k2_domain[-1]] (where the incident
    pulse is defined) and sampled with len(k1_list) equally spaced points.

    Parameters
    ----------
    xi : float
        Boost parameter; only |xi| is used.
    k1_list : 1-D array
        Outgoing wavenumbers.
    k2_domain : 1-D array
        Only its first and last entries are read, as the k2 bounds.

    Returns
    -------
    ndarray, shape (len(k1_list), len(k1_list))
        Row i is the k2 grid belonging to k1_list[i].
    """
    num_k = len(k1_list)
    window_low = np.exp(-2 * np.abs(xi))
    window_high = np.exp(2 * np.abs(xi))
    domain_low = k2_domain[0]
    domain_high = k2_domain[-1]

    k2_lists = np.zeros((num_k, num_k))
    for row, k1 in enumerate(k1_list):
        lower = max(domain_low, k1 * window_low)
        upper = min(domain_high, k1 * window_high)
        k2_lists[row] = np.linspace(lower, upper, num_k)
    return k2_lists

def get_k3_lists(xi, k1_list, k2_lists):
    """Build the k3 integration grid for every (k1, k2) pair.

    For each pair the grid spans the overlap of the windows
    [k1*exp(-|xi|), k1*exp(|xi|)] and [k2*exp(-|xi|), k2*exp(|xi|)],
    sampled with len(k1_list) equally spaced points.

    Parameters
    ----------
    xi : float
        Boost parameter; only |xi| is used.
    k1_list : 1-D array
        Outgoing wavenumbers.
    k2_lists : 2-D array, shape (len(k1_list), num_k2)
        Row i holds the k2 grid paired with k1_list[i].

    Returns
    -------
    ndarray, shape (len(k1_list), num_k2, len(k1_list))

    Raises
    ------
    ValueError
        If the two windows do not overlap, i.e. the lower bound exceeds
        the upper bound by more than 1e-10.
    """
    num_k3 = len(k1_list)
    shrink = np.exp(-np.abs(xi))
    stretch = np.exp(np.abs(xi))

    k3_lists = np.zeros((num_k3, k2_lists.shape[1], num_k3))
    for row, k1 in enumerate(k1_list):
        for col, k2 in enumerate(k2_lists[row]):
            lower = max(k1 * shrink, k2 * shrink)
            upper = min(k1 * stretch, k2 * stretch)
            # Tolerate tiny inversions from round-off; anything larger means
            # k2 lies outside the window that can couple to k1.
            if lower - upper > 1e-10:
                raise ValueError(f"k3 list broken: {lower}, {upper}")
            k3_lists[row, col] = np.linspace(lower, upper, num_k3)
    return k3_lists

def interpolate_tmat(k_list, tmat, k3_lists):
    """Linearly interpolate a sampled T-matrix onto the k3 integration grids.

    Parameters
    ----------
    k_list : 1-D array
        Wavenumbers at which ``tmat`` is sampled; must be increasing
        (``np.interp`` requirement).
    tmat : complex array, shape (k, lam, lam, j, j, m, m)
        The j axes hold jay = 1..jay_max at index jay - 1; the m axes hold
        m = -jay_max..jay_max at index m + jay_max.
    k3_lists : array, shape (k1, k2, k3)
        The k3 grid for every (k1, k2) pair.

    Returns
    -------
    complex ndarray, shape (k1, k2, k3, lam, lam, j, j, m, m)
        Only entries with |m1| <= jay1 and |m2| <= jay2 are filled; all
        others stay zero. k3 values outside [k_list[0], k_list[-1]] receive
        the nearest endpoint value (``np.interp`` clamps, it does not
        extrapolate).

    TODO: compare the interpolation with direct generation of the T-matrix.
    """
    jay1_max = tmat.shape[3]
    jay2_max = tmat.shape[4]
    num_k1, num_k2, num_k3 = k3_lists.shape

    tmat_interpol = np.zeros(
        (
            num_k1,
            num_k2,
            num_k3,
            2,
            2,
            jay1_max,
            jay2_max,
            2 * jay1_max + 1,
            2 * jay2_max + 1,
        ),
        dtype=complex,
    )

    # (jay index, m index) pairs for every jay >= 1 and |m| <= jay.
    jm_pairs_1 = [
        (jay - 1, m + jay1_max)
        for jay in range(1, jay1_max + 1)
        for m in range(-jay, jay + 1)
    ]
    jm_pairs_2 = [
        (jay - 1, m + jay2_max)
        for jay in range(1, jay2_max + 1)
        for m in range(-jay, jay + 1)
    ]

    for i_jay1, i_m1 in jm_pairs_1:
        for i_jay2, i_m2 in jm_pairs_2:
            for i_lam1 in (0, 1):
                for i_lam2 in (0, 1):
                    # 1-D slice along k; no squeeze, so len(k_list) == 1
                    # still gives a valid 1-D fp for np.interp.
                    column = tmat[:, i_lam1, i_lam2, i_jay1, i_jay2, i_m1, i_m2]
                    # np.interp accepts x of any shape, so a single call
                    # covers every (k1, k2, k3) sample at once.
                    real_part = np.interp(k3_lists, k_list, np.real(column))
                    imag_part = np.interp(k3_lists, k_list, np.imag(column))
                    tmat_interpol[
                        :, :, :, i_lam1, i_lam2, i_jay1, i_jay2, i_m1, i_m2
                    ] = real_part + 1j * imag_part

    return tmat_interpol


def _clip_costheta(costheta, message, k1, k2, k3_list):
    """Validate cos(theta) samples and clip them into [-1, 1].

    Values with |cos(theta)| up to 1.05 are accepted and clipped (presumably
    grid/round-off error); anything larger indicates inconsistent k1/k2/k3
    grids and raises ValueError(message + diagnostic context).
    """
    if np.any(np.abs(costheta) > 1.05):
        raise ValueError(
            f"{message} (k1={k1}, k2={k2}, "
            f"k3 in [{k3_list[0]}, {k3_list[-1]}])"
        )
    return np.clip(costheta, a_min=-1, a_max=1)


def _wigner_table(costheta, jay_max, lam_list, fac_divisor=1.0, result_divisor=1.0):
    """Tabulate sqrt(2j+1)/fac_divisor * d^j_{m,lam}(arccos(costheta)) / result_divisor.

    Returns a complex array of shape
    (len(lam_list), jay_max, 2*jay_max + 1, len(costheta)), indexed as
    [i_lam, jay - 1, m + jay_max, i_k3]; entries with |m| > jay stay zero.
    The scalar divisor is folded into the sqrt(2j+1) prefactor and the
    (possibly per-sample) divisor is applied last, in that fixed order.
    """
    theta = np.arccos(costheta)  # loop-invariant: evaluate once per table
    table = np.zeros(
        (len(lam_list), jay_max, 2 * jay_max + 1, len(theta)), dtype=complex
    )
    for i_lam, lam in enumerate(lam_list):
        for jay in range(1, jay_max + 1):
            fac = np.sqrt(2 * jay + 1) / fac_divisor
            for m in range(-jay, jay + 1):
                table[i_lam, jay - 1, m + jay_max] = (
                    fac * sp.wignersmalld(jay, m, lam, theta) / result_divisor
                )
    return table


def boost_tmat_vals_precise(xi, k_list, tmat):
    """Boost a monochromatic T-matrix: get T(k1, k2) from T(k).

    The result is a quadrature over k3 of the interpolated T(k3) contracted
    with four Wigner small-d tables, times an overall factor 0.25. The k3
    integration domain is a sub-interval of k_list for each (k1, k2).

    Parameters
    ----------
    xi : float
        Boost parameter (rapidity). Must be nonzero, because the kernel
        divides by sinh(xi). The k1/k2/k3 windows depend only on |xi|; the
        sign enters the cos(theta) arguments through sinh(xi).
    k_list : 1-D array
        Increasing wavenumbers at which ``tmat`` is sampled (required by the
        ``np.interp`` call inside ``interpolate_tmat``).
    tmat : complex array, shape (k, lam, lam, j, j, m, m)
        The j axes hold jay = 1..jay_max at index jay - 1; the m axes hold m
        at index m + jay_max. The Wigner tables use lam_list = [1, -1], so
        helicity index 0 is lam = +1 and 1 is lam = -1. tmat is assumed to
        follow the same ordering (TODO confirm).

    Returns
    -------
    complex ndarray, shape (k1, k2, lam1, lam2, j1, j2, m1, m2)
        Sampled on k1_list = linspace(k_list[0]*exp(-|xi|),
        k_list[-1]*exp(|xi|), len(k_list)) and, for each k1, the matching row
        of get_k2_lists(xi, k1_list, k2_domain). These grids are not returned.

    Raises
    ------
    ValueError
        If xi == 0, if a k2 value falls outside the window coupled to its
        k1, or if any |cos(theta)| exceeds 1.05.
    """
    if xi == 0:
        raise ValueError("xi must be nonzero: the boost kernel divides by sinh(xi)")

    abs_xi = np.abs(xi)
    cosh_xi = np.cosh(xi)
    sinh_xi = np.sinh(xi)

    # Outgoing wavenumbers: every k reachable from k_list by the boost.
    k1_list = np.linspace(
        k_list[0] * np.exp(-abs_xi), k_list[-1] * np.exp(abs_xi), len(k_list)
    )
    # Where the incident pulse is defined. Restricting k2 to this range keeps
    # every k3 produced by get_k3_lists inside [k_list[0], k_list[-1]], since
    # k3 is bounded by k2*exp(-|xi|) and k2*exp(|xi|).
    k2_domain = np.linspace(
        k_list[0] * np.exp(abs_xi), k_list[-1] * np.exp(-abs_xi), len(k_list)
    )

    k2_lists = get_k2_lists(xi, k1_list, k2_domain)
    # Integration domain for each (k1, k2).
    k3_lists = get_k3_lists(xi, k1_list, k2_lists)
    # TODO: compare the interpolated T-matrix with directly computed ones.
    tmat_interpolated = interpolate_tmat(k_list, tmat, k3_lists)

    num_k1, num_k2, num_k3 = k3_lists.shape
    jay1_max = tmat.shape[3]
    jay2_max = tmat.shape[4]
    lam_list = np.array([1, -1])

    # Quadrature weights along k3: forward differences, last point weighted
    # zero. The 1/sinh(xi)^2 factor belongs to the monochromatic T-matrix
    # definition used here.
    integration_measure = np.diff(
        k3_lists, axis=-1, append=k3_lists[..., -1:]
    ) / (sinh_xi ** 2)

    shape_1 = (num_k1, num_k2, 2, jay1_max, 2 * jay1_max + 1, num_k3)
    shape_2 = (num_k1, num_k2, 2, jay2_max, 2 * jay2_max + 1, num_k3)
    wig_i = np.zeros(shape_1, dtype=complex)
    wig_r = np.zeros(shape_1, dtype=complex)
    wig_s = np.zeros(shape_2, dtype=complex)
    wig_j = np.zeros(shape_2, dtype=complex)

    k2_low_fac = np.exp(-2 * abs_xi)
    k2_high_fac = np.exp(2 * abs_xi)
    for i_k1, k1 in enumerate(k1_list):
        for i_k2, k2 in enumerate(k2_lists[i_k1]):
            k2_low = k1 * k2_low_fac
            k2_high = k1 * k2_high_fac
            if k2_low - k2 > 1e-7 or k2 - k2_high > 1e-7:
                raise ValueError(f'Bad k2_lists: not {k2_low} < {k2} < {k2_high}')
            k3_list = k3_lists[i_k1, i_k2]

            # wig_i: output jay1 leg (einsum index i), carries 1/k1.
            costheta = _clip_costheta(
                (k1 * cosh_xi - k3_list) / (k1 * sinh_xi),
                "Something wrong with cos theta", k1, k2, k3_list,
            )
            wig_i[i_k1, i_k2] = _wigner_table(
                costheta, jay1_max, lam_list, fac_divisor=k1
            )

            # wig_r: jay1prime leg contracted with the T-matrix (index r).
            costheta = _clip_costheta(
                (k1 - k3_list * cosh_xi) / (k3_list * sinh_xi),
                "Something wrong with cos theta", k1, k2, k3_list,
            )
            wig_r[i_k1, i_k2] = _wigner_table(costheta, jay1_max, lam_list)

            # wig_s: jay2prime leg contracted with the T-matrix (index s),
            # divided pointwise by k3.
            costheta = _clip_costheta(
                -(k3_list * cosh_xi - k2) / (k3_list * sinh_xi),
                "Bad costheta", k1, k2, k3_list,
            )
            wig_s[i_k1, i_k2] = _wigner_table(
                costheta, jay2_max, lam_list, result_divisor=k3_list
            )

            # wig_j: output jay2 leg (einsum index j), carries 1/k2.
            costheta = _clip_costheta(
                -(k3_list - k2 * cosh_xi) / (k2 * sinh_xi),
                "Bad costheta", k1, k2, k3_list,
            )
            wig_j[i_k1, i_k2] = _wigner_table(
                costheta, jay2_max, lam_list, fac_divisor=k2
            )

    # Einsum legend: p=k1, q=k2, k=k3 (summed),
    # i=j1, j=j2, r=j1prime, s=j2prime (r, s summed),
    # m=m1, n=m2, a=lam1, b=lam2
    tmat_arr_boosted = (
        np.einsum(
            "pqk,pqkabrsmn,pqaimk,pqarmk,pqbsnk,pqbjnk->pqabijmn",
            integration_measure,  # (k1,k2,k3)
            tmat_interpolated,  # (k1,k2,k3,lam,lam,j1prime,j2prime,m1,m2)
            wig_i,  # (k1,k2,lam1,jay1,m1,k3)
            wig_r,  # (k1,k2,lam1,jay1prime,m1,k3)
            wig_s,  # (k1,k2,lam2,jay2prime,m2,k3)
            wig_j,  # (k1,k2,lam2,jay2,m2,k3)
            optimize=True,
        )
        * 0.25
    )
    return tmat_arr_boosted
