"""
PSF interpolation for CSST-Sim.

NOTE: [iccd, iwave, ipsf] are counted from 1 to n, whereas
[tccd, twave, tpsf] are counted from 0 to n-1.
"""
import sys
import time
import copy

import numpy as np
import scipy.spatial as spatial
import galsim
import h5py

from ObservationSim.PSF.PSFModel import PSFModel


LOG_DEBUG = False  # print progress messages while loading/preparing PSFs
NPSF = 900  # default number of PSFs per wavelength (30*30)
PixSizeInMicrons = 5.  # pixel size of the PSF matrices, in microns


###find neighbors-KDtree###
def findNeighbors(tx, ty, px, py, dr=0.1, dn=1, OnlyDistance=True):
    """
    Find neighbours of a position with a 2D KD-tree.

    Parameters:
        tx, ty (float, float): the target position
        px, py (numpy.array, numpy.array): positions of the candidate points
        dr (float-optional): search radius when OnlyDistance is True;
            radius increment of the growing search otherwise
        dn (int-optional): number of nearest neighbours to return when
            OnlyDistance is False (capped at the number of points)
        OnlyDistance (bool-optional): if True, return every point within dr
            of the target; if False, return the dn nearest points.
            Default: True

    Returns:
        dataq: indices of the neighbours -- a list when OnlyDistance is
        True, a numpy.array sorted by increasing distance otherwise
    """
    datax = px
    datay = py
    tree = spatial.KDTree(list(zip(datax.ravel(), datay.ravel())))

    dataq = []
    rr = dr
    if OnlyDistance:
        dataq = tree.query_ball_point([tx, ty], rr)
    else:
        # Without this cap the expanding-radius loop below never terminates
        # when fewer than dn points are available.
        dn = min(dn, datax.size)
        while len(dataq) < dn:
            dataq = tree.query_ball_point([tx, ty], rr)
            rr += dr
        dd = np.hypot(datax[dataq] - tx, datay[dataq] - ty)
        ddSortindx = np.argsort(dd)
        dataq = np.array(dataq)[ddSortindx[0:dn]]
    return dataq


###find neighbors-hoclist###
def hocBuild(partx, party, nhocx, nhocy, dhocx, dhocy):
    """
    Build a head-of-chain (linked cell) list on a regular 2D grid.

    Parameters:
        partx, party (numpy.array): non-negative point coordinates
        nhocx, nhocy (int): number of cells along x / y
        dhocx, dhocy (float): cell size along x / y

    Returns:
        hoc (numpy.array, [nhocy, nhocx]): index of the first point in each
            cell, -1 for an empty cell
        hoclist (numpy.array, [npart]): for each point, the index of the
            next point in the same cell, -1 at the end of a chain

    Exits the process with status 1 if a point falls outside the grid.
    """
    if np.max(partx) > nhocx*dhocx:
        print('ERROR')
        sys.exit(1)
    if np.max(party) > nhocy*dhocy:
        print('ERROR')
        sys.exit(1)

    npart = partx.size
    hoclist = np.zeros(npart, dtype=np.int32) - 1
    hoc = np.zeros([nhocy, nhocx], dtype=np.int32) - 1
    for ipart in range(npart):
        ix = int(partx[ipart]/dhocx)
        iy = int(party[ipart]/dhocy)
        # Push the point onto the front of its cell's chain.
        hoclist[ipart] = hoc[iy, ix]
        hoc[iy, ix] = ipart
    return hoc, hoclist


def hocFind(px, py, dhocx, dhocy, hoc, hoclist):
    """Return the indices of all points stored in the cell containing (px, py)."""
    ix = int(px/dhocx)
    iy = int(py/dhocy)
    neigh = []
    it = hoc[iy, ix]
    while it != -1:
        neigh.append(it)
        it = hoclist[it]
    return neigh


def findNeighbors_hoclist(px, py, tx=None, ty=None, dn=4, hoc=None, hoclist=None):
    """
    Build, or query, a head-of-chain cell list over the points (px, py).

    Without `hoc`: build the cell list on a 20x20 grid spanning the points
    and return (hoc, hoclist).

    With `hoc`/`hoclist` (from an earlier call on the same px, py): collect
    the points in the 3x3 block of cells around (tx, ty) and return the
    indices of the dn nearest of them (all of them when dn == -1).
    """
    nhocy = nhocx = 20
    pxMin = np.min(px)
    pxMax = np.max(px)
    pyMin = np.min(py)
    pyMax = np.max(py)

    dhocx = (pxMax - pxMin)/(nhocx - 1)
    dhocy = (pyMax - pyMin)/(nhocy - 1)
    # Shift by half a cell so every point lies strictly inside the grid.
    partx = px - pxMin + dhocx/2
    party = py - pyMin + dhocy/2

    if hoc is None:
        hoc, hoclist = hocBuild(partx, party, nhocx, nhocy, dhocx, dhocy)
        return hoc, hoclist

    tx = tx - pxMin + dhocx/2
    ty = ty - pyMin + dhocy/2
    itx = int(tx/dhocx)
    ity = int(ty/dhocy)

    neigh = []
    for dx in (-1, 0, 1):
        for dy in (-1, 0, 1):
            ix = itx + dx
            iy = ity + dy
            if ix < 0 or iy < 0 or ix > nhocx - 1 or iy > nhocy - 1:
                continue
            it = hoc[iy, ix]
            while it != -1:
                neigh.append(it)
                it = hoclist[it]

    if dn != -1:
        ptx = np.array(partx[neigh])
        pty = np.array(party[neigh])
        dd = np.hypot(ptx - tx, pty - ty)
        idx = np.argsort(dd)
        neigh = np.array(neigh)[idx[0:dn]]
    return neigh


###PSF-IDW###
def psfMaker_IDW(px, py, PSFMat, cen_col, cen_row, IDWindex=2, OnlyNeighbors=True, hoc=None, hoclist=None, PSFCentroidWgt=False):
    """
    PSF interpolation by inverse distance weighting (IDW).

    Parameters:
        px, py (float, float): position of the target
        PSFMat (numpy.array): PSF images, indexed as [ipsf, row, col]
        cen_col, cen_row (numpy.array, numpy.array): positions of the PSF centres
        IDWindex (int-optional): power p of the weights 1/dist**p; any
            numeric value is accepted
        OnlyNeighbors (bool-optional): use only the 4 nearest PSFs
        hoc, hoclist (optional): cell list from findNeighbors_hoclist; when
            given it is used for the neighbour search instead of a KD-tree
        PSFCentroidWgt (bool-optional): accepted for compatibility; unused

    Returns:
        psfMaker (numpy.array): interpolated PSF, normalised to unit nansum
    """
    # Floor on dist**p so a target sitting exactly on a PSF centre does not
    # produce an infinite weight.
    minimum_psf_weight = 1e-8
    ref_col = px
    ref_row = py

    ngy, ngx = PSFMat[0, :, :].shape
    npsf = PSFMat.shape[0]
    psfWeight = np.zeros([npsf])

    if OnlyNeighbors:
        if hoc is None:
            neigh = findNeighbors(px, py, cen_col, cen_row, dr=5., dn=4, OnlyDistance=False)
        else:
            neigh = findNeighbors_hoclist(cen_col, cen_row, tx=px, ty=py, dn=4, hoc=hoc, hoclist=hoclist)
        neighFlag = np.zeros(npsf)
        neighFlag[neigh] = 1
        usedPSF = np.flatnonzero(neighFlag)  # ascending, de-duplicated
    else:
        usedPSF = range(npsf)

    for ipsf in usedPSF:
        dist = np.sqrt((ref_col - cen_col[ipsf])**2 + (ref_row - cen_row[ipsf])**2)
        psfWeight[ipsf] = dist**IDWindex
        psfWeight[ipsf] = 1./max(psfWeight[ipsf], minimum_psf_weight)
    psfWeight /= np.sum(psfWeight)

    psfMaker = np.zeros([ngy, ngx], dtype=np.float32)
    for ipsf in usedPSF:
        psfMaker += PSFMat[ipsf, :, :] * psfWeight[ipsf]
    psfMaker /= np.nansum(psfMaker)
    return psfMaker


###define PSFInterp###
class PSFInterp(PSFModel):
    """
    PSF model for one CCD: tabulated PSFs read from an HDF5 PSF cube are
    interpolated to arbitrary image positions with psfMaker_IDW.
    """

    def __init__(self, chip, npsf=NPSF, PSF_data=None, PSF_data_file=None, PSF_data_prefix="", sigSpin=0, psfRa=0.15, HocBuild=False):
        """
        Parameters:
            chip: chip object; its label selects the PSF cube file
            npsf (int-optional): number of PSFs per wavelength to load
            PSF_data: unused; PSFs are always read from PSF_data_file
            PSF_data_file (str): directory containing the PSF cube files
            PSF_data_prefix (str-optional): filename prefix of the cube
            sigSpin (float-optional): spin angle used by PSFspin
            psfRa (float-optional): stored as self.sigGauss, used by PSFspin
            HocBuild (bool-optional): also build one cell list per wavelength
                so get_PSF can use findNeighMode='hoclistFind'

        Exits the process with status 1 if PSF_data_file is None.
        """
        if LOG_DEBUG:
            print('===================================================')
            print('DEBUG: psf module for csstSim '
                  + time.strftime("(%Y-%m-%d %H:%M:%S)", time.localtime()), flush=True)
            print('===================================================')

        self.sigSpin = sigSpin
        self.sigGauss = psfRa
        self.iccd = int(chip.getChipLabel(chipID=chip.chipID))
        if PSF_data_file is None:
            print('Error - PSF_data_file is None')
            sys.exit(1)

        self.nwave = self._getPSFwave(self.iccd, PSF_data_file, PSF_data_prefix)
        self.npsf = npsf
        self.PSF_data = self._loadPSF(self.iccd, PSF_data_file, PSF_data_prefix)

        if LOG_DEBUG:
            print('nwave-{:} on ccd-{:}::'.format(self.nwave, self.iccd), flush=True)
            print('self.PSF_data ... ok', flush=True)
            print('Preparing self.[psfMat,cen_col,cen_row] for psfMaker ... ', end='', flush=True)

        ngy, ngx = self.PSF_data[0][0]['psfMat'].shape
        self.psfMat = np.zeros([self.nwave, self.npsf, ngy, ngx], dtype=np.float32)
        self.cen_col = np.zeros([self.nwave, self.npsf], dtype=np.float32)
        self.cen_row = np.zeros([self.nwave, self.npsf], dtype=np.float32)
        self.hoc = []
        self.hoclist = []

        for twave in range(self.nwave):
            for tpsf in range(self.npsf):
                psfInfo = self.PSF_data[twave][tpsf]
                self.psfMat[twave, tpsf, :, :] = psfInfo['psfMat']
                psfInfo['psfMat'] = 0  # free psfMat; the data now lives in self.psfMat
                self.pixsize = psfInfo['pixsize']*1e-3  # mm
                self.cen_col[twave, tpsf] = psfInfo['image_x'] + psfInfo['centroid_x']
                self.cen_row[twave, tpsf] = psfInfo['image_y'] + psfInfo['centroid_y']
            if HocBuild:
                # One cell list per wavelength, indexed by twave in get_PSF.
                hoc, hoclist = findNeighbors_hoclist(self.cen_col[twave], self.cen_row[twave])
                self.hoc.append(hoc)
                self.hoclist.append(hoclist)

        if LOG_DEBUG:
            print('ok', flush=True)

    @staticmethod
    def _psfCubePath(iccd, PSF_data_file, PSF_data_prefix):
        """Return the path of the HDF5 PSF cube for CCD `iccd`."""
        return PSF_data_file + '/' + PSF_data_prefix + 'psfCube_{:}.h5'.format(iccd)

    def _getPSFwave(self, iccd, PSF_data_file, PSF_data_prefix):
        """Return the number of wavelength groups (top-level keys) in the PSF cube."""
        with h5py.File(self._psfCubePath(iccd, PSF_data_file, PSF_data_prefix), 'r') as fq:
            nwave = len(fq.keys())
        return nwave

    def _loadPSF(self, iccd, PSF_data_file, PSF_data_prefix):
        """
        Read self.nwave x self.npsf PSFs from the PSF cube.

        Returns:
            psfSet: list (per wavelength) of lists (per PSF) of dicts with keys
            wavelength, pixsize, field_x, field_y, image_x, image_y,
            centroid_x, centroid_y, psfMat
        """
        psfSet = []
        with h5py.File(self._psfCubePath(iccd, PSF_data_file, PSF_data_prefix), 'r') as fq:
            for ii in range(self.nwave):
                iwave = ii + 1  # HDF5 groups are numbered from 1
                fq_iwave = fq['w_{:}'.format(iwave)]
                psfWave = []
                for jj in range(self.npsf):
                    ipsf = jj + 1
                    fq_iwave_ipsf = fq_iwave['psf_{:}'.format(ipsf)]
                    psfInfo = {
                        'wavelength': fq_iwave['wavelength'][()],
                        'pixsize': PixSizeInMicrons,
                        'field_x': fq_iwave_ipsf['field_x'][()],
                        'field_y': fq_iwave_ipsf['field_y'][()],
                        'image_x': fq_iwave_ipsf['image_x'][()],
                        'image_y': fq_iwave_ipsf['image_y'][()],
                        'centroid_x': fq_iwave_ipsf['cx'][()],
                        'centroid_y': fq_iwave_ipsf['cy'][()],
                        'psfMat': fq_iwave_ipsf['psfMat'][()],
                    }
                    psfWave.append(psfInfo)
                psfSet.append(psfWave)

        if LOG_DEBUG:
            print('psfSet has been loaded:', flush=True)
            print('psfSet[iwave][ipsf][keys]:', psfSet[0][0].keys(), flush=True)
        return psfSet

    def _findWave(self, bandpass):
        """
        Return the 0-based wavelength index for `bandpass`.

        An int is returned unchanged; otherwise the first wavelength strictly
        between bandpass.blue_limit and bandpass.red_limit is chosen, or -1
        if none matches.
        """
        if isinstance(bandpass, int):
            return bandpass
        for twave in range(self.nwave):
            bandwave = self.PSF_data[twave][0]['wavelength']
            if bandpass.blue_limit < bandwave < bandpass.red_limit:
                return twave
        return -1

    def get_PSF(self, chip, pos_img, bandpass, galsimGSObject=True, findNeighMode='treeFind', folding_threshold=5.e-3, pointing_pa=0.0):
        """
        Get the PSF at a given image position.

        Parameters:
            chip: A 'Chip' object representing the chip we want to extract PSF from.
            pos_img: A 'galsim.Position' object representing the image position.
            bandpass: A 'galsim.Bandpass' object representing the wavelength range
                (or an int wavelength index).
            galsimGSObject (bool): return a galsim object instead of the array.
            findNeighMode: 'treeFind' or 'hoclistFind' (needs HocBuild=True).
            folding_threshold (float): galsim GSParams folding_threshold.
            pointing_pa: currently unused.

        Returns:
            (galsim.GSObject, galsim.Shear) when galsimGSObject is True,
            otherwise the interpolated PSF as a numpy array.

        Raises:
            RuntimeError: 'hoclistFind' requested but no hoclist was built.
            ValueError: unknown findNeighMode.
        """
        # PSF pixel scale in arcsec: self.pixsize is in mm; the 28 is
        # presumably the focal length in metres -- TODO confirm.
        pixSize = np.rad2deg(self.pixsize*1e-3/28)*3600

        twave = self._findWave(bandpass)
        if twave == -1:
            print("!!!PSF bandpass does not match.")
            sys.exit(1)
        PSFMat = self.psfMat[twave]
        cen_col = self.cen_col[twave]
        cen_row = self.cen_row[twave]

        # Offset from the chip centre, scaled by 0.01 to the units of
        # cen_col/cen_row (presumably mm for 10-micron pixels -- confirm).
        px = (pos_img.x - chip.cen_pix_x)*0.01
        py = (pos_img.y - chip.cen_pix_y)*0.01
        if findNeighMode == 'treeFind':
            imPSF = psfMaker_IDW(px, py, PSFMat, cen_col, cen_row, IDWindex=2, OnlyNeighbors=True, PSFCentroidWgt=True)
        elif findNeighMode == 'hoclistFind':
            if not self.hoc:
                raise RuntimeError('hoclist should be built correctly!')
            imPSF = psfMaker_IDW(px, py, PSFMat, cen_col, cen_row, IDWindex=2, OnlyNeighbors=True,
                                 hoc=self.hoc[twave], hoclist=self.hoclist[twave], PSFCentroidWgt=True)
        else:
            raise ValueError("findNeighMode must be 'treeFind' or 'hoclistFind', got {!r}".format(findNeighMode))

        ############TEST: START
        TestGaussian = False
        if TestGaussian:
            gsx = galsim.Gaussian(sigma=0.04)
            imPSF = gsx.shear(g1=0.8, g2=0.).rotate(0.*galsim.degrees).drawImage(nx=256, ny=256, scale=pixSize).array
        ############TEST: END

        if galsimGSObject:
            # Pad with one extra row and column of zeros (an even-sized
            # matrix becomes odd-sized) before building the galsim image.
            ny, nx = imPSF.shape
            imPSFt = np.zeros([ny + 1, nx + 1])
            imPSFt[0:ny, 0:nx] = imPSF
            img = galsim.ImageF(imPSFt, scale=pixSize)
            gsp = galsim.GSParams(folding_threshold=folding_threshold)
            self.psf = galsim.InterpolatedImage(img, gsparams=gsp)
            wcs = chip.img.wcs.local(pos_img)
            scale = galsim.PixelScale(0.074)
            self.psf = wcs.toWorld(scale.toImage(self.psf), image_pos=(pos_img))
            return self.psf, galsim.Shear(e=0., beta=(np.pi/2)*galsim.radians)
        return imPSF

    def PSFspin(self, x, y):
        """
        Shear the PSF built by the last get_PSF call by a spin-induced ellipticity.

        The ellipticity grows with the distance from the axis centre and with
        self.sigSpin (spin angle, in arcsec); its orientation is
        perpendicular to the direction of (x, y).

        Parameters:
            x, y: position relative to the axis centre, in pixels

        Returns:
            (sheared PSF, galsim.Shear applied)
        """
        a2Rad = np.pi/(60.0*60.0*180.0)  # arcsec -> rad
        ff = self.sigGauss*0.107*(1000.0/10.0)  # in unit of [pixels]
        rc = np.sqrt(x*x + y*y)
        cpix = rc*(self.sigSpin*a2Rad)

        beta = (np.arctan2(y, x) + np.pi/2)
        ell = cpix**2/(2.0*ff**2 + cpix**2)
        PSFshear = galsim.Shear(e=ell, beta=beta*galsim.radians)
        return self.psf.shear(PSFshear), PSFshear


if __name__ == '__main__':
    pass