"""
Created on Mon Mar 20 11:27:25 2023
@author: Maxim Vavilin maxim.vavilin@kit.de
"""

import numpy as np
from scipy.integrate import quad
from wigners import wigner_3j


import repscat as rs


def get_spherical_mesh(x_list, y_list, z_list):
    """Convert a Cartesian tensor-product grid to spherical coordinates.

    Args:
        x_list, y_list, z_list: 1-D coordinate arrays; the grid is built with
            ``np.meshgrid(..., indexing="ij")`` so axis 0 follows ``x_list``,
            axis 1 ``y_list`` and axis 2 ``z_list``.

    Returns:
        Tuple ``(r, theta, phi)`` of arrays with shape
        ``(len(x_list), len(y_list), len(z_list))``: the radius, the polar
        angle measured from the +z axis (in ``[0, pi]``) and the azimuth
        from ``arctan2(y, x)`` (in ``[-pi, pi]``).
    """
    grid_x, grid_y, grid_z = np.meshgrid(x_list, y_list, z_list, indexing="ij")
    rho_squared = grid_x**2 + grid_y**2  # squared distance from the z axis
    radius = np.sqrt(rho_squared + grid_z**2)
    polar = np.arctan2(np.sqrt(rho_squared), grid_z)
    azimuth = np.arctan2(grid_y, grid_x)
    return radius, polar, azimuth


def get_field_on_cube_side(repr, radtype, cube_side, region):
    """Sample the field of ``repr`` on a Cartesian grid, one point at a time.

    The grid is the tensor product of the three coordinate arrays in
    ``cube_side`` (for a cube face one of them has a single entry). Every grid
    point is converted to spherical coordinates with ``get_spherical_mesh`` and
    the field is evaluated there by ``rs.get_field_of_k_from_wf_am``.

    Args:
        repr: Representation to evaluate; its ``k_list`` sets the first axis.
        radtype: Forwarded unchanged to ``rs.get_field_of_k_from_wf_am``.
        cube_side: Sequence ``[x_list, y_list, z_list]`` of 1-D coordinate arrays.
        region: Stored under ``"region"`` in the spacetime domain passed on.

    Returns:
        Complex array of shape
        ``(len(k_list), 2, 3, len(x_list), len(y_list), len(z_list))``, i.e.
        ``(k, lam, 3, x, y, z)``.
    """
    x_list = cube_side[0]
    y_list = cube_side[1]
    z_list = cube_side[2]

    r_array, theta_array, phi_array = get_spherical_mesh(x_list, y_list, z_list)
    k_list = repr.k_list
    field = np.zeros(
        (len(k_list), 2, 3, len(x_list), len(y_list), len(z_list)), dtype=complex
    )
    # np.ndindex yields plain integer indices in C order, so the spherical
    # coordinates below are true scalars and the domain lists are 1-D.
    for i_x, i_y, i_z in np.ndindex(r_array.shape):
        spacetime_domain = {
            "r_list": np.array([r_array[i_x, i_y, i_z]]),
            "theta_list": np.array([theta_array[i_x, i_y, i_z]]),
            "phi_list": np.array([phi_array[i_x, i_y, i_z]]),
            "region": region,
        }
        # Squeeze drops the singleton r/theta/phi axes; if len(k_list) == 1
        # the k axis is dropped too and broadcasting restores it here.
        field[:, :, :, i_x, i_y, i_z] = np.squeeze(
            rs.get_field_of_k_from_wf_am(repr, radtype, spacetime_domain)
        )
    return field


def integrate_side(repr, radtype, cube_side, norm_vec, region):
    """Integrate energy, helicity and photon flux through one flat cube face.

    The face is given by ``cube_side = [x_list, y_list, z_list]`` where exactly
    one array has a single entry (the fixed coordinate of the face). The
    integrand is ``n . (conj(E) x E)`` with ``n = norm_vec``, weighted by
    forward-difference (left Riemann) weights along k and both face axes.

    Args:
        repr: Representation whose field is integrated; provides ``k_list``.
        radtype: Forwarded to ``get_field_on_cube_side``.
        cube_side: ``[x_list, y_list, z_list]`` describing the face grid.
        norm_vec: Length-3 normal vector of the face.
        region: Forwarded to ``get_field_on_cube_side``.

    Returns:
        Tuple ``(energy, helicity, photons)`` of complex scalars.

    Raises:
        ValueError: if ``cube_side`` does not have exactly one single-point
            coordinate array.
    """
    k_list = np.asarray(repr.k_list)

    # Identify the flat axis of the face; the other two span the 2D side a x b.
    flat_axes = [axis for axis, arr in enumerate(cube_side) if len(arr) == 1]
    if len(flat_axes) != 1:
        raise ValueError(
            f"cube_side must have exactly one single-point axis, got {len(flat_axes)}"
        )
    flat_axis = flat_axes[0]
    a_list, b_list = [arr for axis, arr in enumerate(cube_side) if axis != flat_axis]

    # Field in (k, lam, 3, x, y, z); remove only the flat spatial axis so a
    # single-k (or single-point) grid keeps its axes for the einsum below.
    field = np.squeeze(
        get_field_on_cube_side(repr, radtype, cube_side, region), axis=3 + flat_axis
    )

    factor = (
        2  # I use E-field instead of RS vectors
        * (-1j)
        * 1e-18  # presumably nm^2 -> m^2 for the area element; TODO confirm
        * rs.EPSILON_0_SI  # Different def of RS vecs
        / (rs.H_BAR_SI * rs.C_0_SI)
    )

    # Totally antisymmetric Levi-Civita tensor.
    levi_civita = np.zeros((3, 3, 3))
    for i, j, k in ((0, 1, 2), (1, 2, 0), (2, 0, 1)):
        levi_civita[i, j, k] = 1.0
        levi_civita[i, k, j] = -1.0

    k_part = np.diff(k_list, append=k_list[-1]) / k_list
    a_part = np.diff(a_list, append=a_list[-1])
    b_part = np.diff(b_list, append=b_list[-1])
    lam_sign = np.array([1, -1])  # sign weight per lam index

    # Face-integrated n . (conj(E) x E) for every (k, lam); computed once and
    # reused for all three quantities instead of three full contractions.
    flux = np.einsum(
        "a,b,h,fgh,klfab,klgab->kl",
        a_part,
        b_part,
        norm_vec,
        levi_civita,
        np.conj(field),
        field,
        optimize=True,
    )

    photons = np.einsum("l,k,kl->", lam_sign, k_part, flux) * factor
    helicity = np.einsum("k,kl->", k_part, flux) * factor * rs.H_BAR_SI
    energy = (
        np.einsum("l,k,kl->", lam_sign, k_part * k_list, flux)
        * factor
        * rs.C_0_SI
        * rs.H_BAR_SI
        * 1e9  # presumably k in 1/nm -> 1/m; TODO confirm
    )

    return energy, helicity, photons

def integrate_side_scalar_product(repr1, radtype1, repr2, radtype2, cube_side, norm_vec, region):
    """Integrate the mixed flux ``n . (conj(E1) x E2)`` through one cube face.

    Both fields are sampled on the same face grid and share the k grid of
    ``repr1``. The sum uses sign weights ``[1, -1]`` over lam and
    forward-difference weights along k and both face axes.

    Args:
        repr1, radtype1: First representation and its radial type (conjugated).
        repr2, radtype2: Second representation and its radial type.
        cube_side: ``[x_list, y_list, z_list]`` with exactly one single-point array.
        norm_vec: Length-3 normal vector of the face.
        region: Forwarded to ``get_field_on_cube_side``.

    Returns:
        Complex scalar with the face contribution to the scalar product.

    Raises:
        ValueError: if the two representations have different ``k_list`` or
            ``cube_side`` does not have exactly one single-point array.
    """
    # Both fields are weighted with repr1's k grid, so the grids must agree.
    if not np.array_equal(repr1.k_list, repr2.k_list):
        raise ValueError("k_list of the two representations must coincide")
    k_list = np.asarray(repr1.k_list)

    # Identify the flat axis of the face; the other two span the 2D side a x b.
    flat_axes = [axis for axis, arr in enumerate(cube_side) if len(arr) == 1]
    if len(flat_axes) != 1:
        raise ValueError(
            f"cube_side must have exactly one single-point axis, got {len(flat_axes)}"
        )
    flat_axis = flat_axes[0]
    a_list, b_list = [arr for axis, arr in enumerate(cube_side) if axis != flat_axis]

    # Fields in (k, lam, 3, x, y, z); drop only the flat spatial axis.
    field1 = np.squeeze(
        get_field_on_cube_side(repr1, radtype1, cube_side, region), axis=3 + flat_axis
    )
    field2 = np.squeeze(
        get_field_on_cube_side(repr2, radtype2, cube_side, region), axis=3 + flat_axis
    )

    factor = (
        2  # I use E-field instead of RS vectors
        * (-1j)
        * 1e-18  # presumably nm^2 -> m^2 for the area element; TODO confirm
        * rs.EPSILON_0_SI  # Different def of RS vecs
        / (rs.H_BAR_SI * rs.C_0_SI)
    )

    # Totally antisymmetric Levi-Civita tensor.
    levi_civita = np.zeros((3, 3, 3))
    for i, j, k in ((0, 1, 2), (1, 2, 0), (2, 0, 1)):
        levi_civita[i, j, k] = 1.0
        levi_civita[i, k, j] = -1.0

    k_part = np.diff(k_list, append=k_list[-1]) / k_list
    a_part = np.diff(a_list, append=a_list[-1])
    b_part = np.diff(b_list, append=b_list[-1])

    photons = (
        np.einsum(
            "l,k,a,b,h,fgh,klfab,klgab->",
            [1, -1],
            k_part,
            a_part,
            b_part,
            norm_vec,
            levi_civita,
            np.conj(field1),
            field2,
            optimize=True,
        )
        * factor
    )

    return photons

def get_scalar_product_surface_cube(repr1, radtype1, repr2, radtype2, cube_params):
    """Sum the mixed-flux face integrals over all six faces of a cube.

    The cube is centred at the origin, spans ``[-h, h]`` on every axis with
    ``h = cube_params["cube_side_half_length"]`` and each face is sampled on a
    ``num_points x num_points`` grid.

    Args:
        repr1, radtype1, repr2, radtype2: Forwarded to
            ``integrate_side_scalar_product``.
        cube_params: Dict with keys ``"cube_side_half_length"``,
            ``"num_points"`` and ``"region"``.

    Returns:
        Sum of the six face contributions.
    """
    half = cube_params["cube_side_half_length"]
    grid = np.linspace(-half, half, cube_params["num_points"])
    region = cube_params['region']

    # Faces in the order +z, +y, +x, -z, -y, -x with their outward normals.
    faces = []
    for sgn in (1.0, -1.0):
        wall = np.array([sgn * half])
        faces.extend(
            [
                ([grid, grid, wall], np.array([0, 0, sgn])),
                ([grid, wall, grid], np.array([0, sgn, 0])),
                ([wall, grid, grid], np.array([sgn, 0, 0])),
            ]
        )

    return sum(
        (
            integrate_side_scalar_product(
                repr1, radtype1, repr2, radtype2, side, normal, region
            )
            for side, normal in faces
        ),
        0,
    )


def get_quantity_surface_cube(repr, radtype, surface_params):
    """Integrate energy, helicity and photon flux over the surface of a cube.

    The cube is centred at the origin, spans ``[-h, h]`` on every axis with
    ``h = surface_params["cube_side_half_length"]``, and each of its six faces
    is integrated with ``integrate_side``.

    Args:
        repr, radtype: Forwarded to ``integrate_side``.
        surface_params: Dict with keys ``"cube_side_half_length"``,
            ``"num_points"`` and ``"region"``.

    Returns:
        Dict with keys ``"Energy"``, ``"Helicity"`` and ``"Photons"`` holding
        the sums over all six faces.
    """
    half = surface_params["cube_side_half_length"]
    grid = np.linspace(-half, half, surface_params["num_points"])

    # Faces in the order +z, +y, +x, -z, -y, -x with their outward normals.
    faces = []
    for sgn in (1.0, -1.0):
        wall = np.array([sgn * half])
        faces.append(([grid, grid, wall], np.array([0, 0, sgn])))
        faces.append(([grid, wall, grid], np.array([0, sgn, 0])))
        faces.append(([wall, grid, grid], np.array([sgn, 0, 0])))

    totals = {"Energy": 0, "Helicity": 0, "Photons": 0}
    for side, normal in faces:
        side_energy, side_helicity, side_photons = integrate_side(
            repr, radtype, side, normal, surface_params['region']
        )
        totals["Energy"] += side_energy
        totals["Helicity"] += side_helicity
        totals["Photons"] += side_photons

    return totals

def get_scalar_product_surface_sphere(repr1, radtype1, repr2, radtype2, surface_params):
    """Integrate the mixed flux ``n . (conj(E1) x E2)`` over a sphere.

    The sphere has radius ``surface_params["surface_radius"]`` and is sampled
    on ``num_theta x num_phi`` angles (``theta`` in ``[0, pi]``, ``phi`` in
    ``[0, 2 pi]``). The sum uses sign weights ``[1, -1]`` over lam and
    forward-difference weights along k, theta (times ``sin(theta)``) and phi.

    Args:
        repr1, radtype1: First representation and its radial type (conjugated).
        repr2, radtype2: Second representation and its radial type.
        surface_params: Dict with keys ``"surface_radius"``, ``"num_theta"``,
            ``"num_phi"`` and ``"region"``.

    Returns:
        Complex scalar with the surface scalar product.

    Raises:
        ValueError: if the representations' ``vals`` shapes or ``k_list``
            differ.
    """
    # Check domains
    if repr1.vals.shape != repr2.vals.shape:
        raise ValueError(f'Shapes of values of reresentations do not coinside: {repr1.vals.shape} and {repr2.vals.shape}')
    if not np.array_equal(repr1.k_list, repr2.k_list):
        raise ValueError('Domains of representations do not coinside)')

    surface_radius = surface_params["surface_radius"]
    num_theta = surface_params["num_theta"]
    num_phi = surface_params["num_phi"]

    # Dimensional constant
    factor = (
        2  # I use E-field instead of RS vectors
        * (-1j)
        * surface_radius**2
        * 1e-18  # presumably nm^2 -> m^2; TODO confirm
        * rs.EPSILON_0_SI  # Different def of RS vecs
        / (rs.H_BAR_SI * rs.C_0_SI)
    )

    theta_list = np.linspace(0, np.pi, num_theta)
    phi_list = np.linspace(0, 2 * np.pi, num_phi)
    k_list = repr1.k_list
    params = {
        "num_r": 1,
        "num_theta": num_theta,
        "num_phi": num_phi,
        "t_list": None,
        "max_r": surface_radius,
        "what_to_plot": "E-density",
        "region": surface_params['region'],
    }
    spacetime_domain = rs.visuals.get_spacetime_domain(params)
    spacetime_domain["r_list"] = [surface_radius]

    # Fields as (k, lam, 3, theta, phi), the index order the einsum expects.
    # Reshape instead of squeeze so length-1 k/theta/phi axes are kept.
    field_shape = (len(k_list), 2, 3, num_theta, num_phi)
    field1 = np.reshape(
        rs.get_field_of_k_from_wf_am(repr1, radtype1, spacetime_domain), field_shape
    )
    field2 = np.reshape(
        rs.get_field_of_k_from_wf_am(repr2, radtype2, spacetime_domain), field_shape
    )

    # Totally antisymmetric Levi-Civita tensor.
    levi_civita = np.zeros((3, 3, 3))
    for i, j, k in ((0, 1, 2), (1, 2, 0), (2, 0, 1)):
        levi_civita[i, j, k] = 1.0
        levi_civita[i, k, j] = -1.0

    # Outward unit normal of the sphere on the (theta, phi) grid: (3, theta, phi).
    vec_surf = np.array(
        [
            np.outer(np.sin(theta_list), np.cos(phi_list)),
            np.outer(np.sin(theta_list), np.sin(phi_list)),
            np.outer(np.cos(theta_list), np.ones_like(phi_list)),
        ]
    )

    k_part = np.diff(k_list, append=k_list[-1]) / k_list
    theta_part = np.diff(theta_list, append=theta_list[-1]) * np.sin(theta_list)
    phi_part = np.diff(phi_list, append=phi_list[-1])

    scalar_product = (
        np.einsum(
            "a,b,c,d,kcd,ijk,baicd,bajcd->",
            [1, -1],
            k_part,
            theta_part,
            phi_part,
            vec_surf,
            levi_civita,
            np.conj(field1),
            field2,
            optimize=True,
        )
        * factor
    )

    return scalar_product

def get_quantity_surface_sphere(repr, radtype, surface_params):
    """Integrate energy, helicity and photon flux through a sphere.

    The sphere has radius ``surface_params["surface_radius"]`` and is sampled
    on ``num_theta x num_phi`` angles (``theta`` in ``[0, pi]``, ``phi`` in
    ``[0, 2 pi]``). The integrand is ``n . (conj(E) x E)`` with ``n`` the
    radial unit vector, weighted by forward-difference weights along k, theta
    (times ``sin(theta)``) and phi.

    Args:
        repr: Representation whose field is integrated; provides ``k_list``.
        radtype: Forwarded to ``rs.get_field_of_k_from_wf_am``.
        surface_params: Dict with keys ``"surface_radius"``, ``"num_theta"``,
            ``"num_phi"`` and ``"region"``.

    Returns:
        Dict with keys ``"Energy"``, ``"Helicity"`` and ``"Photons"``.
    """
    surface_radius = surface_params["surface_radius"]
    num_theta = surface_params["num_theta"]
    num_phi = surface_params["num_phi"]
    region = surface_params["region"]

    # Dimensional constant
    factor = (
        2  # I use E-field instead of RS vectors
        * (-1j)
        * surface_radius**2
        * 1e-18  # presumably nm^2 -> m^2; TODO confirm
        * rs.EPSILON_0_SI  # Different def of RS vecs
        / (rs.H_BAR_SI * rs.C_0_SI)
    )

    theta_list = np.linspace(0, np.pi, num_theta)
    phi_list = np.linspace(0, 2 * np.pi, num_phi)
    k_list = np.asarray(repr.k_list)

    params = {
        "num_r": 1,
        "num_theta": num_theta,
        "num_phi": num_phi,
        # NOTE(review): hard-coded times; presumably unused for fields in k — confirm.
        "t_list": [6, 10, 14],
        "max_r": surface_radius,
        "what_to_plot": "E-density",
        "region": region,
    }
    spacetime_domain = rs.visuals.get_spacetime_domain(params)
    spacetime_domain["r_list"] = [surface_radius]
    spacetime_domain["radial_function"] = 'h+'

    # Field as (k, lam, 3, theta, phi), the index order the einsum expects.
    # Reshape instead of squeeze so length-1 k/theta/phi axes are kept.
    field = np.reshape(
        rs.get_field_of_k_from_wf_am(repr, radtype, spacetime_domain),
        (len(k_list), 2, 3, num_theta, num_phi),
    )

    # Totally antisymmetric Levi-Civita tensor.
    levi_civita = np.zeros((3, 3, 3))
    for i, j, k in ((0, 1, 2), (1, 2, 0), (2, 0, 1)):
        levi_civita[i, j, k] = 1.0
        levi_civita[i, k, j] = -1.0

    # Outward unit normal of the sphere on the (theta, phi) grid: (3, theta, phi).
    vec_surf = np.array(
        [
            np.outer(np.sin(theta_list), np.cos(phi_list)),
            np.outer(np.sin(theta_list), np.sin(phi_list)),
            np.outer(np.cos(theta_list), np.ones_like(phi_list)),
        ]
    )

    k_part = np.diff(k_list, append=k_list[-1]) / k_list
    theta_part = np.diff(theta_list, append=theta_list[-1]) * np.sin(theta_list)
    phi_part = np.diff(phi_list, append=phi_list[-1])
    lam_sign = np.array([1, -1])  # sign weight per lam index

    # Surface-integrated n . (conj(E) x E) for every (k, lam); computed once
    # and reused for all three quantities instead of three full contractions.
    flux = np.einsum(
        "t,p,mtp,ijm,klitp,kljtp->kl",
        theta_part,
        phi_part,
        vec_surf,
        levi_civita,
        np.conj(field),
        field,
        optimize=True,
    )

    photons = np.einsum("l,k,kl->", lam_sign, k_part, flux) * factor
    helicity = np.einsum("k,kl->", k_part, flux) * factor * rs.H_BAR_SI
    energy = (
        np.einsum("l,k,kl->", lam_sign, k_part * k_list, flux)
        * factor
        * rs.C_0_SI
        * rs.H_BAR_SI
        * 1e9  # presumably k in 1/nm -> 1/m; TODO confirm
    )

    quantities = {
        "Energy": energy,
        "Helicity": helicity,
        "Photons": photons,
    }

    return quantities
