"""
@author: Maxim Vavilin maxim.vavilin@kit.edu
"""

import os
import numpy as np
import matplotlib.pyplot as plt
import copy
from wigners import wigner_3j

# from repscat.boost_tmat import boost_tmat_vals
import matplotlib as mpl

import repscat as rs

from boost_tmat_precise_pw import interpolate_tmat


class TmatrixPolyPrecise:
    """T-matrix in form (k k lam lam j j m m) sampled on a per-k2 k1 grid.

    ``vals`` has the eight axes ``(k1, k2, lam1, lam2, jay1, jay2, m1, m2)``.
    For every incident wavenumber ``k2_list[q]`` the outgoing wavenumbers are
    the column ``k1_lists[:, q]``, so ``k1_lists.shape == vals.shape[:2]``
    (this is what the contraction in ``retrieve_diagonal`` requires).
    """

    def __init__(self, k1_lists, k2_list, vals, radius=None):
        """Store the sampled T-matrix.

        Args:
            k1_lists: 2D array with the same shape as ``vals.shape[:2]``;
                column ``q`` holds the k1 samples belonging to ``k2_list[q]``.
            k2_list: 1D array of k2 samples, one per column of ``k1_lists``.
            vals: array with axes (k1, k2, lam1, lam2, jay1, jay2, m1, m2);
                the size of the jay1 axis is taken as ``jay_max``.
            radius: optional value, only stored and exposed via ``radius``.
        """
        self._k1_lists = k1_lists
        self._k2_list = k2_list
        self._vals = vals
        self._radius = radius
        # jay runs from 1 to jay_max, so the jay axis has jay_max entries.
        self._jay_max = vals.shape[4]

    def _check_domain_vs_vals(self, sh1, sh2, vshape):
        """Validate domain shapes against the value array shape.

        Args:
            sh1: shape of ``k1_lists``; must equal ``vshape[:2]``.
            sh2: shape of ``k2_list``; must be ``(vshape[1],)``.
            vshape: shape of ``vals``; must have 8 axes.

        Raises:
            ValueError: if any of the shapes are inconsistent.
        """
        if len(vshape) != 8:
            raise ValueError("Bad shape of values")
        if tuple(sh1) != tuple(vshape[:2]) or tuple(sh2) != (vshape[1],):
            raise ValueError("Domain shape does not correspond to values")

    @property
    def k1_lists(self):
        return self._k1_lists

    @property
    def k2_list(self):
        return self._k2_list

    @property
    def vals(self):
        return self._vals

    @property
    def radius(self):
        return self._radius

    @property
    def jay_max(self):
        return self._jay_max

    def _get_meshes(self):
        """Return ``(k1_mesh, k2_mesh)`` aligned with the first two axes of ``vals``.

        ``k1_mesh`` is ``k1_lists`` itself; ``k2_mesh[:, q] == k2_list[q]``.
        """
        k1_mesh = self._k1_lists
        k2_mesh = np.empty_like(k1_mesh)
        # Every k1 sample in column q belongs to the incident wavenumber k2_list[q].
        k2_mesh[...] = np.asarray(self._k2_list)[np.newaxis, :]
        return k1_mesh, k2_mesh

    def plot(self, *indices):
        """Plot real and imaginary parts of one T-matrix element over (k1, k2).

        Args:
            *indices: ``(lam1, lam2, jay1, jay2, m1, m2)`` as physical values,
                converted to array indices by ``_get_indices``.

        Returns:
            ``(fig, axs)``: the created figure and its two axes.
        """
        i_lam1, i_lam2, i_jay1, i_jay2, i_m1, i_m2 = self._get_indices(*indices)

        k1_mesh, k2_mesh = self._get_meshes()
        # Integer indices on the last six axes leave exactly the (k1, k2) plane;
        # no squeeze, so single-row/column grids keep the 2D shape pcolormesh needs.
        element = self._vals[:, :, i_lam1, i_lam2, i_jay1, i_jay2, i_m1, i_m2]
        plot_vals_real = np.real(element)
        plot_vals_imag = np.imag(element)

        # Symmetric colour range so zero sits in the middle of the diverging map.
        maxval = max(np.max(np.abs(plot_vals_real)), np.max(np.abs(plot_vals_imag)))
        minval = -maxval
        fig, axs = plt.subplots(nrows=2, ncols=1)
        cmap = mpl.cm.seismic  # pylint: disable=no-member
        fig.subplots_adjust(
            bottom=0.1, top=0.9, left=0.1, right=0.8, wspace=0.4, hspace=0.1
        )
        cb_ax = fig.add_axes([0.83, 0.1, 0.02, 0.8])
        norm = mpl.colors.Normalize(vmin=minval, vmax=maxval)
        fig.colorbar(
            mpl.cm.ScalarMappable(norm=norm, cmap=cmap),
            cax=cb_ax,
            orientation="vertical",
        )
        axs[0].pcolormesh(k1_mesh, k2_mesh, plot_vals_real, cmap=cmap, norm=norm)
        axs[0].set_ylabel("Real part")
        axs[1].pcolormesh(k1_mesh, k2_mesh, plot_vals_imag, cmap=cmap, norm=norm)
        axs[1].set_ylabel("Imaginary part")
        return fig, axs

    def _get_indices(self, *indices):
        """Convert physical quantum numbers to indices into ``vals``.

        Args:
            *indices: ``(lam1, lam2, jay1, jay2, m1, m2)``. Helicity ``1``
                maps to index 0 and any other helicity (presumably ``-1``)
                to index 1; ``jay`` starts at 1; ``m`` is shifted by
                ``jay_max`` so that ``m = -jay_max`` maps to index 0.

        Returns:
            ``(i_lam1, i_lam2, i_jay1, i_jay2, i_m1, i_m2)``.
        """
        # Previously both branches returned 0, so the second helicity was unreachable.
        i_lam1 = 0 if indices[0] == 1 else 1
        i_lam2 = 0 if indices[1] == 1 else 1
        i_jay1 = indices[2] - 1
        i_jay2 = indices[3] - 1
        i_m1 = indices[4] + self._jay_max
        i_m2 = indices[5] + self._jay_max
        return i_lam1, i_lam2, i_jay1, i_jay2, i_m1, i_m2

    def check_scatter(self, incident):
        """Check that ``incident`` matches the k2-side domain of this T-matrix.

        Raises:
            ValueError: if the (k2, lam2, jay2, m2) axes of ``vals`` differ from
                ``incident.vals.shape`` or ``k2_list`` differs from
                ``incident.k_list``.
        """
        # vals.shape[1::2] picks the axes k2, lam2, jay2, m2.
        if self._vals.shape[1::2] != incident.vals.shape:
            raise ValueError("Tmatrix shape does not correspond to incident values")
        if not np.allclose(self._k2_list, incident.k_list):
            raise ValueError("Tmatrix domain does not correspond to incident k_list")

    def scatter(self, incident, num_k1):
        """Apply the T-matrix to an incident angular-momentum wave function.

        Each k2 slice is interpolated onto a common uniform k1 grid of
        ``num_k1`` points and then contracted with ``incident.vals`` over
        (k2, lam2, jay2, m2) with the weight ``dk2 * k2`` (the last k2 sample
        gets zero weight).

        Returns:
            ``rs.WaveFunctionAngularMomentum`` on the uniform k1 grid.
        """
        # NOTE(review): assumes the smallest k1 is k1_lists[0, 0] and the largest
        # is k1_lists[-1, -1] (k1 increasing along both axes) — confirm.
        k1_big_list = np.linspace(self._k1_lists[0, 0], self._k1_lists[-1, -1], num_k1)
        tmat_interpolated = np.zeros(
            (
                num_k1,
                len(self._k2_list),
                2,
                2,
                self._jay_max,
                self._jay_max,
                2 * self._jay_max + 1,
                2 * self._jay_max + 1,
            ),
            dtype=complex,
        )
        for i_k2 in range(len(self._k2_list)):
            # Column i_k2 holds the k1 samples for k2_list[i_k2].
            k1_list = np.squeeze(self._k1_lists[:, i_k2])
            # Indexing with the list [i_k2] keeps a length-1 k2 axis on both the
            # slice handed to interpolate_tmat and the destination.
            tmat_interpolated[:, [i_k2]] = interpolate_tmat(
                k1_list, self._vals[:, [i_k2]], k1_big_list
            )

        measure = np.diff(self._k2_list, append=self._k2_list[-1]) * self._k2_list
        scattered_vals = np.einsum(
            "q,pqabijmn,qbjn->paim",
            measure,
            tmat_interpolated,
            incident.vals,
            optimize=True,
        )

        return rs.WaveFunctionAngularMomentum(k1_big_list, scattered_vals)

    def retrieve_diagonal(self):
        """Integrate the T-matrix over k1 separately for every k2.

        Each column of ``k1_lists`` is weighted with ``dk1 * k1`` (forward
        differences, zero weight on the last sample of the column).

        Returns:
            array with axes (k2, lam1, lam2, jay1, jay2, m1, m2).
        """
        k1 = np.asarray(self._k1_lists, dtype=float)
        # Forward differences down each column (k1 axis); appending the last
        # row makes the final difference of each column zero.
        measure_arr = np.diff(k1, axis=0, append=k1[-1:, :]) * k1
        return np.einsum(
            "pq,pqabijmn->qabijmn",
            measure_arr,
            self._vals,
            optimize=True,
        )

    # def compute_transfer(self, incident):
    #     scattered = self.scatter(incident)
    #     incident_interp = incident.interpolate(scattered.k_list)
    #     outgoing = incident_interp + scattered
    #     # outgoing.plot("--")
    #     # plt.show()

    #     return {
    #         "energy": incident.energy() - outgoing.energy(),
    #         "momentum_z": incident.momentum_z() - outgoing.momentum_z(),
    #     }
