"""
@author: Maxim Vavilin maxim.vavilin@kit.edu
"""
import numpy as np
import repscat as rs
import copy
import treams.special as sp
import matplotlib.pyplot as plt


def get_domain_tmat_boosted(xi_list, k_list):
    """Build the scattered (k1) and incident (k2) wavenumber grids per rapidity.

    For every ``xi`` both grids have ``len(k_list)`` equally spaced points:

    * k1 (scattered) is widened:  ``[k_list[0]*exp(-|xi|), k_list[-1]*exp(|xi|)]``
    * k2 (incident) is narrowed:  ``[k_list[0]*exp(|xi|),  k_list[-1]*exp(-|xi|)]``

    Returns:
        Tuple ``(k1_lists, k2_lists)``, each of shape ``(len(xi_list), len(k_list))``.

    Raises:
        ValueError: if ``|xi|`` is so large that either window collapses
            (start >= stop).
    """
    num_k = len(k_list)
    k_first, k_last = k_list[0], k_list[-1]
    k1_lists = np.zeros((len(xi_list), num_k))  # scattered
    k2_lists = np.zeros((len(xi_list), num_k))  # incident

    for row, xi in enumerate(xi_list):
        widen = np.exp(np.abs(xi))
        narrow = np.exp(-np.abs(xi))
        k1_start, k1_stop = k_first * narrow, k_last * widen
        k2_start, k2_stop = k_first * widen, k_last * narrow
        if k2_start >= k2_stop or k1_start >= k1_stop:
            raise ValueError('xi too big for the k_list')
        k1_lists[row] = np.linspace(k1_start, k1_stop, num_k)
        k2_lists[row] = np.linspace(k2_start, k2_stop, num_k)

    return k1_lists, k2_lists


def get_k3_lists(xi_list, k1_lists, k_list):
    """Build the k3 integration grid for every (xi, k1) pair.

    Each grid has ``len(k_list)`` equally spaced points on
    ``[max(k1*exp(-|xi|), k_list[0]), min(k1*exp(|xi|), k_list[-1])]``,
    i.e. the boosted window around ``k1`` clipped to the tabulated range.

    Returns:
        Array of shape ``(len(xi_list), k1_lists.shape[1], len(k_list))``.

    Raises:
        ValueError: if a window's lower end exceeds its upper end by more
            than 1e-10 (the window misses ``k_list`` entirely).
    """
    num_k = len(k_list)
    k_floor, k_ceil = k_list[0], k_list[-1]
    k3_lists = np.zeros((len(xi_list), k1_lists.shape[1], num_k))

    for row, xi in enumerate(xi_list):
        narrow = np.exp(-np.abs(xi))
        widen = np.exp(np.abs(xi))
        for col, k1 in enumerate(k1_lists[row]):
            lower = max(k1 * narrow, k_floor)
            upper = min(k1 * widen, k_ceil)
            # Small tolerance: tiny inversions from rounding are accepted.
            if lower - upper > 1e-10:
                raise ValueError(f'k3 list broken: {lower}, {upper}')
            k3_lists[row, col] = np.linspace(lower, upper, num_k)

    return k3_lists


def check_wavenumbers(xi_list, k1s, k2s, ks):
    """Validate the endpoints of boosted grids against the tabulated range ``ks``.

    For each rapidity ``xi`` (one row of ``k1s``/``k2s``):

    * the scattered row ``k1s`` must start at or above ``ks[0]*exp(-|xi|)``
      and end at or below ``ks[-1]*exp(|xi|)``;
    * the incident row ``k2s`` must start at or above ``ks[0]*exp(|xi|)``
      and end at or below ``ks[-1]*exp(-|xi|)``.

    Only the first and last element of each row are inspected.

    Raises:
        ValueError: naming the first offending grid (k1 is checked before k2).
    """
    k_first, k_last = ks[0], ks[-1]
    for idx, xi in np.ndenumerate(xi_list):
        widen = np.exp(np.abs(xi))
        narrow = np.exp(-np.abs(xi))

        k1_row = k1s[idx]
        if k1_row[0] < k_first * narrow or k1_row[-1] > k_last * widen:
            raise ValueError(f"Bad k1 list at xi={xi}")

        k2_row = k2s[idx]
        if k2_row[0] < k_first * widen or k2_row[-1] > k_last * narrow:
            raise ValueError(f"Bad k2 list at xi={xi}")


def interpolate_tmat(xi_list, tmat, k1_lists, k_list, k3_lists):
    """Resample a tabulated T-matrix onto every (xi, k1) k3 integration grid.

    Real and imaginary parts are linearly interpolated separately with
    ``np.interp``; targets outside ``k_list`` take the end values.

    Args:
        xi_list: rapidities; only ``len(xi_list)`` is used here.
        tmat: array of shape (k, lam, lam, j, j, m, m) sampled on ``k_list``;
            the j axes hold jay - 1 and the m axes hold m + jay_max, where
            jay_max is the length of the respective j axis.
        k1_lists: array whose second axis length is the number of k1 values.
        k_list: wavenumbers at which ``tmat`` is sampled (must be increasing,
            as ``np.interp`` requires).
        k3_lists: array (xi, k1, k3) of target wavenumbers.

    Returns:
        Complex array of shape (xi, k1, k3, lam, lam, j, j, m, m).  Only slots
        with |m| <= jay are filled; all others stay zero.

    To do: compare interpolation with generation of Tmatrix.
    """
    num_jay1, num_jay2 = tmat.shape[3], tmat.shape[4]
    num_xi = len(xi_list)
    num_k1 = k1_lists.shape[1]
    num_k3 = k3_lists.shape[2]

    out_shape = (
        num_xi, num_k1, num_k3, 2, 2,
        num_jay1, num_jay2, 2 * num_jay1 + 1, 2 * num_jay2 + 1,
    )
    tmat_interpol = np.zeros(out_shape, dtype=complex)

    # Index tuples (i_lam1, i_lam2, i_jay1, i_jay2, i_m1, i_m2) of every
    # physically allowed channel (|m| <= jay).
    channels = [
        (i_lam1, i_lam2, jay1 - 1, jay2 - 1, m1 + num_jay1, m2 + num_jay2)
        for jay1 in range(1, num_jay1 + 1)
        for m1 in range(-jay1, jay1 + 1)
        for jay2 in range(1, num_jay2 + 1)
        for m2 in range(-jay2, jay2 + 1)
        for i_lam1 in (0, 1)
        for i_lam2 in (0, 1)
    ]

    for i_xi in range(num_xi):
        for i_k1 in range(num_k1):
            targets = k3_lists[i_xi, i_k1]
            for channel in channels:
                samples = np.squeeze(tmat[(slice(None),) + channel])
                tmat_interpol[(i_xi, i_k1, slice(None)) + channel] = (
                    np.interp(targets, k_list, np.real(samples))
                    + 1j * np.interp(targets, k_list, np.imag(samples))
                )

    return tmat_interpol


def _resolve_boost_inputs(xi_list, tmat, k_list, k1_lists, k2_lists, k3_lists):
    """Return (tmat, k_list, k1_lists, k2_lists, k3_lists) for boost_tmat_vals.

    Explicit arguments win.  ``boost_tmat_vals`` originally read these names
    from module scope, so when ``k_list`` is not passed they are looked up as
    module-level globals of the same name (legacy mode).  Grids that are still
    missing afterwards are derived from ``k_list`` with
    ``get_domain_tmat_boosted`` / ``get_k3_lists``.
    """
    legacy = globals()
    if tmat is None:
        tmat = legacy.get("tmat")
    if k_list is None:
        # Legacy mode: take k_list and any unspecified grids from module scope
        # together, so a passed k_list is never mixed with stale global grids.
        k_list = legacy.get("k_list")
        if k1_lists is None:
            k1_lists = legacy.get("k1_lists")
        if k2_lists is None:
            k2_lists = legacy.get("k2_lists")
        if k3_lists is None:
            k3_lists = legacy.get("k3_lists")
    if tmat is None or k_list is None:
        raise TypeError(
            "boost_tmat_vals() needs 'tmat' and 'k_list'; pass them as keyword arguments"
        )
    if k1_lists is None or k2_lists is None:
        derived_k1, derived_k2 = get_domain_tmat_boosted(xi_list, k_list)
        if k1_lists is None:
            k1_lists = derived_k1
        if k2_lists is None:
            k2_lists = derived_k2
    if k3_lists is None:
        k3_lists = get_k3_lists(xi_list, k1_lists, k_list)
    return tmat, k_list, k1_lists, k2_lists, k3_lists


def _wigner_d_table(theta, jay_max):
    """Tabulate sqrt(2*jay + 1) * d^jay_{m, lam}(theta) for lam in (+1, -1).

    Returns a complex array of shape (2, jay_max, 2*jay_max + 1) + theta.shape,
    indexed [i_lam, jay - 1, m + jay_max, ...]; slots with |m| > jay stay zero.
    """
    theta = np.asarray(theta)
    table = np.zeros((2, jay_max, 2 * jay_max + 1) + theta.shape, dtype=complex)
    for i_lam, lam in enumerate(np.array([1, -1])):
        for jay in range(1, jay_max + 1):
            norm = np.sqrt(2 * jay + 1)
            for m in range(-jay, jay + 1):
                table[i_lam, jay - 1, m + jay_max] = norm * sp.wignersmalld(
                    jay, m, lam, theta
                )
    return table


def _integration_measure(xi_list, k3_lists, num_k1, num_k3):
    """Quadrature weights dk3 / sinh(xi)**2 on every (xi, k1) k3 grid.

    Forward differences are used, so the last node of each grid gets weight 0.
    """
    measure = np.zeros((len(xi_list), num_k1, num_k3))
    for i_xi, xi in enumerate(xi_list):
        for i_k1 in range(num_k1):
            grid = k3_lists[i_xi, i_k1]
            measure[i_xi, i_k1] = np.diff(grid, append=grid[-1]) / (np.sinh(xi) ** 2)
    return measure


def _wig_i(xi_list, k1_lists, k3_lists, jay_max, num_k3):
    """Wigner table / k1 at cos(theta) = (k1 cosh xi - k3) / (k1 sinh xi); shape (xi, k1, lam, j, m, k3)."""
    num_k1 = k1_lists.shape[1]
    wig = np.zeros(
        (len(xi_list), num_k1, 2, jay_max, 2 * jay_max + 1, num_k3), dtype=complex
    )
    for i_xi, xi in enumerate(xi_list):
        for i_k1, k1 in enumerate(k1_lists[i_xi]):
            k3s = k3_lists[i_xi, i_k1]
            costheta = np.clip((k1 * np.cosh(xi) - k3s) / (k1 * np.sinh(xi)), -1, 1)
            wig[i_xi, i_k1] = _wigner_d_table(np.arccos(costheta), jay_max) / k1
    return wig


def _wig_r(xi_list, k1_lists, k3_lists, jay_max, num_k3):
    """Wigner table at cos(theta) = (k1 - k3 cosh xi) / (k3 sinh xi); shape (xi, k1, lam, j, m, k3)."""
    num_k1 = k1_lists.shape[1]
    wig = np.zeros(
        (len(xi_list), num_k1, 2, jay_max, 2 * jay_max + 1, num_k3), dtype=complex
    )
    for i_xi, xi in enumerate(xi_list):
        for i_k1, k1 in enumerate(k1_lists[i_xi]):
            k3s = k3_lists[i_xi, i_k1]
            costheta = np.clip((k1 - k3s * np.cosh(xi)) / (k3s * np.sinh(xi)), -1, 1)
            wig[i_xi, i_k1] = _wigner_d_table(np.arccos(costheta), jay_max)
    return wig


def _wig_s_and_j(xi_list, k1_lists, k2_lists, k3_lists, jay_max, num_k3):
    """Incident-side Wigner tables (wig_s, wig_j), each shaped (xi, k1, k2, lam, j, m, k3).

    wig_s: table / k3 at cos(theta) = -(k3 cosh xi - k2) / (k3 sinh xi),
           nonzero only where k2 e^-|xi| < k3 < k2 e^|xi|.
    wig_j: table / k2 at cos(theta) = -(k3 - k2 cosh xi) / (k2 sinh xi),
           nonzero only where k3 e^-|xi| < k2 < k3 e^|xi|.
    """
    num_k1 = k1_lists.shape[1]
    num_k2 = k2_lists.shape[1]
    shape = (len(xi_list), num_k1, num_k2, 2, jay_max, 2 * jay_max + 1, num_k3)
    wig_s = np.zeros(shape, dtype=complex)
    wig_j = np.zeros(shape, dtype=complex)
    for i_xi, xi in enumerate(xi_list):
        widen, narrow = np.exp(np.abs(xi)), np.exp(-np.abs(xi))
        cosh_xi, sinh_xi = np.cosh(xi), np.sinh(xi)
        for i_k1 in range(num_k1):
            k3s = k3_lists[i_xi, i_k1]
            for i_k2, k2 in enumerate(k2_lists[i_xi]):
                in_s = (k2 * narrow < k3s) & (k3s < k2 * widen)
                if in_s.any():
                    k3_in = k3s[in_s]
                    costheta = np.clip(-(k3_in * cosh_xi - k2) / (k3_in * sinh_xi), -1, 1)
                    # Two-step indexing: assign through the (lam, j, m, k3) view so the
                    # boolean mask only selects along the k3 axis.
                    wig_s[i_xi, i_k1, i_k2][..., in_s] = (
                        _wigner_d_table(np.arccos(costheta), jay_max) / k3_in
                    )
                in_j = (k3s * narrow < k2) & (k2 < k3s * widen)
                if in_j.any():
                    k3_in = k3s[in_j]
                    costheta = np.clip(-(k3_in - k2 * cosh_xi) / (k2 * sinh_xi), -1, 1)
                    wig_j[i_xi, i_k1, i_k2][..., in_j] = (
                        _wigner_d_table(np.arccos(costheta), jay_max) / k2
                    )
    return wig_s, wig_j


def boost_tmat_vals(
    xi_list, *, tmat=None, k_list=None, k1_lists=None, k2_lists=None, k3_lists=None
):
    """Get boosted T(k1, k2) from T(k) via the plane-wave basis.

    For each rapidity xi, scattered wavenumber k1 and incident wavenumber k2,
    the result is 0.25 times a sum over the k3 grid of
    dk3 / sinh(xi)**2 * T(k3) * wig_i * wig_r * wig_s * wig_j,
    contracted as in the einsum below.  T(k3) is ``tmat`` linearly
    interpolated onto the k3 grid of that (xi, k1) pair.
    Helicity index 0 pairs with lam = +1 and index 1 with lam = -1 in the
    Wigner tables.

    Args:
        xi_list: 1-D sequence of non-zero rapidities.
        tmat: array (k, lam, lam, j, j, m, m) sampled on ``k_list``.
        k_list: increasing wavenumbers at which ``tmat`` is sampled.
        k1_lists, k2_lists: scattered / incident grids of shape (xi, k);
            derived with ``get_domain_tmat_boosted`` when omitted.
        k3_lists: integration grids of shape (xi, k1, k3); derived with
            ``get_k3_lists`` when omitted.
        If ``k_list`` is omitted, all omitted inputs are read from module-level
        globals of the same names (the original calling convention).

    Returns:
        Complex array of shape (xi, k1, k2, lam1, lam2, j1, j2, m1, m2).

    Raises:
        ValueError: if any xi is exactly 0 (the 1/sinh(xi) factors diverge),
            or from the grid builders when xi is too large for ``k_list``.
        TypeError: if ``tmat`` or ``k_list`` is neither passed nor available
            in module scope.
    """
    if np.any(np.asarray(xi_list) == 0):
        raise ValueError("xi=0 is singular for the boost integral (1/sinh(xi))")

    tmat, k_list, k1_lists, k2_lists, k3_lists = _resolve_boost_inputs(
        xi_list, tmat, k_list, k1_lists, k2_lists, k3_lists
    )

    # Interpolate tmat (to do: compare with computed tmats)
    tmat_interpolated = interpolate_tmat(xi_list, tmat, k1_lists, k_list, k3_lists)

    num_k1 = k1_lists.shape[1]
    num_k3 = len(k_list)
    jay1_max = tmat.shape[3]
    jay2_max = tmat.shape[4]

    integration_measure = _integration_measure(xi_list, k3_lists, num_k1, num_k3)
    wig_i = _wig_i(xi_list, k1_lists, k3_lists, jay1_max, num_k3)
    wig_r = _wig_r(xi_list, k1_lists, k3_lists, jay1_max, num_k3)
    wig_s, wig_j = _wig_s_and_j(xi_list, k1_lists, k2_lists, k3_lists, jay2_max, num_k3)

    # Einsum legend: x=xi, p=k1, q=k2, k=k3,
    # i=j1, j=j2, r=j1prime, s=j2prime,
    # m=m1, n=m2, a=lam1, b=lam2
    tmat_arr_boosted = (
        np.einsum(
            "xpk,xpkabrsmn,xpaimk,xparmk,xpqbsnk,xpqbjnk->xpqabijmn",
            integration_measure,  # (xi,k1,k3)
            tmat_interpolated,  # (xi,k1,k3,lam,lam,j,j,m,m)
            wig_i,  # (xi,k1,lam1,jay1,m1,k3)
            wig_r,  # (xi,k1,lam1,jay1prime,m1,k3)
            wig_s,  # (xi,k1,k2,lam2,jay2prime,m2,k3)
            wig_j,  # (xi,k1,k2,lam2,jay2,m2,k3)
            optimize=True,
        )
        * 0.25
    )
    return tmat_arr_boosted
