"""
@author: Maxim Vavilin maxim.vavilin@kit.edu
"""
# from myfuncs import *

import os
import time
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from wigners import wigner_3j, clebsch_gordan

import repscat as rs

from repscat.fields.spherical_pulse import get_spherical_pulse_wavefunction
import surface_formula as sf
import vis


def create_reps(jay_max, region_min, width_time, num_k):
    """Build the two-component outgoing wave function and a reference one.

    Both results share one wavenumber grid, obtained by merging the grids
    that ``rs.get_k_list_gaussian`` returns for the 800 and 400 centre
    wavelengths through ``rs.get_common_domain``.

    Args:
        jay_max: Size of the ``jay`` axis of the coefficient array; the ``m``
            axis has ``2 * jay_max + 1`` entries and is indexed as
            ``jay_max + m``.
        region_min: Not used by this function; kept for call compatibility.
        width_time: Width argument forwarded unchanged to
            ``rs.get_k_list_gaussian`` and ``rs.gaussian_wavenumber``.
        num_k: Number of samples requested per grid and length of the first
            axis of the coefficient array.

    Returns:
        Tuple ``(outgoing, g_wavefunction)`` of two
        ``rs.WaveFunctionAngularMomentum`` objects on the shared grid.
    """
    amplitude = 2e10

    # The "lam" entries are never read below: the position on axis 1 is
    # hard-coded (0 for first/reference, 1 for second).
    # NOTE(review): presumably slot 0 <-> lam=+1 and slot 1 <-> lam=-1 -- confirm.
    first = dict(center_wavelength=800, jay=3, m=3, lam=1)  # nm
    second = dict(center_wavelength=400, jay=2, m=-2, lam=-1)  # nm
    reference = dict(center_wavelength=600, jay=3, m=3, lam=1)  # nm

    k_grid = rs.get_common_domain(
        rs.get_k_list_gaussian(first["center_wavelength"], width_time, num_k),
        rs.get_k_list_gaussian(second["center_wavelength"], width_time, num_k),
    )

    def spectrum(center_wavelength):
        # Scaled Gaussian profile on the shared grid, centred at 2*pi/wavelength.
        center_k = 2 * np.pi / center_wavelength
        return amplitude * rs.gaussian_wavenumber(k_grid, center_k, width_time)

    # assumes the merged grid holds num_k samples so the slices below match
    # -- TODO confirm against rs.get_common_domain
    coeffs = np.zeros((num_k, 2, jay_max, 2 * jay_max + 1), dtype=complex)
    coeffs[:, 0, first["jay"] - 1, jay_max + first["m"]] = spectrum(
        first["center_wavelength"]
    )
    coeffs[:, 1, second["jay"] - 1, jay_max + second["m"]] = spectrum(
        second["center_wavelength"]
    )
    outgoing = rs.WaveFunctionAngularMomentum(k_grid, coeffs)

    # Reference wave function: a single populated component, same array layout.
    coeffs_g = np.zeros_like(coeffs)
    coeffs_g[:, 0, reference["jay"] - 1, jay_max + reference["m"]] = spectrum(
        reference["center_wavelength"]
    )
    g_wavefunction = rs.WaveFunctionAngularMomentum(k_grid, coeffs_g)

    return outgoing, g_wavefunction


def create_displaced_outgoing(jay_max, region_min, width_time, num_k):
    """Build the two-component outgoing wave function.

    NOTE(review): despite the name, nothing in this body applies a
    displacement; it only assembles the coefficient array and wraps it.

    Args:
        jay_max: Size of the ``jay`` axis of the coefficient array; the ``m``
            axis has ``2 * jay_max + 1`` entries and is indexed as
            ``jay_max + m``.
        region_min: Not used by this function; kept for call compatibility.
        width_time: Width argument forwarded unchanged to
            ``rs.get_k_list_gaussian`` and ``rs.gaussian_wavenumber``.
        num_k: Number of samples requested per grid and length of the first
            axis of the coefficient array.

    Returns:
        An ``rs.WaveFunctionAngularMomentum`` built from the merged
        wavenumber grid and the filled coefficient array.
    """
    scale = 2e10

    # One row per populated component:
    # (slot on axis 1, centre wavelength in nm, jay, m).
    components = (
        (0, 800, 3, 3),
        (1, 400, 2, -2),
    )

    # One grid per component, merged into the grid shared by both.
    per_component_grids = [
        rs.get_k_list_gaussian(wavelength, width_time, num_k)
        for _, wavelength, _, _ in components
    ]
    k_list = rs.get_common_domain(*per_component_grids)

    vals = np.zeros((num_k, 2, jay_max, 2 * jay_max + 1), dtype=complex)
    for slot, wavelength, jay, m in components:
        # Gaussian profile centred at wavenumber 2*pi/wavelength.
        vals[:, slot, jay - 1, jay_max + m] = scale * rs.gaussian_wavenumber(
            k_list, 2 * np.pi / wavelength, width_time
        )

    return rs.WaveFunctionAngularMomentum(k_list, vals)


##############################################################################
### Compute scalar product with new formula over sphere
# max_r = 50000
# # num_theta = 200
# # num_phi = 200
# # sphere_params = {
# # 'surface_radius': max_r/2,
# # 'num_theta': num_theta,
# # 'num_phi': num_phi
# # }
# # # print(sf.get_quantity_surface_sphere(outgoing, sphere_params))
# # print(sf.get_scalar_product_surface_sphere(outgoing, g_field, sphere_params))

# cube_params = {
# 'cube_side_length': max_r/2, # half of length actually
# # 'num_theta': num_theta,
# # 'num_phi': num_phi
# }
# # print(sf.get_quantity_surface_cube(outgoing, cube_params))
# # print('out out ',sf.get_scalar_product_surface_cube(outgoing, outgoing, cube_params))
# # print('out reg ',sf.get_scalar_product_surface_cube(outgoing, g_field_j, cube_params))
# print('out out ',sf.get_scalar_product_surface_cube(outgoing, g_field_hplus, cube_params))
# print('out in ', sf.get_scalar_product_surface_cube(outgoing, g_field_hminus, cube_params))


#### Get representations
# Rep params
# Arguments handed to create_reps below (as jay_max, region_min, num_k).
max_jay = 3
region_min = 200
num_k = 200


# Plot params
# Grid resolution and radial extent stored in plot_params; the trailing
# numbers in the comments are alternative values left by the author.
max_r = 5000
num_theta = 400  # 300
num_phi = 3 #200
num_r = 500  # 300

width_time = 2
# Alternative time lists; only t2_list is used by plot_params below.
# t1_list = [60, 100, 140]
t2_list = [6, 9, 12]
# t3_list = [5.8, 6.0, 6.2]

# Runs at import time: builds both wave functions used by the snippets below.
outgoing, g_wavefunc = create_reps(max_jay, region_min, width_time, num_k)

### Plot energy density
# 1) 5000
# Parameter bundle for the (currently commented-out) plotting call that follows.
plot_params = {
    "num_r": num_r,
    "num_theta": num_theta,
    "num_phi": num_phi,
    "t_list": t2_list,
    "max_r": max_r,
    "what_to_plot": "E-density",
    "region_min": region_min,
    # np.inf instead of np.infty: the latter alias was removed in NumPy 2.0
    # and raises AttributeError there; np.inf exists in every NumPy release.
    "region": [region_min, np.inf],
}

# rs.plot_field_from_wf_am(outgoing, plot_params, "h+", draw_patch=True)

# # ## 2) 50000
# # max_r = 5000
# # plot_params = {
# #     "num_r": num_r,
# #     "num_theta": num_theta,
# #     "num_phi": num_phi,
# #     "t_list": t1_list,
# #     "max_r": max_r,
# #     "what_to_plot": "E-density",
# #     "region_min": region_min
# # }
# # plot_params_list.append(plot_params)

# # rs.plot_field_from_wf_am(outgoing, plot_params, draw_patch=True, log10=log10)


# # # 3) 500
# # max_r = 500 #500
# # plot_params = {
# #     "num_r": num_r,
# #     "num_theta": num_theta,
# #     "num_phi": num_phi,
# #     "t_list": t3_list,
# #     "max_r": max_r,
# #     "what_to_plot": "E-density",
# #     "region_min": region_min
# # }
# # plot_params_list.append(plot_params)

# # # vis.plot_field_from_wf_am_table(outgoing, plot_params_list, draw_patch=True, log10=log10)

######## Compute quantities with coefficients
# print('Reference values:')
# print('Number of photons: ', outgoing.photons())
# print('Helicity: ', outgoing.helicity())
# print('Energy: ',outgoing.energy())

######## Compute quantities with new formula over sphere
# sphere_params = {
# 'surface_radius': max_r/2,
# 'num_theta': num_theta,
# 'num_phi': num_phi,
# "region": [region_min, np.inf],
# }
# print(sf.get_quantity_surface_sphere(outgoing, 'h+', sphere_params))

######## Compute quantities with new formula over cube
# cube_params = {
# 'cube_side_length': max_r/2,  # note: this value is half of the side length
# "region": [region_min, np.inf],
# }
# print('Quantities over cube:')
# print(sf.get_quantity_surface_cube(outgoing, 'h+', cube_params))


###### Compute scalar product with new formula over sphere
# max_r = 5000
# num_theta = 400
# num_phi = 200
# sphere_params = {
#     "surface_radius": max_r / 2,
#     "num_theta": num_theta,
#     "num_phi": num_phi,
#     "region": [region_min, np.inf],
# }

# print('Scalar products over sphere:')
# print('out reg ',sf.get_scalar_product_surface_sphere(outgoing, 'h+', g_wavefunc, 'j', sphere_params))
# print('out out ',sf.get_scalar_product_surface_sphere(outgoing, 'h+', g_wavefunc, 'h+', sphere_params))
# print('out in ', sf.get_scalar_product_surface_sphere(outgoing, 'h+', g_wavefunc, 'h-', sphere_params))

# cube_params = {
#     "cube_side_length": max_r / 2,  # note: this value is half of the side length
#     "region": [region_min, np.inf],
# }

# print('Scalar products over cube:')
# print("out reg ", sf.get_scalar_product_surface_cube(outgoing, 'h+',g_wavefunc, 'j',cube_params))
# print("out out ", sf.get_scalar_product_surface_cube(outgoing,'h+', g_wavefunc, 'h+',cube_params))
# print("out in ", sf.get_scalar_product_surface_cube(outgoing,'h+', g_wavefunc,'h-', cube_params))
