Commit df15b1ac authored by Zhang Xin's avatar Zhang Xin
Browse files

modify sls convolve psf method

parent 21c0174d
......@@ -23,7 +23,7 @@ class Observation(object):
self.filter_param = FilterParam()
self.Catalog = Catalog
def prepare_chip_for_exposure(self, chip, ra_cen, dec_cen, pointing, wcs_fp=None):
def prepare_chip_for_exposure(self, chip, ra_cen, dec_cen, pointing, wcs_fp=None, slsPSFOptim = False):
# Get WCS for the focal plane
if wcs_fp == None:
wcs_fp = self.focal_plane.getTanWCS(
......@@ -34,6 +34,26 @@ class Observation(object):
chip.img.setOrigin(chip.bound.xmin, chip.bound.ymin)
chip.img.wcs = wcs_fp
chip.slsPSFOptim = slsPSFOptim
if chip.chipID in [1,2,3,4,5,10,21,26,27,28,29,30] and slsPSFOptim:
chip.img_stack = {}
for id1 in np.arange(2):
gn = chip_utils.getChipSLSGratingID(chip.chipID)[id1]
orders = {}
# for id2 in ['-2','-1','0','1','2']:
for id2 in ['0','1']:
o_n = "order"+id2
allbands = {}
for id3 in ['1','2','3','4']:
w_n = "w"+id3
allbands[w_n] = galsim.ImageF(chip.npix_x, chip.npix_y)
allbands[w_n].setOrigin(chip.bound.xmin, chip.bound.ymin)
allbands[w_n].wcs = wcs_fp
orders[o_n] = allbands
chip.img_stack[gn] = orders
else:
chip.img_stack = {}
# Get random generators for this chip
chip.rng_poisson, chip.poisson_noise = chip_utils.get_poisson(
seed=int(self.config["random_seeds"]["seed_poisson"]) + pointing.id*30 + chip.chipID, sky_level=0.)
......@@ -95,9 +115,10 @@ class Observation(object):
else:
ra_cen = pointing.ra
dec_cen = pointing.dec
slsPSFOpt = True
# Prepare necessary chip properties for simulation
chip = self.prepare_chip_for_exposure(chip, ra_cen, dec_cen, pointing)
chip = self.prepare_chip_for_exposure(chip, ra_cen, dec_cen, pointing, slsPSFOptim = slsPSFOpt)
# Initialize SimSteps
sim_steps = SimSteps(overall_config=self.config,
......
......@@ -11,6 +11,8 @@ from observation_sim.mock_objects._util import integrate_sed_bandpass, getNormFa
getABMAG
from observation_sim.mock_objects.SpecDisperser import SpecDisperser
from observation_sim.instruments.chip import chip_utils
class MockObject(object):
def __init__(self, param, logger=None):
......@@ -239,63 +241,101 @@ class MockObject(object):
def addSLStoChipImageWithPSF(self, sdp=None, chip=None, pos_img_local=[1, 1], psf_model=None, bandNo=1, grating_split_pos=3685, local_wcs=None, pos_img=None):
spec_orders = sdp.compute_spec_orders()
for k, v in spec_orders.items():
img_s = v[0]
# print(bandNo,k)
try:
psf, pos_shear = psf_model.get_PSF(
chip, pos_img_local=pos_img_local, bandNo=bandNo, galsimGSObject=True, g_order=k, grating_split_pos=grating_split_pos)
except:
psf, pos_shear = psf_model.get_PSF(chip=chip, pos_img=pos_img)
psf_img = psf.drawImage(nx=100, ny=100, wcs=local_wcs)
if chip.slsPSFOptim:
for k, v in spec_orders.items():
img_s = v[0]
pos_shear = galsim.Shear(e=0., beta=(np.pi/2)*galsim.radians)
nan_ids = np.isnan(img_s)
if img_s[nan_ids].shape[0] > 0:
img_s[nan_ids] = 0
print("DEBUG: specImg nan num is", img_s[nan_ids].shape[0])
#########################################################
# img_s, orig_off = convolveImg(img_s, psf_img_m)
orig_off = [0,0]
origin_order_x = v[1] - orig_off[0]
origin_order_y = v[2] - orig_off[1]
specImg = galsim.ImageF(img_s)
specImg.wcs = local_wcs
specImg.setOrigin(origin_order_x, origin_order_y)
bounds = specImg.bounds & galsim.BoundsI(
0, chip.npix_x - 1, 0, chip.npix_y - 1)
if bounds.area() == 0:
continue
# orders = {'A': 'order1', 'B': 'order0', 'C': 'order2', 'D': 'order-1', 'E': 'order-2'}
orders = {'A': 'order1', 'B': 'order0', 'C': 'order0', 'D': 'order0', 'E': 'order0'}
gratingN = chip_utils.getChipSLSGratingID(chip.chipID)[1]
if pos_img_local[0] < grating_split_pos:
gratingN = chip_utils.getChipSLSGratingID(chip.chipID)[0]
chip.img_stack[gratingN][orders[k]]['w' + str(bandNo)].setOrigin(0, 0)
chip.img_stack[gratingN][orders[k]]['w' + str(bandNo)][bounds] = chip.img_stack[gratingN][orders[k]]['w' + str(bandNo)][bounds] + specImg[bounds]
chip.img_stack[gratingN][orders[k]]['w' + str(bandNo)].setOrigin(chip.bound.xmin, chip.bound.ymin)
psf_img_m = psf_img.array
#########################################################
# DEBUG
#########################################################
# ids_p = psf_img_m < 0
# psf_img_m[ids_p] = 0
# from astropy.io import fits
# fits.writeto(str(bandNo) + '_' + str(k) + '_psf.fits', psf_img_m)
# print("DEBUG: orig_off is", orig_off)
nan_ids = np.isnan(img_s)
if img_s[nan_ids].shape[0] > 0:
img_s[nan_ids] = 0
print("DEBUG: specImg nan num is", img_s[nan_ids].shape[0])
#########################################################
img_s, orig_off = convolveImg(img_s, psf_img_m)
origin_order_x = v[1] - orig_off[0]
origin_order_y = v[2] - orig_off[1]
specImg = galsim.ImageF(img_s)
# photons = galsim.PhotonArray.makeFromImage(specImg)
# photons.x += origin_order_x
# photons.y += origin_order_y
# xlen_imf = int(specImg.xmax - specImg.xmin + 1)
# ylen_imf = int(specImg.ymax - specImg.ymin + 1)
# stamp = galsim.ImageF(xlen_imf, ylen_imf)
# stamp.wcs = local_wcs
# stamp.setOrigin(origin_order_x, origin_order_y)
specImg.wcs = local_wcs
specImg.setOrigin(origin_order_x, origin_order_y)
bounds = specImg.bounds & galsim.BoundsI(
0, chip.npix_x - 1, 0, chip.npix_y - 1)
if bounds.area() == 0:
continue
chip.img.setOrigin(0, 0)
chip.img[bounds] = chip.img[bounds] + specImg[bounds]
# stamp[bounds] = chip.img[bounds]
# # chip.sensor.accumulate(photons, stamp)
# chip.img[bounds] = stamp[bounds]
chip.img.setOrigin(chip.bound.xmin, chip.bound.ymin)
# del stamp
else:
for k, v in spec_orders.items():
img_s = v[0]
# print(bandNo,k)
try:
psf, pos_shear = psf_model.get_PSF(
chip, pos_img_local=pos_img_local, bandNo=bandNo, galsimGSObject=True, g_order=k, grating_split_pos=grating_split_pos)
except:
psf, pos_shear = psf_model.get_PSF(chip=chip, pos_img=pos_img)
psf_img = psf.drawImage(nx=100, ny=100, wcs=local_wcs)
psf_img_m = psf_img.array
#########################################################
# DEBUG
#########################################################
# ids_p = psf_img_m < 0
# psf_img_m[ids_p] = 0
# from astropy.io import fits
# fits.writeto(str(bandNo) + '_' + str(k) + '_psf.fits', psf_img_m)
# print("DEBUG: orig_off is", orig_off)
nan_ids = np.isnan(img_s)
if img_s[nan_ids].shape[0] > 0:
img_s[nan_ids] = 0
print("DEBUG: specImg nan num is", img_s[nan_ids].shape[0])
#########################################################
img_s, orig_off = convolveImg(img_s, psf_img_m)
origin_order_x = v[1] - orig_off[0]
origin_order_y = v[2] - orig_off[1]
specImg = galsim.ImageF(img_s)
# photons = galsim.PhotonArray.makeFromImage(specImg)
# photons.x += origin_order_x
# photons.y += origin_order_y
# xlen_imf = int(specImg.xmax - specImg.xmin + 1)
# ylen_imf = int(specImg.ymax - specImg.ymin + 1)
# stamp = galsim.ImageF(xlen_imf, ylen_imf)
# stamp.wcs = local_wcs
# stamp.setOrigin(origin_order_x, origin_order_y)
specImg.wcs = local_wcs
specImg.setOrigin(origin_order_x, origin_order_y)
bounds = specImg.bounds & galsim.BoundsI(
0, chip.npix_x - 1, 0, chip.npix_y - 1)
if bounds.area() == 0:
continue
chip.img.setOrigin(0, 0)
chip.img[bounds] = chip.img[bounds] + specImg[bounds]
# stamp[bounds] = chip.img[bounds]
# # chip.sensor.accumulate(photons, stamp)
# chip.img[bounds] = stamp[bounds]
chip.img.setOrigin(chip.bound.xmin, chip.bound.ymin)
# del stamp
del spec_orders
return pos_shear
......
......@@ -20,8 +20,10 @@ import os
from astropy.io import fits
from astropy.modeling.models import Gaussian2D
from scipy import signal
from scipy import signal, interpolate
import datetime
import gc
from jax import numpy as jnp
LOG_DEBUG = False # ***#
NPSF = 900 # ***# 30*30
......@@ -479,6 +481,106 @@ class PSFInterpSLS(PSFModel):
return PSF_int_trans, PSF_int
def convolveFullImgWithPCAPSF(self, chip, folding_threshold=5.e-3):
keys_L1= chip_utils.getChipSLSGratingID(chip.chipID)
# keys_L2 = ['order-2','order-1','order0','order1','order2']
keys_L2 = ['order0','order1']
keys_L3 = ['w1','w2','w3','w4']
npca = 10
x_start = chip.x_cen/chip.pix_size - chip.npix_x / 2.
y_start = chip.y_cen/chip.pix_size - chip.npix_y / 2.
for i,gt in enumerate(keys_L1):
psfCo = self.grating1_data
if i > 0:
psfCo = self.grating2_data
for od in keys_L2:
psfCo_L2 = psfCo['order1']
if od in ['order-2','order-1','order0','order2']:
psfCo_L2 = psfCo['order0']
for w in keys_L3:
img = chip.img_stack[gt][od][w]
pcs = psfCo_L2['band'+w[1]]['band_data'][0].data
pos_p = psfCo_L2['band'+w[1]]['band_data'][1].data/chip.pix_size - np.array([y_start, x_start])
pc_coeff = psfCo_L2['band'+w[1]]['band_data'][2].data
# print("DEBUG-----------",np.max(pos_p[:,1]),np.min(pos_p[:,1]), np.max(pos_p[:,0]),np.min(pos_p[:,0]))
sum_img = np.sum(img.array)
# coeff_mat = np.zeros([npca, chip.npix_y, chip.npix_x])
# for m in np.arange(chip.npix_y):
# for n in np.arange(chip.npix_x):
# px = n
# py = m
# dist2 = (pos_p[:, 1] - px)*(pos_p[:, 1] - px) + (pos_p[:, 0] - py)*(pos_p[:, 0] - py)
# temp_sort_dist = np.zeros([dist2.shape[0], 2])
# temp_sort_dist[:, 0] = np.arange(0, dist2.shape[0], 1)
# temp_sort_dist[:, 1] = dist2
# # print(temp_sort_dist)
# dits2_sortlist = sorted(temp_sort_dist, key=lambda x: x[1])
# # print(dits2_sortlist)
# nearest4p = np.zeros([4, 3])
# pc_coeff_4p = np.zeros([npca, 4])
# for i in np.arange(4):
# smaller_ids = int(dits2_sortlist[i][0])
# nearest4p[i, 0] = pos_p[smaller_ids, 1]
# nearest4p[i, 1] = pos_p[smaller_ids, 0]
# # print(pos_p[smaller_ids, 1],pos_p[smaller_ids, 0])
# nearest4p[i, 2] = dits2_sortlist[i][1]
# pc_coeff_4p[:, i] = pc_coeff[npca, smaller_ids]
# # idw_dist = 1/(np.sqrt((px-nearest4p[:, 0]) * (px-nearest4p[:, 0]) + (
# # py-nearest4p[:, 1]) * (py-nearest4p[:, 1])))
# idw_dist = 1/(np.sqrt(nearest4p[:, 2]))
# coeff_int = np.zeros(npca)
# for i in np.arange(4):
# coeff_int = coeff_int + pc_coeff_4p[:, i]*idw_dist[i]
# coeff_mat[:, m, n] = coeff_int
m_size = int(pcs.shape[0]**0.5)
tmp_img = np.zeros_like(img.array,dtype=np.float32)
for j in np.arange(npca):
print(gt, od, w, j)
X_ = jnp.hstack((pos_p[:,1].flatten()[:, None], pos_p[:,0].flatten()[:, None]),dtype=np.float32)
Z_ = (pc_coeff[j].astype(np.float32)).flatten()
# print(pc_coeff[j].shape[0], pos_p[:,1].shape[0], pos_p[:,0].shape[0])
sub_size = 4
cx_len = int(chip.npix_x/sub_size)
cy_len = int(chip.npix_y/sub_size)
n_x = jnp.arange(0, cx_len, 1, dtype = int)
n_y = jnp.arange(0, cy_len, 1, dtype = int)
M, N = jnp.meshgrid(n_x, n_y)
t1=datetime.datetime.now()
# U = interpolate.griddata(X_, Z_, (M[0:cy_len, 0:cx_len],N[0:cy_len, 0:cx_len]),
# method='nearest',fill_value=1.0)
U1 = interpolate.griddata(X_, Z_, (M[0:cy_len, 0:cx_len],N[0:cy_len, 0:cx_len]),
method='nearest',fill_value=1.0)
U = np.zeros_like(chip.img.array, dtype=np.float32)
for mi in np.arange(cx_len):
for mj in np.arange(cx_len):
U[mi*sub_size:(mi+1)*sub_size, mj*sub_size:(mj+1)*sub_size]=U1[mi,mj]
t2=datetime.datetime.now()
print("time interpolate:", t2-t1)
img_tmp = img.array*U
psf = pcs[:, j].reshape(m_size, m_size)
tmp_img = tmp_img + signal.fftconvolve(img_tmp, psf, mode='same', axes=None)
t3=datetime.datetime.now()
print("time convole:", t3-t2)
del U
del U1
chip.img = chip.img + tmp_img*sum_img/np.sum(tmp_img)
del tmp_img
gc.collect()
# pixSize = np.rad2deg(self.pixsize*1e-3/28)*3600 #set psf pixsize
#
# # assert self.iccd == int(chip.getChipLabel(chipID=chip.chipID)), 'ERROR: self.iccd != chip.chipID'
......
......@@ -217,6 +217,28 @@ def add_objects(self, chip, filt, tel, pointing, catalog, obs_param):
obj.unload_SED()
del obj
gc.collect()
if chip.survey_type == "spectroscopic" and not self.overall_config["run_option"]["out_cat_only"] and chip.slsPSFOptim:
# from observation_sim.instruments.chip import chip_utils as chip_utils
# gn = chip_utils.getChipSLSGratingID(chip.chipID)[0]
# img1 = np.zeros([2,chip.img.array.shape[0],chip.img.array.shape[1]])
# for id1 in np.arange(2):
# gn = chip_utils.getChipSLSGratingID(chip.chipID)[id1]
# img_i = 0
# for id2 in ['0','1']:
# o_n = "order"+id2
# for id3 in ['1','2','3','4']:
# w_n = "w"+id3
# img1[img_i] = img1[img_i] + chip.img_stack[gn][o_n][w_n].array
# img_i = img_i + 1
# from astropy.io import fits
# fits.writeto('order0.fits',img1[0],overwrite=True)
# fits.writeto('order1.fits',img1[1],overwrite=True)
psf_model.convolveFullImgWithPCAPSF(chip)
del psf_model
gc.collect()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment