'''
PSF interpolation for CSST-Sim

NOTE: indices named [iccd, iwave, ipsf] are 1-based (1 to n), whereas [tccd, twave, tpsf] are 0-based (0 to n-1).
'''

import sys
import time
import copy
import numpy as np
import scipy.spatial as spatial
import galsim
import h5py

from observation_sim.psf.PSFModel import PSFModel
from observation_sim.psf._util import psf_extrapolate


NPSF = 30 * 30  # default number of sampled PSFs read per wavelength (a 30 x 30 set)
PixSizeInMicrons = 5.0  # pixel size stored as 'pixsize' for every sampled PSF, in microns


# find neighbors-KDtree #
def findNeighbors(tx, ty, px, py, dr=0.1, dn=1, OnlyDistance=True):
    """
    find nearest neighbors by 2D-KDTree

    Parameters:
        tx, ty (float, float): a given position
        px, py (numpy.array, numpy.array): position data for tree (flattened if multi-dimensional)
        dr (float-optional): search radius; with OnlyDistance=False also the radius increment
        dn (int-optional): number of nearest neighbors to return when OnlyDistance=False;
            capped at the number of available points
        OnlyDistance (bool-optional): only use distance to find neighbors. Default: True

    Returns:
        dataq: with OnlyDistance=True, the list of indices within ``dr`` of (tx, ty);
            otherwise a numpy.array of the ``dn`` nearest indices, sorted by distance.

    Raises:
        ValueError: if OnlyDistance is False and dr <= 0 (the radius could never grow).
    """
    # Flatten once so that the tree and the distance computation index the same data.
    datax = np.asarray(px).ravel()
    datay = np.asarray(py).ravel()
    tree = spatial.KDTree(np.column_stack([datax, datay]))

    if OnlyDistance:
        return tree.query_ball_point([tx, ty], dr)

    if dr <= 0:
        raise ValueError('findNeighbors: dr must be positive, got {:}'.format(dr))

    # Grow the search radius until at least dn points are inside it. dn is capped at the
    # number of points, otherwise the loop would never terminate.
    dn = min(dn, datax.size)
    dataq = []
    rr = dr
    while len(dataq) < dn:
        dataq = tree.query_ball_point([tx, ty], rr)
        rr += dr
    dataq = np.array(dataq, dtype=int)
    dd = np.hypot(datax[dataq]-tx, datay[dataq]-ty)
    ddSortindx = np.argsort(dd)
    return dataq[ddSortindx[0:dn]]

# find neighbors-hoclist#


def hocBuild(partx, party, nhocx, nhocy, dhocx, dhocy):
    """
    Bin points into a regular grid stored as head-of-chain (hoc) linked lists.

    Parameters:
        partx, party (numpy.array): 1-D point coordinates, already shifted so the grid starts at 0
        nhocx, nhocy (int): number of cells along x and y
        dhocx, dhocy (float): cell size along x and y

    Returns:
        hoc (numpy.array, int32, [nhocy, nhocx]): index of the last point inserted in each
            cell, -1 for an empty cell
        hoclist (numpy.array, int32, [npart]): index of the next point in the same cell,
            -1 at the end of a chain

    Raises:
        ValueError: if any point falls outside the nhocx x nhocy grid.
    """
    partx = np.asarray(partx)
    party = np.asarray(party)
    # astype truncates toward zero, exactly like int() on each element.
    ixs = (partx / dhocx).astype(np.int64)
    iys = (party / dhocy).astype(np.int64)
    # Check the actual cell indices: a point on the upper edge would index past the grid,
    # and a negative index would silently wrap around to the opposite edge.
    if np.any((ixs < 0) | (ixs >= nhocx)):
        raise ValueError(
            'hocBuild: x-coordinates fall outside the grid of {:} cells'.format(nhocx))
    if np.any((iys < 0) | (iys >= nhocy)):
        raise ValueError(
            'hocBuild: y-coordinates fall outside the grid of {:} cells'.format(nhocy))

    npart = partx.size
    hoclist = np.full(npart, -1, dtype=np.int32)
    hoc = np.full([nhocy, nhocx], -1, dtype=np.int32)
    # Insert each point at the head of its cell's chain (order-dependent, so keep the loop).
    for ipart, (ix, iy) in enumerate(zip(ixs, iys)):
        hoclist[ipart] = hoc[iy, ix]
        hoc[iy, ix] = ipart
    return hoc, hoclist


def hocFind(px, py, dhocx, dhocy, hoc, hoclist):
    """
    Return the indices chained in the hoc grid cell that contains (px, py).

    Parameters:
        px, py (float, float): position, in the same shifted coordinates used to build hoc
        dhocx, dhocy (float): cell size along x and y
        hoc, hoclist (numpy.array): head-of-chain grid and linked list (see hocBuild)

    Returns:
        neigh (list): indices found by following hoc[iy, ix] through hoclist; an empty list
            when (px, py) lies outside the grid.
    """
    ix = int(px/dhocx)
    iy = int(py/dhocy)
    nhocy, nhocx = np.shape(hoc)
    # Outside the grid there are no points; a negative index would otherwise wrap around
    # to the opposite edge of the grid.
    if not (0 <= ix < nhocx and 0 <= iy < nhocy):
        return []

    neigh = []
    it = hoc[iy, ix]
    while it != -1:
        neigh.append(it)
        it = hoclist[it]
    return neigh


def findNeighbors_hoclist(px, py, tx=None, ty=None, dn=4, hoc=None, hoclist=None):
    """
    Build a 20x20 hoc grid over (px, py), or query a grid built from the same points.

    Parameters:
        px, py (numpy.array, numpy.array): point positions
        tx, ty (float, float): query position (used only when hoc is given)
        dn (int-optional): number of nearest points to return; -1 returns every point
            found in the 3x3 block of cells around (tx, ty), unsorted, as a list
        hoc, hoclist (numpy.array-optional): grid from a previous build call on the same points

    Returns:
        (hoc, hoclist) when hoc is None; otherwise the neighbor indices (an integer
        numpy.array sorted by distance, or a list when dn == -1).
    """
    nhocy = nhocx = 20

    pxMin, pxMax = np.min(px), np.max(px)
    pyMin, pyMax = np.min(py), np.max(py)

    dhocx = (pxMax - pxMin)/(nhocx-1)
    dhocy = (pyMax - pyMin)/(nhocy-1)
    # Points sharing one coordinate give a zero cell size (division by zero below); any
    # positive size works then, since build and query both use this same value.
    if dhocx == 0:
        dhocx = 1.0
    if dhocy == 0:
        dhocy = 1.0
    # Shift by half a cell so the extreme points sit at cell centers.
    partx = px - pxMin + dhocx/2
    party = py - pyMin + dhocy/2

    if hoc is None:
        return hocBuild(partx, party, nhocx, nhocy, dhocx, dhocy)

    tx = tx - pxMin + dhocx/2
    ty = ty - pyMin + dhocy/2
    itx = int(tx/dhocx)
    ity = int(ty/dhocy)

    # Collect the points chained in the 3x3 block of cells around the query cell.
    neigh = []
    for ix in range(itx - 1, itx + 2):
        for iy in range(ity - 1, ity + 2):
            if not (0 <= ix < nhocx and 0 <= iy < nhocy):
                continue
            it = hoc[iy, ix]
            while it != -1:
                neigh.append(it)
                it = hoclist[it]

    if dn != -1:
        # Integer dtype even when empty, so the result can always be used as an index.
        neigh = np.array(neigh, dtype=int)
        dd = np.hypot(partx[neigh]-tx, party[neigh]-ty)
        neigh = neigh[np.argsort(dd)[0:dn]]
    return neigh


# PSF-IDW#
def psfMaker_IDW(px, py, PSFMat, cen_col, cen_row, IDWindex=2, OnlyNeighbors=True, hoc=None, hoclist=None, PSFCentroidWgt=False):
    """
    psf interpolation by IDW (inverse distance weighting)

    Parameters:
        px, py (float, float): position of the target
        PSFMat (numpy.array): sampled PSF images, shape [npsf, ngy, ngx]
        cen_col, cen_row (numpy.array, numpy.array): positions of the psf centers
        IDWindex (number-optional): the power index of IDW; weights are 1/dist**IDWindex
        OnlyNeighbors (bool-optional): only the 4 nearest psfs are used for interpolation
        hoc, hoclist (numpy.array-optional): prebuilt grid for findNeighbors_hoclist;
            the KD-tree search (findNeighbors) is used when hoc is None
        PSFCentroidWgt (bool-optional): accepted for interface compatibility; not used

    Returns:
        psfMaker (numpy.array, float32, [ngy, ngx]): interpolated psf normalized to unit sum

    Raises:
        ValueError: if no sampled psf is selected for the interpolation.
    """
    # Floor on dist**IDWindex so a target exactly on a psf center does not divide by zero.
    minimum_psf_weight = 1e-8
    ref_col = px
    ref_row = py

    npsf, ngy, ngx = PSFMat.shape

    if OnlyNeighbors:
        if hoc is None:
            neigh = findNeighbors(px, py, cen_col, cen_row,
                                  dr=5., dn=4, OnlyDistance=False)
        else:
            neigh = findNeighbors_hoclist(
                cen_col, cen_row, tx=px, ty=py, dn=4, hoc=hoc, hoclist=hoclist)
        # Unique ascending indices; integer dtype even when the neighbor set is empty.
        used = np.unique(np.asarray(neigh, dtype=int))
    else:
        used = np.arange(npsf)

    if used.size == 0:
        raise ValueError(
            'psfMaker_IDW: no sampled psf found near ({:}, {:})'.format(px, py))

    dist = np.hypot(ref_col - np.asarray(cen_col, dtype=np.float64)[used],
                    ref_row - np.asarray(cen_row, dtype=np.float64)[used])
    psfWeight = 1. / np.maximum(dist**IDWindex, minimum_psf_weight)
    psfWeight /= np.sum(psfWeight)

    psfMaker = np.zeros([ngy, ngx], dtype=np.float32)
    for ipsf, ipsfWeight in zip(used, psfWeight):
        psfMaker += PSFMat[ipsf] * ipsfWeight
    psfMaker /= np.nansum(psfMaker)

    return psfMaker


# define PSFInterp#
class PSFInterp(PSFModel):
    """
    PSF model that interpolates the sampled PSFs of one CCD by inverse distance weighting.

    The sampled PSFs are read from the HDF5 file
    ``PSF_data_file + '/' + PSF_data_prefix + 'psfCube_{iccd}.h5'``, which holds one group
    ``w_<iwave>`` per wavelength, each with a ``wavelength`` dataset and ``psf_<ipsf>`` groups.
    """

    def __init__(self, chip, npsf=NPSF, PSF_data=None, PSF_data_file=None, PSF_data_prefix="", sigSpin=0, psfRa=0.15, HocBuild=False, LOG_DEBUG=False):
        """
        Load the PSF cube of ``chip`` and pack it into the arrays used by ``get_PSF``.

        Parameters:
            chip: chip object; ``chip.getChipLabel(chipID=chip.chipID)`` selects the PSF cube
            npsf (int-optional): number of sampled PSFs read per wavelength
            PSF_data: not read; kept for interface compatibility
            PSF_data_file (str): directory holding the PSF cubes (required; exits if None)
            PSF_data_prefix (str-optional): file-name prefix of the PSF cube
            sigSpin (float-optional): stored as ``self.sigSpin``
            psfRa (float-optional): stored as ``self.sigGauss``
            HocBuild (bool-optional): build hoc/hoclist grids per wavelength; required by
                ``get_PSF(..., findNeighMode='hoclistFind')``
            LOG_DEBUG (bool-optional): print progress messages
        """
        self.LOG_DEBUG = LOG_DEBUG
        if self.LOG_DEBUG:
            print('===================================================')
            print('DEBUG: psf module for csstSim '
                  + time.strftime("(%Y-%m-%d %H:%M:%S)", time.localtime()), flush=True)
            print('===================================================')

        self.sigSpin = sigSpin
        self.sigGauss = psfRa

        self.iccd = int(chip.getChipLabel(chipID=chip.chipID))
        if PSF_data_file is None:
            print('Error - PSF_data_file is None')
            sys.exit()

        self.nwave = self._getPSFwave(
            self.iccd, PSF_data_file, PSF_data_prefix)
        self.npsf = npsf
        self.PSF_data = self._loadPSF(
            self.iccd, PSF_data_file, PSF_data_prefix)

        if self.LOG_DEBUG:
            print('nwave-{:} on ccd-{:}::'.format(self.nwave,
                  self.iccd), flush=True)
            print('self.PSF_data ... ok', flush=True)
            print(
                'Preparing self.[psfMat,cen_col,cen_row] for psfMaker ... ', end='', flush=True)

        ngy, ngx = self.PSF_data[0][0]['psfMat'].shape
        self.psfMat = np.zeros(
            [self.nwave, self.npsf, ngy, ngx], dtype=np.float32)
        self.cen_col = np.zeros([self.nwave, self.npsf], dtype=np.float32)
        self.cen_row = np.zeros([self.nwave, self.npsf], dtype=np.float32)
        self.hoc = []
        self.hoclist = []

        for twave in range(self.nwave):
            for tpsf in range(self.npsf):
                psfInfo = self.PSF_data[twave][tpsf]
                self.psfMat[twave, tpsf, :, :] = psfInfo['psfMat']
                psfInfo['psfMat'] = 0  # free psfMat; the image now lives in self.psfMat

                self.pixsize = psfInfo['pixsize']*1e-3  # mm
                # sample position = image_x/y + centroid offset (cx/cy in the cube)
                self.cen_col[twave, tpsf] = psfInfo['image_x'] + psfInfo['centroid_x']
                self.cen_row[twave, tpsf] = psfInfo['image_y'] + psfInfo['centroid_y']

            if HocBuild:
                # hoclist on twave for neighborsFinding
                hoc, hoclist = findNeighbors_hoclist(
                    self.cen_col[twave], self.cen_row[twave])
                self.hoc.append(hoc)
                self.hoclist.append(hoclist)

        if self.LOG_DEBUG:
            print('ok', flush=True)

    @staticmethod
    def _psfCubePath(PSF_data_file, PSF_data_prefix, iccd):
        """Return the path of the PSF cube file of CCD ``iccd``."""
        return PSF_data_file + '/' + PSF_data_prefix + 'psfCube_{:}.h5'.format(iccd)

    def _getPSFwave(self, iccd, PSF_data_file, PSF_data_prefix):
        """Return the number of top-level groups (wavelengths) in the PSF cube of ``iccd``."""
        with h5py.File(self._psfCubePath(PSF_data_file, PSF_data_prefix, iccd), 'r') as fq:
            return len(fq.keys())

    def _loadPSF(self, iccd, PSF_data_file, PSF_data_prefix):
        """
        Read ``self.nwave`` x ``self.npsf`` sampled PSFs from the PSF cube of ``iccd``.

        Returns:
            psfSet (list of list of dict): ``psfSet[twave][tpsf]`` with keys 'wavelength',
                'pixsize', 'field_x', 'field_y', 'image_x', 'image_y', 'centroid_x',
                'centroid_y' and 'psfMat'.
        """
        psfSet = []
        with h5py.File(self._psfCubePath(PSF_data_file, PSF_data_prefix, iccd), 'r') as fq:
            for twave in range(self.nwave):
                # HDF5 group names are 1-based (w_1, psf_1, ...).
                fq_iwave = fq['w_{:}'.format(twave + 1)]
                wavelength = fq_iwave['wavelength'][()]
                psfWave = []
                for tpsf in range(self.npsf):
                    fq_iwave_ipsf = fq_iwave['psf_{:}'.format(tpsf + 1)]
                    psfWave.append({
                        'wavelength': wavelength,
                        'pixsize': PixSizeInMicrons,
                        'field_x': fq_iwave_ipsf['field_x'][()],
                        'field_y': fq_iwave_ipsf['field_y'][()],
                        'image_x': fq_iwave_ipsf['image_x'][()],
                        'image_y': fq_iwave_ipsf['image_y'][()],
                        'centroid_x': fq_iwave_ipsf['cx'][()],
                        'centroid_y': fq_iwave_ipsf['cy'][()],
                        'psfMat': fq_iwave_ipsf['psfMat'][()],
                    })
                psfSet.append(psfWave)

        if self.LOG_DEBUG:
            print('psfSet has been loaded:', flush=True)
            print('psfSet[iwave][ipsf][keys]:',
                  psfSet[0][0].keys(), flush=True)
        return psfSet

    def _findWave(self, bandpass):
        """
        Return the 0-based wavelength index ``twave`` to use for ``bandpass``.

        An int ``bandpass`` is returned unchanged. Otherwise the first sampled wavelength with
        ``bandpass.blue_limit < wavelength < bandpass.red_limit`` is chosen; -1 if none matches.
        """
        if isinstance(bandpass, int):
            return bandpass

        for twave in range(self.nwave):
            bandwave = self.PSF_data[twave][0]['wavelength']
            if bandpass.blue_limit < bandwave < bandpass.red_limit:
                return twave
        return -1

    def get_PSF(self, chip, pos_img, bandpass, galsimGSObject=True, findNeighMode='treeFind', folding_threshold=5.e-3, pointing_pa=0.0, extrapolate=False, ngg=2048):
        """
        Get the PSF at a given image position

        Parameters:
            chip: A 'Chip' object representing the chip we want to extract PSF from.
            pos_img: A 'galsim.Position' object representing the image position.
            bandpass: A 'galsim.Bandpass' object representing the wavelength range,
                or an int wavelength index.
            galsimGSObject (bool-optional): return a GalSim object instead of the raw image.
            findNeighMode: 'treeFind' or 'hoclistFind' ('hoclistFind' needs HocBuild=True).
            folding_threshold (float-optional): GalSim folding_threshold of the InterpolatedImage.
            pointing_pa: not used here.
            extrapolate (bool-optional): extend the PSF image with psf_extrapolate.
            ngg (int-optional): size parameter passed to psf_extrapolate.
        Returns:
            (PSF, shear): a 'galsim.GSObject' and a zero 'galsim.Shear' when galsimGSObject
            is True; otherwise the interpolated PSF image (numpy.array).
        Raises:
            ValueError: unknown findNeighMode, or 'hoclistFind' without prebuilt hoc grids.
        """
        # PSF pixel scale in arcsec: self.pixsize (mm) * 1e-3 -> m, / 28 (presumably the
        # focal length in metres — TODO confirm) -> radians -> degrees -> arcsec.
        pixSize = np.rad2deg(self.pixsize*1e-3/28)*3600

        twave = self._findWave(bandpass)
        if twave == -1:
            print("!!!PSF bandpass does not match.")
            sys.exit()
        PSFMat = self.psfMat[twave]
        cen_col = self.cen_col[twave]
        cen_row = self.cen_row[twave]

        # Offset from the chip center in detector pixels, scaled by 0.01 into the units of
        # cen_col/cen_row (presumably mm for 10 um pixels — TODO confirm).
        px = (pos_img.x - chip.cen_pix_x)*0.01
        py = (pos_img.y - chip.cen_pix_y)*0.01
        if findNeighMode == 'treeFind':
            imPSF = psfMaker_IDW(px, py, PSFMat, cen_col, cen_row,
                                 IDWindex=2, OnlyNeighbors=True, PSFCentroidWgt=True)
        elif findNeighMode == 'hoclistFind':
            # self.hoc is a list filled only when the object was built with HocBuild=True.
            if len(self.hoc) == 0:
                raise ValueError('hoclist should be built correctly!')
            imPSF = psfMaker_IDW(px, py, PSFMat, cen_col, cen_row, IDWindex=2, OnlyNeighbors=True,
                                 hoc=self.hoc[twave], hoclist=self.hoclist[twave], PSFCentroidWgt=True)
        else:
            raise ValueError(
                "findNeighMode must be 'treeFind' or 'hoclistFind', got {!r}".format(findNeighMode))

        if extrapolate is True:
            ccdList = [6,  7,  8,  9, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 22, 23, 24, 25]
            rr_trim_list = [72, 64, 96, 88, 64, 72, 72, 76, 72, 72, 76, 72, 72, 64, 88, 96, 64, 72]
            imPSF = psf_extrapolate(imPSF, rr_trim=rr_trim_list[ccdList.index(chip.chipID)], ngg=ngg)

        if galsimGSObject:
            # Pad with one trailing row and column of zeros (e.g. 256x256 -> 257x257).
            ny, nx = imPSF.shape
            imPSFt = np.zeros([ny + 1, nx + 1])
            imPSFt[:-1, :-1] = imPSF

            img = galsim.ImageF(imPSFt, scale=pixSize)
            gsp = galsim.GSParams(folding_threshold=folding_threshold)
            self.psf = galsim.InterpolatedImage(img, gsparams=gsp)
            # Map the profile to image coordinates with a fixed 0.074 pixel scale, then back
            # to world coordinates with the chip's local WCS at pos_img.
            wcs = chip.img.wcs.local(pos_img)
            scale = galsim.PixelScale(0.074)
            self.psf = wcs.toWorld(scale.toImage(
                self.psf), image_pos=(pos_img))

            return self.psf, galsim.Shear(e=0., beta=(np.pi/2)*galsim.radians)
        return imPSF

    '''
    def PSFspin(self, x, y):
        """
        The PSF profile at a given image position relative to the axis center

        Parameters:
        theta : spin angles in a given exposure in unit of [arcsecond]
        dx, dy: relative position to the axis center in unit of [pixels]

        Return:
        Spinned PSF: g1, g2 and axis ratio 'a/b'
        """
        a2Rad = np.pi/(60.0*60.0*180.0)

        ff = self.sigGauss * 0.107 * (1000.0/10.0) # in unit of [pixels]
        rc = np.sqrt(x*x + y*y)
        cpix = rc*(self.sigSpin*a2Rad)

        beta = (np.arctan2(y,x) + np.pi/2)
        ell = cpix**2/(2.0*ff**2+cpix**2)
        qr = np.sqrt((1.0+ell)/(1.0-ell))
        PSFshear = galsim.Shear(e=ell, beta=beta*galsim.radians)
        return self.psf.shear(PSFshear), PSFshear
    '''