from functools import lru_cache

import numpy as np
from scipy.special import spherical_jn, spherical_yn
import repscat as rs
import repscat.aux_funcs as af
from repscat.constants import C_0_SI, H_BAR_SI, EPSILON_0_SI
from wigners import wigner_3j, clebsch_gordan
import treams.special as sp

from compute_quantities_in_field_over_shifted_sphere import get_field_at_displaced_sphere

def compute_scalar_product_over_shifted_sphere(outgoing, g_wavefunc, sphere_params, displacement):
    """Print scalar products of ``outgoing`` with ``g_wavefunc`` over shifted spheres.

    For each Cartesian axis (x, then y, then z), the sphere described by
    ``sphere_params`` is shifted by ``displacement`` along that axis. Three
    products are printed: ``outgoing`` taken with radial type 'h+' against
    ``g_wavefunc`` taken as 'j' ("out reg"), 'h+' ("out out") and 'h-'
    ("out in").

    Returns None; results are only printed.
    """
    # (printed label, radial type used for g_wavefunc), in print order.
    pairings = (('out reg ', 'j'), ('out out ', 'h+'), ('out in ', 'h-'))

    for axis_idx, axis_name in enumerate(('x', 'y', 'z')):
        # Shift only along the current axis.
        shift = [displacement if comp == axis_idx else 0 for comp in range(3)]
        print(f'Scalar products over {axis_name}-displaced sphere:')
        for label, g_radtype in pairings:
            value = get_scalar_product_over_shifted_sphere(
                outgoing, 'h+', g_wavefunc, g_radtype, sphere_params, shift
            )
            print(label, value)
    
def get_scalar_product_over_shifted_sphere(repr1, radtype1, repr2, radtype2, sphere_params, displacement_vector):
    """Scalar product of two fields, integrated over a displaced sphere.

    Both fields are sampled by ``get_field_at_displaced_sphere`` on a sphere
    of radius ``sphere_params["surface_radius"]``, shifted by
    ``displacement_vector``. They are then combined with the quadrature sum

        factor * sum_{lam, k, theta, phi} lam * (dk / k) * sin(theta) dtheta dphi
                 * eps_ijk n_k conj(E1_i) E2_j

    Here ``n`` is the sphere's outward unit normal and ``lam`` runs over the
    helicities (+1, -1).

    Parameters
    ----------
    repr1, repr2
        Wavefunction representations passed to ``get_field_at_displaced_sphere``.
        ``repr1.k_list`` supplies the wavenumber grid used for the dk/k weights.
        The code assumes ``repr2`` shares that grid (TODO confirm).
    radtype1, radtype2 : str
        Radial function type for each field ('j', 'h+' or 'h-').
    sphere_params : dict
        Must contain "surface_radius", "num_theta" and "num_phi".
    displacement_vector : sequence of 3 numbers
        Shift of the sphere's centre.

    Returns
    -------
    complex
        The weighted sum multiplied by the dimensional ``factor``.

    Notes
    -----
    Every integral is a left Riemann sum built with ``np.diff(x, append=x[-1])``,
    so the last sample on each grid gets zero weight. For phi this drops the
    2*pi sample, which duplicates phi = 0. A single-point ``k_list`` therefore
    gives 0.
    """
    surface_radius = sphere_params["surface_radius"]
    num_theta = sphere_params["num_theta"]
    num_phi = sphere_params["num_phi"]

    # Dimensional constant
    factor = (
        2  # I use E-field instead of RS vectors
        * (-1j)
        * surface_radius**2
        * 1e-18  # presumably nm^2 -> m^2 for surface_radius**2; TODO confirm
        * rs.EPSILON_0_SI  # Different def of RS vecs
        / (rs.H_BAR_SI * rs.C_0_SI)
    )

    # Same angular grids that get_field_at_displaced_sphere samples on.
    theta_list = np.linspace(0, np.pi, num_theta)
    phi_list = np.linspace(0, 2 * np.pi, num_phi)
    k_list = repr1.k_list

    # Electromagnetic field indexed (k, lam, component, theta, phi)
    field1 = get_field_at_displaced_sphere(repr1, radtype1, num_theta, num_phi, surface_radius, displacement_vector)
    field2 = get_field_at_displaced_sphere(repr2, radtype2, num_theta, num_phi, surface_radius, displacement_vector)

    # Levi-Civita symbol eps_ijk: +1 on even permutations, -1 on odd ones.
    levi_civita = np.zeros((3, 3, 3), dtype=int)
    for a, b, c in ((0, 1, 2), (1, 2, 0), (2, 0, 1)):
        levi_civita[a, b, c] = 1
        levi_civita[a, c, b] = -1

    # Outward unit normal of the sphere about its own centre, with shape
    # (3, num_theta, num_phi). Shifting the sphere does not change the normal.
    sin_theta = np.sin(theta_list)
    vec_surf = np.zeros((3, len(theta_list), len(phi_list)), dtype=complex)
    vec_surf[0] = np.outer(sin_theta, np.cos(phi_list))
    vec_surf[1] = np.outer(sin_theta, np.sin(phi_list))
    vec_surf[2] = np.cos(theta_list)[:, np.newaxis]

    # Quadrature weights (left Riemann sums, last sample weighted 0).
    k_part = np.diff(k_list, append=k_list[-1]) / k_list
    theta_part = np.diff(theta_list, append=theta_list[-1]) * np.sin(theta_list)
    phi_part = np.diff(phi_list, append=phi_list[-1])

    # a: helicity weight (+1, -1), b: k, c: theta, d: phi, i/j/k: components.
    scalar_product = (
        np.einsum(
            "a,b,c,d,kcd,ijk,baicd,bajcd->",
            [1, -1],
            k_part,
            theta_part,
            phi_part,
            vec_surf,
            levi_civita,
            np.conj(field1),
            field2,
            optimize=True,
        )
        * factor
    )

    return scalar_product

def get_displaced_point(r_sphere, theta, phi, displacement_vec):
    """Origin-centred spherical coordinates of a point on a shifted sphere.

    The point sits at polar angle ``theta`` and azimuth ``phi`` on a sphere of
    radius ``r_sphere`` whose centre is at ``displacement_vec``. The first
    three entries of ``displacement_vec`` are the x, y, z shift. Works
    elementwise on numpy arrays as well as scalars.

    Returns
    -------
    tuple
        ``(r, theta, phi)`` of the point relative to the origin. theta comes
        from ``arctan2(rho, z)``, so it lies in [0, pi]. phi comes from
        ``arctan2(y, x)``, so it lies in [-pi, pi].
    """
    shift_x = displacement_vec[0]
    shift_y = displacement_vec[1]
    shift_z = displacement_vec[2]

    # Cartesian position of the sample point after shifting the sphere.
    rho_on_sphere = r_sphere * np.sin(theta)
    x = rho_on_sphere * np.cos(phi) + shift_x
    y = rho_on_sphere * np.sin(phi) + shift_y
    z = r_sphere * np.cos(theta) + shift_z

    rho_sq = x**2 + y**2
    radius = np.sqrt(rho_sq + z**2)
    polar = np.arctan2(np.sqrt(rho_sq), z)
    azimuth = np.arctan2(y, x)
    return radius, polar, azimuth

def get_field_at_displaced_sphere(rep_wf_am, radtype, num_theta, num_phi, r_sphere,  displacement_vec):
    ''' Electric field in SI, sampled over a displaced sphere.

    The sphere has radius ``r_sphere`` and is centred at ``displacement_vec``.
    It is sampled on the grid theta = linspace(0, pi, num_theta),
    phi = linspace(0, 2*pi, num_phi). Each sample comes from
    ``get_field_at_displaced_sphere_point``.

    Returns a complex array of shape
    (len(rep_wf_am.k_list), 2, 3, num_theta, num_phi), indexed
    (k, helicity, component, theta, phi).

    NOTE(review): this definition rebinds the name imported at the top of the
    file from compute_quantities_in_field_over_shifted_sphere, so calls made
    at runtime within this module resolve to this version.
    '''
    thetas = np.linspace(0, np.pi, num_theta)
    phis = np.linspace(0, 2*np.pi, num_phi)

    # (k, helicity, Cartesian-like component, theta, phi)
    out_shape = (len(rep_wf_am.k_list), 2, 3, num_theta, num_phi)
    field_vals = np.zeros(out_shape, dtype=complex)

    for it, theta in enumerate(thetas):
        for ip, phi in enumerate(phis):
            point_field = get_field_at_displaced_sphere_point(
                rep_wf_am, radtype, theta, phi, r_sphere, displacement_vec
            )
            field_vals[..., it, ip] = point_field

    return field_vals
    
    
@lru_cache(maxsize=16)
def _clebsch_gordan_tables(max_J):
    """Return the angle-independent coupling tables for multipoles J = 1..max_J.

    These tables depend only on ``max_J``, while
    ``get_field_at_displaced_sphere_point`` is called once per (theta, phi)
    sample. They are therefore built once per ``max_J`` and cached. Both
    arrays are marked read-only so the shared cached copies cannot be
    modified.

    Returns
    -------
    cg_1_JL : complex ndarray, shape (max_J, 2, max_J + 2)
        ``<L 0; 1 lam | J lam> * (2L+1) * i**L / sqrt(2J+1)``, indexed
        (J - 1, helicity, L). Helicities are ordered (+1, -1).
    cg_2 : complex ndarray, shape (max_J, 2*max_J + 1, max_J + 2, 3)
        ``<L m-s; 1 s | J m>``, indexed (J - 1, m + max_J, L, s) for
        s in (-1, 0, 1). Entries are zero where |m - s| > L.
    """
    num_L = max_J + 2
    num_m = 2 * max_J + 1

    cg_1_JL = np.zeros((max_J, 2, num_L), dtype=complex)
    cg_2 = np.zeros((max_J, num_m, num_L, 3), dtype=complex)
    for J in range(1, max_J + 1):
        i_J = J - 1
        trm_J = 1 / np.sqrt(2 * J + 1)
        for L in [J - 1, J, J + 1]:
            trm_L = (2 * L + 1) * (1j) ** L
            for i_lam, lam in enumerate(np.array([1, -1])):
                cg_1_JL[i_J, i_lam, L] = (
                    clebsch_gordan(L, 0, 1, lam, J, lam) * trm_J * trm_L
                )
            for m in range(-J, J + 1):
                i_m = m + max_J
                for i_s, s in enumerate(np.array([-1, 0, 1])):
                    if abs(m - s) <= L:
                        cg_2[i_J, i_m, L, i_s] = clebsch_gordan(
                            L, m - s, 1, s, J, m
                        )  # (j1, m1, j2, m2, j3, m3)

    cg_1_JL.setflags(write=False)
    cg_2.setflags(write=False)
    return cg_1_JL, cg_2


def get_field_at_displaced_sphere_point(rep_wf_am, radtype, theta, phi, r_sphere, displacement_vec):
    '''Electric field in (k, lam, r) in SI at one point of a displaced sphere.

    The point is the sample (theta, phi) on a sphere of radius ``r_sphere``
    centred at ``displacement_vec``. Its origin-centred coordinates come from
    ``get_displaced_point``. The field there is the multipole sum over the
    coefficients in ``rep_wf_am.vals``.

    Parameters
    ----------
    rep_wf_am
        Angular-momentum representation. ``rep_wf_am.vals`` is indexed
        (k, helicity, J - 1, m + max_J), as the einsum below requires, and
        ``rep_wf_am.k_list`` holds the wavenumbers.
    radtype : str
        Radial function: 'j' (spherical Bessel j_L), 'h+' = 0.5*(j_L + i y_L)
        or 'h-' = 0.5*(j_L - i y_L). At the origin (r == 0) the 'h+'/'h-'
        radial functions are left at zero instead of diverging.
    theta, phi : float
        Sample angles on the sphere, relative to the sphere's own centre.
    r_sphere : float
        Sphere radius.
    displacement_vec : sequence of 3 numbers
        Shift of the sphere's centre.

    Returns
    -------
    complex ndarray, shape (len(k_list), 2, 3)
        Indexed (k, helicity, component). Helicities are ordered (+1, -1).
        Components follow the polarisation vectors from
        ``af.get_polarizvec_z``.

    Raises
    ------
    ValueError
        If ``radtype`` is not one of 'j', 'h+', 'h-'.
    '''
    if radtype not in ("j", "h+", "h-"):
        raise ValueError(f"radtype must be 'j', 'h+' or 'h-', got {radtype!r}")

    max_J = rep_wf_am.vals.shape[2]
    k_list = rep_wf_am.k_list

    r_displaced, theta_displaced, phi_displaced = get_displaced_point(r_sphere, theta, phi, displacement_vec)

    L_list = np.arange(0, max_J + 2)
    m_list = np.arange(-max_J, max_J + 1)
    s_list = np.array([-1, 0, 1])

    trm_k = np.square(k_list)

    # Radial functions, shape (L, k).
    L_col = L_list[:, np.newaxis]
    kr = np.asarray(k_list) * r_displaced
    bes = np.zeros((len(L_list), len(k_list)), dtype=complex)
    if radtype == "j":
        bes[:] = spherical_jn(L_col, kr)
    elif r_displaced > 0:
        # Hankel functions diverge at r == 0; in that case bes stays zero.
        if radtype == "h-":
            bes[:] = 0.5 * (spherical_jn(L_col, kr) - 1j * spherical_yn(L_col, kr))
        else:  # "h+"
            bes[:] = 0.5 * (spherical_jn(L_col, kr) + 1j * spherical_yn(L_col, kr))

    # Angle-independent coupling coefficients (cached per max_J).
    cg_1_JL, cg_2 = _clebsch_gordan_tables(max_J)

    # Spherical-basis polarisation vectors, one column per s in (-1, 0, 1).
    polariz_is = np.zeros((3, 3), dtype=complex)
    for i_s, s in enumerate(s_list):
        polariz_is[:, i_s] = af.get_polarizvec_z(s).reshape(3)

    # Wigner small-d d^L_{m-s,0}(theta), indexed (L, m + max_J, s).
    small_wig = np.zeros((max_J + 2, 2 * max_J + 1, 3), dtype=float)
    for L in L_list:
        for i_m, m in enumerate(m_list):
            for i_s, s in enumerate(s_list):
                if abs(m - s) <= L:
                    small_wig[L, i_m, i_s] = sp.wignersmalld(L, m - s, 0, theta_displaced)

    # exp(i (m - s) phi), indexed (m + max_J, s).
    exp_msphi = np.exp(1j * (m_list[:, np.newaxis] - s_list[np.newaxis, :]) * phi_displaced)

    units_factor = np.sqrt(C_0_SI * H_BAR_SI / EPSILON_0_SI) * 10**9 / (2 * np.pi) * np.sqrt(2*np.pi)

    # k: wavenumber, a: helicity, j: J - 1, m: m + max_J, l: L, s: spin
    # component, i: output vector component.
    res = (
        np.einsum(
            "kajm,lk,jal,is,lms,ms,jmls,k->kai",
            rep_wf_am.vals,
            bes,
            cg_1_JL,
            polariz_is,
            small_wig,
            exp_msphi,
            cg_2,
            trm_k,
            optimize=True,
        )
        * units_factor
    )

    return res