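"""add_objects.py: render catalog objects onto a simulated chip image.

Defines ``add_objects`` and its validation helper ``_is_obj_valid``. Both
take a ``self`` argument (presumably the simulation object they are attached
to as methods — confirm against the importing class).
"""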
import os
import gc
import psutil
import traceback
import numpy as np
import galsim
from observation_sim._util import get_shear_field
from observation_sim.psf import PSFGauss, FieldDistortion, PSFInterp, PSFInterpSLS

from astropy.time import Time
from datetime import datetime, timezone

def _is_obj_valid(self, obj):
    """Tell whether a catalog object carries the parameters needed to render it.

    A parameter whose value equals the sentinel ``-999.`` counts as missing.

    * ``obj.param['star'] == 4`` ('calib' type): accepted without checks.
    * Every other object needs both ``ra`` and ``dec``.
    * ``obj.param['star'] == 0`` objects (presumably galaxies, given the
      bulge/disk keys) additionally need every bulge/disk shape parameter.

    Each rejection is reported through ``self.chip_output.Log_error``.

    Returns:
        bool: True if the object can be simulated, False otherwise.
    """
    params = obj.param
    if params['star'] == 4:
        # No parameter checks exist yet for the 'calib' type.
        return True

    missing = -999.

    if any(params[name] == missing for name in ('ra', 'dec')):
        self.chip_output.Log_error(
            'One or more positional information (ra, dec) is missing')
        return False

    shape_names = ('hlr_bulge', 'hlr_disk', 'e1_disk', 'e2_disk',
                   'e1_bulge', 'e2_bulge')
    needs_shape = params['star'] == 0
    if needs_shape and any(params[name] == missing for name in shape_names):
        self.chip_output.Log_error(
            'One or more shape information (hlr_bulge, hlr_disk, e1_disk, '
            'e2_disk, e1_bulge, e2_bulge) is missing')
        return False

    return True


def add_objects(self, chip, filt, tel, pointing, catalog, obs_param):
    """Render every catalog object onto ``chip`` and finalize the image.

    Steps:
      1. Resolve the exposure time: ``obs_param["exptime"]`` (when set and
         not None) overrides ``pointing.exp_time``.
      2. Instantiate ``catalog``, create the output file(s), and build the PSF
         model and, optionally, the field-distortion model.
      3. Update limiting/saturation magnitudes of every filter in
         ``self.all_filters`` and select the filter used for magnitude cuts
         (``overall_config["obs_setting"]["cut_in_band"]``).
      4. For each object: validate it, convert its SED, and skip it if it is
         too bright (cut band), too dim (observing band) or falls off the
         focal plane. Otherwise draw it (or only compute its position when
         ``out_cat_only`` is set) and add it to the output catalog.
      5. Apply flat fielding / shutter effect to ``chip.img`` and update the
         DATE-OBS and DARKTIME headers.

    Args:
        chip: Chip being simulated; ``chip.img`` is modified in place.
        filt: Observing filter of this chip.
        tel: Telescope model forwarded to the object drawing routines.
        pointing: Pointing information (``exp_time``, ``timestamp``, ``id``...).
        catalog: Catalog interface *class*; it is instantiated here.
        obs_param: Observation options mapping. It may be None or empty. In
            that case the exposure time falls back to the pointing, and the
            switches "field_dist", "flat_fielding" and "shutter_effect" are
            treated as disabled.

    Returns:
        tuple: ``(chip, filt, tel, pointing)``.

    Raises:
        ValueError: If ``catalog`` is None, the PSF model type is unknown, no
            filter matches ``cut_in_band``, or an object reaches the shear
            step with an unknown shear type.
    """
    config = self.overall_config
    # The exptime lookup treats obs_param as optional, so every other
    # lookup below does too.
    obs_param = obs_param or {}

    # Get exposure time
    if obs_param.get("exptime") is not None:
        exptime = obs_param["exptime"]
    else:
        exptime = pointing.exp_time

    # Load catalogues
    if catalog is None:
        msg = "Catalog interface class must be specified for SCIE-OBS"
        self.chip_output.Log_error(msg)
        raise ValueError(msg)
    cat = catalog(config=config, chip=chip,
                  pointing=pointing, chip_output=self.chip_output, filt=filt)

    # [NOTE] Headers of the output .cat file may be updated by the Catalog
    # instance, so the output file(s) must be created after it exists.
    self.chip_output.create_output_file()

    # Prepare the PSF model
    psf_setting = config["psf_setting"]
    psf_type = psf_setting["psf_model"]
    if psf_type == "Gauss":
        psf_model = PSFGauss(chip=chip, psfRa=psf_setting["psf_rcont"])
    elif psf_type == "Interp":
        if chip.survey_type == "spectroscopic":
            psf_model = PSFInterpSLS(
                chip, filt, PSF_data_prefix=psf_setting["psf_sls_dir"])
        else:
            psf_model = PSFInterp(chip=chip, npsf=chip.n_psf_samples,
                                  PSF_data_file=psf_setting["psf_pho_dir"])
    else:
        # Fail fast: without a PSF model no object below can be drawn.
        msg = "unrecognized PSF model type!!"
        self.chip_output.Log_error(msg)
        raise ValueError("%s (%r)" % (msg, psf_type))

    # Apply field distortion model
    if obs_param.get("field_dist") is True:
        fd_model = FieldDistortion(chip=chip, img_rot=pointing.img_pa.deg)
    else:
        fd_model = None

    # Update limiting magnitudes for all filters based on the exposure time,
    # and find the filter used for the magnitude cut.
    cut_in_band = config["obs_setting"]["cut_in_band"]
    cut_band = cut_in_band.lower()
    cut_filter = None
    for temp_filter in self.all_filters:
        # NOTE(review): this uses pointing.exp_time, not the resolved
        # `exptime` above; confirm whether an obs_param exptime override
        # should also change the limiting/saturation magnitudes.
        temp_filter.update_limit_saturation_mags(
            exptime=pointing.exp_time,
            full_depth_exptime=pointing.get_full_depth_exptime(
                temp_filter.filter_type),
            chip=chip)
        if temp_filter.filter_type.lower() == cut_band:
            cut_filter = temp_filter
    if cut_filter is None:
        # Without a cut filter every object would fail SED conversion below.
        msg = "No filter matches cut_in_band = %s" % cut_in_band
        self.chip_output.Log_error(msg)
        raise ValueError(msg)

    # Read in shear values from configuration file if the constant shear
    # type is used
    shear_type = config["shear_setting"]["shear_type"]
    if shear_type == "constant":
        g1_field, g2_field, _ = get_shear_field(config=config)
        self.chip_output.Log_info(
            "Use constant shear: g1=%.5f, g2=%.5f" % (g1_field, g2_field))

    # Get chip WCS
    if not hasattr(self, 'h_ext'):
        _, _ = self.prepare_headers(chip=chip, pointing=pointing)
    chip_wcs = galsim.FitsWCS(header=self.h_ext)

    # Loop-invariant settings
    filt_band = filt.filter_type.lower()
    mag_sat_margin = config["obs_setting"]["mag_sat_margin"]
    mag_lim_margin = config["obs_setting"]["mag_lim_margin"]
    out_cat_only = config["run_option"]["out_cat_only"]

    # Loop over objects
    nobj = len(cat.objs)
    missed_obj = 0
    bright_obj = 0
    dim_obj = 0
    for obj in cat.objs:
        if not self._is_obj_valid(obj):
            continue

        # Load and convert the SED; also compute the object's magnitude and
        # flux in the observing band and in the magnitude-cut band.
        try:
            sed_data = cat.load_sed(obj)
            norm_filt = cat.load_norm_filt(obj)
            (obj.sed,
             obj.param["mag_%s" % filt_band],
             obj.param["flux_%s" % filt_band]) = cat.convert_sed(
                mag=obj.param["mag_use_normal"],
                sed=sed_data,
                target_filt=filt,
                norm_filt=norm_filt,
                mu=obj.mu)
            # NOTE(review): only this call falls back to `filt` when
            # norm_filt is falsy; confirm the asymmetry is intended.
            (_,
             obj.param["mag_%s" % cut_band],
             obj.param["flux_%s" % cut_band]) = cat.convert_sed(
                mag=obj.param["mag_use_normal"],
                sed=sed_data,
                target_filt=cut_filter,
                norm_filt=(norm_filt if norm_filt else filt),
                mu=obj.mu)
        except Exception as e:
            traceback.print_exc()
            self.chip_output.Log_error(e)
            continue

        # Exclude very bright/dim objects (for now)
        cut_mag = obj.param["mag_%s" % cut_band]
        if cut_filter.is_too_bright(mag=cut_mag, margin=mag_sat_margin):
            self.chip_output.Log_info("obj %s too bright!! mag_%s = %.3f" % (
                obj.id, cut_filter.filter_type, cut_mag))
            bright_obj += 1
            obj.unload_SED()
            continue
        obs_mag = obj.getMagFilter(filt)
        if filt.is_too_dim(mag=obs_mag, margin=mag_lim_margin):
            self.chip_output.Log_info("obj %s too dim!! mag_%s = %.3f" % (
                obj.id, filt.filter_type, obs_mag))
            dim_obj += 1
            obj.unload_SED()
            continue

        # Get corresponding shear values
        if shear_type == "constant":
            if obj.type == 'star':
                obj.g1, obj.g2 = 0., 0.
            else:
                # Shear field from the overall configuration shear setting
                obj.g1, obj.g2 = g1_field, g2_field
        elif shear_type != "catalog":
            # For "catalog", obj.g1/obj.g2 are left untouched (presumably
            # filled from the catalog — confirm).
            self.chip_output.Log_error("Unknown shear input")
            raise ValueError("Unknown shear input")

        # Get position of object on the focal plane
        pos_img, _, _, _, fd_shear = obj.getPosImg_Offset_WCS(
            img=chip.img, fdmodel=fd_model, chip=chip, verbose=False,
            chip_wcs=chip_wcs, img_header=self.h_ext,
            ra_offset=self.ra_offset, dec_offset=self.dec_offset)

        # [TODO] For now, only objects whose centers (after field distortion)
        # are projected within the focal plane are drawn; the rest are
        # counted as missed.
        if pos_img.x == -1 or pos_img.y == -1:
            self.chip_output.Log_info('obj_ra = %.6f, obj_dec = %.6f, obj_ra_orig = %.6f, obj_dec_orig = %.6f' % (
                obj.ra, obj.dec, obj.ra_orig, obj.dec_orig))
            self.chip_output.Log_error("Object missed: %s" % (obj.id))
            missed_obj += 1
            obj.unload_SED()
            continue

        # Draw object & update output catalog
        try:
            if out_cat_only:
                is_updated = True
                obj.real_pos = obj.getRealPos(
                    chip.img, global_x=obj.posImg.x, global_y=obj.posImg.y,
                    img_real_wcs=obj.chip_wcs)
                pos_shear = 0.
            elif chip.survey_type == "photometric":
                is_updated, pos_shear = obj.drawObj_multiband(
                    tel=tel,
                    pos_img=pos_img,
                    psf_model=psf_model,
                    bandpass_list=filt.bandpass_sub_list,
                    filt=filt,
                    chip=chip,
                    g1=obj.g1,
                    g2=obj.g2,
                    exptime=exptime,
                    fd_shear=fd_shear)
            elif chip.survey_type == "spectroscopic":
                is_updated, pos_shear = obj.drawObj_slitless(
                    tel=tel,
                    pos_img=pos_img,
                    psf_model=psf_model,
                    bandpass_list=filt.bandpass_sub_list,
                    filt=filt,
                    chip=chip,
                    g1=obj.g1,
                    g2=obj.g2,
                    exptime=exptime,
                    normFilter=norm_filt,
                    fd_shear=fd_shear)
            else:
                # Reported through the handler below, like any draw failure.
                raise ValueError(
                    "Unknown chip survey type: %s" % (chip.survey_type,))

            if is_updated == 1:
                # TODO: add up stats
                self.chip_output.cat_add_obj(
                    obj, pos_img, pos_shear,
                    ra_offset=self.ra_offset, dec_offset=self.dec_offset)
            elif is_updated == 0:
                missed_obj += 1
                self.chip_output.Log_error("Object missed: %s" % (obj.id))
            else:
                # No `continue` here, so the SED is still unloaded below.
                self.chip_output.Log_error(
                    "Draw error, object omitted: %s" % (obj.id))
        except Exception as e:
            traceback.print_exc()
            self.chip_output.Log_error(e)
            self.chip_output.Log_error(
                "pointing: #%d, chipID: %d" % (pointing.id, chip.chipID))
            if obj.type == "galaxy":
                # .get() so a missing key cannot raise inside this handler
                # and abort the whole chip.
                par = obj.param
                nan = float("nan")
                self.chip_output.Log_error("obj id: %s" % (par.get('id'),))
                self.chip_output.Log_error("    e1: %.5f\n    e2: %.5f\n    size: %f\n    bfrac: %f\n    detA: %f\n    g1: %.5f\n    g2: %.5f\n" % tuple(
                    par.get(key, nan)
                    for key in ('e1', 'e2', 'size', 'bfrac', 'detA', 'g1', 'g2')))
        # Unload SED:
        obj.unload_SED()
        del obj

    # Release the catalog's star data, if this catalog type has any.
    if hasattr(cat, "starDDL"):
        cat.starDDL.freeGlobeData()
        del cat.starDDL

    if chip.survey_type == "spectroscopic" and not out_cat_only and chip.slsPSFOptim:
        psf_model.convolveFullImgWithPCAPSF(chip)
    del psf_model
    gc.collect()

    self.chip_output.Log_info("Running checkpoint #1 (Object rendering finished): pointing-%d chip-%d pid-%d memory-%6.2fGB" %
                              (pointing.id, chip.chipID, os.getpid(), (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024)))

    self.chip_output.Log_info(
        "# objects that are too bright %d out of %d" % (bright_obj, nobj))
    self.chip_output.Log_info(
        "# objects that are too dim %d out of %d" % (dim_obj, nobj))
    self.chip_output.Log_info(
        "# objects that are missed %d out of %d" % (missed_obj, nobj))

    # Apply flat fielding (with shutter effects)
    flat_normal = np.ones_like(chip.img.array)
    if obs_param.get("flat_fielding") is True:
        flat_normal = flat_normal * chip.flat_img.array / \
            np.mean(chip.flat_img.array)
    if obs_param.get("shutter_effect") is True:
        flat_normal = flat_normal * chip.shutter_img
        flat_normal = np.array(flat_normal, dtype='float32')
        self.updateHeaderInfo(header_flag='ext', keys=[
                              'SHTSTAT'], values=[True])
    else:
        # NOTE(review): SHTSTAT is set True in both branches; confirm that
        # is intended when the shutter effect is disabled.
        self.updateHeaderInfo(header_flag='ext', keys=['SHTSTAT', 'SHTOPEN1', 'SHTCLOS0'], values=[
                              True, self.h_ext['SHTCLOS1'], self.h_ext['SHTOPEN0']])
    chip.img *= flat_normal
    del flat_normal

    # Renew header info.
    # datetime.utcfromtimestamp is deprecated (Python 3.12); fromtimestamp
    # with tz=utc yields the same aware datetime.
    datetime_obs = datetime.fromtimestamp(pointing.timestamp, tz=timezone.utc)
    t_obs = Time(datetime_obs)

    # The CCD is flushed for 2 s, then waits 0 s before the exposure starts.
    t_obs_renew = Time(t_obs.mjd - (2. + 0.) / 86400., format="mjd")

    # DATE-OBS is written with 0.1 s precision.
    t_obs_utc = datetime.fromtimestamp(
        float(np.round(t_obs_renew.unix, 1)), tz=timezone.utc)
    self.updateHeaderInfo(header_flag='prim', keys=[
                          'DATE-OBS'], values=[t_obs_utc.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-5]])

    # Dark time = exposure time + 0 s wait after the flush + 0 s wait between
    # shutter close and readout.
    # NOTE(review): uses pointing.exp_time, not the resolved `exptime`.
    self.updateHeaderInfo(header_flag='ext', keys=[
                          'DARKTIME'], values=[0. + 0. + pointing.exp_time])

    return chip, filt, tel, pointing