Skip to content
PSFInterp.py 13.8 KiB
Newer Older
Fang Yuedong's avatar
Fang Yuedong committed
'''
PSF interpolation for CSST-Sim.

NOTE: indices named [iccd, iwave, ipsf] are 1-based (1 .. n), whereas [tccd, twave, tpsf] are 0-based (0 .. n-1).
'''

import galsim
import numpy as np

import os
import time
import copy
import PSF.PSFInterp.PSFConfig as myConfig
import PSF.PSFInterp.PSFUtil as myUtil
from PSF.PSFModel import PSFModel

LOG_DEBUG = False  # when True, print progress/debug messages while loading and interpolating
NPSF      = 900    # number of PSFs per wavelength per CCD (30*30 grid)
iccdTest  = 1      # 1-based CCD index used by testPSFInterp()


class PSFInterp(PSFModel):
    """
    Interpolated PSF model for a single CSST CCD.

    The constructor loads the PSF grid of one CCD from disk: NPSF PSF matrices
    for each sampled wavelength. ``get_PSF`` then builds the PSF at an
    arbitrary image position by inverse-distance-weighted interpolation of
    neighbouring grid PSFs (``myConfig.psfMaker_IDW``). It can optionally wrap
    the result in a galsim object sheared by ``PSFspin``.
    """

    def __init__(self, chip, PSF_data=None, PSF_data_file=None, sigSpin=0., psfRa=0.15):
        """
        Load the PSF grid of ``chip`` and prepare the interpolation tables.

        Parameters:
            chip: Chip object. ``chip.getChipLabel(chipID=chip.chipID)`` gives
                the CCD number used to select the ``ccd<N>`` sub-directory.
            PSF_data: Not used by this implementation (the grid is always read
                from ``PSF_data_file``); kept for interface compatibility.
            PSF_data_file: Root directory of the PSF data. If None, the default
                30x30 CSST PSF grid path is used.
            sigSpin: Spin angle used by ``PSFspin`` (converted there from
                arcsec to radians).
            psfRa: 80% light radius, stored as ``self.sigGauss`` and used by
                ``PSFspin``.

        Raises:
            ValueError: if no ``wave_*`` entries exist for this CCD.
        """
        if LOG_DEBUG:
            print('===================================================')
            print('DEBUG: psf module for csstSim '
                  + time.strftime("(%Y-%m-%d %H:%M:%S)", time.localtime()), flush=True)
            print('===================================================')

        self.sigSpin = sigSpin
        self.sigGauss = psfRa  # 80% light radius

        self.iccd = int(chip.getChipLabel(chipID=chip.chipID))
        if PSF_data_file is None:
            PSF_data_file = '/data/simudata/CSSOSDataProductsSims/data/csstPSFdata/CSSOS_psf_ciomp_30X30'

        self.nwave = self._getPSFwave(self.iccd, PSF_data_file)
        if self.nwave == 0:
            # Without any wavelength sample the tables below cannot be built;
            # fail here with a clear message instead of an IndexError later.
            raise ValueError('No PSF wavelength samples (wave_*) found for ccd{:} under {:}'
                             .format(self.iccd, PSF_data_file))
        self.PSF_data = self._loadPSF(self.iccd, PSF_data_file)

        if LOG_DEBUG:
            print('nwave-{:} on ccd-{:}::'.format(self.nwave, self.iccd), flush=True)
            print('self.PSF_data ... ok', flush=True)
            print('Preparing self.[psfMat,cen_col,cen_row] for psfMaker ... ', end='', flush=True)

        # All PSF matrices are assumed to share the shape of the first one.
        ngy, ngx = self.PSF_data[0][0]['psfMat'].shape
        self.psfMat = np.zeros([self.nwave, NPSF, ngy, ngx])
        self.cen_col = np.zeros([self.nwave, NPSF])
        self.cen_row = np.zeros([self.nwave, NPSF])
        self.hoc = []
        self.hoclist = []
        for twave in range(self.nwave):
            for tpsf in range(NPSF):
                entry = self.PSF_data[twave][tpsf]
                self.psfMat[twave, tpsf, :, :] = entry['psfMat']
                # Grid PSF position = image position + centroid offset.
                self.cen_col[twave, tpsf] = entry['image_x'] + entry['centroid_x']
                self.cen_row[twave, tpsf] = entry['image_y'] + entry['centroid_y']
            # Per-wavelength neighbour lookup tables used by psfMaker_IDW.
            hoc, hoclist = myUtil.findNeighbors_hoclist(self.cen_col[twave], self.cen_row[twave])
            self.hoc.append(hoc)
            self.hoclist.append(hoclist)

        if LOG_DEBUG:
            print('ok', flush=True)

    def _getPSFwave(self, iccd, PSF_data_file):
        """
        Count the sampled wavelengths available for one CCD.

        Parameters:
            iccd: 1-based CCD index (selects sub-directory ``ccd<iccd>``).
            PSF_data_file: Root directory of the PSF data.
        Returns:
            nwave: Number of entries in that sub-directory whose name contains 'wave_'.
        """
        ccd_dir = os.path.join(PSF_data_file, 'ccd{:}'.format(iccd))
        return sum(1 for name in os.listdir(ccd_dir) if 'wave_' in name)

    def _loadPSF(self, iccd, PSF_data_file):
        """
        Load all PSF records of one CCD.

        Parameters:
            iccd: 1-based CCD index.
            PSF_data_file: Root directory of the PSF data.
        Returns:
            psfSet: Nested list indexed as ``psfSet[twave][tpsf]`` (0-based),
                each item being the record returned by ``myConfig.LoadPSF``.
        """
        psfSet = []
        for twave in range(self.nwave):
            iwave = twave + 1  # LoadPSF takes 1-based indices (see module NOTE)
            if LOG_DEBUG:
                print('self._loadPSF: iwave::', iwave, flush=True)
            psfWave = [
                myConfig.LoadPSF(iccd, iwave, tpsf + 1, PSF_data_file,
                                 InputMaxPixelPos=True, PSFCentroidWgt=True)
                for tpsf in range(NPSF)
            ]
            psfSet.append(psfWave)
        if LOG_DEBUG:
            print('psfSet has been loaded:', flush=True)
            print('psfSet[iwave][ipsf][keys]:', psfSet[0][0].keys(), flush=True)
        return psfSet

    def _preprocessPSF(self):
        """
        Deprecated no-op, kept for interface compatibility.

        A former version recentred the PSF matrices here. Centroid weighting is
        now handled when loading (``PSFCentroidWgt=True``) and in
        ``psfMaker_IDW``, so this method does nothing and returns None.
        """
        return None

    def _findWave(self, bandpass):
        """
        Return the 0-based index of the first sampled wavelength that lies
        strictly inside (bandpass.blue_limit, bandpass.red_limit), or -1 if
        none does.
        """
        for twave in range(self.nwave):
            bandwave = self.PSF_data[twave][0]['wavelength']
            if bandpass.blue_limit < bandwave < bandpass.red_limit:
                return twave
        return -1

    def get_PSF(self, chip, pos_img, bandpass, pixSize=0.037, galsimGSObject=True, folding_threshold=5.e-3):
        """
        Get the PSF at a given image position.

        Parameters:
            chip: The 'Chip' object this PSFInterp was built for.
            pos_img: A 'galsim.Position' giving the image position in pixels.
            bandpass: A 'galsim.Bandpass'. The first sampled PSF wavelength
                strictly inside (blue_limit, red_limit) is used.
            pixSize: Pixel scale assigned to the PSF image when building the
                galsim object.
            galsimGSObject: If True, return a galsim object; otherwise return
                the raw interpolated matrix.
            folding_threshold: ``folding_threshold`` of the galsim GSParams
                used for the InterpolatedImage.
        Returns:
            If ``galsimGSObject`` is True, the tuple ``(psf, shear)`` returned
            by ``PSFspin``. Otherwise, the interpolated PSF matrix returned by
            ``myConfig.psfMaker_IDW``.
        Raises:
            ValueError: if ``chip`` is not the CCD this object was built for,
                or no sampled wavelength falls inside ``bandpass``.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if self.iccd != int(chip.getChipLabel(chipID=chip.chipID)):
            raise ValueError('ERROR: self.iccd != chip.label')
        twave = self._findWave(bandpass)
        if twave == -1:
            # Raise instead of terminating the whole process with exit().
            raise ValueError("!!!PSF bandpass does not match: no sampled PSF wavelength in ({:}, {:})"
                             .format(bandpass.blue_limit, bandpass.red_limit))
        PSFMat = self.psfMat[twave]
        cen_col = self.cen_col[twave]
        cen_row = self.cen_row[twave]

        # Offset from the chip centre in pixels. The 0.01 factor converts it to
        # the units of cen_col/cen_row (presumably mm for 10 um pixels -- confirm).
        dx = pos_img.x - chip.cen_pix_x
        dy = pos_img.y - chip.cen_pix_y
        px = dx * 0.01
        py = dy * 0.01
        imPSF = myConfig.psfMaker_IDW(px, py, PSFMat, cen_col, cen_row, IDWindex=2, OnlyNeighbors=True,
                                      hoc=self.hoc[twave], hoclist=self.hoclist[twave], PSFCentroidWgt=True)
        if not galsimGSObject:
            return imPSF

        img = galsim.ImageF(imPSF, scale=pixSize)
        gsp = galsim.GSParams(folding_threshold=folding_threshold)
        self.psf = galsim.InterpolatedImage(img, gsparams=gsp)
        # PSFspin expects pixel offsets. Pass dx/dy directly rather than
        # undoing the 0.01 scaling (px/0.01), which adds float round-off.
        return self.PSFspin(x=dx, y=dy)

    def PSFspin(self, x, y):
        """
        Apply a position-dependent shear to the current PSF (``self.psf``).

        The ellipticity grows with the distance r from the axis centre:
        ``ell = cpix**2 / (2*ff**2 + cpix**2)``. Here ``cpix = r * sigSpin``,
        with sigSpin converted from arcsec to radians, and ``ff`` is derived
        from ``self.sigGauss``. The shear position angle is
        ``atan2(y, x) + pi/2``.

        Parameters:
            x, y: Position relative to the axis centre, in pixels.
        Returns:
            (galsim.GSObject, galsim.Shear): the sheared ``self.psf`` and the
            applied shear.

        Note: ``self.psf`` must already be set (``get_PSF`` does this).
        """
        a2Rad = np.pi / (60.0 * 60.0 * 180.0)  # arcsec -> radians

        ff = self.sigGauss * 0.107 * (1000.0 / 10.0)  # in unit of [pixels]
        rc = np.sqrt(x * x + y * y)
        cpix = rc * (self.sigSpin * a2Rad)

        beta = np.arctan2(y, x) + np.pi / 2
        ell = cpix ** 2 / (2.0 * ff ** 2 + cpix ** 2)
        PSFshear = galsim.Shear(e=ell, beta=beta * galsim.radians)
        return self.psf.shear(PSFshear), PSFshear




def testPSFInterp():
    """
    Compare interpolated PSFs against a reference PSF grid and plot the result.

    A PSFInterp is built from the 30x30 PSF grid (psfPathB) and evaluated at
    the centres of the 100 PSFs of the 10x10 grid (psfPathA), which is loaded
    with myConfig.LoadPSFset. At each position, myUtil.psfSizeCalculator
    measures the size and ellipticity (e1, e2) of both the reference PSF and
    the interpolated PSF. Run times are printed. A three-panel figure
    (ellipticity sticks, ellipticity histograms, relative size difference) is
    saved under test/figs/.

    NOTE(review): this function is out of sync with the current PSFInterp API.
    PSFInterp.__init__ requires a positional `chip` argument. get_PSF expects
    a chip object, a position with `.x`/`.y` and a bandpass with
    `blue_limit`/`red_limit`, but ints and a list are passed below. As
    written, the PSFInterp(...) call raises TypeError. Update before use.
    """
    import time
    import matplotlib.pyplot as plt

    iccd = iccdTest  # 1-based CCD index ([1, 30] for test)
    iwave= 1   # 1-based wavelength index ([1, 4] for test)

    npsfB   = 900  # size of the 30x30 grid (not used below)
    psfPathB = '/data/simudata/CSSOSDataProductsSims/data/csstPSFdata/CSSOS_psf_ciomp_30X30'
    psfCSST = PSFInterp(PSF_data_file = psfPathB) ### define PSF_data from 30*30


    npsfA   = 100  # size of the 10x10 reference grid
    psfPathA = '/data/simudata/CSSOSDataProductsSims/data/csstPSFdata/CSSOS_psf_ciomp'
    psfSetA, PSFMatA, cen_colA, cen_rowA = myConfig.LoadPSFset(iccd, iwave, npsfA, psfPathA, InputMaxPixelPos=False, PSFCentroidWgt=True)  ###load test_data from 10*10

    # Size/ellipticity of the reference PSFs (psf_*) and of the interpolated PSFs (psfMaker_*).
    psf_sz = np.zeros(npsfA)
    psf_e1 = np.zeros(npsfA)
    psf_e2 = np.zeros(npsfA)
    psfMaker_sz = np.zeros(npsfA)
    psfMaker_e1 = np.zeros(npsfA)
    psfMaker_e2 = np.zeros(npsfA)

    runtimeInterp = 0  # accumulated time spent inside get_PSF only
    starttime = time.time()
    for ipsf in range(npsfA):
        print('IDW-ipsf: {:4}/100'.format(ipsf), end='\r', flush=True)

        starttimeInterp = time.time()
        # Evaluate the interpolated PSF at the centre of the ipsf-th reference PSF.
        px = cen_colA[ipsf]
        py = cen_rowA[ipsf]
        pos_img = [px, py]
        img = psfCSST.get_PSF(iccd, pos_img, iwave, galsimGSObject=False)  ###interpolate PSF at[px,py]
        endtimeInterp = time.time()
        runtimeInterp = runtimeInterp + (endtimeInterp - starttimeInterp)


        imx = psfSetA[ipsf]['psfMat']  # reference PSF
        imy = img                      # interpolated PSF

        cenX, cenY, sz, e1, e2, REE80 = myUtil.psfSizeCalculator(imx, CalcPSFcenter=True, SigRange=True, TailorScheme=2)
        psf_sz[ipsf] = sz
        psf_e1[ipsf] = e1
        psf_e2[ipsf] = e2
        cenX, cenY, sz, e1, e2, REE80 = myUtil.psfSizeCalculator(imy, CalcPSFcenter=True, SigRange=True, TailorScheme=2)
        psfMaker_sz[ipsf] = sz
        psfMaker_e1[ipsf] = e1
        psfMaker_e2[ipsf] = e2
    endtime = time.time()
    # Total loop time, then time spent in interpolation alone.
    print('run time::', endtime - starttime, runtimeInterp)

    if True:  # plotting section (always enabled)
        ell_iccd = np.zeros(npsfA)
        ell_iccd_psfMaker = np.zeros(npsfA)
        fig = plt.figure(figsize=(18, 5))
        plt.subplots_adjust(wspace=0.1, hspace=0.1)
        # Panel 1: PSF positions with ellipticity sticks (blue: reference,
        # red: interpolated). Stick angle = atan2(e2, e1)/2; length = |e|*50.
        ax = plt.subplot(1, 3, 1)
        for ipsf in range(npsfA):
            imx = cen_colA[ipsf]
            imy = cen_rowA[ipsf]
            plt.plot(imx, imy, 'b.')

            ang = np.arctan2(psf_e2[ipsf], psf_e1[ipsf])/2
            ell = np.sqrt(psf_e1[ipsf]**2 + psf_e2[ipsf]**2)
            ell_iccd[ipsf] = ell
            ell *= 50  # scale the stick length for visibility
            lcos = ell*np.cos(ang)
            lsin = ell*np.sin(ang)
            plt.plot([imx-lcos, imx+lcos],[imy-lsin, imy+lsin],'b', lw=2)
            ###########
            ang = np.arctan2(psfMaker_e2[ipsf], psfMaker_e1[ipsf])/2
            ell = np.sqrt(psfMaker_e1[ipsf]**2 + psfMaker_e2[ipsf]**2)
            ell_iccd_psfMaker[ipsf] = ell
            ell *= 50  # scale the stick length for visibility
            lcos = ell*np.cos(ang)
            lsin = ell*np.sin(ang)
            plt.plot([imx-lcos, imx+lcos],[imy-lsin, imy+lsin],'r', lw=1)
            ###########
        plt.gca().set_aspect(1)

        # Panel 2: histograms of |e| (blue: reference, red: interpolated).
        ax = plt.subplot(1, 3, 2)
        tmp = plt.hist(ell_iccd, bins = 8, color='b', alpha=0.5)
        tmp = plt.hist(ell_iccd_psfMaker, bins = 8, color='r', alpha=0.5)
        plt.annotate('iccd-{:} iwave-{:}'.format(iccd, iwave), (0.55, 0.85), xycoords='axes fraction',fontsize=15)
        plt.xlabel('ell')
        plt.ylabel('PDF')

        # Panel 3: relative size difference (interpolated - reference) / reference.
        ax = plt.subplot(1, 3, 3)
        dsz = (psfMaker_sz - psf_sz)/psf_sz
        dsz_hist= plt.hist(dsz)
        plt.xlabel('dsz')
        plt.savefig('test/figs/testPSFInterp_30t10_iccd{:}_iwave{:}.pdf'.format(iccd, iwave))


    


if __name__ == '__main__':
    # Script entry point: run the interpolation-vs-reference comparison
    # (see testPSFInterp for what is computed and where the figure is saved).
    testPSFInterp()