import sys
from itertools import islice

import mpi4py.MPI as MPI

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.use('Agg')

import scipy.io
#import xlrd
from scipy import ndimage

sys.path.append("/public/home/weichengliang/lnData/CSST_new_framwork/csstPSF_20210108")
import PSFConfig as myConfig
import PSFUtil as myUtil

# Number of PSF samples per CCD: test_psfREE80 loads PSF indices 1..NPSF for every CCD and wave.
NPSF = 900
##############################
##############################
##############################
def _task_range(npsf, ThisTask, NTasks):
    """Return the 0-based half-open PSF index range ``[iStart, iEnd)`` for one rank.

    PSFs are split into contiguous chunks of ``npsf // NTasks``. MPI ranks run
    from 0 to ``NTasks - 1``, so the last rank also takes the remainder; this
    guarantees every PSF index is handled by exactly one rank.
    """
    npsfPerTasks = npsf // NTasks
    iStart = npsfPerTasks * ThisTask
    iEnd = iStart + npsfPerTasks
    if ThisTask == NTasks - 1:
        iEnd = npsf
    return iStart, iEnd


def _psfREE80_at(iccd, ipsf, psfPath, cenPixUsed=True, verbose=False):
    """Return the REE80 values for one (CCD, PSF index) pair.

    For each of the 4 waves the PSF is loaded with ``myConfig.LoadPSF`` and
    measured with ``myUtil.psfSizeCalculator``. The 4 PSF matrices are then
    combined with ``myUtil.psfStack`` and the stacked PSF is measured too.

    Returns
    -------
    wvREE80 : list of 4 floats
        REE80 for wave 1..4, in that order.
    ttREE80 : float
        REE80 of the stacked PSF.
    """
    pixsize = 2.5 * 1e-3  # mm, will use binningPSF
    wvREE80 = []
    psf4iwave = []
    for iwave in range(1, 5):
        if verbose:
            print('iccd-ipsf-iwave: {:} {:} {:}'.format(iccd, ipsf, iwave), end='\r')
        psfInfo = myConfig.LoadPSF(iccd, iwave, ipsf, psfPath, InputMaxPixelPos=True, PSFCentroidWgt=False)

        cenPix = None
        if cenPixUsed:
            # Centroid offsets (mm) from the centroid-weighted load, converted to
            # pixels around the matrix centre.
            # NOTE(review): the 512/2 centre assumes a 512x512 psfMat -- confirm.
            psfInfoX = myConfig.LoadPSF(iccd, iwave, ipsf, psfPath, InputMaxPixelPos=True, PSFCentroidWgt=True)
            deltX = psfInfoX['centroid_x']  # in mm
            deltY = psfInfoX['centroid_y']  # in mm
            cenPix = [512 / 2 + deltX / pixsize, 512 / 2 + deltY / pixsize]

        ipsfMat = psfInfo['psfMat']
        _, _, _, _, _, REE80 = myUtil.psfSizeCalculator(ipsfMat, CalcPSFcenter=True, SigRange=True, TailorScheme=2, cenPix=cenPix)
        wvREE80.append(REE80)
        psf4iwave.append(ipsfMat)

    tt = myUtil.psfStack(*psf4iwave)
    _, _, _, _, _, ttREE80 = myUtil.psfSizeCalculator(tt, CalcPSFcenter=True, SigRange=True, TailorScheme=2)
    return wvREE80, ttREE80


def _save_and_plot_ree80(wvREE80, ttREE80, figPath='figs/psfStackedREE80.pdf'):
    """Write the per-CCD REE80 tables to text files and draw the summary figure.

    Writes REE80_w1.txt .. REE80_w4.txt (one row of ``wvREE80`` each) and
    REE80_tt.txt to the current directory, then saves the figure to ``figPath``.
    The figure's directory is created if it does not exist.
    """
    import os

    nccd = len(ttREE80)
    for k in range(4):
        np.savetxt('REE80_w{:}.txt'.format(k + 1), wvREE80[k, :])
    np.savetxt('REE80_tt.txt', ttREE80)

    ccdFilterLayout = ['GV', 'GV', 'GU', 'GU', 'GI', 'y', 'i', 'g', 'r', 'GI', 'z', 'NUV', 'NUV', 'u', 'y', 'y','u', 'NUV', 'NUV', 'z', 'GI', 'r', 'g', 'i', 'y', 'GI', 'GU', 'GU','GV', 'GV']
    waveColors = ['k', 'b', 'g', 'r']  # marker colour for wave-1 .. wave-4
    REE80W1 = wvREE80[0, :]
    REE80W4 = wvREE80[3, :]

    fig = plt.figure(figsize=(18, 10))
    for iccd in range(nccd):
        x0 = iccd + 1
        xs = [x0 + 0.1 * k for k in range(4)]
        ys = wvREE80[:, iccd]
        # Arrow from the wave-1 REE80 to the wave-4 REE80 of this CCD.
        plt.arrow(x0, REE80W1[iccd], 0, REE80W4[iccd] - REE80W1[iccd], width=0.05, head_length=0.002, ec='None', color='k')
        for x, y, c in zip(xs, ys, waveColors):
            plt.plot([x], [y], 'o', c=c)
        plt.plot(xs, ys, '--', c='k')
        # Put the filter name on the side of the wave-1 point away from the arrow
        # (no label when wave-1 and wave-4 are equal).
        if REE80W1[iccd] < REE80W4[iccd]:
            plt.text(x0 - 0.2, REE80W1[iccd] - 0.005, ccdFilterLayout[iccd], fontsize=15)
        elif REE80W1[iccd] > REE80W4[iccd]:
            plt.text(x0 - 0.2, REE80W1[iccd] + 0.003, ccdFilterLayout[iccd], fontsize=15)

    # Grey bands cover the CCDs whose ccdFilterLayout entry is GV/GU/GI.
    plt.fill_betweenx([0.078, 0.145], [0.5, 0.5], [5.5, 5.5], color='gray', alpha=0.5)
    plt.fill_betweenx([0.078, 0.145], [25.5, 25.5], [30.5, 30.5], color='gray', alpha=0.5)
    plt.fill_betweenx([0.078, 0.145], [9.5, 9.5], [10.5, 10.5], color='gray', alpha=0.5)
    plt.fill_betweenx([0.078, 0.145], [20.5, 20.5], [21.5, 21.5], color='gray', alpha=0.5)

    # Dotted separators between groups of 5 CCDs.
    for xsep in (5.5, 10.5, 15.5, 20.5, 25.5):
        plt.plot([xsep, xsep], [0.078, 0.5], 'k:')

    plt.ylim(0.078, 0.145)
    plt.xlim(0.5, nccd + 0.5)

    plt.xticks([])
    plt.yticks(fontsize=15)
    # Group labels sit just below the axes (x ticks are hidden).
    for first in range(1, nccd + 1, 5):
        plt.text(first + 0.5, 0.074, 'CCD{:} - CCD{:}'.format(first, first + 4), fontsize=15)

    # Overplot the stacked-PSF REE80 for each CCD.
    xccd = np.arange(1, nccd + 1)
    plt.plot(xccd, ttREE80, 'm*', ms=20, markerfacecolor='None', markeredgewidth=2, label='stacked')

    # Empty labelled plots act as legend proxies. Their colours match the per-wave
    # markers, and plt.legend keeps the legend inside the axes regardless of ylim.
    for k, c in enumerate(waveColors):
        plt.plot([], [], 'o', c=c, label='wave-{:}'.format(k + 1))
    plt.legend(fontsize=15, loc='best')

    figDir = os.path.dirname(figPath)
    if figDir:
        os.makedirs(figDir, exist_ok=True)
    plt.savefig(figPath)
    plt.close(fig)


def test_psfREE80(psfPath, ThisTask, NTasks, comm=None):
    """Compute per-CCD mean REE80 for 4 waves and for the stacked PSF, in parallel over MPI.

    For every CCD (1..30), each rank processes its own contiguous slice of the
    NPSF PSF indices. The per-PSF REE80 tables are summed across ranks and then
    averaged per CCD. Rank 0 writes the results to text files and a figure.

    REE80 is whatever ``myUtil.psfSizeCalculator`` returns as its last value
    (presumably the 80% encircled-energy radius -- confirm units there).

    Parameters
    ----------
    psfPath : str
        Data directory passed to ``myConfig.LoadPSF``.
    ThisTask : int
        MPI rank of this process (0 .. NTasks-1).
    NTasks : int
        Number of MPI ranks.
    comm : mpi4py communicator, optional
        Communicator used for the reductions. Defaults to ``MPI.COMM_WORLD``.

    Side effects
    ------------
    Rank 0 writes REE80_w1..w4.txt and REE80_tt.txt to the current directory,
    and writes figs/psfStackedREE80.pdf.
    """
    if comm is None:
        comm = MPI.COMM_WORLD

    nccd = 30
    npsf = NPSF
    iStart, iEnd = _task_range(npsf, ThisTask, NTasks)

    CENPIXUSED = True
    wvREE80 = np.zeros([4, nccd])  # per-CCD mean REE80, one row per wave
    ttREE80 = np.zeros(nccd)       # per-CCD mean REE80 of the stacked PSF

    for iccd in range(1, nccd + 1):
        # Each rank fills only its own PSF slots and leaves the rest at 0. The
        # SUM-allreduce below then assembles the full per-PSF table on every rank.
        psf_wvREE80 = np.zeros([4, npsf])
        psf_ttREE80 = np.zeros(npsf)
        for ipsf in range(iStart + 1, iEnd + 1):
            wv, tt = _psfREE80_at(iccd, ipsf, psfPath, cenPixUsed=CENPIXUSED, verbose=(ThisTask == 0))
            psf_wvREE80[:, ipsf - 1] = wv
            psf_ttREE80[ipsf - 1] = tt

        if ThisTask == 0:
            print('iccd-{:}'.format(iccd), end='\r', flush=True)

        # Allreduce is itself a collective call, so no barrier is needed before it.
        comm.Allreduce(MPI.IN_PLACE, psf_ttREE80, op=MPI.SUM)
        comm.Allreduce(MPI.IN_PLACE, psf_wvREE80, op=MPI.SUM)

        ttREE80[iccd - 1] = np.mean(psf_ttREE80)
        wvREE80[:, iccd - 1] = np.mean(psf_wvREE80, axis=1)

    comm.barrier()

    if ThisTask == 0:
        print('', flush=True)  # end the '\r' progress line
        _save_and_plot_ree80(wvREE80, ttREE80)



##############################
##############################
##############################
if __name__ == '__main__':
    # Every MPI rank runs this entry point. The rank and size decide which PSF
    # indices each process handles. `comm` stays a module-level name in case
    # other code reads the global communicator.
    comm = MPI.COMM_WORLD
    ThisTask, NTasks = comm.Get_rank(), comm.Get_size()

    # PSF data directory handed to test_psfREE80 (and on to myConfig.LoadPSF).
    psfPath = '/data/simudata/CSSOSDataProductsSims/data/csstPSFdata/CSSOS_psf_20210108/CSST_psf_ciomp_2p5um_cycle3'
    test_psfREE80(psfPath=psfPath, ThisTask=ThisTask, NTasks=NTasks)