# -*- coding: utf-8 -*-
"""
Created on Fri Mar 24 03:08:20 2023

@author: maxim
"""

import numpy as np
import treams.special as sp
from reptemscat.constants import C_0

# import quaternionic
# import spherical
from scipy.integrate import quad  # , dblquad


def complex_quadrature(func, a, b, **kwargs):
    """Integrate a complex-valued function of one real variable over [a, b].

    ``scipy.integrate.quad`` only handles real integrands, so the real and
    imaginary parts of ``func`` are integrated separately (real part first)
    and recombined.

    Parameters
    ----------
    func : callable
        ``func(x, *args)`` returning a (possibly complex) scalar.
    a, b : float
        Integration limits.
    **kwargs
        Forwarded unchanged to both ``quad`` calls (e.g. ``args=...``).

    Returns
    -------
    tuple
        ``(value, real_rest, imag_rest)`` where ``value`` is the complex
        integral and ``real_rest`` / ``imag_rest`` are everything ``quad``
        returned after the integral value (error estimate, and info dict
        when ``full_output`` is requested) for each part.
    """

    def _projected(part):
        # Wrap ``func`` so quad sees a real-valued integrand.
        return lambda x, *args: part(func(x, *args))

    re_out = quad(_projected(np.real), a, b, **kwargs)
    im_out = quad(_projected(np.imag), a, b, **kwargs)
    value = re_out[0] + 1j * im_out[0]
    return value, re_out[1:], im_out[1:]


def integrand_eta(eta, J, m, lam, k, Dr2):
    """Integrand over ``eta`` used to project onto the angular-momentum basis.

    Evaluates::

        exp(-k**2 * (1 - eta**2) * Dr2) * |eta| * (1 + lam*eta)
            * wignersmalld(J, m, lam, arccos(eta))

    Parameters
    ----------
    eta : float
        Cosine of the polar angle (``arccos(eta)`` is passed to the Wigner
        small-d function), so it is expected in [-1, 1].
    J, m : int
        Angular-momentum indices.
    lam : int
        Helicity index; callers in this file pass +1 or -1.
    k : float
        Wavenumber.
    Dr2 : float
        Transverse width term of the Gaussian factor.
    """
    gaussian = np.exp(-(k**2) * (1 - eta**2) * Dr2)
    small_d = sp.wignersmalld(J, m, lam, np.arccos(eta))
    return gaussian * np.abs(eta) * (1 + lam * eta) * small_d


def getWaveFunc(
    k_list,
    N_eta,
    N_phi,
    N_J,
    k0,
    dt,
    dr,
    mult_f,
    WF_incid_PW_name,
    WF_incid_AM_name,
    ifPositiveEta,
    ifPositiveHelicity,
):
    """Generate the incident wave function in the plane-wave and AM bases.

    Parameters
    ----------
    k_list : array_like
        1-D sequence of wavenumbers.
    N_eta, N_phi : int
        Number of grid points for ``eta`` on [-1, 1] and ``phi`` on
        [0, 2*pi] (both endpoints included).
    N_J : int
        Highest ``J``. AM coefficients are stored for J = 1..N_J and
        m = -J..J, flattened into ``N_J * (N_J + 2)`` entries.
    k0 : float
        Wavenumber at the centre of the Gaussian spectral envelope.
    dt, dr : float
        Widths entering the spectral (``dt``, combined with ``C_0``) and
        transverse (``dr``) Gaussian factors.
    mult_f : float or complex
        Overall prefactor.
    WF_incid_PW_name, WF_incid_AM_name
        Unused: saving to disk is currently disabled. Kept for API
        compatibility.
    ifPositiveEta : bool
        If true, only ``eta >= 0`` contributes: PW entries with ``eta < 0``
        are zero and the AM integral runs over [0, 1].
    ifPositiveHelicity : bool
        If true, the ``lam = -1`` components are zero.

    Returns
    -------
    Wfunc_PW_BothHelicities : ndarray, complex, shape (2, N_k, N_eta, N_phi)
    Wfunc_AM_BothHelicities : ndarray, complex, shape (2, N_k, N_J*(N_J+2))
        Index 0 of the first axis holds ``lam = +1``, index 1 ``lam = -1``.
    """
    k_arr = np.asarray(k_list)
    N_k = np.shape(k_list)[0]

    Max_eta = 1
    Min_eta = -1
    eta_list = np.linspace(Min_eta, Max_eta, N_eta, endpoint=True)
    phi_list = np.linspace(0, 2 * np.pi, N_phi, endpoint=True)

    # NOTE(review): the Gaussian widths here are divided by 4, whereas
    # get_plane_pulse_wave_func_pw/_am divide by 2 -- confirm which
    # convention is intended.
    Dt2C2 = dt**2 * C_0**2 / 4
    Dr2 = dr**2 / 4

    helicities = (1, -1)  # first-axis index 0 -> lam = +1, index 1 -> lam = -1

    ##################################    Generate WF in PW basis
    # Broadcast over (N_k, N_eta, N_phi). The factors are multiplied in the
    # same left-to-right order as the scalar formula.
    k = k_arr[:, None, None]
    eta = eta_list[None, :, None]
    phi = phi_list[None, None, :]
    common = (
        mult_f
        * np.exp(-((k - k0) ** 2) * Dt2C2)
        * np.exp(-(k**2) * (1 - eta**2) * Dr2)
        * np.exp(1j * phi)
        * np.abs(eta)
    )
    Wfunc_PW_BothHelicities = np.zeros((2, N_k, N_eta, N_phi), complex)
    for i_lam, lam in enumerate(helicities):
        if ifPositiveHelicity and lam == -1:
            # Left at zero (original note: this helicity does not
            # contribute significantly).
            continue
        Wfunc_PW_BothHelicities[i_lam] = common * (1 + lam * eta)
    if ifPositiveEta:
        Wfunc_PW_BothHelicities[:, :, eta_list < 0, :] = 0

    print("Generated WF in PW basis")

    ############################     Generate WF in AM basis
    # Restrict the eta integral to [0, 1] when only eta >= 0 is kept,
    # consistent with the PW masking above.
    eta_lower = 0 if ifPositiveEta else Min_eta
    progress_step = N_k // 10 + 1
    Wfunc_AM_BothHelicities = np.zeros((2, N_k, N_J * (N_J + 2)), complex)
    for i_lam, lam in enumerate(helicities):
        if ifPositiveHelicity and lam == -1:
            continue  # would be zeroed anyway; skip the quadratures
        for i_k, k_val in enumerate(k_arr):
            if i_k % progress_step == 0:
                print(f"{i_k}/{N_k}")
            envelope = mult_f * np.exp(-((k_val - k0) ** 2) * Dt2C2)
            for J in range(1, N_J + 1):
                # Only m == 1 is filled: integration over phi gives the
                # initial polarization.
                m = 1
                # Flat index of (J, m) in the ordering J = 1..N_J, m = -J..J.
                idx = J**2 - 1 + (m + J)
                int_eta = np.sqrt((2 * J + 1) / (4 * np.pi)) * complex_quadrature(
                    integrand_eta,
                    eta_lower,
                    Max_eta,
                    args=(J, m, lam, k_val, Dr2),
                )[0]
                Wfunc_AM_BothHelicities[i_lam, i_k, idx] = envelope * int_eta

    print("Generated WF in AM basis")

    return Wfunc_PW_BothHelicities, Wfunc_AM_BothHelicities


def get_plane_pulse_wave_func_pw(
    k_list,
    N_eta,
    N_phi,
    center_wavelength,
    dt,
    dr,
    mult_f,
    ifPositiveHelicity,
    ifPositiveEta,
):
    """Generate the plane-pulse wave function in the plane-wave (PW) basis.

    The result is the product of four factors combined with ``einsum``:

    - the spectral envelope ``exp(-(k - k0)**2 * dt**2 * C_0**2 / 2)``;
    - the transverse Gaussian ``exp(-k**2 * (1 - eta**2) * dr**2 / 2)``;
    - the azimuthal phase ``exp(1j*phi)``;
    - the helicity weight ``|eta| * (1 + lam*eta)``.

    Here ``k0 = 2*pi / center_wavelength``.

    Parameters
    ----------
    k_list : array_like
        1-D sequence of wavenumbers.
    N_eta, N_phi : int
        Grid sizes for ``eta`` on [-1, 1] and ``phi`` on [0, 2*pi]
        (both endpoints included).
    center_wavelength : float
        Central wavelength, giving ``k0``.
    dt, dr : float
        Widths of the spectral and transverse Gaussian factors.
    mult_f : float or complex
        Overall prefactor.
    ifPositiveHelicity : bool
        If true, the ``lam = -1`` slice is zero.
    ifPositiveEta : bool
        If true, entries with ``eta < 0`` are zero.

    Returns
    -------
    ndarray, complex, shape (2, N_k, N_eta, N_phi)
        Index 0 of the first axis is ``lam = +1``, index 1 is ``lam = -1``.
    """
    k0 = 2 * np.pi / center_wavelength
    k_arr = np.asarray(k_list)
    eta_list = np.linspace(-1, 1, N_eta, endpoint=True)
    phi_list = np.linspace(0, 2 * np.pi, N_phi, endpoint=True)
    Dt2C2 = dt**2 * C_0**2 / 2
    Dr2 = dr**2 / 2

    # eta samples that contribute: all of them, or only eta >= 0.
    if ifPositiveEta:
        eta_kept = ~(eta_list < 0)
    else:
        eta_kept = np.ones(eta_list.shape, dtype=bool)

    # Spectral envelope, shape (N_k,).
    exp_k = np.exp(-((k_arr - k0) ** 2) * Dt2C2)
    # Transverse Gaussian, shape (N_k, N_eta), zero on masked-out eta.
    exp_k_eta = np.where(
        eta_kept,
        np.exp(-(k_arr[:, None] ** 2) * (1 - eta_list**2) * Dr2),
        0,
    ).astype(complex)
    # Azimuthal phase, shape (N_phi,).
    exp_phi = np.exp(1j * phi_list)
    # Helicity weight |eta| * (1 + lam*eta), shape (2, N_eta).
    fac_lam_eta = np.zeros((2, N_eta), dtype=complex)
    for i_lam, lam in enumerate((1, -1)):
        if ifPositiveHelicity and lam == -1:
            continue
        fac_lam_eta[i_lam] = np.where(
            eta_kept, np.abs(eta_list) * (1 + lam * eta_list), 0
        )

    wave_func_pw = (
        np.einsum("b,bc,d,ac->abcd", exp_k, exp_k_eta, exp_phi, fac_lam_eta)
        * mult_f
    )

    print("Generated WF in PW basis")

    return wave_func_pw

def get_plane_pulse_wave_func_am(
    k_list,
    N_J,
    center_wavelength,
    width_time,
    width_space,
    mult_f,
    ifPositiveHelicity,
    ifPositiveEta,
):
    """Generate the plane-pulse wave function in the angular-momentum basis.

    For every helicity and wavenumber ``k``, only the ``m == 1`` entry of
    each ``J`` is filled:

        mult_f * exp(-(k - k0)**2 * width_time**2 * C_0**2 / 2)
            * sqrt((2J + 1) / (4 pi)) * integral(integrand_eta, eta)

    Here ``k0 = 2*pi / center_wavelength``. The eta integral runs over
    [-1, 1], or over [0, 1] when ``ifPositiveEta`` is true.

    Parameters
    ----------
    k_list : array_like
        1-D sequence of wavenumbers.
    N_J : int
        Highest ``J``. Entries are laid out for J = 1..N_J, m = -J..J.
    center_wavelength : float
        Central wavelength, giving ``k0``.
    width_time, width_space : float
        Widths of the spectral and transverse Gaussian factors.
    mult_f : float or complex
        Overall prefactor.
    ifPositiveHelicity : bool
        If true, only ``lam = +1`` (first-axis index 0) is computed and
        index 1 stays zero.
    ifPositiveEta : bool
        If true, integrate over eta in [0, 1] only.

    Returns
    -------
    ndarray, complex, shape (2, N_k, N_J*(N_J+2))
    """
    central_k = 2 * np.pi / center_wavelength
    n_k = np.shape(k_list)[0]
    time_factor = width_time**2 * C_0**2 / 2
    space_factor = width_space**2 / 2
    eta_upper = 1
    eta_lower = 0 if ifPositiveEta else -1
    report_every = n_k // 10 + 1

    wave_func_am = np.zeros((2, n_k, N_J * (N_J + 2)), complex)

    # (first-axis slot, helicity) pairs to evaluate.
    if ifPositiveHelicity:
        slots = [(0, 1)]
    else:
        slots = [(0, 1), (1, -1)]

    for slot, lam in slots:
        for pos, k in enumerate(np.asarray(k_list)):
            if pos % report_every == 0:
                print(f"{pos}/{n_k}")
            for J in range(1, N_J + 1):
                # Only m == 1 is non-zero: integration over phi gives the
                # initial polarization.
                flat = J * J + J  # position of (J, m=1) in the m = -J..J layout
                norm = np.sqrt((2 * J + 1) / (4 * np.pi))
                eta_integral = complex_quadrature(
                    integrand_eta,
                    eta_lower,
                    eta_upper,
                    args=(J, 1, lam, k, space_factor),
                )[0]
                wave_func_am[slot, pos, flat] = (
                    mult_f
                    * np.exp(-((k - central_k) ** 2) * time_factor)
                    * (norm * eta_integral)
                )
    print("Generated WF in AM basis")
    return wave_func_am
