'''
PSF interpolation for CSST-Sim.

Provides PSFInterpSLS, a PSF model for the slitless-spectroscopy (SLS) chips
that reconstructs PSF images from principal-component (PCA) data stored in
FITS files.

NOTE: [iccd, iwave, ipsf] are counted from 1 to n, but [tccd, twave, tpsf] are
counted from 0 to n-1
'''
import sys
import time
import copy
import numpy as np
import scipy.spatial as spatial
import galsim
import h5py
from ObservationSim.PSF.PSFModel import PSFModel
from ObservationSim.Instrument.Chip import ChipUtils as chip_utils
import os
from astropy.io import fits
from astropy.modeling.models import Gaussian2D
from scipy import signal

# Set to True to print extra debug messages (checked by PSFInterpSLS).
LOG_DEBUG = False
# Number of PSF samples per CCD (30*30); not referenced by the active code in this file.
NPSF = 900
# PSF-matrix pixel size in microns; not referenced by the active code in this file.
PixSizeInMicrons = 5.

# NOTE: this section previously contained commented-out legacy helpers
# (findNeighbors, hocBuild, hocFind, findNeighbors_hoclist and psfMaker_IDW:
# KD-tree / head-of-chain neighbour search and an inverse-distance-weighted
# PSF maker).  None of them were called; the active inverse-distance
# interpolation is done inline in PSFInterpSLS.get_PSF.
###define PSFInterp###
class PSFInterpSLS(PSFModel):
    """PCA-based PSF model for the CSST slitless-spectroscopy (SLS) chips.

    For each of the chip's two gratings (``chip_utils.getChipSLSGratingID``),
    each PCA order ('0' and '1') and each of four sub-bands, one ``*_PCs.fits``
    file is opened.  ``get_PSF`` reads from it the principal components
    (HDU 0), the sample positions (HDU 1) and the per-sample PC coefficients
    (HDU 2).  It inverse-distance-weights the coefficients of the nearest 4
    samples and reconstructs the PSF image from the first 10 components.
    """

    def __init__(self, chip, filt, PSF_data_prefix="", sigSpin=0, psfRa=0.15, pix_size=0.005):
        """
        Parameters:
            chip: 'Chip' object; its chipID selects the SLS gratings and the
                grating type ('GU', 'GV' or 'GI').
            filt: filter object; its blue_limit / red_limit bound the outer
                sub-bands.
            PSF_data_prefix (str): root folder holding ``wavelist.dat`` and the
                ``<grating_id>/PSF_Order_<n>/<band>/`` PCA folders.
            sigSpin: stored as ``self.sigSpin`` (not used by the active code).
            psfRa: stored as ``self.sigGauss`` (not used by the active code).
            pix_size (float): pixel size of the PSF matrix.  get_PSF converts it
                to arcsec as ``pix_size * 1e-3 / 28`` rad, i.e. it is treated as
                mm over a 28 m focal length (0.005 -> 5 um).
                NOTE(review): confirm units.
        """
        if LOG_DEBUG:
            print('===================================================')
            print('DEBUG: psf module for csstSim '
                  + time.strftime("(%Y-%m-%d %H:%M:%S)", time.localtime()), flush=True)
            print('===================================================')

        self.sigSpin = sigSpin
        self.sigGauss = psfRa
        self.grating_ids = chip_utils.getChipSLSGratingID(chip.chipID)
        _, self.grating_type = chip.getChipFilter(chipID=chip.chipID)
        self.data_folder = PSF_data_prefix
        self.getPSFDataFromFile(filt)
        self.pixsize = pix_size  # PSF-matrix pixel size, treated as mm (see docstring)

    def getPSFDataFromFile(self, filt):
        """Load the wavelength list and open the PCA PSF files of both gratings.

        Sets ``self.waveList``, ``self.bandranges``, ``self.grating1_data`` and
        ``self.grating2_data``.  Each grating dict is keyed
        ``'order<g>' -> 'band<i>' -> {'bandrange', 'band_data'}``, where
        ``band_data`` is the opened ``*_PCs.fits`` HDUList.  The files stay open
        because get_PSF reads their ``.data`` later.
        """
        gratingInwavelist = {'GU': 0, 'GV': 1, 'GI': 2}
        waveListFn = self.data_folder + '/wavelist.dat'
        wavelists = np.loadtxt(waveListFn)
        self.waveList = wavelists[:, gratingInwavelist[self.grating_type]]

        # Sub-band edges: midpoints between consecutive waveList entries (scaled
        # by 1e4, presumably um -> Angstrom), with the filter limits at the ends.
        midBand = (self.waveList[0:3] + self.waveList[1:4]) / 2. * 10000.
        bandranges = np.zeros([4, 2])
        bandranges[0, 0] = filt.blue_limit
        bandranges[1:4, 0] = midBand
        bandranges[0:3, 1] = midBand
        bandranges[3, 1] = filt.red_limit
        self.bandranges = bandranges

        self.grating1_data = self._loadGratingPSFData(self.grating_ids[0], bandranges)
        self.grating2_data = self._loadGratingPSFData(self.grating_ids[1], bandranges)

    def _loadGratingPSFData(self, grating_id, bandranges):
        """Open the ``*_PCs.fits`` file of every PCA order and sub-band of one grating."""
        grating_data = {}
        g_folder = self.data_folder + '/' + grating_id + '/'
        for g_order in ['0', '1']:
            g_folder_order = g_folder + 'PSF_Order_' + g_order + '/'
            grating_order_data = {}
            for bandi in [1, 2, 3, 4]:
                subBand_data = {'bandrange': bandranges[bandi - 1]}
                final_folder = g_folder_order + str(bandi) + '/'
                if LOG_DEBUG:
                    print(final_folder)
                # Ignore hidden files; sort so the chosen file does not depend
                # on directory-listing order, and open only that one file.
                pca_fs = sorted(fname for fname in os.listdir(final_folder)
                                if '_PCs.fits' in fname and not fname.startswith('.'))
                if pca_fs:
                    subBand_data['band_data'] = fits.open(final_folder + pca_fs[-1])
                grating_order_data['band' + str(bandi)] = subBand_data
            grating_data['order' + g_order] = grating_order_data
        return grating_data

    def convolveWithGauss(self, img=None, sigma=1):
        """Convolve ``img`` with a normalized circular Gaussian and renormalize.

        The kernel is ``2*ceil(3*sigma)+1`` pixels on a side and centred on its
        middle pixel.  ``mode='full'`` makes the output larger than ``img`` by
        the kernel size minus one along each axis.  The result sums to 1.
        """
        offset = int(np.ceil(sigma * 3))
        g_size = 2 * offset + 1
        if LOG_DEBUG:
            print('-----', g_size)
        g_PSF_ = Gaussian2D(1, offset, offset, sigma, sigma)
        yp, xp = np.mgrid[0:g_size, 0:g_size]
        g_PSF = g_PSF_(xp, yp)
        kernel = g_PSF / g_PSF.sum()
        convImg = signal.fftconvolve(img, kernel, mode='full', axes=None)
        return convImg / np.sum(convImg)

    def get_PSF(self, chip, pos_img_local=(1000, 1000), bandNo=1, galsimGSObject=True,
                folding_threshold=5.e-3, g_order='A', grating_split_pos=3685):
        """
        Get the PSF at a given image position

        Parameters:
            chip: A 'Chip' object representing the chip we want to extract PSF from.
            pos_img_local: (x, y) pixel position relative to the chip origin.  It is
                shifted by (chip.x_cen / chip.pix_size - chip.npix_x / 2,
                chip.y_cen / chip.pix_size - chip.npix_y / 2) to a global pixel position.
            bandNo (int): sub-band index 1..4 selecting the 'band<bandNo>' PCA data.
            galsimGSObject (bool): return a galsim object (True) or raw arrays (False).
            folding_threshold (float): passed to galsim.GSParams for the InterpolatedImage.
            g_order (str): order label 'A'..'E'; 'A' uses PCA order '1', the others '0'.
            grating_split_pos (float): local x position below which the first
                grating's PSF data is used, otherwise the second grating's.

        Returns:
            If galsimGSObject: (PSF, shear) -- the galsim GSObject (also stored in
            self.psf) and a zero-ellipticity galsim.Shear.
            Otherwise: (PSF_int_trans, PSF_int) -- the reoriented PSF normalized to
            unit sum, and the reconstruction before reorientation/normalization.
        """
        order_IDs = {'A': '1', 'B': '0', 'C': '0', 'D': '0', 'E': '0'}
        # Extra Gaussian broadening (arcsec) for contamination orders, currently not applied:
        #   {'C': 0.28032344707964174, 'D': 0.39900182912061344, 'E': 1.1988309797685412}

        x_start = chip.x_cen / chip.pix_size - chip.npix_x / 2.
        y_start = chip.y_cen / chip.pix_size - chip.npix_y / 2.
        pos_img = galsim.PositionD(pos_img_local[0] + x_start, pos_img_local[1] + y_start)

        psf_data = self.grating1_data if pos_img_local[0] < grating_split_pos else self.grating2_data
        grating_order = order_IDs[g_order]
        psf_b_dat = psf_data['order' + grating_order]['band' + str(bandNo)]['band_data']
        pcs = psf_b_dat[0].data       # principal components, one per column
        pos_p = psf_b_dat[1].data     # sample positions: column 1 is x, column 0 is y
        pc_coeff = psf_b_dat[2].data  # PC coefficients, one column per sample

        # Target position in the units of pos_p (pixels * chip.pix_size).
        px = pos_img.x * chip.pix_size
        py = pos_img.y * chip.pix_size

        # Inverse-distance weighting over the (up to) 4 nearest samples.
        dist2 = (pos_p[:, 1] - px) ** 2 + (pos_p[:, 0] - py) ** 2
        nearest = np.argsort(dist2, kind='stable')[:4]
        # Clamp the distance so a target exactly on a sample cannot divide by zero.
        weights = 1. / np.maximum(np.sqrt(dist2[nearest]), 1e-8)
        # Normalize by the weight sum (not the coefficient sum, which may be ~0).
        coeff_int = np.dot(pc_coeff[:, nearest], weights) / np.sum(weights)

        npc = 10
        m_size = int(pcs.shape[0] ** 0.5)
        PSF_int = np.dot(pcs[:, 0:npc], coeff_int[0:npc]).reshape(m_size, m_size)

        # Reorient: rotate by 180 degrees, then transpose and flip left-right.
        PSF_int_trans = np.flipud(np.fliplr(PSF_int))
        PSF_int_trans = np.fliplr(PSF_int_trans.T)
        PSF_int_trans = PSF_int_trans / np.sum(PSF_int_trans)

        if galsimGSObject:
            # PSF-matrix pixel scale in arcsec (self.pixsize * 1e-3 / 28 radians).
            pixel_size_arc = np.rad2deg(self.pixsize * 1e-3 / 28) * 3600
            img = galsim.ImageF(PSF_int_trans, scale=pixel_size_arc)
            gsp = galsim.GSParams(folding_threshold=folding_threshold)
            self.psf = galsim.InterpolatedImage(img, gsparams=gsp)
            # Express the profile in image units at 0.074 arcsec/pixel, then map it
            # to world coordinates with the chip's local WCS at pos_img.
            wcs = chip.img.wcs.local(pos_img)
            scale = galsim.PixelScale(0.074)
            self.psf = wcs.toWorld(scale.toImage(self.psf), image_pos=pos_img)
            return self.psf, galsim.Shear(e=0., beta=(np.pi / 2) * galsim.radians)
        return PSF_int_trans, PSF_int


if __name__ == '__main__':
    # Imported only when run as a script so importing this module stays light.
    import yaml
    from ObservationSim.Instrument import Filter, FilterParam, Chip

    configfn = (sys.argv[1] if len(sys.argv) > 1 else
                '/Users/zhangxin/Work/SlitlessSim/CSST_SIM/CSST_new_sim/csst-simulation/config/config_C6_dev.yaml')
    psf_prefix = sys.argv[2] if len(sys.argv) > 2 else "/Volumes/EAGET/CSST_PSF_data/SLS_PSF_PCA_fp/"

    with open(configfn, "r") as stream:
        try:
            config = yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            print(exc)
            sys.exit(1)
    for key, value in config.items():
        print(key + " : " + str(value))

    chip = Chip(chipID=1, config=config)
    filter_id, filter_type = chip.getChipFilter()
    filt = Filter(filter_id=filter_id, filter_type=filter_type, filter_param=FilterParam())
    psf_i = PSFInterpSLS(chip, filt, PSF_data_prefix=psf_prefix)

    # get_PSF expects a chip-local position, so convert this global pixel position.
    pos_img = galsim.PositionD(x=25155, y=-22060)
    x_start = chip.x_cen / chip.pix_size - chip.npix_x / 2.
    y_start = chip.y_cen / chip.pix_size - chip.npix_y / 2.
    # Order label 'A' maps to PCA order '1'.
    psf_im = psf_i.get_PSF(chip, pos_img_local=(pos_img.x - x_start, pos_img.y - y_start), g_order='A')