Commit df15b1ac authored by Zhang Xin's avatar Zhang Xin
Browse files

modify sls convolve psf method

parent 21c0174d
...@@ -23,7 +23,7 @@ class Observation(object): ...@@ -23,7 +23,7 @@ class Observation(object):
self.filter_param = FilterParam() self.filter_param = FilterParam()
self.Catalog = Catalog self.Catalog = Catalog
def prepare_chip_for_exposure(self, chip, ra_cen, dec_cen, pointing, wcs_fp=None): def prepare_chip_for_exposure(self, chip, ra_cen, dec_cen, pointing, wcs_fp=None, slsPSFOptim = False):
# Get WCS for the focal plane # Get WCS for the focal plane
if wcs_fp == None: if wcs_fp == None:
wcs_fp = self.focal_plane.getTanWCS( wcs_fp = self.focal_plane.getTanWCS(
...@@ -34,6 +34,26 @@ class Observation(object): ...@@ -34,6 +34,26 @@ class Observation(object):
chip.img.setOrigin(chip.bound.xmin, chip.bound.ymin) chip.img.setOrigin(chip.bound.xmin, chip.bound.ymin)
chip.img.wcs = wcs_fp chip.img.wcs = wcs_fp
chip.slsPSFOptim = slsPSFOptim
if chip.chipID in [1,2,3,4,5,10,21,26,27,28,29,30] and slsPSFOptim:
chip.img_stack = {}
for id1 in np.arange(2):
gn = chip_utils.getChipSLSGratingID(chip.chipID)[id1]
orders = {}
# for id2 in ['-2','-1','0','1','2']:
for id2 in ['0','1']:
o_n = "order"+id2
allbands = {}
for id3 in ['1','2','3','4']:
w_n = "w"+id3
allbands[w_n] = galsim.ImageF(chip.npix_x, chip.npix_y)
allbands[w_n].setOrigin(chip.bound.xmin, chip.bound.ymin)
allbands[w_n].wcs = wcs_fp
orders[o_n] = allbands
chip.img_stack[gn] = orders
else:
chip.img_stack = {}
# Get random generators for this chip # Get random generators for this chip
chip.rng_poisson, chip.poisson_noise = chip_utils.get_poisson( chip.rng_poisson, chip.poisson_noise = chip_utils.get_poisson(
seed=int(self.config["random_seeds"]["seed_poisson"]) + pointing.id*30 + chip.chipID, sky_level=0.) seed=int(self.config["random_seeds"]["seed_poisson"]) + pointing.id*30 + chip.chipID, sky_level=0.)
...@@ -96,8 +116,9 @@ class Observation(object): ...@@ -96,8 +116,9 @@ class Observation(object):
ra_cen = pointing.ra ra_cen = pointing.ra
dec_cen = pointing.dec dec_cen = pointing.dec
slsPSFOpt = True
# Prepare necessary chip properties for simulation # Prepare necessary chip properties for simulation
chip = self.prepare_chip_for_exposure(chip, ra_cen, dec_cen, pointing) chip = self.prepare_chip_for_exposure(chip, ra_cen, dec_cen, pointing, slsPSFOptim = slsPSFOpt)
# Initialize SimSteps # Initialize SimSteps
sim_steps = SimSteps(overall_config=self.config, sim_steps = SimSteps(overall_config=self.config,
......
...@@ -11,6 +11,8 @@ from observation_sim.mock_objects._util import integrate_sed_bandpass, getNormFa ...@@ -11,6 +11,8 @@ from observation_sim.mock_objects._util import integrate_sed_bandpass, getNormFa
getABMAG getABMAG
from observation_sim.mock_objects.SpecDisperser import SpecDisperser from observation_sim.mock_objects.SpecDisperser import SpecDisperser
from observation_sim.instruments.chip import chip_utils
class MockObject(object): class MockObject(object):
def __init__(self, param, logger=None): def __init__(self, param, logger=None):
...@@ -239,6 +241,44 @@ class MockObject(object): ...@@ -239,6 +241,44 @@ class MockObject(object):
def addSLStoChipImageWithPSF(self, sdp=None, chip=None, pos_img_local=[1, 1], psf_model=None, bandNo=1, grating_split_pos=3685, local_wcs=None, pos_img=None): def addSLStoChipImageWithPSF(self, sdp=None, chip=None, pos_img_local=[1, 1], psf_model=None, bandNo=1, grating_split_pos=3685, local_wcs=None, pos_img=None):
spec_orders = sdp.compute_spec_orders() spec_orders = sdp.compute_spec_orders()
if chip.slsPSFOptim:
for k, v in spec_orders.items():
img_s = v[0]
pos_shear = galsim.Shear(e=0., beta=(np.pi/2)*galsim.radians)
nan_ids = np.isnan(img_s)
if img_s[nan_ids].shape[0] > 0:
img_s[nan_ids] = 0
print("DEBUG: specImg nan num is", img_s[nan_ids].shape[0])
#########################################################
# img_s, orig_off = convolveImg(img_s, psf_img_m)
orig_off = [0,0]
origin_order_x = v[1] - orig_off[0]
origin_order_y = v[2] - orig_off[1]
specImg = galsim.ImageF(img_s)
specImg.wcs = local_wcs
specImg.setOrigin(origin_order_x, origin_order_y)
bounds = specImg.bounds & galsim.BoundsI(
0, chip.npix_x - 1, 0, chip.npix_y - 1)
if bounds.area() == 0:
continue
# orders = {'A': 'order1', 'B': 'order0', 'C': 'order2', 'D': 'order-1', 'E': 'order-2'}
orders = {'A': 'order1', 'B': 'order0', 'C': 'order0', 'D': 'order0', 'E': 'order0'}
gratingN = chip_utils.getChipSLSGratingID(chip.chipID)[1]
if pos_img_local[0] < grating_split_pos:
gratingN = chip_utils.getChipSLSGratingID(chip.chipID)[0]
chip.img_stack[gratingN][orders[k]]['w' + str(bandNo)].setOrigin(0, 0)
chip.img_stack[gratingN][orders[k]]['w' + str(bandNo)][bounds] = chip.img_stack[gratingN][orders[k]]['w' + str(bandNo)][bounds] + specImg[bounds]
chip.img_stack[gratingN][orders[k]]['w' + str(bandNo)].setOrigin(chip.bound.xmin, chip.bound.ymin)
else:
for k, v in spec_orders.items(): for k, v in spec_orders.items():
img_s = v[0] img_s = v[0]
# print(bandNo,k) # print(bandNo,k)
......
...@@ -20,8 +20,10 @@ import os ...@@ -20,8 +20,10 @@ import os
from astropy.io import fits from astropy.io import fits
from astropy.modeling.models import Gaussian2D from astropy.modeling.models import Gaussian2D
from scipy import signal from scipy import signal, interpolate
import datetime
import gc
from jax import numpy as jnp
LOG_DEBUG = False # ***# LOG_DEBUG = False # ***#
NPSF = 900 # ***# 30*30 NPSF = 900 # ***# 30*30
...@@ -479,6 +481,106 @@ class PSFInterpSLS(PSFModel): ...@@ -479,6 +481,106 @@ class PSFInterpSLS(PSFModel):
return PSF_int_trans, PSF_int return PSF_int_trans, PSF_int
def convolveFullImgWithPCAPSF(self, chip, folding_threshold=5.e-3):
keys_L1= chip_utils.getChipSLSGratingID(chip.chipID)
# keys_L2 = ['order-2','order-1','order0','order1','order2']
keys_L2 = ['order0','order1']
keys_L3 = ['w1','w2','w3','w4']
npca = 10
x_start = chip.x_cen/chip.pix_size - chip.npix_x / 2.
y_start = chip.y_cen/chip.pix_size - chip.npix_y / 2.
for i,gt in enumerate(keys_L1):
psfCo = self.grating1_data
if i > 0:
psfCo = self.grating2_data
for od in keys_L2:
psfCo_L2 = psfCo['order1']
if od in ['order-2','order-1','order0','order2']:
psfCo_L2 = psfCo['order0']
for w in keys_L3:
img = chip.img_stack[gt][od][w]
pcs = psfCo_L2['band'+w[1]]['band_data'][0].data
pos_p = psfCo_L2['band'+w[1]]['band_data'][1].data/chip.pix_size - np.array([y_start, x_start])
pc_coeff = psfCo_L2['band'+w[1]]['band_data'][2].data
# print("DEBUG-----------",np.max(pos_p[:,1]),np.min(pos_p[:,1]), np.max(pos_p[:,0]),np.min(pos_p[:,0]))
sum_img = np.sum(img.array)
# coeff_mat = np.zeros([npca, chip.npix_y, chip.npix_x])
# for m in np.arange(chip.npix_y):
# for n in np.arange(chip.npix_x):
# px = n
# py = m
# dist2 = (pos_p[:, 1] - px)*(pos_p[:, 1] - px) + (pos_p[:, 0] - py)*(pos_p[:, 0] - py)
# temp_sort_dist = np.zeros([dist2.shape[0], 2])
# temp_sort_dist[:, 0] = np.arange(0, dist2.shape[0], 1)
# temp_sort_dist[:, 1] = dist2
# # print(temp_sort_dist)
# dits2_sortlist = sorted(temp_sort_dist, key=lambda x: x[1])
# # print(dits2_sortlist)
# nearest4p = np.zeros([4, 3])
# pc_coeff_4p = np.zeros([npca, 4])
# for i in np.arange(4):
# smaller_ids = int(dits2_sortlist[i][0])
# nearest4p[i, 0] = pos_p[smaller_ids, 1]
# nearest4p[i, 1] = pos_p[smaller_ids, 0]
# # print(pos_p[smaller_ids, 1],pos_p[smaller_ids, 0])
# nearest4p[i, 2] = dits2_sortlist[i][1]
# pc_coeff_4p[:, i] = pc_coeff[npca, smaller_ids]
# # idw_dist = 1/(np.sqrt((px-nearest4p[:, 0]) * (px-nearest4p[:, 0]) + (
# # py-nearest4p[:, 1]) * (py-nearest4p[:, 1])))
# idw_dist = 1/(np.sqrt(nearest4p[:, 2]))
# coeff_int = np.zeros(npca)
# for i in np.arange(4):
# coeff_int = coeff_int + pc_coeff_4p[:, i]*idw_dist[i]
# coeff_mat[:, m, n] = coeff_int
m_size = int(pcs.shape[0]**0.5)
tmp_img = np.zeros_like(img.array,dtype=np.float32)
for j in np.arange(npca):
print(gt, od, w, j)
X_ = jnp.hstack((pos_p[:,1].flatten()[:, None], pos_p[:,0].flatten()[:, None]),dtype=np.float32)
Z_ = (pc_coeff[j].astype(np.float32)).flatten()
# print(pc_coeff[j].shape[0], pos_p[:,1].shape[0], pos_p[:,0].shape[0])
sub_size = 4
cx_len = int(chip.npix_x/sub_size)
cy_len = int(chip.npix_y/sub_size)
n_x = jnp.arange(0, cx_len, 1, dtype = int)
n_y = jnp.arange(0, cy_len, 1, dtype = int)
M, N = jnp.meshgrid(n_x, n_y)
t1=datetime.datetime.now()
# U = interpolate.griddata(X_, Z_, (M[0:cy_len, 0:cx_len],N[0:cy_len, 0:cx_len]),
# method='nearest',fill_value=1.0)
U1 = interpolate.griddata(X_, Z_, (M[0:cy_len, 0:cx_len],N[0:cy_len, 0:cx_len]),
method='nearest',fill_value=1.0)
U = np.zeros_like(chip.img.array, dtype=np.float32)
for mi in np.arange(cx_len):
for mj in np.arange(cx_len):
U[mi*sub_size:(mi+1)*sub_size, mj*sub_size:(mj+1)*sub_size]=U1[mi,mj]
t2=datetime.datetime.now()
print("time interpolate:", t2-t1)
img_tmp = img.array*U
psf = pcs[:, j].reshape(m_size, m_size)
tmp_img = tmp_img + signal.fftconvolve(img_tmp, psf, mode='same', axes=None)
t3=datetime.datetime.now()
print("time convole:", t3-t2)
del U
del U1
chip.img = chip.img + tmp_img*sum_img/np.sum(tmp_img)
del tmp_img
gc.collect()
# pixSize = np.rad2deg(self.pixsize*1e-3/28)*3600 #set psf pixsize # pixSize = np.rad2deg(self.pixsize*1e-3/28)*3600 #set psf pixsize
# #
# # assert self.iccd == int(chip.getChipLabel(chipID=chip.chipID)), 'ERROR: self.iccd != chip.chipID' # # assert self.iccd == int(chip.getChipLabel(chipID=chip.chipID)), 'ERROR: self.iccd != chip.chipID'
......
...@@ -217,6 +217,28 @@ def add_objects(self, chip, filt, tel, pointing, catalog, obs_param): ...@@ -217,6 +217,28 @@ def add_objects(self, chip, filt, tel, pointing, catalog, obs_param):
obj.unload_SED() obj.unload_SED()
del obj del obj
gc.collect() gc.collect()
if chip.survey_type == "spectroscopic" and not self.overall_config["run_option"]["out_cat_only"] and chip.slsPSFOptim:
# from observation_sim.instruments.chip import chip_utils as chip_utils
# gn = chip_utils.getChipSLSGratingID(chip.chipID)[0]
# img1 = np.zeros([2,chip.img.array.shape[0],chip.img.array.shape[1]])
# for id1 in np.arange(2):
# gn = chip_utils.getChipSLSGratingID(chip.chipID)[id1]
# img_i = 0
# for id2 in ['0','1']:
# o_n = "order"+id2
# for id3 in ['1','2','3','4']:
# w_n = "w"+id3
# img1[img_i] = img1[img_i] + chip.img_stack[gn][o_n][w_n].array
# img_i = img_i + 1
# from astropy.io import fits
# fits.writeto('order0.fits',img1[0],overwrite=True)
# fits.writeto('order1.fits',img1[1],overwrite=True)
psf_model.convolveFullImgWithPCAPSF(chip)
del psf_model del psf_model
gc.collect() gc.collect()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment