Skip to content
add_objects.py 10.4 KiB
Newer Older
import os
import gc
import psutil
Fang Yuedong's avatar
Fang Yuedong committed
import traceback
import numpy as np
import galsim
from ObservationSim._util import get_shear_field
from ObservationSim.PSF import PSFGauss, FieldDistortion, PSFInterp, PSFInterpSLS

from astropy.time import Time
from datetime import datetime, timezone

def add_objects(self, chip, filt, tel, pointing, catalog, obs_param):
    """Render all catalog objects onto ``chip`` and finalise its headers.

    ``self`` is taken explicitly. The function reads ``self.overall_config``,
    ``self.chip_output``, ``self.all_filters`` and ``self.h_ext``, and calls
    ``self.prepare_headers`` / ``self.updateHeaderInfo``. It is presumably bound
    as a method of the observation/simulation class (TODO confirm).

    Steps:
      1. Build the catalog via ``catalog(...)`` and create the chip output file(s).
      2. Build the PSF model (Gaussian or interpolated) and, if requested,
         the field-distortion model.
      3. Update limiting/saturation magnitudes of every filter in
         ``self.all_filters`` and pick the filter used for the magnitude cut
         (``obs_setting.cut_in_band``).
      4. For each catalog object:
         - convert its SED to the observing filter and to the cut filter;
         - skip objects that are too bright, too dim, or off the focal plane;
         - assign shear;
         - draw it, or only compute its real position when ``out_cat_only``;
         - add it to the output catalog.
      5. Multiply the image by the flat field and/or shutter pattern and
         update the SHT* extension-header keys.
      6. Update the DATE-OBS (primary) and DARKTIME (extension) header keys.

    Args:
        chip: Chip being simulated; ``chip.img`` is modified in place.
        filt: Observing filter for this chip.
        tel: Telescope object passed through to the drawing routines.
        pointing: Pointing. Supplies ``exp_time``, ``timestamp``, ``id``,
            ``img_pa`` and ``get_full_depth_exptime``.
        catalog: Catalog interface class; it is called here with
            ``config``, ``chip``, ``pointing``, ``chip_output`` and ``filt``.
        obs_param: Observation parameters (dict-like).
            - Required keys: ``field_dist``, ``flat_fielding``, ``shutter_effect``.
            - Optional key: ``exptime``.

    Returns:
        tuple: ``(chip, filt, tel, pointing)``.

    Raises:
        ValueError: if any of the following holds:
            - ``catalog`` is None;
            - the PSF model type is unrecognised;
            - the shear type is unrecognised;
            - no filter in ``self.all_filters`` matches ``cut_in_band``.
    """
    config = self.overall_config
    obs_setting = config["obs_setting"]

    # Exposure time: an explicit obs_param["exptime"] overrides the pointing's.
    # NOTE(review): `exptime` is not used further in this function (DARKTIME
    # below uses pointing.exp_time) -- confirm which one is intended.
    if (obs_param) and ("exptime" in obs_param) and (obs_param["exptime"] is not None):
        exptime = obs_param["exptime"]
    else:
        exptime = pointing.exp_time

    # Load catalogues
    if catalog is None:
        self.chip_output.Log_error("Catalog interface class must be specified for SCIE-OBS")
        raise ValueError("Catalog interface class must be specified for SCIE-OBS")
    cat = catalog(config=self.overall_config, chip=chip, pointing=pointing, chip_output=self.chip_output, filt=filt)

    # Prepare output file(s) for this chip.
    # [NOTE] Headers of the output .cat file may be updated by the Catalog
    # instance, so this must be called after the Catalog is created.
    self.chip_output.create_output_file()

    # Prepare the PSF model
    psf_setting = config["psf_setting"]
    psf_type = psf_setting["psf_model"]
    if psf_type == "Gauss":
        psf_model = PSFGauss(chip=chip, psfRa=psf_setting["psf_rcont"])
    elif psf_type == "Interp":
        if chip.survey_type == "spectroscopic":
            psf_model = PSFInterpSLS(chip, filt, PSF_data_prefix=psf_setting["psf_sls_dir"])
        else:
            psf_model = PSFInterp(chip=chip, npsf=chip.n_psf_samples, PSF_data_file=psf_setting["psf_pho_dir"])
    else:
        # Fail immediately. Otherwise psf_model stays unbound: every object
        # would fail to draw, and `del psf_model` below would crash anyway.
        self.chip_output.Log_error("unrecognized PSF model type!!")
        raise ValueError("unrecognized PSF model type: %s" % psf_type)

    # Apply field distortion model
    if obs_param["field_dist"] == True:
        fd_model = FieldDistortion(chip=chip, img_rot=pointing.img_pa.deg)
    else:
        fd_model = None

    # Update limiting magnitudes for all filters based on their full-depth
    # exposure time, and pick the filter used for the magnitude cut.
    cut_band = obs_setting["cut_in_band"].lower()
    cut_filter = None
    for temp_filter in self.all_filters:
        temp_filter.update_limit_saturation_mags(exptime=pointing.get_full_depth_exptime(temp_filter.filter_type), chip=chip)
        if temp_filter.filter_type.lower() == cut_band:
            cut_filter = temp_filter
    if cut_filter is None:
        # Without a cut filter every object would raise NameError inside the
        # SED try-block and be silently dropped.
        msg = "No filter matches cut_in_band = %s" % obs_setting["cut_in_band"]
        self.chip_output.Log_error(msg)
        raise ValueError(msg)

    # Resolve the shear setting once. Constant shear is read from the
    # configuration; "catalog" shear presumably comes with each object
    # (TODO confirm).
    shear_type = config["shear_setting"]["shear_type"]
    if shear_type == "constant":
        g1_field, g2_field, _ = get_shear_field(config=self.overall_config)
        self.chip_output.Log_info("Use constant shear: g1=%.5f, g2=%.5f"%(g1_field, g2_field))
    elif shear_type != "catalog":
        self.chip_output.Log_error("Unknown shear input")
        raise ValueError("Unknown shear input")

    # Get chip WCS (headers are prepared lazily if not done yet)
    if not hasattr(self, 'h_ext'):
        _, _ = self.prepare_headers(chip=chip, pointing=pointing)
    chip_wcs = galsim.FitsWCS(header=self.h_ext)

    # Loop-invariant settings
    mag_sat_margin = obs_setting["mag_sat_margin"]
    mag_lim_margin = obs_setting["mag_lim_margin"]
    out_cat_only = config["run_option"]["out_cat_only"]
    filt_key = filt.filter_type.lower()

    # Loop over objects
    nobj = len(cat.objs)
    missed_obj = 0
    bright_obj = 0
    dim_obj = 0
    for obj in cat.objs:
        # Load and convert the SED. Also calculate the object's magnitude/flux
        # in the observing band and in the cut band.
        try:
            sed_data = cat.load_sed(obj)
            norm_filt = cat.load_norm_filt(obj)
            obj.sed, obj.param["mag_%s" % filt_key], obj.param["flux_%s" % filt_key] = cat.convert_sed(
                mag=obj.param["mag_use_normal"],
                sed=sed_data,
                target_filt=filt,
                norm_filt=norm_filt,
            )
            _, obj.param["mag_%s" % cut_band], obj.param["flux_%s" % cut_band] = cat.convert_sed(
                mag=obj.param["mag_use_normal"],
                sed=sed_data,
                target_filt=cut_filter,
                norm_filt=norm_filt,
            )
        except Exception as e:
            traceback.print_exc()
            self.chip_output.Log_error(e)
            continue

        # Exclude very bright/dim objects (for now)
        cut_mag = obj.param["mag_%s" % cut_band]
        if cut_filter.is_too_bright(mag=cut_mag, margin=mag_sat_margin):
            self.chip_output.Log_info("obj %s too birght!! mag_%s = %.3f"%(obj.id, cut_filter.filter_type, cut_mag))
            bright_obj += 1
            obj.unload_SED()
            continue
        filt_mag = obj.getMagFilter(filt)
        if filt.is_too_dim(mag=filt_mag, margin=mag_lim_margin):
            self.chip_output.Log_info("obj %s too dim!! mag_%s = %.3f"%(obj.id, filt.filter_type, filt_mag))
            dim_obj += 1
            obj.unload_SED()
            continue

        # Assign shear. Stars are never sheared. For "catalog" shear the
        # object's g1/g2 are left untouched.
        if shear_type == "constant":
            if obj.type == 'star':
                obj.g1, obj.g2 = 0., 0.
            else:
                obj.g1, obj.g2 = g1_field, g2_field

        # Get position of object on the focal plane
        pos_img, _, _, _, fd_shear = obj.getPosImg_Offset_WCS(img=chip.img, fdmodel=fd_model, chip=chip, verbose=False, chip_wcs=chip_wcs, img_header=self.h_ext)

        # [TODO] Only objects whose centres (after field distortion) project
        # onto the focal plane are drawn; a stricter on-chip check
        # (chip.isContainObj) is not applied yet.
        if pos_img.x == -1 or pos_img.y == -1:
            self.chip_output.Log_info('obj_ra = %.6f, obj_dec = %.6f, obj_ra_orig = %.6f, obj_dec_orig = %.6f'%(obj.ra, obj.dec, obj.ra_orig, obj.dec_orig))
            self.chip_output.Log_error("Objected missed: %s"%(obj.id))
            missed_obj += 1
            obj.unload_SED()
            continue

        # Draw object & update output catalog
        try:
            if out_cat_only:
                isUpdated = True
                obj.real_pos = obj.getRealPos(chip.img, global_x=obj.posImg.x, global_y=obj.posImg.y, img_real_wcs=obj.chip_wcs)
                pos_shear = 0.
            elif chip.survey_type == "photometric":
                isUpdated, pos_shear = obj.drawObj_multiband(
                    tel=tel,
                    pos_img=pos_img,
                    psf_model=psf_model,
                    bandpass_list=filt.bandpass_sub_list,
                    filt=filt,
                    chip=chip,
                    g1=obj.g1,
                    g2=obj.g2,
                    fd_shear=fd_shear)
            elif chip.survey_type == "spectroscopic":
                isUpdated, pos_shear = obj.drawObj_slitless(
                    tel=tel,
                    pos_img=pos_img,
                    psf_model=psf_model,
                    bandpass_list=filt.bandpass_sub_list,
                    filt=filt,
                    chip=chip,
                    g1=obj.g1,
                    g2=obj.g2,
                    normFilter=norm_filt,
                    fd_shear=fd_shear)
            else:
                # Previously this left isUpdated unbound (NameError).
                raise ValueError("Unknown survey type: %s" % chip.survey_type)

            if isUpdated == 1:
                # TODO: add up stats
                self.chip_output.cat_add_obj(obj, pos_img, pos_shear)
            elif isUpdated == 0:
                missed_obj += 1
                self.chip_output.Log_error("Objected missed: %s"%(obj.id))
            else:
                # Fall through so the SED is still unloaded below.
                self.chip_output.Log_error("Draw error, object omitted: %s"%(obj.id))
        except Exception as e:
            traceback.print_exc()
            self.chip_output.Log_error(e)

        # Unload SED and free memory after each object
        obj.unload_SED()
        del obj
        gc.collect()
    del psf_model
    gc.collect()

    mem_gb = psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024
    self.chip_output.Log_info("Running checkpoint #1 (Object rendering finished): pointing-%d chip-%d pid-%d memory-%6.2fGB"%(pointing.id, chip.chipID, os.getpid(), mem_gb))

    self.chip_output.Log_info("# objects that are too bright %d out of %d"%(bright_obj, nobj))
    self.chip_output.Log_info("# objects that are too dim %d out of %d"%(dim_obj, nobj))
    self.chip_output.Log_info("# objects that are missed %d out of %d"%(missed_obj, nobj))

    # Apply flat fielding (normalised to unit mean) and shutter effects
    flat_normal = np.ones_like(chip.img.array)
    if obs_param["flat_fielding"] == True:
        flat_normal = flat_normal * chip.flat_img.array / np.mean(chip.flat_img.array)
    # NOTE(review): SHTSTAT is set to True in both branches below -- confirm intended.
    if obs_param["shutter_effect"] == True:
        flat_normal = flat_normal * chip.shutter_img
        flat_normal = np.array(flat_normal, dtype='float32')
        self.updateHeaderInfo(header_flag='ext', keys=['SHTSTAT'], values=[True])
    else:
        self.updateHeaderInfo(header_flag='ext', keys=['SHTSTAT','SHTOPEN0','SHTOPEN1','SHTCLOS0','SHTCLOS1'], values=[True,'','','',''])
    chip.img *= flat_normal
    del flat_normal

    # Renew header info.
    # The CCD is flushed for 2 s and then waits 0 s before the exposure starts,
    # so DATE-OBS is set (2 + 0) s before pointing.timestamp, rounded to 0.1 s.
    t_obs = Time(datetime.fromtimestamp(pointing.timestamp, tz=timezone.utc))
    t_obs_renew = Time(t_obs.mjd - (2. + 0.) / 86400., format="mjd")
    t_obs_utc = datetime.fromtimestamp(np.round(t_obs_renew.unix, 1), tz=timezone.utc)
    # "%f" gives 6 fractional digits; dropping the last 5 keeps one decimal.
    self.updateHeaderInfo(header_flag='prim', keys=['DATE-OBS'], values=[t_obs_utc.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-5]])

    # DARKTIME = exposure time + wait after flush (0 s)
    #            + wait between shutter close and readout (0 s)
    self.updateHeaderInfo(header_flag='ext', keys=['DARKTIME'], values=[0. + 0. + pointing.exp_time])
    return chip, filt, tel, pointing