import os
import gc
import traceback

import psutil
import numpy as np
import galsim

from ObservationSim._util import get_shear_field
from ObservationSim.PSF import PSFGauss, FieldDistortion, PSFInterp, PSFInterpSLS


def _prepare_psf_model(config, chip, filt, chip_output):
    """Build the PSF model selected by ``config["psf_setting"]["psf_model"]``.

    "Gauss" -> ``PSFGauss``; "Interp" -> ``PSFInterpSLS`` for spectroscopic
    chips and ``PSFInterp`` otherwise.

    Raises:
        ValueError: if the configured PSF model name is not recognized.
    """
    psf_setting = config["psf_setting"]
    psf_type = psf_setting["psf_model"]
    if psf_type == "Gauss":
        return PSFGauss(chip=chip, psfRa=psf_setting["psf_rcont"])
    if psf_type == "Interp":
        if chip.survey_type == "spectroscopic":
            return PSFInterpSLS(chip, filt, PSF_data_prefix=psf_setting["psf_sls_dir"])
        return PSFInterp(chip=chip, npsf=chip.n_psf_samples,
                         PSF_data_file=psf_setting["psf_pho_dir"])
    # Fail fast: previously this only logged, leaving the PSF model unbound and
    # crashing only after the whole object loop had run.
    chip_output.Log_error("unrecognized PSF model type!!")
    raise ValueError("unrecognized PSF model type: %r" % (psf_type,))


def _update_filter_limits(all_filters, cut_in_band, pointing, chip, chip_output):
    """Refresh limiting/saturation magnitudes of every filter for this pointing.

    Returns the filter whose ``filter_type`` equals ``cut_in_band``
    (case-insensitive; if several match, the last one wins). This filter is
    used for the bright-object magnitude cut.

    Raises:
        ValueError: if no filter in ``all_filters`` matches ``cut_in_band``.
    """
    cut_band = cut_in_band.lower()
    cut_filter = None
    for temp_filter in all_filters:
        temp_filter.update_limit_saturation_mags(
            exptime=pointing.get_full_depth_exptime(temp_filter.filter_type), chip=chip)
        if temp_filter.filter_type.lower() == cut_band:
            cut_filter = temp_filter
    if cut_filter is None:
        # Without this check every object would silently fail SED conversion.
        chip_output.Log_error("No filter matches cut_in_band: %s" % cut_in_band)
        raise ValueError("No filter matches cut_in_band: %s" % cut_in_band)
    return cut_filter


def _apply_flat_field(chip, obs_param):
    """Multiply ``chip.img`` in place by the mean-normalized flat and/or shutter map.

    The flat (``chip.flat_img``) is applied when ``obs_param["flat_fielding"]``
    is set. The shutter pattern (``chip.shutter_img``) is applied when
    ``obs_param["shutter_effect"]`` is set.
    """
    flat_normal = np.ones_like(chip.img.array)
    if obs_param["flat_fielding"]:
        flat_normal = flat_normal * chip.flat_img.array / np.mean(chip.flat_img.array)
    if obs_param["shutter_effect"]:
        flat_normal = flat_normal * chip.shutter_img
        flat_normal = np.array(flat_normal, dtype='float32')
    chip.img *= flat_normal


def add_objects(self, chip, filt, tel, pointing, catalog, obs_param):
    """Render every catalogue object onto ``chip``, then apply flat/shutter effects.

    Written to be bound as a method: it reads ``self.overall_config``,
    ``self.chip_output`` and ``self.all_filters``. It uses ``self.h_ext``,
    calling ``self.prepare_headers`` first if that attribute is missing.

    Steps:
      1. Create the chip output file and instantiate ``catalog``.
      2. Build the PSF model and, if ``obs_param["field_dist"]`` is set, the
         field-distortion model.
      3. Refresh all filters' limiting/saturation magnitudes and pick the
         magnitude-cut filter.
      4. For each object:
         - convert its SED;
         - skip it if it is too bright, too dim, or off the chip;
         - draw it, or, with ``run_option.out_cat_only``, only compute its
           real position;
         - append it to the output catalogue.
         Per-object errors are logged and the object is skipped.
      5. Multiply ``chip.img`` by the flat/shutter pattern.

    Args:
        chip: Chip being simulated; ``chip.img`` is modified in place.
        filt: Filter used for drawing on this chip.
        tel: Telescope object passed through to the draw routines.
        pointing: Current pointing (exposure time, position angle, id).
        catalog: Catalogue class. It is called with ``config``, ``chip``,
            ``pointing``, ``chip_output`` and ``filt``.
        obs_param: Mapping with boolean-like flags ``"field_dist"``,
            ``"flat_fielding"`` and ``"shutter_effect"``.

    Returns:
        tuple: ``(chip, filt, tel, pointing)``, the same objects passed in.

    Raises:
        ValueError: if ``catalog`` is None, the PSF model type is unknown, no
            filter matches ``obs_setting.cut_in_band``, or the shear type is
            neither "constant" nor "catalog".
    """
    config = self.overall_config
    chip_output = self.chip_output

    # Prepare output file(s) for this chip
    chip_output.create_output_file()

    # Load catalogues
    if catalog is None:
        chip_output.Log_error("Catalog interface class must be specified for SCIE-OBS")
        raise ValueError("Catalog interface class must be specified for SCIE-OBS")
    cat = catalog(config=config, chip=chip, pointing=pointing,
                  chip_output=chip_output, filt=filt)

    psf_model = _prepare_psf_model(config, chip, filt, chip_output)

    # Optional field-distortion model
    fd_model = FieldDistortion(chip=chip, img_rot=pointing.img_pa.deg) if obs_param["field_dist"] else None

    obs_setting = config["obs_setting"]
    cut_filter = _update_filter_limits(self.all_filters, obs_setting["cut_in_band"],
                                       pointing, chip, chip_output)

    # Validate the shear type once, before any per-object work.
    shear_type = config["shear_setting"]["shear_type"]
    if shear_type == "constant":
        g1_field, g2_field, _ = get_shear_field(config=config)
        chip_output.Log_info("Use constant shear: g1=%.5f, g2=%.5f"%(g1_field, g2_field))
    elif shear_type != "catalog":
        chip_output.Log_error("Unknown shear input")
        raise ValueError("Unknown shear input")

    # Get chip WCS (prepare_headers is expected to set self.h_ext)
    if not hasattr(self, 'h_ext'):
        self.prepare_headers(chip=chip, pointing=pointing)
    chip_wcs = galsim.FitsWCS(header=self.h_ext)

    nobj = len(cat.objs)
    missed_obj = 0
    bright_obj = 0
    dim_obj = 0
    filt_band = filt.filter_type.lower()
    cut_band = cut_filter.filter_type.lower()

    for obj in cat.objs:
        # Load and convert the SED; record magnitude/flux in the drawing band
        # and in the magnitude-cut band.
        try:
            sed_data = cat.load_sed(obj)
            norm_filt = cat.load_norm_filt(obj)
            obj.sed, obj.param["mag_%s" % filt_band], obj.param["flux_%s" % filt_band] = cat.convert_sed(
                mag=obj.param["mag_use_normal"],
                sed=sed_data,
                target_filt=filt,
                norm_filt=norm_filt,
            )
            _, obj.param["mag_%s" % cut_band], obj.param["flux_%s" % cut_band] = cat.convert_sed(
                mag=obj.param["mag_use_normal"],
                sed=sed_data,
                target_filt=cut_filter,
                norm_filt=norm_filt,
            )
        except Exception as e:
            traceback.print_exc()
            chip_output.Log_error(e)
            continue

        # Exclude very bright/dim objects (for now)
        cut_mag = obj.param["mag_%s" % cut_band]
        if cut_filter.is_too_bright(mag=cut_mag, margin=obs_setting["mag_sat_margin"]):
            chip_output.Log_info("obj %s too birght!! mag_%s = %.3f"%(obj.id, cut_filter.filter_type, cut_mag))
            bright_obj += 1
            obj.unload_SED()
            continue
        filt_mag = obj.getMagFilter(filt)
        if filt.is_too_dim(mag=filt_mag, margin=obs_setting["mag_lim_margin"]):
            chip_output.Log_info("obj %s too dim!! mag_%s = %.3f"%(obj.id, filt.filter_type, filt_mag))
            dim_obj += 1
            obj.unload_SED()
            continue

        # Constant shear: stars are unsheared, everything else gets the field shear.
        # For "catalog", obj.g1/obj.g2 are left untouched (presumably set by the
        # catalogue class; verify).
        if shear_type == "constant":
            if obj.type == 'star':
                obj.g1, obj.g2 = 0., 0.
            else:
                obj.g1, obj.g2 = g1_field, g2_field

        # Get position of object on the focal plane
        pos_img, _, _, _, fd_shear = obj.getPosImg_Offset_WCS(
            img=chip.img, fdmodel=fd_model, chip=chip, verbose=False,
            chip_wcs=chip_wcs, img_header=self.h_ext)

        # A position of -1 marks an object not projected onto the focal plane.
        # TODO: consider also rejecting centres outside the chip
        # (chip.isContainObj with margin=0.).
        if pos_img.x == -1 or pos_img.y == -1:
            chip_output.Log_info('obj_ra = %.6f, obj_dec = %.6f, obj_ra_orig = %.6f, obj_dec_orig = %.6f'%(obj.ra, obj.dec, obj.ra_orig, obj.dec_orig))
            chip_output.Log_error("Objected missed: %s"%(obj.id))
            missed_obj += 1
            obj.unload_SED()
            continue

        # Draw object & update output catalog
        try:
            if config["run_option"]["out_cat_only"]:
                is_updated = True
                obj.real_pos = obj.getRealPos(chip.img, global_x=obj.posImg.x,
                                              global_y=obj.posImg.y, img_real_wcs=obj.chip_wcs)
                pos_shear = 0.
            elif chip.survey_type == "photometric":
                is_updated, pos_shear = obj.drawObj_multiband(
                    tel=tel,
                    pos_img=pos_img,
                    psf_model=psf_model,
                    bandpass_list=filt.bandpass_sub_list,
                    filt=filt,
                    chip=chip,
                    g1=obj.g1,
                    g2=obj.g2,
                    exptime=pointing.exp_time,
                    fd_shear=fd_shear)
            elif chip.survey_type == "spectroscopic":
                is_updated, pos_shear = obj.drawObj_slitless(
                    tel=tel,
                    pos_img=pos_img,
                    psf_model=psf_model,
                    bandpass_list=filt.bandpass_sub_list,
                    filt=filt,
                    chip=chip,
                    g1=obj.g1,
                    g2=obj.g2,
                    exptime=pointing.exp_time,
                    normFilter=norm_filt,
                    fd_shear=fd_shear)
            else:
                # Previously surfaced as an opaque UnboundLocalError.
                raise ValueError("Unknown survey type: %s" % (chip.survey_type,))

            if is_updated == 1:
                chip_output.cat_add_obj(obj, pos_img, pos_shear)
            elif is_updated == 0:
                missed_obj += 1
                chip_output.Log_error("Objected missed: %s"%(obj.id))
            else:
                # No `continue` here: fall through so the SED is still unloaded
                # (previously this path leaked it).
                chip_output.Log_error("Draw error, object omitted: %s"%(obj.id))
        except Exception as e:
            traceback.print_exc()
            chip_output.Log_error(e)

        # Release per-object memory. A full collection runs after every object;
        # presumably intentional to bound peak RSS, but it is costly.
        obj.unload_SED()
        del obj
        gc.collect()

    del psf_model
    gc.collect()

    rss_gb = psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024
    chip_output.Log_info("Running checkpoint #1 (Object rendering finished): pointing-%d chip-%d pid-%d memory-%6.2fGB"%(pointing.id, chip.chipID, os.getpid(), rss_gb))

    chip_output.Log_info("# objects that are too bright %d out of %d"%(bright_obj, nobj))
    chip_output.Log_info("# objects that are too dim %d out of %d"%(dim_obj, nobj))
    chip_output.Log_info("# objects that are missed %d out of %d"%(missed_obj, nobj))

    # Apply flat fielding (with shutter effects)
    _apply_flat_field(chip, obs_param)

    return chip, filt, tel, pointing