diff --git a/tests/test_field_distortion.py b/tests/test_field_distortion.py new file mode 100644 index 0000000000000000000000000000000000000000..0f77edbe4062a24d36f90bb07963045b700b4bc9 --- /dev/null +++ b/tests/test_field_distortion.py @@ -0,0 +1,517 @@ +import unittest + +import numpy as np +import galsim +import os +import sys +from astropy.table import Table +from scipy import interpolate +import pickle + + +class test_field_distortion(unittest.TestCase): + def __init__(self, methodName="runTest"): + super(test_field_distortion, self).__init__(methodName) + self.dataMainPath = os.path.join( + os.getenv("UNIT_TEST_DATA_ROOT"), "csst_msc_sim/field_distortion" + ) + self.dataInputPath = os.path.join(self.dataMainPath, "input_catalog") + self.fdModelName = "FieldDistModel_v2.0_test.pickle" + + def test_fd_model(self): + cat_dir = self.dataInputPath + model_dir = self.dataMainPath + model_date = "2024-05-08" + model_name = self.fdModelName + field_distortion_model( + cat_dir, + model_dir, + poly_degree=4, + model_date=model_date, + model_name=model_name, + ) + + def test_fd_apply(self): + model_name = self.fdModelName + model_dir = self.dataMainPath + cat_dir = self.dataMainPath + field_distortion_apply( + model_name, model_dir, cat_dir, ra_cen=60.0, dec_cen=-40.0, img_rot=0.0 + ) + + +def ccdParam(): + """ + Basic CCD size and noise parameters. + """ + # CCD size + xt, yt = 59516, 49752 + x0, y0 = 9216, 9232 + xgap, ygap = (534, 1309), 898 + xnchip, ynchip = 6, 5 + ccdSize = xt, yt, x0, y0, xgap, ygap, xnchip, ynchip + + # other parameters + readNoise = 5.0 # e/pix + darkNoise = 0.02 # e/pix/s + pixel_scale = 0.074 # pixel scale + gain = 1.0 + ccdBase = readNoise, darkNoise, pixel_scale, gain + + return ccdSize, ccdBase + + +def chipLim(chip): + ccdSize, ccdBase = ccdParam() + xt, yt, x0, y0, gx, gy, xnchip, ynchip = ccdSize + gx1, gx2 = gx + + rowID = ((chip - 1) % 5) + 1 + colID = 6 - ((chip - 1) // 5) + + xrem = 2 * (colID - 1) - (xnchip - 1) + xcen = (x0 // 2 + gx1 // 2) * xrem + if chip <= 5 or chip == 10: + xcen = (x0 // 2 + gx1 // 2) * xrem + (gx2 - gx1) + if chip >= 26 or chip == 21: + xcen = (x0 // 2 + gx1 // 2) * xrem - (gx2 - gx1) + nx0 = xcen - x0 // 2 + 1 + nx1 = xcen + x0 // 2 + + yrem = (rowID - 1) - ynchip // 2 + ycen = (y0 + gy) * yrem + ny0 = ycen - y0 // 2 + 1 + ny1 = ycen + y0 // 2 + + return nx0, nx1, ny0, ny1 + + +def chip_filter(nchip): + """ + return filter name of a given chip + """ + filtype = ["nuv", "u", "g", "r", "i", "z", "y"] + + # updated configurations + # if nchip>24 or nchip<7: raise ValueError("!!! Chip ID: [7,24]") + if nchip in [6, 15, 16, 25]: + filter_name = "y" + if nchip in [11, 20]: + filter_name = "z" + if nchip in [7, 24]: + filter_name = "i" + if nchip in [14, 17]: + filter_name = "u" + if nchip in [9, 22]: + filter_name = "r" + if nchip in [12, 13, 18, 19]: + filter_name = "nuv" + if nchip in [8, 23]: + filter_name = "g" + + filter_id = filtype.index(filter_name) + + return filter_id, filter_name + + +def skyLim(wcs, x0, x1, y0, y1): + """ + The sky coverage of a single exposure image + """ + r2d = 180.0 / np.pi + # xt, yt, x0, y0, gx, gy, xnchip, ynchip = ccdSize() + s1 = wcs.toWorld(galsim.PositionD(x0, y0)) + s2 = wcs.toWorld(galsim.PositionD(x0, y1)) + + s3 = wcs.toWorld(galsim.PositionD(x1, y0)) + s4 = wcs.toWorld(galsim.PositionD(x1, y1)) + ra = [s1.ra.rad * r2d, s2.ra.rad * r2d, s3.ra.rad * r2d, s4.ra.rad * r2d] + dec = [s1.dec.rad * r2d, s2.dec.rad * r2d, s3.dec.rad * r2d, s4.dec.rad * r2d] + + return min(ra), max(ra), min(dec), max(dec) + + +def wcsMain(imgRotation=0.0, raCenter=0.0, decCenter=0.0): + ccdSize, ccdBase = ccdParam() + xsize, ysize, _, _, _, _, _, _ = ccdSize + _, _, pixelScale, _ = ccdBase + xmcen, ymcen = 0.0, 0.0 + imrot = imgRotation * galsim.degrees + racen = raCenter * galsim.degrees + deccen = decCenter * galsim.degrees + + # define the wcs + dudx = -np.cos(imrot.rad) * pixelScale + dudy = +np.sin(imrot.rad) * pixelScale + dvdx = -np.sin(imrot.rad) * pixelScale + dvdy = -np.cos(imrot.rad) * pixelScale + moscen = galsim.PositionD(x=xmcen, y=ymcen) + skyCenter = galsim.CelestialCoord(ra=racen, dec=deccen) + affine = galsim.AffineTransform(dudx, dudy, dvdx, dvdy, origin=moscen) + wcs = galsim.TanWCS(affine, skyCenter, units=galsim.arcsec) + + return wcs + + +# FD model +def field_distortion_model( + cat_dir, + model_dir, + poly_degree=4, + model_date="2024-05-08", + model_name="FieldDistModel_v2.0_test.pickle", +): + # default parameter setup + nccd, nwave, npsf = 30, 4, 30 * 30 + + # load a CSST-like wcs + wcs = wcsMain() + cd11, cd12 = wcs.cd[0, 0], wcs.cd[0, 1] + cd21, cd22 = wcs.cd[1, 0], wcs.cd[1, 1] + xmcen, ymcen = wcs.crpix + + # obtain the interpolation model + fdFunList = {} + fdFunList["date"] = model_date + for iwave in range(1, nwave + 1): + # if iwave!=1: continue + iwaveKey = "wave%d" % iwave + + # first construct the global interpolation + xwList, ywList = [], [] + xdList, ydList = [], [] + fdFunList[iwaveKey] = {} + for iccd in range(1, nccd + 1): + # if iccd!=9: continue + iccdKey = "ccd" + str("0%d" % (iccd))[-2:] + + # load PSF data + ipsfDatn = os.path.join(cat_dir, "ccd%d_%s.dat" % (iccd, iwaveKey)) + ipsfDat = Table.read(ipsfDatn, format="ascii") + for ipsf in range(1, npsf + 1): + # if ipsf!=2: continue + xField = ipsfDat["field_x"][ipsf - 1] + yField = ipsfDat["field_y"][ipsf - 1] + + # image coordinate with field distortion + xImage = 100.0 * ( + ipsfDat["image_x"][ipsf - 1] + ipsfDat["centroid_x"][ipsf - 1] + ) + yImage = 100.0 * ( + ipsfDat["image_y"][ipsf - 1] + ipsfDat["centroid_y"][ipsf - 1] + ) + + # image coordinate only with wcs projection + xwcs = (cd12 * yField - cd22 * xField) / ( + cd12 * cd21 - cd11 * cd22 + ) + xmcen + ywcs = (cd21 * xField - cd11 * yField) / ( + cd12 * cd21 - cd11 * cd22 + ) + ymcen + + xwList += [xwcs] + ywList += [ywcs] + xdList += [xImage] + ydList += [yImage] + + # global interpolation + xImageFun = interpolate.SmoothBivariateSpline( + xwList, ywList, xdList, kx=poly_degree, ky=poly_degree + ) + yImageFun = interpolate.SmoothBivariateSpline( + xwList, ywList, ydList, kx=poly_degree, ky=poly_degree + ) + fdFunList[iwaveKey] = { + "xImagePos": xImageFun, + "yImagePos": yImageFun, + "interpLimit": [ + np.min(xwList), + np.max(xwList), + np.min(ywList), + np.max(ywList), + ], + } + + # construct the residual interpolation + fdFunList[iwaveKey]["residual"] = {} + for iccd in range(1, nccd + 1): + # if iccd!=1: continue + iccdKey = "ccd" + str("0%d" % (iccd))[-2:] + # open the ditortion data + ipsfDatn = os.path.join(cat_dir, "ccd%d_%s.dat" % (iccd, iwaveKey)) + ipsfDat = Table.read(ipsfDatn, format="ascii") + + ixwList, iywList = [], [] + idxList, idyList = [], [] + for ipsf in range(1, npsf + 1): + # if ipsf!=1: continue + print( + "^_^ loading: iccd-{:} iwave-{:} ipsf-{:}".format(iccd, iwave, ipsf) + ) + xField = ipsfDat["field_x"][ipsf - 1] + yField = ipsfDat["field_y"][ipsf - 1] + xImage = 100.0 * ( + ipsfDat["image_x"][ipsf - 1] + ipsfDat["centroid_x"][ipsf - 1] + ) + yImage = 100.0 * ( + ipsfDat["image_y"][ipsf - 1] + ipsfDat["centroid_y"][ipsf - 1] + ) + + # image coordinate only with wcs projection + xwcs = (cd12 * yField - cd22 * xField) / ( + cd12 * cd21 - cd11 * cd22 + ) + xmcen + ywcs = (cd21 * xField - cd11 * yField) / ( + cd12 * cd21 - cd11 * cd22 + ) + ymcen + + ixPred = xImageFun(xwcs, ywcs)[0][0] + iyPred = yImageFun(xwcs, ywcs)[0][0] + idx = xImage - ixPred + idy = yImage - iyPred + # print(idx, idy) + ixwList += [xwcs] + iywList += [ywcs] + idxList += [idx] + idyList += [idy] + + # interpolation + xResFun = interpolate.SmoothBivariateSpline( + ixwList, iywList, idxList, kx=poly_degree, ky=poly_degree + ) + yResFun = interpolate.SmoothBivariateSpline( + ixwList, iywList, idyList, kx=poly_degree, ky=poly_degree + ) + + fdFunList[iwaveKey]["residual"][iccdKey] = { + "xResidual": xResFun, + "yResidual": yResFun, + "interpLimit": [ + np.min(ixwList), + np.max(ixwList), + np.min(iywList), + np.max(iywList), + ], + } + + # save the interpolation functions + model_name_full = os.path.join(model_dir, model_name) + with open(model_name_full, "wb") as out: + pickle.dump(fdFunList, out, pickle.HIGHEST_PROTOCOL) + + return + + +def field_distortion_apply( + model_name, model_dir, cat_dir, ra_cen=60.0, dec_cen=-40.0, img_rot=0.0 +): + # CCD and observation + ccdSize, ccdBase = ccdParam() + xsize, ysize, xchip, ychip, xgap, ygap, xnchip, ynchip = ccdSize + nchip = xnchip * ynchip + + ################################################# + xmcen, ymcen = 0.0, 0.0 + ################################################# + + badchip = list(range(1, 6)) + list(range(26, 31)) + [10, 21] + + # define the wcs of the image mosaic + print( + "^_^ Construct the wcs of the entire image mosaic using Gnomonic/TAN projection" + ) + wcs = wcsMain(imgRotation=img_rot, raCenter=ra_cen, decCenter=dec_cen) + + ################################################# + # load the field distortion model + model_name_full = os.path.join(model_dir, model_name) + with open(model_name_full, "rb") as f: + fdModel = pickle.load(f) + ################################################# + + raLow, raUp, decLow, decUp = skyLim( + wcs, -xsize // 2 + 1, xsize // 2, -ysize // 2 + 1, ysize // 2 + ) + dra = (raUp - raLow) * np.cos(dec_cen * np.pi / 180.0) + ddec = decUp - decLow + print( + " Image pixel size: %d*%d; center: (Ra, Dec)=(%.3f, %.3f)." + % (xsize, ysize, ra_cen, dec_cen) + ) + print(" Field of Veiw: %.2f * %.2f deg^2." % (dra, ddec)) + + # filters and corresponding bounds in the image mosaic + fbound = {} + print(" Model the filter distributions in the image mosaic ...") + stats = {} + for i in range(nchip): + chip_id = i + 1 + if chip_id in badchip: + continue + cx0, cx1, cy0, cy1 = chipLim(chip_id) + chip_bound = galsim.BoundsD(cx0 - 1, cx1 - 1, cy0 - 1, cy1 - 1) + chip_filter_id, chip_filt = chip_filter(chip_id) + # print "^_^ CHIP %d, Filter %s"%(chip_id,chip_filter) + fbound[chip_id] = [chip_filter_id, chip_filt, chip_bound] + stats[chip_id] = [0, 0, 0] + + # generate object grid + ra_input = np.arange(ra_cen - 1.0, ra_cen + 1.0, 0.00125) + dec_input = np.arange(dec_cen - 1.0, dec_cen + 1.0, 0.00125) + nobj = len(ra_input) * len(dec_input) + crdCat = np.zeros((nobj, 2)) + cid = 0 + for id1 in range(len(ra_input)): + ira = ra_input[id1] + for id2 in range(len(dec_input)): + idec = dec_input[id2] + crdCat[cid, :] = ira, idec + cid += 1 + print("^_^ Total %d objects are generaged" % nobj) + + # main program + for i in range(nchip): + # if i not in [6]: continue + if i + 1 in badchip: + continue + filtidk, filtnmk, boundk = fbound[i + 1] + idStr = str("0%d" % (i + 1))[-2:] + + ################################################################### + # 1) Use global field distortion model: FieldDistModelGlobal_v2.0.pickle + ifdModel = fdModel["wave1"] + irsModel = fdModel["wave1"]["residual"]["ccd" + idStr] + xLowI, xUpI, yLowI, yUpI = ifdModel["interpLimit"] + xlLowI, xlUpI, ylLowI, ylUpI = irsModel["interpLimit"] + + # field distortion model along x/y-axis + ixfdModel = ifdModel["xImagePos"] + iyfdModel = ifdModel["yImagePos"] + ixrsModel = irsModel["xResidual"] + iyrsModel = irsModel["yResidual"] + + # first-order derivatives of the global field distortion model + ifx_dx = ixfdModel.partial_derivative(1, 0) + ifx_dy = ixfdModel.partial_derivative(0, 1) + ify_dx = iyfdModel.partial_derivative(1, 0) + ify_dy = iyfdModel.partial_derivative(0, 1) + # first-order derivatives of the residual field distortion model + irx_dx = ixrsModel.partial_derivative(1, 0) + irx_dy = ixrsModel.partial_derivative(0, 1) + iry_dx = iyrsModel.partial_derivative(1, 0) + iry_dy = iyrsModel.partial_derivative(0, 1) + ################################################################### + + # construct the image mosaic firstly + xorigin, yorigin = xmcen - boundk.xmin, ymcen - boundk.ymin + print(" Construct the chip mosaic ...") + fimage = galsim.ImageF(xchip, ychip) + fimage.setOrigin(boundk.xmin, boundk.ymin) + + fimage.wcs = wcs + raLow, raUp, decLow, decUp = skyLim( + wcs, boundk.xmin, boundk.xmax, boundk.ymin, boundk.ymax + ) + dra = (raUp - raLow) * np.cos(dec_cen * np.pi / 180.0) + ddec = decUp - decLow + print(" Image coverage: %.2f * %.2f arcmin^2." % (dra * 60.0, ddec * 60.0)) + # enlarge the sky coverage in order to catch the galaxies at the chip edge + raLow -= 0.2 / 60.0 + decLow -= 0.2 / 60.0 + raUp += 0.2 / 60.0 + decUp += 0.2 / 60.0 + print( + " Range: RA=[%.4f, %.4f]; DEC=[%.4f, %.4f]" + % (raLow, raUp, decLow, decUp) + ) + + # generate the galaxy and star images + catxxn = os.path.join( + cat_dir, "csst_mainfocus_field_distortion_ccd%s_%s.cat" % (idStr, filtnmk) + ) + hdrxx = "#id_obj id_chip filter ra_true dec_ture x_image_ture y_image_ture x_image y_image g1_fd g2_fd\n" + fmtxx = "%8d %3d %4s %12.6f %12.6f %13.6f %13.6f %13.6f %13.6f %9.5f %9.5f\n" + catxx = open(catxxn, "w") + catxx.write(hdrxx) + oidxx = 0 + for k in range(nobj): + # if k != 0: continue + # input galaxy parameters + rak = crdCat[k, 0] + deck = crdCat[k, 1] + + # reject objects out of the image + if (rak - raLow) * (rak - raUp) > 0.0 or (deck - decLow) * ( + deck - decUp + ) > 0.0: + continue + + world_pos = galsim.CelestialCoord( + ra=rak * galsim.degrees, dec=deck * galsim.degrees + ) + image_pos = fimage.wcs.toImage(world_pos) + xk_true = image_pos.x + yk_true = image_pos.y + + ################################################################# + # field distortion + if (xLowI - xk_true) * (xUpI - xk_true) > 0 or (yLowI - yk_true) * ( + yUpI - yk_true + ) > 0: + continue + xk = ixfdModel(xk_true, yk_true)[0][0] + yk = iyfdModel(xk_true, yk_true)[0][0] + + # global offset correction + if (xlLowI - xk) * (xlUpI - xk) > 0 or (ylLowI - yk) * (ylUpI - yk) > 0: + continue + dxk = ixrsModel(xk, yk)[0][0] + dyk = iyrsModel(xk, yk)[0][0] + xk = xk + dxk + yk = yk + dyk + + # field distortion induced ellipticity + ix_dx = ifx_dx(xk, yk) + irx_dx(xk, yk) + ix_dy = ifx_dy(xk, yk) + irx_dy(xk, yk) + iy_dx = ify_dx(xk, yk) + iry_dx(xk, yk) + iy_dy = ify_dy(xk, yk) + iry_dy(xk, yk) + + g1k_fd = 0.0 + (iy_dy - ix_dx) / (iy_dy + ix_dx) + g2k_fd = 0.0 - (iy_dx + ix_dy) / (iy_dy + ix_dx) + ################################################################# + dxk_true, dyk_true = xk_true - xmcen, yk_true - ymcen + xLock_true, yLock_true = dxk_true + xorigin + 1.0, dyk_true + yorigin + 1.0 + + dxk, dyk = xk - xmcen, yk - ymcen + xLock, yLock = dxk + xorigin + 1.0, dyk + yorigin + 1.0 + + if (xLock_true < 0) or (xLock_true > xchip): + continue + if (yLock_true < 0) or (yLock_true > ychip): + continue + if (xLock < 0) or (xLock > xchip): + continue + if (yLock < 0) or (yLock > ychip): + continue + linexx = fmtxx % ( + k + 1, + i + 1, + filtnmk.lower(), + rak, + deck, + xLock_true, + yLock_true, + xLock, + yLock, + g1k_fd[0][0], + g2k_fd[0][0], + ) + catxx.write(linexx) + + catxx.close() + + return + + +if __name__ == "__main__": + unittest.main()