"""
@author: Maxim Vavilin maxim.vavilin@kit.edu
"""

import numpy as np

import repscat as rs
import surface_formula as sf

from compute_quantities_in_field_over_shifted_sphere import compute_quantities_in_field_over_shifted_sphere
from compute_scalar_product_over_shifted_sphere import compute_scalar_product_over_shifted_sphere

def create_reps(jay_max, width_time, num_k):
    """Build the two wave functions used by this script.

    Both are ``rs.WaveFunctionAngularMomentum`` objects defined on the same
    wavenumber grid, namely the common domain of the Gaussian grids generated
    for the 800 nm and 400 nm center wavelengths.

    Parameters
    ----------
    jay_max
        Sizes the last two axes of the amplitude array
        (``jay_max`` and ``2 * jay_max + 1``).
    width_time
        Width forwarded to ``rs.get_k_list_gaussian`` and
        ``rs.gaussian_wavenumber``.
    num_k
        Number of wavenumber samples; also the length of the first axis of
        the amplitude array.
        NOTE(review): the grid returned by ``rs.get_common_domain`` is
        presumably of length ``num_k`` as well -- confirm.

    Returns
    -------
    tuple
        ``(outgoing, g_wavefunction)``: the two-component wave function and
        the single-component wave function centered at 600 nm.
    """
    norm_fac = 2e10

    # The "lam" entries are never read below: the index along the second
    # axis (0 or 1) is hard-coded at each call of ``fill_component``.
    first_rep = {"center_wavelength": 800, "jay": 3, "m": 3, "lam": 1}  # nm
    second_rep = {"center_wavelength": 400, "jay": 2, "m": -2, "lam": -1}  # nm
    g_rep = {"center_wavelength": 600, "jay": 3, "m": 3, "lam": 1}  # nm

    grid_first = rs.get_k_list_gaussian(
        first_rep["center_wavelength"], width_time, num_k
    )
    grid_second = rs.get_k_list_gaussian(
        second_rep["center_wavelength"], width_time, num_k
    )
    k_list = rs.get_common_domain(grid_first, grid_second)

    def fill_component(target, second_axis, rep):
        # Write the scaled Gaussian profile, centered at the wavenumber
        # 2*pi / center_wavelength, into the (jay, m) slot described by rep.
        profile = rs.gaussian_wavenumber(
            k_list, 2 * np.pi / rep["center_wavelength"], width_time
        )
        target[:, second_axis, rep["jay"] - 1, jay_max + rep["m"]] = (
            norm_fac * profile
        )

    amplitudes = np.zeros((num_k, 2, jay_max, 2 * jay_max + 1), dtype=complex)
    fill_component(amplitudes, 0, first_rep)
    fill_component(amplitudes, 1, second_rep)
    outgoing = rs.WaveFunctionAngularMomentum(k_list, amplitudes)

    g_amplitudes = np.zeros_like(amplitudes)
    fill_component(g_amplitudes, 0, g_rep)
    g_wavefunction = rs.WaveFunctionAngularMomentum(k_list, g_amplitudes)

    return outgoing, g_wavefunction


def plot_energy_density(outgoing, max_r, region_min):
    """Plot the "E-density" quantity of ``outgoing`` via ``rs.plot_field_from_wf_am``.

    Parameters
    ----------
    outgoing
        Wave function handed to ``rs.plot_field_from_wf_am``; it is plotted
        with the ``"h+"`` radial type and ``draw_patch=True``.
    max_r
        Stored under ``"max_r"`` in the plot parameters.
    region_min
        Stored under ``"region_min"`` and used as the lower bound of
        ``"region"``, whose upper bound is infinity.
    """
    plot_params = {
        "num_r": 500,
        "num_theta": 400,
        "num_phi": 3,
        "t_list": [6, 9, 12],
        "max_r": max_r,
        "what_to_plot": "E-density",
        "region_min": region_min,
        # np.inf instead of the np.infty alias, which NumPy 2.0 removed.
        "region": [region_min, np.inf],
    }
    rs.plot_field_from_wf_am(outgoing, plot_params, "h+", draw_patch=True)

def compute_reference_quantities_in_field(outgoing):
    """Print the real parts of the photon number, helicity and energy of ``outgoing``.

    The values come straight from ``outgoing.photons()``,
    ``outgoing.helicity()`` and ``outgoing.energy()``.
    """
    print('Reference values in |f>:')

    photon_number = np.real(outgoing.photons())
    print('Number of photons: ', photon_number)

    helicity = np.real(outgoing.helicity())
    print('Helicity: ', helicity)

    energy = np.real(outgoing.energy())
    # The extra '\n' argument leaves a blank line after this report.
    print('Energy: ', energy, '\n')
    
def compute_quantities_in_field_over_sphere(outgoing, sphere_params):
    """Print photon number, helicity and energy of ``outgoing`` from a sphere surface.

    The values are read from the ``'Photons'``, ``'Helicity'`` and
    ``'Energy'`` entries of the result of ``sf.get_quantity_surface_sphere``
    (called with the ``'h+'`` radial type); only their real parts are printed.
    """
    surface_values = sf.get_quantity_surface_sphere(outgoing, 'h+', sphere_params)

    print('Quantities in |f> over sphere:')
    # (label, result key, extra print arguments); the trailing '\n' on the
    # last row leaves a blank line after the report.
    report_rows = (
        ('Number of photons: ', 'Photons', ()),
        ('Helicity: ', 'Helicity', ()),
        ('Energy: ', 'Energy', ('\n',)),
    )
    for label, key, trailer in report_rows:
        print(label, np.real(surface_values[key]), *trailer)
    
def compute_quantities_in_field_over_cube(outgoing, cube_params):
    """Print photon number, helicity and energy of ``outgoing`` from a cube surface.

    The values are read from the ``'Photons'``, ``'Helicity'`` and
    ``'Energy'`` entries of the result of ``sf.get_quantity_surface_cube``
    (called with the ``'h+'`` radial type); only their real parts are printed.
    """
    surface_values = sf.get_quantity_surface_cube(outgoing, 'h+', cube_params)

    print('Quantities in |f> over cube:')
    # (label, result key, extra print arguments); the trailing '\n' on the
    # last row leaves a blank line after the report.
    report_rows = (
        ('Number of photons: ', 'Photons', ()),
        ('Helicity: ', 'Helicity', ()),
        ('Energy: ', 'Energy', ('\n',)),
    )
    for label, key, trailer in report_rows:
        print(label, np.real(surface_values[key]), *trailer)

    
def compute_scalar_product_over_sphere(outgoing, g_wavefunc, sphere_params):
    """Print three sphere-surface scalar products of ``outgoing`` with ``g_wavefunc``.

    ``outgoing`` is always passed with the ``'h+'`` radial type, while
    ``g_wavefunc`` is passed in turn with ``'j'``, ``'h+'`` and ``'h-'``
    (labelled "reg", "out" and "in" in the output).
    """
    print('Scalar products over sphere:')
    labelled_kinds = (
        ('out reg ', 'j'),
        ('out out ', 'h+'),
        ('out in ', 'h-'),
    )
    for label, g_kind in labelled_kinds:
        product = sf.get_scalar_product_surface_sphere(
            outgoing, 'h+', g_wavefunc, g_kind, sphere_params
        )
        print(label, product)
    
def compute_scalar_product_over_cube(outgoing, g_wavefunc, cube_params):
    """Print three cube-surface scalar products of ``outgoing`` with ``g_wavefunc``.

    ``outgoing`` is always passed with the ``'h+'`` radial type, while
    ``g_wavefunc`` is passed in turn with ``'j'``, ``'h+'`` and ``'h-'``
    (labelled "reg", "out" and "in" in the output).
    """
    print('Scalar products over cube:')
    labelled_kinds = (
        ("out reg ", 'j'),
        ("out out ", 'h+'),
        ("out in ", 'h-'),
    )
    for label, g_kind in labelled_kinds:
        product = sf.get_scalar_product_surface_cube(
            outgoing, 'h+', g_wavefunc, g_kind, cube_params
        )
        print(label, product)
    
    
def compute_quantities_in_field(outgoing, sphere_params, cube_params, sphere_displacement):
    """Run every field-quantity report for ``outgoing``, one after another.

    Order: reference values taken directly from the wave function, then the
    values over the sphere, over the shifted sphere, and over the cube.

    ``sphere_params`` is used for both the centered and the shifted sphere;
    ``sphere_displacement`` is forwarded only to the shifted-sphere routine
    (imported from another module -- presumably the shift of the sphere
    center; confirm there). ``cube_params`` is used for the cube report.
    """
    compute_reference_quantities_in_field(outgoing)
    compute_quantities_in_field_over_sphere(outgoing, sphere_params)
    compute_quantities_in_field_over_shifted_sphere(outgoing, sphere_params, sphere_displacement)
    compute_quantities_in_field_over_cube(outgoing, cube_params)
    
def compute_scalar_product_of_fields(outgoing, g_wavefunc, sphere_params, cube_params, displacement):
    """Run every scalar-product report for ``outgoing`` and ``g_wavefunc``.

    Order: over the sphere, over the shifted sphere, then over the cube.

    ``sphere_params`` is used for both the centered and the shifted sphere;
    ``displacement`` is forwarded only to the shifted-sphere routine
    (imported from another module -- presumably the shift of the sphere
    center; confirm there). ``cube_params`` is used for the cube report.
    """
    compute_scalar_product_over_sphere(outgoing, g_wavefunc, sphere_params)
    compute_scalar_product_over_shifted_sphere(outgoing, g_wavefunc, sphere_params, displacement)
    compute_scalar_product_over_cube(outgoing, g_wavefunc, cube_params)
    
def get_sphere_params(max_r, min_r):
    """Return the parameter dict for the sphere-surface computations.

    Parameters
    ----------
    max_r
        The sphere radius is set to half of this value.
    min_r
        Lower bound of the ``"region"`` interval; the upper bound is infinity.

    Returns
    -------
    dict
        Keys ``'surface_radius'``, ``'num_theta'``, ``'num_phi'`` and
        ``'region'``.
    """
    return {
        'surface_radius': max_r / 2,
        'num_theta': 400,
        'num_phi': 200,
        # np.inf instead of the np.infty alias, which NumPy 2.0 removed.
        "region": [min_r, np.inf],
    }
    
def get_cube_params(max_r, min_r):
    """Return the parameter dict for the cube-surface computations.

    Parameters
    ----------
    max_r
        The half side length of the cube is set to half of this value.
    min_r
        Lower bound of the ``"region"`` interval; the upper bound is infinity.

    Returns
    -------
    dict
        Keys ``'cube_side_half_length'``, ``'region'`` and ``'num_points'``.
    """
    return {
        'cube_side_half_length': max_r / 2,
        # np.inf instead of the np.infty alias, which NumPy 2.0 removed.
        "region": [min_r, np.inf],
        "num_points": 200,
    }
    
# Script entry point: build the wave functions, plot the field, then print
# the surface-integral reports for the sphere, shifted sphere and cube.
if __name__ == "__main__":
    
    # Wave function params
    # NOTE(review): the unit of width_time is not stated in this file -- confirm.
    width_time = 2
    max_jay = 3
    num_k = 200
    
    # Screen size
    # NOTE(review): presumably in nm, like the center wavelengths -- confirm.
    min_r = 200
    max_r = 5000
    
    outgoing, g_wavefunc = create_reps(max_jay, width_time, num_k)
    
    plot_energy_density(outgoing, max_r, min_r)
    
    # The sphere radius and the cube half side length are both max_r / 2.
    sphere_params = get_sphere_params(max_r, min_r)
    cube_params = get_cube_params(max_r, min_r)
    
    # The same displacement is used for the quantity report and the
    # scalar-product report on the shifted sphere.
    sphere_displacement = 1500
    compute_quantities_in_field(outgoing, sphere_params, cube_params, sphere_displacement)
    compute_scalar_product_of_fields(outgoing, g_wavefunc, sphere_params, cube_params, sphere_displacement)
    

    