import numpy as np
import matplotlib.pyplot as plt
import treams.special as sp
from reptemscat.representation import Info
import reptemscat.aux_funcs as af
from reptemscat.fields.spherical_pulse import (
    WaveFunctionAngularMomentum as sph_pulse_wf_am,
)
from reptemscat.fields.plane_pulse import WaveFunctionAngularMomentum as plane_pulse_wf_am


def am_from_pw(rep_pw, max_J):
    """Convert a plane-wave ("WF_PW") representation into the "WF_AM" basis.

    Args:
        rep_pw: Representation object whose ``info.representation_name`` must
            be ``"WF_PW"``. Its ``domain`` must provide ``"k_list"``,
            ``"eta_list"`` and ``"phi_list"``, and its ``info.parameters`` must
            provide ``"num_k"`` and ``"center_wavelength"``.
        max_J: Largest total angular momentum J kept in the expansion.

    Returns:
        A ``plane_pulse_wf_am`` instance that shares the original ``k_list``
        domain and holds the coefficients computed by ``am_from_pw_vals``.

    Raises:
        TypeError: If ``rep_pw`` is not in the plane-wave basis.
    """
    pw_info = rep_pw.info
    if pw_info.representation_name != "WF_PW":
        raise TypeError("Representation must be in plane wave basis")

    pw_params = pw_info.parameters
    am_info = Info(
        pw_info.field_name,
        "WF_AM",
        # NOTE(review): originally annotated "1/nm"; the name reads like a
        # point count rather than a wavenumber -- confirm the units.
        num_k=pw_params["num_k"],
        center_wavelength=pw_params["center_wavelength"],
        max_J=max_J,
    )

    pw_domain = rep_pw.domain
    k_list = pw_domain["k_list"]
    am_domain = {"k_list": k_list}

    am_vals = am_from_pw_vals(
        rep_pw.vals,
        k_list,
        pw_domain["eta_list"],
        pw_domain["phi_list"],
        max_J,
    )

    return plane_pulse_wf_am(am_info, domain=am_domain, vals=am_vals)


def am_from_pw_vals(vals_pw, k_list, eta_list, phi_list, max_J):
    """Project plane-wave coefficients onto the (helicity, J, m) basis.

    The angular integral over ``eta`` (treated as cos(theta), since
    ``arccos(eta)`` is taken) and ``phi`` is approximated with a left Riemann
    sum. Node ``i`` is weighted by ``grid[i + 1] - grid[i]``, and the last node
    of each grid gets weight zero. For each helicity ``lam`` and each
    ``(J, m)`` with ``1 <= J <= max_J`` and ``|m| <= J``, the kernel is::

        sqrt((2J + 1) / (4 pi)) * d^J_{m, lam}(arccos(eta)) * exp(-1j m phi)

    Args:
        vals_pw: Rank-4 array indexed ``[helicity, k, eta, phi]``. The helicity
            axis has length 2, with index 0 meaning lam=+1 and index 1 meaning
            lam=-1. The eta/phi axes match ``eta_list`` / ``phi_list``.
        k_list: Wavenumber grid. It is not used in the computation; it is kept
            for interface compatibility.
        eta_list: 1-D grid of cos(theta) values.
        phi_list: 1-D grid of azimuthal angles.
        max_J: Largest total angular momentum J kept.

    Returns:
        Complex array of shape ``(2, vals_pw.shape[1], (max_J + 1)**2 - 1)``.
        The last axis is flattened with ``af.get_idx(J, m)``. This assumes
        get_idx maps the (J, m) pairs above onto ``0..(max_J+1)**2 - 2``;
        TODO confirm.
    """
    eta_arr = np.asarray(eta_list)
    phi_arr = np.asarray(phi_list)
    num_idx = (max_J + 1) ** 2 - 1

    # Electromagnetic helicities; the sum starts at J=1 because |lam| = 1.
    # (Index 1 previously used lam=0, i.e. a scalar harmonic.)
    helicities = (1, -1)

    # Left-endpoint quadrature weights; the final node contributes nothing.
    w_eta = np.zeros(len(eta_arr))
    w_eta[:-1] = np.diff(eta_arr)
    w_phi = np.zeros(len(phi_arr))
    w_phi[:-1] = np.diff(phi_arr)

    # Only evaluate arccos on weighted nodes so that an out-of-range final
    # node (e.g. 1 + eps) cannot inject NaN via NaN * 0.
    theta = np.arccos(eta_arr[:-1])
    w_eta_inner = w_eta[:-1]

    idx_phi_mat = np.zeros((num_idx, len(phi_arr)), dtype=complex)
    lam_idx_eta_mat = np.zeros((2, num_idx, len(eta_arr)), dtype=complex)
    for J in range(1, max_J + 1):
        norm_J = np.sqrt((2 * J + 1) / (4 * np.pi))
        for m in range(-J, J + 1):
            # One flattening rule for both matrices so the einsum contracts
            # matching (J, m) rows.
            idx = af.get_idx(J, m)
            idx_phi_mat[idx] = np.exp(-1j * m * phi_arr) * w_phi
            for i_lam, lam in enumerate(helicities):
                lam_idx_eta_mat[i_lam, idx, :-1] = (
                    sp.wignersmalld(J, m, lam, theta) * norm_J * w_eta_inner
                )

    # a: helicity, b: k, c: flattened (J, m), i: eta node, j: phi node.
    return np.einsum("aci,cj,abij->abc", lam_idx_eta_mat, idx_phi_mat, vals_pw)
