'''
PSF interpolation for CSST-Sim slitless spectroscopy (SLS) chips.

NOTE: [iccd, iwave, ipsf] are counted from 1 to n, but [tccd, twave, tpsf] are counted from 0 to n-1
'''

import yaml
import sys
import time
import copy
import numpy as np
import scipy.spatial as spatial
import galsim
import h5py

from observation_sim.instruments import Filter, FilterParam, Chip
from observation_sim.psf.PSFModel import PSFModel
from observation_sim.instruments.chip import chip_utils
import os
from astropy.io import fits

from astropy.modeling.models import Gaussian2D
from scipy import signal, interpolate
import datetime
import gc
from astropy.io import fits

from observation_sim.psf._util import psf_extrapolate, psf_extrapolate1
# from jax import numpy as jnp

# Module-wide settings.
LOG_DEBUG: bool = False  # when True, the class below prints diagnostic output
NPSF: int = 30 * 30  # PSF samples on a 30 x 30 grid; not referenced by active code in this file
PIX_SIZE_MICRON: float = 5.0  # PSF pixel size in microns; only the commented-out loader uses it


# Legacy helpers, kept commented out for reference (PSFInterpSLS does not use them).
# find neighbors-KDtree
# def findNeighbors(tx, ty, px, py, dr=0.1, dn=1, OnlyDistance=True):
#     """
#     find nearest neighbors by 2D-KDTree
#
#     Parameters:
#         tx, ty (float, float): a given position
#         px, py (numpy.array, numpy.array): position data for tree
#         dr (float-optional): distance
#         dn (int-optional): nearest-N
#         OnlyDistance (bool-optional): only use distance to find neighbors. Default: True
#
#     Returns:
#         dataq (numpy.array): index
#     """
#     datax = px
#     datay = py
#     tree = spatial.KDTree(list(zip(datax.ravel(), datay.ravel())))
#
#     dataq=[]
#     rr = dr
#     if OnlyDistance == True:
#         dataq = tree.query_ball_point([tx, ty], rr)
#     if OnlyDistance == False:
#         while len(dataq) < dn:
#             dataq = tree.query_ball_point([tx, ty], rr)
#             rr += dr
#         dd = np.hypot(datax[dataq]-tx, datay[dataq]-ty)
#         ddSortindx = np.argsort(dd)
#         dataq = np.array(dataq)[ddSortindx[0:dn]]
#     return dataq
#
# ###find neighbors-hoclist###
# def hocBuild(partx, party, nhocx, nhocy, dhocx, dhocy):
#     if np.max(partx) > nhocx*dhocx:
#         print('ERROR')
#         sys.exit()
#     if np.max(party) > nhocy*dhocy:
#         print('ERROR')
#         sys.exit()
#
#     npart  = partx.size
#     hoclist= np.zeros(npart, dtype=np.int32)-1
#     hoc = np.zeros([nhocy, nhocx], dtype=np.int32)-1
#     for ipart in range(npart):
#         ix = int(partx[ipart]/dhocx)
#         iy = int(party[ipart]/dhocy)
#         hoclist[ipart] = hoc[iy, ix]
#         hoc[iy, ix] = ipart
#     return hoc, hoclist
#
# def hocFind(px, py, dhocx, dhocy, hoc, hoclist):
#     ix = int(px/dhocx)
#     iy = int(py/dhocy)
#
#     neigh=[]
#     it = hoc[iy, ix]
#     while it != -1:
#         neigh.append(it)
#         it = hoclist[it]
#     return neigh
#
# def findNeighbors_hoclist(px, py, tx=None,ty=None, dn=4, hoc=None, hoclist=None):
#     nhocy = nhocx = 20
#
#     pxMin = np.min(px)
#     pxMax = np.max(px)
#     pyMin = np.min(py)
#     pyMax = np.max(py)
#
#     dhocx = (pxMax - pxMin)/(nhocx-1)
#     dhocy = (pyMax - pyMin)/(nhocy-1)
#     partx = px - pxMin +dhocx/2
#     party = py - pyMin +dhocy/2
#
#     if hoc is None:
#         hoc, hoclist = hocBuild(partx, party, nhocx, nhocy, dhocx, dhocy)
#         return hoc, hoclist
#
#     if hoc is not None:
#         tx = tx - pxMin +dhocx/2
#         ty = ty - pyMin +dhocy/2
#         itx = int(tx/dhocx)
#         ity = int(ty/dhocy)
#
#         ps = [-1, 0, 1]
#         neigh=[]
#         for ii in range(3):
#             for jj in range(3):
#                 ix = itx + ps[ii]
#                 iy = ity + ps[jj]
#                 if ix < 0:
#                     continue
#                 if iy < 0:
#                     continue
#                 if ix > nhocx-1:
#                     continue
#                 if iy > nhocy-1:
#                     continue
#
#                 #neightt = myUtil.hocFind(ppx, ppy, dhocx, dhocy, hoc, hoclist)
#                 it = hoc[iy, ix]
#                 while it != -1:
#                     neigh.append(it)
#                     it = hoclist[it]
#                 #neigh.append(neightt)
#         #ll = [i for k in neigh for i in k]
#         if dn != -1:
#             ptx = np.array(partx[neigh])
#             pty = np.array(party[neigh])
#             dd  = np.hypot(ptx-tx, pty-ty)
#             idx = np.argsort(dd)
#             neigh= np.array(neigh)[idx[0:dn]]
#         return neigh
#
#
# ###PSF-IDW###
# def psfMaker_IDW(px, py, PSFMat, cen_col, cen_row, IDWindex=2, OnlyNeighbors=True, hoc=None, hoclist=None, PSFCentroidWgt=False):
#     """
#     psf interpolation by IDW
#
#     Parameters:
#         px, py (float, float): position of the target
#         PSFMat (numpy.array): image
#         cen_col, cen_row (numpy.array, numpy.array): potions of the psf centers
#         IDWindex (int-optional): the power index of IDW
#         OnlyNeighbors (bool-optional): only neighbors are used for psf interpolation
#
#     Returns:
#         psfMaker (numpy.array)
#     """
#
#     minimum_psf_weight = 1e-8
#     ref_col = px
#     ref_row = py
#
#     ngy, ngx = PSFMat[0, :, :].shape
#     npsf = PSFMat[:, :, :].shape[0]
#     psfWeight = np.zeros([npsf])
#
#     if OnlyNeighbors == True:
#         if hoc is None:
#             neigh = findNeighbors(px, py, cen_col, cen_row, dr=5., dn=4, OnlyDistance=False)
#         if hoc is not None:
#             neigh = findNeighbors_hoclist(cen_col, cen_row, tx=px,ty=py, dn=4, hoc=hoc, hoclist=hoclist)
#
#         neighFlag = np.zeros(npsf)
#         neighFlag[neigh] = 1
#
#     for ipsf in range(npsf):
#         if OnlyNeighbors == True:
#             if neighFlag[ipsf] != 1:
#                 continue
#
#         dist = np.sqrt((ref_col - cen_col[ipsf])**2 + (ref_row - cen_row[ipsf])**2)
#         if IDWindex == 1:
#             psfWeight[ipsf] = dist
#         if IDWindex == 2:
#             psfWeight[ipsf] = dist**2
#         if IDWindex == 3:
#             psfWeight[ipsf] = dist**3
#         if IDWindex == 4:
#             psfWeight[ipsf] = dist**4
#         psfWeight[ipsf] = max(psfWeight[ipsf], minimum_psf_weight)
#         psfWeight[ipsf] = 1./psfWeight[ipsf]
#     psfWeight /= np.sum(psfWeight)
#
#     psfMaker  = np.zeros([ngy, ngx], dtype=np.float32)
#     for ipsf in range(npsf):
#         if OnlyNeighbors == True:
#             if neighFlag[ipsf] != 1:
#                 continue
#
#         iPSFMat = PSFMat[ipsf, :, :].copy()
#         ipsfWeight = psfWeight[ipsf]
#
#         psfMaker += iPSFMat * ipsfWeight
#     psfMaker /= np.nansum(psfMaker)
#
#     return psfMaker


# define PSFInterpSLS
class PSFInterpSLS(PSFModel):
    """PCA-based PSF model for slitless-spectroscopy (SLS) chips.

    One ``*_PCs.fits`` file is opened for each combination of:

    * the chip's two gratings;
    * each grating order ('0', '1');
    * each of four sub-bands.

    Its HDUs are used as:

    * HDU 0 -- principal components, one flattened square stamp per column;
    * HDU 1 -- sample positions, column 1 is x and column 0 is y;
    * HDU 2 -- PC coefficients, one row per component, one column per sample.
    """

    # Caller-facing order label -> PSF order folder suffix ('PSF_Order_<id>').
    _ORDER_IDS = {'A': '1', 'B': '0', 'C': '0', 'D': '0', 'E': '0'}
    # Floor on the inverse-distance-weighting distance. A target lying exactly
    # on a PSF sample then gets a large finite weight instead of 1/0 = inf,
    # which would turn the interpolated PSF into NaN.
    _MIN_IDW_DIST = 1e-12
    # Number of principal components used for reconstruction/convolution.
    _NPC = 10

    def __init__(self, chip, filt, PSF_data_prefix="", sigSpin=0, psfRa=0.15, pix_size=0.005):
        """
        Parameters:
            chip: 'Chip' object. Its chipID selects the two grating IDs and
                the grating type ('GU', 'GV' or 'GI').
            filt: 'Filter' object. blue_limit/red_limit bound the outer sub-bands.
            PSF_data_prefix: root folder of the PCA PSF data.
            sigSpin: stored as self.sigSpin (not used by the active code here).
            psfRa: stored as self.sigGauss (not used by the active code here).
            pix_size: PSF pixel size; stored as self.pixsize.
        """
        if LOG_DEBUG:
            print('===================================================')
            print('DEBUG: psf module for csstSim '
                  + time.strftime("(%Y-%m-%d %H:%M:%S)", time.localtime()), flush=True)
            print('===================================================')

        self.sigSpin = sigSpin
        self.sigGauss = psfRa
        self.grating_ids = chip_utils.getChipSLSGratingID(chip.chipID)
        _, self.grating_type = chip.getChipFilter(chipID=chip.chipID)
        self.data_folder = PSF_data_prefix
        self.getPSFDataFromFile(filt)
        # get_PSF turns this into an angle via ``pixsize * 1e-3 / 28``.
        # That suggests mm with a 28 m focal length (0.005 mm = 5 um).
        # NOTE(review): an older comment said "um" -- confirm the unit.
        self.pixsize = pix_size

    @staticmethod
    def _chipStart(chip):
        """Return the (x, y) pixel offset of the chip's local origin in focal-plane pixels."""
        x_start = chip.x_cen / chip.pix_size - chip.npix_x / 2.
        y_start = chip.y_cen / chip.pix_size - chip.npix_y / 2.
        return x_start, y_start

    def _selectBandData(self, pos_x_local, bandNo, g_order, grating_split_pos):
        """Return the PCA HDUList for the grating covering ``pos_x_local``, order ``g_order`` and band ``bandNo``."""
        if pos_x_local < grating_split_pos:
            psf_data = self.grating1_data
        else:
            psf_data = self.grating2_data
        order_data = psf_data['order' + self._ORDER_IDS[g_order]]
        return order_data['band' + str(bandNo)]['band_data']

    @staticmethod
    def _nearestSampleIndex(points, grid_x, grid_y):
        """Return, for every (grid_x, grid_y) node, the index of the nearest row of ``points``.

        This is the same nearest-neighbour search that
        ``griddata(..., method='nearest')`` performs. Computing it once lets
        every PC coefficient reuse it, instead of rebuilding the search for
        each component.
        """
        finder = interpolate.NearestNDInterpolator(points, np.arange(points.shape[0]))
        return np.asarray(finder((grid_x, grid_y))).astype(np.intp)

    def getPSFDataFromFile(self, filt):
        """Load the wavelength sampling and the PCA PSF files of both gratings.

        Sets ``self.waveList``, ``self.bandranges``, ``self.grating1_data``
        and ``self.grating2_data``.
        """
        grating_columns = {'GU': 0, 'GV': 1, 'GI': 2}
        wavelists = np.loadtxt(os.path.join(self.data_folder, 'wavelist.dat'))
        self.waveList = wavelists[:, grating_columns[self.grating_type]]

        # Four sub-bands. The inner edges are the midpoints of adjacent sampled
        # wavelengths, scaled by 1e4. That presumably converts um to the
        # filter's Angstrom -- TODO confirm units.
        midBand = (self.waveList[0:3] + self.waveList[1:4]) / 2. * 10000.
        bandranges = np.zeros([4, 2])
        bandranges[0, 0] = filt.blue_limit
        bandranges[1:4, 0] = midBand
        bandranges[0:3, 1] = midBand
        bandranges[3, 1] = filt.red_limit
        self.bandranges = bandranges

        self.grating1_data = self._loadGratingData(self.grating_ids[0], bandranges)
        self.grating2_data = self._loadGratingData(self.grating_ids[1], bandranges)

    def _loadGratingData(self, grating_id, bandranges):
        """Open the per-order, per-band ``*_PCs.fits`` files of one grating.

        Returns:
            {'order0'|'order1': {'band1'..'band4':
                {'bandrange': [lo, hi], 'band_data': HDUList}}}.
            'band_data' is absent when a band folder has no PC file. The
            HDULists stay open because the PSF methods read their ``.data``
            later.
        """
        grating_data = {}
        grating_folder = os.path.join(self.data_folder, grating_id)
        for g_order in ('0', '1'):
            order_folder = os.path.join(grating_folder, 'PSF_Order_' + g_order)
            order_data = {}
            for bandi in range(1, 5):
                band_folder = os.path.join(order_folder, str(bandi))
                if LOG_DEBUG:
                    print(band_folder)
                subBand_data = {'bandrange': bandranges[bandi - 1]}
                pc_files = sorted(fname for fname in os.listdir(band_folder)
                                  if '_PCs.fits' in fname and not fname.startswith('.'))
                if pc_files:
                    # One PC file is expected per band. If several match, the
                    # last in sorted order is used, and only that one is opened
                    # so no file handles are leaked.
                    subBand_data['band_data'] = fits.open(os.path.join(band_folder, pc_files[-1]))
                order_data['band' + str(bandi)] = subBand_data
            grating_data['order' + g_order] = order_data
        return grating_data

    def convolveWithGauss(self, img=None, sigma=1):
        """Convolve ``img`` with a normalized circular 2D Gaussian and renormalize to unit sum.

        The kernel is (2 * ceil(3 * sigma) + 1) pixels on a side. With
        ``mode='full'`` the output is larger than ``img`` by the kernel size
        minus one.
        """
        offset = int(np.ceil(sigma * 3))
        g_size = 2 * offset + 1
        m_cen = int(g_size / 2)
        if LOG_DEBUG:
            print('-----', g_size)
        g_PSF_ = Gaussian2D(1, m_cen, m_cen, sigma, sigma)
        yp, xp = np.mgrid[0:g_size, 0:g_size]
        g_PSF = g_PSF_(xp, yp)
        psf = g_PSF / g_PSF.sum()
        convImg = signal.fftconvolve(img, psf, mode='full', axes=None)
        convImg = convImg / np.sum(convImg)
        return convImg

    def get_PSF(self, chip, pos_img_local=(1000, 1000), bandNo=1, galsimGSObject=True, folding_threshold=5.e-3, g_order='A', grating_split_pos=3685, extrapolate=False, ngg=2048):
        """Interpolate the PCA PSF at a chip-local image position.

        The PC coefficients of the 4 nearest PSF samples are combined by
        inverse-distance weighting. The PSF is rebuilt from the first
        ``_NPC`` components.

        Parameters:
            chip: 'Chip' object. Uses x_cen, y_cen, pix_size, npix_x, npix_y,
                and img.wcs when galsimGSObject is True.
            pos_img_local: (x, y) pixel position in chip-local coordinates.
            bandNo: sub-band index, 1..4.
            galsimGSObject: return a galsim object if True, raw arrays otherwise.
            folding_threshold: galsim GSParams folding_threshold for the InterpolatedImage.
            g_order: order label 'A'..'E'. 'A' uses order-1 data, the others order-0 data.
            grating_split_pos: local x below which grating 1's data is used; grating 2's otherwise.
            extrapolate: if True, pass the stamp through psf_extrapolate1.
            ngg: forwarded to psf_extrapolate1 as ``ngg``.

        Returns:
            If galsimGSObject is True: (psf, shear), where psf is a galsim
                object in world coordinates and shear is a zero galsim.Shear.
            Otherwise: (PSF_int_trans, PSF_int), the normalized, reoriented
                stamp and the raw PCA reconstruction.
        """
        x_start, y_start = self._chipStart(chip)
        pos_img = galsim.PositionD(pos_img_local[0] + x_start, pos_img_local[1] + y_start)
        psf_b_dat = self._selectBandData(pos_img_local[0], bandNo, g_order, grating_split_pos)
        pcs = psf_b_dat[0].data
        pos_p = psf_b_dat[1].data
        pc_coeff = psf_b_dat[2].data

        # Target position in the same units as pos_p (pixels * pix_size).
        px = pos_img.x * chip.pix_size
        py = pos_img.y * chip.pix_size

        # Inverse-distance weighting over the 4 nearest samples.
        # pos_p[:, 1] is x and pos_p[:, 0] is y. The stable argsort keeps
        # ties in sample order.
        dist2 = (pos_p[:, 1] - px) ** 2 + (pos_p[:, 0] - py) ** 2
        nearest = np.argsort(dist2, kind='stable')[:4]
        weights = 1. / np.maximum(np.sqrt(dist2[nearest]), self._MIN_IDW_DIST)
        coeff_int = (pc_coeff[:, nearest] * weights).sum(axis=1)
        # This overall scale cancels in PSF_int_trans, which is renormalized
        # below. It does set the scale of the returned raw PSF_int.
        coeff_int = coeff_int / np.sum(coeff_int)

        npc = self._NPC
        m_size = int(pcs.shape[0] ** 0.5)
        PSF_int = np.dot(pcs[:, 0:npc], coeff_int[0:npc]).reshape(m_size, m_size)

        # Reorient the stamp: flip both axes, transpose, then flip left-right.
        PSF_int_trans = np.flipud(np.fliplr(PSF_int))
        PSF_int_trans = np.fliplr(PSF_int_trans.T)
        # Normalize, shift so the minimum pixel is 0, and renormalize.
        PSF_int_trans = PSF_int_trans / np.sum(PSF_int_trans)
        PSF_int_trans = PSF_int_trans - np.min(PSF_int_trans)
        PSF_int_trans = PSF_int_trans / np.sum(PSF_int_trans)

        n_inf = int(np.sum(np.isinf(PSF_int_trans)))
        n_nan = int(np.sum(np.isnan(PSF_int_trans)))
        n_neg = int(np.sum(PSF_int_trans < 0))
        if n_inf > 0 or n_nan > 0:
            print("DEBUG: PSFInterpSLS, inf:%d, nan:%d, 0 num:%d" %
                  (n_inf, n_nan, n_neg))

        if extrapolate:
            PSF_int_trans = psf_extrapolate1(PSF_int_trans, ngg=ngg)

        if galsimGSObject:
            # PSF pixel scale in arcsec. The divisor 28 is presumably the
            # focal length in m (see the note on self.pixsize).
            pixel_size_arc = np.rad2deg(self.pixsize * 1e-3 / 28) * 3600
            img = galsim.ImageF(PSF_int_trans, scale=pixel_size_arc)
            gsp = galsim.GSParams(folding_threshold=folding_threshold)
            self.psf = galsim.InterpolatedImage(img, gsparams=gsp)
            wcs = chip.img.wcs.local(pos_img)
            # 0.074 arcsec/pixel: presumably the nominal detector pixel scale.
            scale = galsim.PixelScale(0.074)
            self.psf = wcs.toWorld(scale.toImage(self.psf), image_pos=(pos_img))
            return self.psf, galsim.Shear(e=0., beta=(np.pi/2)*galsim.radians)

        return PSF_int_trans, PSF_int

    def get_PSF_AND_convolve_withsubImg(self, chip, cutImg=None, pos_img_local=(1000, 1000), bandNo=1, g_order='A', grating_split_pos=3685):
        """Convolve a cutout with the spatially varying PCA PSF.

        For each of the first ``_NPC`` components:

        1. The cutout is copied, so the input is never modified.
        2. The copy is weighted pixel-by-pixel by that component's coefficient
           at the nearest PSF sample.
        3. The weighted copy is convolved with the component stamp.

        The sum over components is rescaled to the cutout's original flux.

        Parameters:
            chip: 'Chip' object. Uses x_cen, y_cen, pix_size, npix_x, npix_y.
            cutImg: galsim Image. Its bounds are treated as 0-based
                chip-local pixel coordinates (NOTE(review): confirm with callers).
            pos_img_local: (x, y) chip-local position that selects the grating.
            bandNo: sub-band index, 1..4.
            g_order: order label 'A'..'E' (see get_PSF).
            grating_split_pos: local x below which grating 1's data is used.

        Returns:
            The convolved galsim Image, with the same total flux as cutImg.
            cutImg itself is returned if it does not overlap the chip or the
            convolution sums to zero.
        """
        x_start, y_start = self._chipStart(chip)
        psf_b_dat = self._selectBandData(pos_img_local[0], bandNo, g_order, grating_split_pos)
        pcs = psf_b_dat[0].data
        # Sample positions converted to chip-local pixels.
        pos_p = psf_b_dat[1].data / chip.pix_size - np.array([y_start, x_start])
        pc_coeff = psf_b_dat[2].data
        m_size = int(pcs.shape[0] ** 0.5)

        chip_bounds = galsim.BoundsI(0, int(chip.npix_x) - 1, 0, int(chip.npix_y) - 1)
        bounds = cutImg.bounds & chip_bounds
        if bounds.area() == 0:
            return cutImg

        # Pixel coordinates of the overlap region only. Building a full-chip
        # grid would allocate far more memory than is ever used.
        grid_y, grid_x = np.mgrid[bounds.ymin:bounds.ymax + 1, bounds.xmin:bounds.xmax + 1]
        sample_xy = np.column_stack((pos_p[:, 1], pos_p[:, 0])).astype(np.float32)
        nearest_idx = self._nearestSampleIndex(sample_xy, grid_x, grid_y)

        sumImg = np.sum(cutImg.array)
        tmp_img = cutImg * 0
        for j in range(self._NPC):
            coeff_map = pc_coeff[j].astype(np.float32).flatten()[nearest_idx]
            weighted = cutImg.copy()
            weighted[bounds] = weighted[bounds] * coeff_map
            psf = pcs[:, j].reshape(m_size, m_size)
            tmp_img = tmp_img + \
                signal.fftconvolve(weighted.array, psf, mode='same', axes=None)

        total = np.sum(tmp_img.array)
        if total == 0:
            return cutImg
        return tmp_img / total * sumImg

    def convolveFullImgWithPCAPSF(self, chip, folding_threshold=5.e-3):
        """Convolve every stacked order/band image of the chip with the PCA PSF and add it to chip.img.

        Images are read from ``chip.img_stack[grating][order][band]``. PC
        coefficients are sampled on a grid with a step of ``sub_size`` pixels
        and held constant within each sub_size x sub_size cell. Each band's
        convolved image is rescaled to the band image's flux before it is
        added. Empty (zero-sum) results are skipped rather than adding NaN.
        ``folding_threshold`` is accepted for interface compatibility and is
        unused.
        """
        keys_L1 = chip_utils.getChipSLSGratingID(chip.chipID)
        keys_L2 = ['order0', 'order1']
        keys_L3 = ['w1', 'w2', 'w3', 'w4']
        npca = self._NPC
        sub_size = 4

        x_start, y_start = self._chipStart(chip)

        # Coarse sampling grid; the same for every grating/order/band.
        n_x = np.arange(0, chip.npix_x, sub_size, dtype=int)
        n_y = np.arange(0, chip.npix_y, sub_size, dtype=int)
        M, N = np.meshgrid(n_x, n_y)
        cx_len = int(chip.npix_x / sub_size)
        cy_len = int(chip.npix_y / sub_size)

        for i, gt in enumerate(keys_L1):
            psfCo = self.grating1_data if i == 0 else self.grating2_data
            for od in keys_L2:
                if od in ['order-2', 'order-1', 'order0', 'order2']:
                    psfCo_L2 = psfCo['order0']
                else:
                    psfCo_L2 = psfCo['order1']
                for w in keys_L3:
                    img = chip.img_stack[gt][od][w]
                    band_dat = psfCo_L2['band' + w[1]]['band_data']
                    pcs = band_dat[0].data
                    pos_p = band_dat[1].data / chip.pix_size - np.array([y_start, x_start])
                    pc_coeff = band_dat[2].data
                    sum_img = np.sum(img.array)
                    m_size = int(pcs.shape[0] ** 0.5)

                    sample_xy = np.column_stack((pos_p[:, 1], pos_p[:, 0])).astype(np.float32)
                    # The nearest sample per grid node is the same for every PC.
                    nearest_idx = self._nearestSampleIndex(sample_xy, M, N)

                    tmp_img = np.zeros_like(img.array, dtype=np.float32)
                    for j in range(npca):
                        if LOG_DEBUG:
                            print(gt, od, w, j)
                        t1 = datetime.datetime.now()
                        coarse = pc_coeff[j].astype(np.float32).flatten()[nearest_idx]
                        # Expand each coarse value to its sub_size x sub_size
                        # cell. Pixels past the last full cell stay 0.
                        U = np.zeros_like(chip.img.array, dtype=np.float32)
                        U[:cy_len * sub_size, :cx_len * sub_size] = np.repeat(
                            np.repeat(coarse[:cy_len, :cx_len], sub_size, axis=0),
                            sub_size, axis=1)
                        t2 = datetime.datetime.now()
                        if LOG_DEBUG:
                            print("time interpolate:", t2-t1)

                        psf = pcs[:, j].reshape(m_size, m_size)
                        tmp_img = tmp_img + \
                            signal.fftconvolve(img.array * U, psf, mode='same', axes=None)
                        if LOG_DEBUG:
                            print("time convole:", datetime.datetime.now()-t2)
                        del U

                    total = np.sum(tmp_img)
                    if total != 0:
                        chip.img = chip.img + tmp_img * sum_img / total
                    del tmp_img
                    gc.collect()


if __name__ == '__main__':
    # Developer smoke test. The paths default to the original author's machine
    # and can be overridden:
    #   python PSFInterpSLS.py [config.yaml] [psf_data_dir]
    configfn = sys.argv[1] if len(sys.argv) > 1 else \
        '/Users/zhangxin/Work/SlitlessSim/CSST_SIM/CSST_new_sim/csst-simulation/config/config_C6_dev.yaml'
    psf_data_prefix = sys.argv[2] if len(sys.argv) > 2 else \
        "/Volumes/EAGET/CSST_PSF_data/SLS_PSF_PCA_fp/"
    with open(configfn, "r") as stream:
        try:
            config = yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            # Without a config nothing below can run; stop instead of failing
            # later with an undefined name.
            print(exc)
            sys.exit(1)
    for key, value in config.items():
        print(key + " : " + str(value))
    chip = Chip(chipID=1, config=config)
    filter_id, filter_type = chip.getChipFilter()
    filt = Filter(filter_id=filter_id,
                  filter_type=filter_type,
                  filter_param=FilterParam())

    psf_i = PSFInterpSLS(chip, filt, PSF_data_prefix=psf_data_prefix)
    # Focal-plane pixel position. get_PSF expects chip-local pixel coordinates.
    pos_img = galsim.PositionD(x=25155, y=-22060)
    x_start = chip.x_cen / chip.pix_size - chip.npix_x / 2.
    y_start = chip.y_cen / chip.pix_size - chip.npix_y / 2.
    pos_img_local = [pos_img.x - x_start, pos_img.y - y_start]
    # 'A' selects the first-order PSF set.
    # NOTE(review): the default galsimGSObject=True needs chip.img.wcs to be
    # set up -- confirm Chip() provides it.
    psf_im = psf_i.get_PSF(chip, pos_img_local=pos_img_local, g_order='A')