From 2f6b3ce6bf8eadf89c318442d6d414bfbdabd157 Mon Sep 17 00:00:00 2001 From: yuedong Date: Tue, 8 Mar 2022 16:00:57 +0800 Subject: [PATCH] add missed PSFInterp.py --- ObservationSim/PSF/PSFInterp.py | 386 ++++++++++++++++++++++++++++++++ 1 file changed, 386 insertions(+) create mode 100644 ObservationSim/PSF/PSFInterp.py diff --git a/ObservationSim/PSF/PSFInterp.py b/ObservationSim/PSF/PSFInterp.py new file mode 100644 index 0000000..058acca --- /dev/null +++ b/ObservationSim/PSF/PSFInterp.py @@ -0,0 +1,386 @@ +''' +PSF interpolation for CSST-Sim + +NOTE: [iccd, iwave, ipsf] are counted from 1 to n, but [tccd, twave, tpsf] are counted from 0 to n-1 +''' + +import sys +import time +import copy +import numpy as np +import scipy.spatial as spatial +import galsim +import h5py + +from ObservationSim.PSF.PSFModel import PSFModel + + +LOG_DEBUG = False #***# +NPSF = 900 #***# 30*30 +PixSizeInMicrons = 5. #***# in microns + + +###find neighbors-KDtree### +def findNeighbors(tx, ty, px, py, dr=0.1, dn=1, OnlyDistance=True): + """ + find nearest neighbors by 2D-KDTree + + Parameters: + tx, ty (float, float): a given position + px, py (numpy.array, numpy.array): position data for tree + dr (float-optional): distance + dn (int-optional): nearest-N + OnlyDistance (bool-optional): only use distance to find neighbors. Default: True + + Returns: + dataq (numpy.array): index + """ + datax = px + datay = py + tree = spatial.KDTree(list(zip(datax.ravel(), datay.ravel()))) + + dataq=[] + rr = dr + if OnlyDistance == True: + dataq = tree.query_ball_point([tx, ty], rr) + if OnlyDistance == False: + while len(dataq) < dn: + dataq = tree.query_ball_point([tx, ty], rr) + rr += dr + dd = np.hypot(datax[dataq]-tx, datay[dataq]-ty) + ddSortindx = np.argsort(dd) + dataq = np.array(dataq)[ddSortindx[0:dn]] + return dataq + +###find neighbors-hoclist### +def hocBuild(partx, party, nhocx, nhocy, dhocx, dhocy): + if np.max(partx) > nhocx*dhocx: + print('ERROR') + sys.exit() + if np.max(party) > nhocy*dhocy: + print('ERROR') + sys.exit() + + npart = partx.size + hoclist= np.zeros(npart, dtype=np.int32)-1 + hoc = np.zeros([nhocy, nhocx], dtype=np.int32)-1 + for ipart in range(npart): + ix = int(partx[ipart]/dhocx) + iy = int(party[ipart]/dhocy) + hoclist[ipart] = hoc[iy, ix] + hoc[iy, ix] = ipart + return hoc, hoclist + +def hocFind(px, py, dhocx, dhocy, hoc, hoclist): + ix = int(px/dhocx) + iy = int(py/dhocy) + + neigh=[] + it = hoc[iy, ix] + while it != -1: + neigh.append(it) + it = hoclist[it] + return neigh + +def findNeighbors_hoclist(px, py, tx=None,ty=None, dn=4, hoc=None, hoclist=None): + nhocy = nhocx = 20 + + pxMin = np.min(px) + pxMax = np.max(px) + pyMin = np.min(py) + pyMax = np.max(py) + + dhocx = (pxMax - pxMin)/(nhocx-1) + dhocy = (pyMax - pyMin)/(nhocy-1) + partx = px - pxMin +dhocx/2 + party = py - pyMin +dhocy/2 + + if hoc is None: + hoc, hoclist = hocBuild(partx, party, nhocx, nhocy, dhocx, dhocy) + return hoc, hoclist + + if hoc is not None: + tx = tx - pxMin +dhocx/2 + ty = ty - pyMin +dhocy/2 + itx = int(tx/dhocx) + ity = int(ty/dhocy) + + ps = [-1, 0, 1] + neigh=[] + for ii in range(3): + for jj in range(3): + ix = itx + ps[ii] + iy = ity + ps[jj] + if ix < 0: + continue + if iy < 0: + continue + if ix > nhocx-1: + continue + if iy > nhocy-1: + continue + + #neightt = myUtil.hocFind(ppx, ppy, dhocx, dhocy, hoc, hoclist) + it = hoc[iy, ix] + while it != -1: + neigh.append(it) + it = hoclist[it] + #neigh.append(neightt) + #ll = [i for k in neigh for i in k] + if dn != -1: + ptx = np.array(partx[neigh]) + pty = np.array(party[neigh]) + dd = np.hypot(ptx-tx, pty-ty) + idx = np.argsort(dd) + neigh= np.array(neigh)[idx[0:dn]] + return neigh + + +###PSF-IDW### +def psfMaker_IDW(px, py, PSFMat, cen_col, cen_row, IDWindex=2, OnlyNeighbors=True, hoc=None, hoclist=None, PSFCentroidWgt=False): + """ + psf interpolation by IDW + + Parameters: + px, py (float, float): position of the target + PSFMat (numpy.array): image + cen_col, cen_row (numpy.array, numpy.array): potions of the psf centers + IDWindex (int-optional): the power index of IDW + OnlyNeighbors (bool-optional): only neighbors are used for psf interpolation + + Returns: + psfMaker (numpy.array) + """ + + minimum_psf_weight = 1e-8 + ref_col = px + ref_row = py + + ngy, ngx = PSFMat[0, :, :].shape + npsf = PSFMat[:, :, :].shape[0] + psfWeight = np.zeros([npsf]) + + if OnlyNeighbors == True: + if hoc is None: + neigh = findNeighbors(px, py, cen_col, cen_row, dr=5., dn=4, OnlyDistance=False) + if hoc is not None: + neigh = findNeighbors_hoclist(cen_col, cen_row, tx=px,ty=py, dn=4, hoc=hoc, hoclist=hoclist) + + neighFlag = np.zeros(npsf) + neighFlag[neigh] = 1 + + for ipsf in range(npsf): + if OnlyNeighbors == True: + if neighFlag[ipsf] != 1: + continue + + dist = np.sqrt((ref_col - cen_col[ipsf])**2 + (ref_row - cen_row[ipsf])**2) + if IDWindex == 1: + psfWeight[ipsf] = dist + if IDWindex == 2: + psfWeight[ipsf] = dist**2 + if IDWindex == 3: + psfWeight[ipsf] = dist**3 + if IDWindex == 4: + psfWeight[ipsf] = dist**4 + psfWeight[ipsf] = max(psfWeight[ipsf], minimum_psf_weight) + psfWeight[ipsf] = 1./psfWeight[ipsf] + psfWeight /= np.sum(psfWeight) + + psfMaker = np.zeros([ngy, ngx], dtype=np.float32) + for ipsf in range(npsf): + if OnlyNeighbors == True: + if neighFlag[ipsf] != 1: + continue + + iPSFMat = PSFMat[ipsf, :, :].copy() + ipsfWeight = psfWeight[ipsf] + + psfMaker += iPSFMat * ipsfWeight + psfMaker /= np.nansum(psfMaker) + + return psfMaker + + + +###define PSFInterp### +class PSFInterp(PSFModel): + def __init__(self, chip, npsf=NPSF, PSF_data=None, PSF_data_file=None, PSF_data_prefix="", sigSpin=0, psfRa=0.15, HocBuild=False): + if LOG_DEBUG: + print('===================================================') + print('DEBUG: psf module for csstSim ' \ + +time.strftime("(%Y-%m-%d %H:%M:%S)", time.localtime()), flush=True) + print('===================================================') + + self.sigSpin = sigSpin + self.sigGauss = psfRa + + self.iccd = int(chip.getChipLabel(chipID=chip.chipID)) + if PSF_data_file == None: + print('Error - PSF_data_file is None') + sys.exit() + + self.nwave= self._getPSFwave(self.iccd, PSF_data_file, PSF_data_prefix) + self.npsf = npsf + self.PSF_data = self._loadPSF(self.iccd, PSF_data_file, PSF_data_prefix) + + if LOG_DEBUG: + print('nwave-{:} on ccd-{:}::'.format(self.nwave, self.iccd), flush=True) + print('self.PSF_data ... ok', flush=True) + print('Preparing self.[psfMat,cen_col,cen_row] for psfMaker ... ', end='', flush=True) + + ngy, ngx = self.PSF_data[0][0]['psfMat'].shape + self.psfMat = np.zeros([self.nwave, self.npsf, ngy, ngx], dtype=np.float32) + self.cen_col= np.zeros([self.nwave, self.npsf], dtype=np.float32) + self.cen_row= np.zeros([self.nwave, self.npsf], dtype=np.float32) + self.hoc =[] + self.hoclist=[] + + for twave in range(self.nwave): + for tpsf in range(self.npsf): + self.psfMat[twave, tpsf, :, :] = self.PSF_data[twave][tpsf]['psfMat'] + self.PSF_data[twave][tpsf]['psfMat'] = 0 ###free psfMat + + self.pixsize = self.PSF_data[twave][tpsf]['pixsize']*1e-3 ##mm + self.cen_col[twave, tpsf] = self.PSF_data[twave][tpsf]['image_x'] + self.PSF_data[twave][tpsf]['centroid_x'] + self.cen_row[twave, tpsf] = self.PSF_data[twave][tpsf]['image_y'] + self.PSF_data[twave][tpsf]['centroid_y'] + + if HocBuild: + #hoclist on twave for neighborsFinding + hoc,hoclist = findNeighbors_hoclist(self.cen_col[twave], self.cen_row[twave]) + self.hoc.append(hoc) + self.hoclist.append(hoclist) + + if LOG_DEBUG: + print('ok', flush=True) + + + def _getPSFwave(self, iccd, PSF_data_file, PSF_data_prefix): + fq = h5py.File(PSF_data_file+'/' +PSF_data_prefix +'psfCube_ccd{:}.h5'.format(iccd), 'r') + nwave = len(fq.keys()) + fq.close() + return nwave + + + def _loadPSF(self, iccd, PSF_data_file, PSF_data_prefix): + psfSet = [] + fq = h5py.File(PSF_data_file+'/' +PSF_data_prefix +'psfCube_ccd{:}.h5'.format(iccd), 'r') + for ii in range(self.nwave): + iwave = ii+1 + psfWave = [] + + fq_iwave = fq['w_{:}'.format(iwave)] + for jj in range(self.npsf): + ipsf = jj+1 + psfInfo = {} + psfInfo['wavelength']= fq_iwave['wavelength'][()] + + fq_iwave_ipsf = fq_iwave['psf_{:}'.format(ipsf)] + psfInfo['pixsize'] = PixSizeInMicrons + psfInfo['field_x'] = fq_iwave_ipsf['field_x'][()] + psfInfo['field_y'] = fq_iwave_ipsf['field_y'][()] + psfInfo['image_x'] = fq_iwave_ipsf['image_x'][()] + psfInfo['image_y'] = fq_iwave_ipsf['image_y'][()] + psfInfo['centroid_x']= fq_iwave_ipsf['cx'][()] + psfInfo['centroid_y']= fq_iwave_ipsf['cy'][()] + psfInfo['psfMat'] = fq_iwave_ipsf['psfMat'][()] + + psfWave.append(psfInfo) + psfSet.append(psfWave) + fq.close() + + if LOG_DEBUG: + print('psfSet has been loaded:', flush=True) + print('psfSet[iwave][ipsf][keys]:', psfSet[0][0].keys(), flush=True) + return psfSet + + + def _findWave(self, bandpass): + if isinstance(bandpass,int): + twave = bandpass + return twave + + for twave in range(self.nwave): + bandwave = self.PSF_data[twave][0]['wavelength'] + if bandpass.blue_limit < bandwave and bandwave < bandpass.red_limit: + return twave + return -1 + + + def get_PSF(self, chip, pos_img, bandpass, galsimGSObject=True, findNeighMode='treeFind', folding_threshold=5.e-3, pointing_pa=0.0): + """ + Get the PSF at a given image position + + Parameters: + chip: A 'Chip' object representing the chip we want to extract PSF from. + pos_img: A 'galsim.Position' object representing the image position. + bandpass: A 'galsim.Bandpass' object representing the wavelength range. + pixSize: The pixels size of psf matrix + findNeighMode: 'treeFind' or 'hoclistFind' + Returns: + PSF: A 'galsim.GSObject'. + """ + pixSize = np.rad2deg(self.pixsize*1e-3/28)*3600 #set psf pixsize + + assert self.iccd == int(chip.getChipLabel(chipID=chip.chipID)), 'ERROR: self.iccd != chip.chipID' + twave = self._findWave(bandpass) + if twave == -1: + print("!!!PSF bandpass does not match.") + exit() + PSFMat = self.psfMat[twave] + cen_col= self.cen_col[twave] + cen_row= self.cen_row[twave] + + px = (pos_img.x - chip.cen_pix_x)*0.01 + py = (pos_img.y - chip.cen_pix_y)*0.01 + if findNeighMode == 'treeFind': + imPSF = psfMaker_IDW(px, py, PSFMat, cen_col, cen_row, IDWindex=2, OnlyNeighbors=True, PSFCentroidWgt=True) + if findNeighMode == 'hoclistFind': + assert(self.hoc != 0), 'hoclist should be built correctly!' + imPSF = psfMaker_IDW(px, py, PSFMat, cen_col, cen_row, IDWindex=2, OnlyNeighbors=True, hoc=self.hoc[twave], hoclist=self.hoclist[twave], PSFCentroidWgt=True) + + ############TEST: START + TestGaussian = False + if TestGaussian: + gsx = galsim.Gaussian(sigma=0.04) + #pointing_pa = -23.433333 + imPSF= gsx.shear(g1=0.8, g2=0.).rotate(0.*galsim.degrees).drawImage(nx = 256, ny=256, scale=pixSize).array + ############TEST: END + + if galsimGSObject: + imPSFt = np.zeros([257,257]) + imPSFt[0:256, 0:256] = imPSF + + img = galsim.ImageF(imPSFt, scale=pixSize) + gsp = galsim.GSParams(folding_threshold=folding_threshold) + self.psf = galsim.InterpolatedImage(img, gsparams=gsp).rotate(pointing_pa*galsim.degrees) + + return self.PSFspin(x=px/0.01, y=py/0.01) + return imPSF + + def PSFspin(self, x, y): + """ + The PSF profile at a given image position relative to the axis center + + Parameters: + theta : spin angles in a given exposure in unit of [arcsecond] + dx, dy: relative position to the axis center in unit of [pixels] + + Return: + Spinned PSF: g1, g2 and axis ratio 'a/b' + """ + a2Rad = np.pi/(60.0*60.0*180.0) + + ff = self.sigGauss * 0.107 * (1000.0/10.0) # in unit of [pixels] + rc = np.sqrt(x*x + y*y) + cpix = rc*(self.sigSpin*a2Rad) + + beta = (np.arctan2(y,x) + np.pi/2) + ell = cpix**2/(2.0*ff**2+cpix**2) + qr = np.sqrt((1.0+ell)/(1.0-ell)) + PSFshear = galsim.Shear(e=ell, beta=beta*galsim.radians) + return self.psf.shear(PSFshear), PSFshear + + +if __name__ == '__main__': + pass -- GitLab