import os
import gc
import traceback
from datetime import datetime

import galsim
import psutil
import numpy as np
from astropy.io import fits
from numpy.random import Generator, PCG64

from ObservationSim._util import get_shear_field
from ObservationSim.Straylight import calculateSkyMap_split_g
from ObservationSim.Config.Header import generatePrimaryHeader, generateExtensionHeader
from ObservationSim.PSF import PSFGauss, FieldDistortion, PSFInterp, PSFInterpSLS
from ObservationSim.Instrument.Chip import ChipUtils as chip_utils
from ObservationSim.Instrument.Chip import Effects
from ObservationSim.Instrument.Chip.libCTI.CTI_modeling import CTI_sim


class SimSteps:
    """Collection of per-chip simulation steps.

    Each public step method takes ``(chip, filt, tel, pointing, catalog, obs_param)``,
    modifies ``chip.img`` (and/or ``filt``) in place or by reassignment, and
    returns ``(chip, filt, tel, pointing)``. The mapping from step names to
    method names is given by ``SIM_STEP_TYPES`` at module level.
    """

    def __init__(self, overall_config, chip_output, all_filters):
        # overall_config: nested dict-like configuration (keys such as
        #   "psf_setting", "obs_setting", "random_seeds" are read below).
        # chip_output: output helper providing Log_info/Log_error,
        #   create_output_file, cat_add_obj and a `subdir` attribute.
        # all_filters: sequence of filter objects; one of them must match
        #   overall_config["obs_setting"]["cut_in_band"].
        self.overall_config = overall_config
        self.chip_output = chip_output
        self.all_filters = all_filters

    def prepare_headers(self, chip, pointing):
        """Build and cache the primary and extension FITS headers for this chip.

        The observation date/time strings are derived from ``pointing.timestamp``
        interpreted as UTC (``"%y%m%d"`` and ``"%H%M%S"``).

        Returns:
            tuple: ``(h_prim, h_ext)``, also stored as ``self.h_prim`` / ``self.h_ext``.
        """
        datetime_obs = datetime.utcfromtimestamp(pointing.timestamp)
        date_obs = datetime_obs.strftime("%y%m%d")
        time_obs = datetime_obs.strftime("%H%M%S")
        self.h_prim = generatePrimaryHeader(
            xlen=chip.npix_x,
            ylen=chip.npix_y,
            pointNum=str(pointing.id),
            ra=pointing.ra,
            dec=pointing.dec,
            pixel_scale=chip.pix_scale,
            date=date_obs,
            time_obs=time_obs,
            exptime=pointing.exp_time,
            im_type=pointing.pointing_type,
            sat_pos=[pointing.sat_x, pointing.sat_y, pointing.sat_z],
            sat_vel=[pointing.sat_vx, pointing.sat_vy, pointing.sat_vz],
            project_cycle=self.overall_config["project_cycle"],
            run_counter=self.overall_config["run_counter"],
            chip_name=str(chip.chipID).rjust(2, '0'))
        self.h_ext = generateExtensionHeader(
            chip=chip,
            xlen=chip.npix_x,
            ylen=chip.npix_y,
            ra=pointing.ra,
            dec=pointing.dec,
            pa=pointing.img_pa.deg,
            gain=chip.gain,
            readout=chip.read_noise,
            dark=chip.dark_noise,
            saturation=90000,
            pixel_scale=chip.pix_scale,
            pixel_size=chip.pix_size,
            xcen=chip.x_cen,
            ycen=chip.y_cen,
            extName=pointing.pointing_type,
            timestamp=pointing.timestamp,
            exptime=pointing.exp_time,
            readoutTime=chip.readout_time)
        return self.h_prim, self.h_ext

    def _get_flat_normal(self, chip, obs_param):
        """Return the float32 multiplicative map of normalised flat field times shutter pattern.

        The flat is divided by its mean so that it only redistributes flux.
        A flag that is disabled contributes a factor of one.
        """
        flat_normal = np.ones_like(chip.img.array)
        if obs_param["flat_fielding"]:
            flat_normal = flat_normal * chip.flat_img.array / np.mean(chip.flat_img.array)
        if obs_param["shutter_effect"]:
            flat_normal = flat_normal * chip.shutter_img
        return np.array(flat_normal, dtype='float32')

    def add_sky_background(self, chip, filt, tel, pointing, catalog, obs_param):
        """Add the sky (and optionally stray-light) background to ``chip.img``.

        Photometric chips get a uniform sky level from ``filt.getSkyNoise``.
        Other (slitless) chips get a dispersed map from ``calculateSkyMap_split_g``
        plus ``filt.sky_background``, scaled by the pupil area and exposure time.
        Both are modulated by the flat/shutter map.
        """
        flat_normal = self._get_flat_normal(chip, obs_param)

        if obs_param["enable_straylight_model"]:
            # Updates filt.sky_background and filt.zodical_spec in place.
            filt.setFilterStrayLightPixel(
                jtime=pointing.jdt,
                sat_pos=np.array([pointing.sat_x, pointing.sat_y, pointing.sat_z]),
                pointing_radec=np.array([pointing.ra, pointing.dec]),
                sun_pos=np.array([pointing.sun_x, pointing.sun_y, pointing.sun_z]))
            self.chip_output.Log_info("================================================")
            self.chip_output.Log_info("sky background + stray light pixel flux value: %.5f" % (filt.sky_background))

        if chip.survey_type == "photometric":
            sky_map = filt.getSkyNoise(exptime=obs_param["exptime"])
            sky_map = sky_map * np.ones_like(chip.img.array) * flat_normal
            sky_map = galsim.Image(array=sky_map)
        else:
            sky_map = calculateSkyMap_split_g(
                skyMap=flat_normal,
                blueLimit=filt.blue_limit,
                redLimit=filt.red_limit,
                conf=chip.sls_conf,
                pixelSize=chip.pix_scale,
                isAlongY=0,
                flat_cube=chip.flat_cube,
                zoldial_spec=filt.zodical_spec)
            sky_map = sky_map + filt.sky_background
            sky_map = sky_map * tel.pupil_area * obs_param["exptime"]
        chip.img += sky_map
        return chip, filt, tel, pointing

    def _init_psf_model(self, chip, filt):
        """Construct the PSF model selected by ``psf_setting.psf_model``.

        Raises:
            ValueError: if the configured model type is neither "Gauss" nor "Interp".
        """
        psf_setting = self.overall_config["psf_setting"]
        model_type = psf_setting["psf_model"]
        if model_type == "Gauss":
            return PSFGauss(chip=chip, psfRa=psf_setting["psf_rcont"])
        if model_type == "Interp":
            if chip.survey_type == "spectroscopic":
                return PSFInterpSLS(chip, filt, PSF_data_prefix=psf_setting["psf_sls_dir"])
            return PSFInterp(chip=chip, npsf=chip.n_psf_samples, PSF_data_file=psf_setting["psf_pho_dir"])
        # Fail fast: previously this only logged and then crashed later on an unbound psf_model.
        self.chip_output.Log_error("unrecognized PSF model type!!")
        raise ValueError("unrecognized PSF model type: %r" % (model_type,))

    def _update_filter_limits(self, chip, pointing, cut_band):
        """Refresh limiting/saturation magnitudes of all filters and return the cut filter.

        The cut filter is the one whose lower-cased ``filter_type`` equals ``cut_band``.
        If several match, the last one wins.

        Raises:
            ValueError: if no filter matches ``cut_band``.
        """
        cut_filter = None
        for temp_filter in self.all_filters:
            temp_filter.update_limit_saturation_mags(
                exptime=pointing.get_full_depth_exptime(temp_filter.filter_type), chip=chip)
            if temp_filter.filter_type.lower() == cut_band:
                cut_filter = temp_filter
        if cut_filter is None:
            self.chip_output.Log_error("No filter matches cut_in_band = %s" % cut_band)
            raise ValueError("No filter matches cut_in_band = %s" % cut_band)
        return cut_filter

    def add_objects(self, chip, filt, tel, pointing, catalog, obs_param):
        """Render every catalog object onto ``chip.img``, then apply flat fielding.

        Objects that are too bright in the cut band, too dim in ``filt``, or whose
        position falls off the focal plane are skipped and counted. When
        ``run_option.out_cat_only`` is set, only catalog positions are computed and
        nothing is drawn. The whole image is finally multiplied by the flat/shutter map.
        """
        # Prepare output file(s) for this chip
        self.chip_output.create_output_file()

        psf_model = self._init_psf_model(chip, filt)

        # Apply field distortion model
        if obs_param["field_dist"]:
            fd_model = FieldDistortion(chip=chip, img_rot=pointing.img_pa.deg)
        else:
            fd_model = None

        # Update limiting magnitudes for all filters based on the exposure time,
        # and pick the filter used for the brightness cut.
        obs_setting = self.overall_config["obs_setting"]
        cut_band = obs_setting["cut_in_band"].lower()
        cut_filter = self._update_filter_limits(chip=chip, pointing=pointing, cut_band=cut_band)
        mag_sat_margin = obs_setting["mag_sat_margin"]
        mag_lim_margin = obs_setting["mag_lim_margin"]

        # Read in shear values from configuration file if the constant shear type is used
        shear_type = self.overall_config["shear_setting"]["shear_type"]
        if shear_type == "constant":
            g1_field, g2_field, _ = get_shear_field(config=self.overall_config)
            self.chip_output.Log_info("Use constant shear: g1=%.5f, g2=%.5f" % (g1_field, g2_field))

        # Get chip WCS (headers are built lazily if no earlier step produced them)
        if not hasattr(self, 'h_ext'):
            self.prepare_headers(chip=chip, pointing=pointing)
        chip_wcs = galsim.FitsWCS(header=self.h_ext)

        out_cat_only = self.overall_config["run_option"]["out_cat_only"]
        filt_key = filt.filter_type.lower()

        nobj = len(catalog.objs)
        missed_obj = 0
        bright_obj = 0
        dim_obj = 0
        for obj in catalog.objs:
            # Load and convert SED; also calculate the object's magnitude in the
            # current filter and in the cut filter.
            try:
                sed_data = catalog.load_sed(obj)
                norm_filt = catalog.load_norm_filt(obj)
                obj.sed, obj.param["mag_%s" % filt_key], obj.param["flux_%s" % filt_key] = catalog.convert_sed(
                    mag=obj.param["mag_use_normal"],
                    sed=sed_data,
                    target_filt=filt,
                    norm_filt=norm_filt,
                )
                _, obj.param["mag_%s" % cut_band], obj.param["flux_%s" % cut_band] = catalog.convert_sed(
                    mag=obj.param["mag_use_normal"],
                    sed=sed_data,
                    target_filt=cut_filter,
                    norm_filt=norm_filt,
                )
            except Exception as e:
                traceback.print_exc()
                self.chip_output.Log_error(e)
                continue

            # Exclude very bright/dim objects (for now)
            if cut_filter.is_too_bright(mag=obj.param["mag_%s" % cut_band], margin=mag_sat_margin):
                self.chip_output.Log_info("obj %s too birght!! mag_%s = %.3f" % (obj.id, cut_filter.filter_type, obj.param["mag_%s" % cut_band]))
                bright_obj += 1
                obj.unload_SED()
                continue
            if filt.is_too_dim(mag=obj.getMagFilter(filt), margin=mag_lim_margin):
                self.chip_output.Log_info("obj %s too dim!! mag_%s = %.3f" % (obj.id, filt.filter_type, obj.getMagFilter(filt)))
                dim_obj += 1
                obj.unload_SED()
                continue

            # Get corresponding shear values
            if shear_type == "constant":
                if obj.type == 'star':
                    obj.g1, obj.g2 = 0., 0.
                else:
                    obj.g1, obj.g2 = g1_field, g2_field
            elif shear_type == "catalog":
                pass  # shear already carried by the catalog object
            else:
                self.chip_output.Log_error("Unknown shear input")
                raise ValueError("Unknown shear input")

            # Get position of object on the focal plane
            pos_img, _, _, _, fd_shear = obj.getPosImg_Offset_WCS(
                img=chip.img, fdmodel=fd_model, chip=chip, verbose=False,
                chip_wcs=chip_wcs, img_header=self.h_ext)

            # [TODO] For now, only objects whose centers (after field distortion) are
            # projected within the focal plane are drawn; others count as missed.
            if pos_img.x == -1 or pos_img.y == -1:
                self.chip_output.Log_info('obj_ra = %.6f, obj_dec = %.6f, obj_ra_orig = %.6f, obj_dec_orig = %.6f' % (obj.ra, obj.dec, obj.ra_orig, obj.dec_orig))
                self.chip_output.Log_error("Objected missed: %s" % (obj.id))
                missed_obj += 1
                obj.unload_SED()
                continue

            # Draw object & update output catalog
            try:
                if out_cat_only:
                    isUpdated = True
                    obj.real_pos = obj.getRealPos(chip.img, global_x=obj.posImg.x, global_y=obj.posImg.y, img_real_wcs=obj.chip_wcs)
                    pos_shear = 0.
                elif chip.survey_type == "photometric":
                    isUpdated, pos_shear = obj.drawObj_multiband(
                        tel=tel,
                        pos_img=pos_img,
                        psf_model=psf_model,
                        bandpass_list=filt.bandpass_sub_list,
                        filt=filt,
                        chip=chip,
                        g1=obj.g1,
                        g2=obj.g2,
                        exptime=pointing.exp_time,
                        fd_shear=fd_shear)
                elif chip.survey_type == "spectroscopic":
                    isUpdated, pos_shear = obj.drawObj_slitless(
                        tel=tel,
                        pos_img=pos_img,
                        psf_model=psf_model,
                        bandpass_list=filt.bandpass_sub_list,
                        filt=filt,
                        chip=chip,
                        g1=obj.g1,
                        g2=obj.g2,
                        exptime=pointing.exp_time,
                        normFilter=norm_filt,
                        fd_shear=fd_shear)
                if isUpdated == 1:
                    # TODO: add up stats
                    self.chip_output.cat_add_obj(obj, pos_img, pos_shear)
                elif isUpdated == 0:
                    missed_obj += 1
                    self.chip_output.Log_error("Objected missed: %s" % (obj.id))
                else:
                    # Fall through (no `continue`) so the SED is still unloaded below.
                    self.chip_output.Log_error("Draw error, object omitted: %s" % (obj.id))
            except Exception as e:
                traceback.print_exc()
                self.chip_output.Log_error(e)

            # Unload SED and release per-object memory eagerly.
            obj.unload_SED()
            del obj
            gc.collect()

        del psf_model
        gc.collect()

        self.chip_output.Log_info("Running checkpoint #1 (Object rendering finished): pointing-%d chip-%d pid-%d memory-%6.2fGB" % (pointing.id, chip.chipID, os.getpid(), (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024)))
        self.chip_output.Log_info("# objects that are too bright %d out of %d" % (bright_obj, nobj))
        self.chip_output.Log_info("# objects that are too dim %d out of %d" % (dim_obj, nobj))
        self.chip_output.Log_info("# objects that are missed %d out of %d" % (missed_obj, nobj))

        # Apply flat fielding (with shutter effects)
        flat_normal = self._get_flat_normal(chip, obs_param)
        chip.img *= flat_normal
        del flat_normal
        return chip, filt, tel, pointing

    def add_cosmic_rays(self, chip, filt, tel, pointing, catalog, obs_param):
        """Add cosmic-ray hits to ``chip.img`` with a pointing/chip dependent seed."""
        # Previously passed logger=self.logger (undefined on SimSteps) and used
        # pointing.exptime; the rest of this class uses pointing.exp_time.
        self.chip_output.Log_info(" Adding Cosmic-Ray")
        chip.img, crmap_gsimg, cr_event_num = chip_utils.add_cosmic_rays(
            img=chip.img,
            chip=chip,
            exptime=pointing.exp_time,
            seed=self.overall_config["random_seeds"]["seed_CR"] + pointing.id * 30 + chip.chipID)
        # [TODO] output cosmic ray image
        return chip, filt, tel, pointing

    def apply_PRNU(self, chip, filt, tel, pointing, catalog, obs_param):
        """Multiply ``chip.img`` by the chip's PRNU map."""
        chip.img *= chip.prnu_img
        return chip, filt, tel, pointing

    def add_poisson_and_dark(self, chip, filt, tel, pointing, catalog, obs_param):
        """Add Poisson noise, with or without dark current, depending on ``add_dark``."""
        # Hard-coded switch: the input-dark path is currently disabled.
        InputDark = False
        if obs_param["add_dark"] == True:
            if InputDark:
                chip.img = chip_utils.add_inputdark(img=chip.img, chip=chip, exptime=pointing.exp_time)
            else:
                chip.img, _ = chip_utils.add_poisson(img=chip.img, chip=chip, exptime=pointing.exp_time, poisson_noise=chip.poisson_noise)
        else:
            # Was chip=self (the SimSteps instance); the chip object is what add_poisson needs.
            chip.img, _ = chip_utils.add_poisson(img=chip.img, chip=chip, exptime=pointing.exp_time, poisson_noise=chip.poisson_noise, dark_noise=0.)
        return chip, filt, tel, pointing

    def add_brighter_fatter(self, chip, filt, tel, pointing, catalog, obs_param):
        """Apply the brighter-fatter effect to ``chip.img``."""
        chip.img = chip_utils.add_brighter_fatter(img=chip.img)
        return chip, filt, tel, pointing

    def add_detector_defects(self, chip, filt, tel, pointing, catalog, obs_param):
        """Add hot/dead pixels (random fraction seeded per chip) and optional bad columns."""
        rgbadpix = Generator(PCG64(int(self.overall_config["random_seeds"]["seed_defective"] + chip.chipID)))
        # Defective fraction drawn uniformly in [0.7, 1.2) * 5e-5.
        badfraction = 5E-5 * (rgbadpix.random() * 0.5 + 0.7)
        chip.img = Effects.DefectivePixels(
            chip.img,
            IfHotPix=obs_param["hot_pixels"],
            IfDeadPix=obs_param["dead_pixels"],
            fraction=badfraction,
            seed=self.overall_config["random_seeds"]["seed_defective"] + chip.chipID,
            biaslevel=0)
        # Apply Bad columns
        if obs_param["bad_columns"]:
            chip.img = Effects.BadColumns(chip.img, seed=self.overall_config["random_seeds"]["seed_badcolumns"], chipid=chip.chipID)
        return chip, filt, tel, pointing

    def add_nonlinearity(self, chip, filt, tel, pointing, catalog, obs_param):
        """Apply detector non-linearity (beta1=5e-7, beta2=0)."""
        self.chip_output.Log_info(" Applying Non-Linearity on the chip image")
        chip.img = Effects.NonLinearity(GSImage=chip.img, beta1=5.e-7, beta2=0)
        return chip, filt, tel, pointing

    def add_blooming(self, chip, filt, tel, pointing, catalog, obs_param):
        """Apply saturation and charge blooming using the chip's full-well depth."""
        self.chip_output.Log_info(" Applying CCD Saturation & Blooming")
        chip.img = Effects.SaturBloom(GSImage=chip.img, nsect_x=1, nsect_y=1, fullwell=int(chip.full_well))
        return chip, filt, tel, pointing

    def apply_CTE(self, chip, filt, tel, pointing, catalog, obs_param):
        """Apply charge-transfer inefficiency channel by channel.

        The image is reformatted from a 2x8 to a 1x16 channel layout, ``CTI_sim``
        is run on each channel, and the result is reverted to 2x8. Afterwards
        ``chip.overscan_y`` is set to 0.
        """
        self.chip_output.Log_info(" Apply CTE Effect")
        # 2*8 -> 1*16 img-layout
        img = chip_utils.formatOutput(GSImage=chip.img)
        chip.nsecy = 1
        chip.nsecx = 16
        img_arr = img.array
        ny, nx = img_arr.shape
        dx = int(nx / chip.nsecx)
        dy = int(ny / chip.nsecy)
        newimg = galsim.Image(nx, int(ny + chip.overscan_y), init_value=0)
        # Was self.overscan_y, which SimSteps does not define; the overscan belongs to the chip.
        noverscan, nsp, nmax = chip.overscan_y, 3, 10
        beta, w, c = 0.478, 84700, 0
        t = np.array([0.74, 7.7, 37], dtype=np.float32)
        rho_trap = np.array([0.6, 1.6, 1.4], dtype=np.float32)
        for ichannel in range(chip.nsecx):
            print('\n***add CTI effects: pointing-{:} chip-{:} channel-{:}***'.format(pointing.id, chip.chipID, ichannel + 1))
            trap_seeds = np.array([0, 1000, 10000], dtype=np.int32) + ichannel + chip.chipID * 16
            release_seed = 50 + ichannel + pointing.id * 30 + chip.chipID * 16
            col = slice(ichannel * dx, (ichannel + 1) * dx)
            newimg.array[:, col] = CTI_sim(img_arr[:, col], dx, dy, noverscan, nsp, nmax, beta, w, c, t, rho_trap, trap_seeds, release_seed)
        newimg.wcs = img.wcs
        del img
        img = newimg

        # 1*16 -> 2*8 img-layout
        chip.img = chip_utils.formatRevert(GSImage=img)
        chip.nsecy = 2
        chip.nsecx = 8

        # [TODO] make overscan_y == 0
        chip.overscan_y = 0
        return chip, filt, tel, pointing

    def add_prescan_overscan(self, chip, filt, tel, pointing, catalog, obs_param):
        """Pad ``chip.img`` with the chip's pre-scan and over-scan regions."""
        self.chip_output.Log_info("Apply pre/over-scan")
        chip.img = chip_utils.AddPreScan(GSImage=chip.img,
                                         pre1=chip.prescan_x,
                                         pre2=chip.prescan_y,
                                         over1=chip.overscan_x,
                                         over2=chip.overscan_y)
        return chip, filt, tel, pointing

    def add_bias(self, chip, filt, tel, pointing, catalog, obs_param):
        """Add the bias level, per-channel non-uniform if ``bias_16channel`` is True."""
        self.chip_output.Log_info(" Adding Bias level and 16-channel non-uniformity")
        if obs_param["bias_16channel"] == True:
            chip.img = Effects.AddBiasNonUniform16(chip.img,
                                                   bias_level=float(chip.bias_level),
                                                   nsecy=chip.nsecy, nsecx=chip.nsecx,
                                                   seed=self.overall_config["random_seeds"]["seed_biasNonUniform"] + chip.chipID)
        elif obs_param["bias_16channel"] == False:
            # Was self.bias_level, which SimSteps does not define.
            chip.img += chip.bias_level
        return chip, filt, tel, pointing

    def add_readout_noise(self, chip, filt, tel, pointing, catalog, obs_param):
        """Add Gaussian read noise with sigma ``chip.read_noise``, seeded per pointing/chip."""
        seed = int(self.overall_config["random_seeds"]["seed_readout"]) + pointing.id * 30 + chip.chipID
        rng_readout = galsim.BaseDeviate(seed)
        readout_noise = galsim.GaussianNoise(rng=rng_readout, sigma=chip.read_noise)
        chip.img.addNoise(readout_noise)
        return chip, filt, tel, pointing

    def apply_gain(self, chip, filt, tel, pointing, catalog, obs_param):
        """Divide by the gain, per-channel non-uniform if ``gain_16channel`` is True."""
        self.chip_output.Log_info(" Applying Gain")
        if obs_param["gain_16channel"] == True:
            # Was self.nsecy/self.nsecx, which SimSteps does not define.
            chip.img, chip.gain_channel = Effects.ApplyGainNonUniform16(
                chip.img, gain=chip.gain,
                nsecy=chip.nsecy, nsecx=chip.nsecx,
                seed=self.overall_config["random_seeds"]["seed_gainNonUniform"] + chip.chipID)
        elif obs_param["gain_16channel"] == False:
            chip.img /= chip.gain
        return chip, filt, tel, pointing

    def quantization_and_output(self, chip, filt, tel, pointing, catalog, obs_param):
        """Clip to [0, 65535], quantize to uint16 and write a two-HDU FITS file.

        The file is written to ``chip_output.subdir`` as ``<FILENAME>.fits``, where
        FILENAME comes from the primary header. Existing files are overwritten.
        """
        # Headers are built lazily (as in add_objects) if no earlier step produced them.
        if not hasattr(self, 'h_prim') or not hasattr(self, 'h_ext'):
            self.prepare_headers(chip=chip, pointing=pointing)

        chip.img.array[chip.img.array > 65535] = 65535
        chip.img.replaceNegative(replace_value=0)
        chip.img.quantize()
        chip.img = galsim.Image(chip.img.array, dtype=np.uint16)

        hdu_prim = fits.PrimaryHDU(header=self.h_prim)
        hdu_prim.add_checksum()
        hdu_prim.header.comments['CHECKSUM'] = 'HDU checksum'
        hdu_prim.header.comments['DATASUM'] = 'data unit checksum'

        hdu_img = fits.ImageHDU(chip.img.array, header=self.h_ext)
        hdu_img.add_checksum()
        hdu_img.header.comments['XTENSION'] = 'extension type'
        hdu_img.header.comments['CHECKSUM'] = 'HDU checksum'
        hdu_img.header.comments['DATASUM'] = 'data unit checksum'

        hdu_list = fits.HDUList([hdu_prim, hdu_img])
        fname = os.path.join(self.chip_output.subdir, self.h_prim['FILENAME'] + '.fits')
        hdu_list.writeto(fname, output_verify='ignore', overwrite=True)
        return chip, filt, tel, pointing


# Maps simulation-step names (presumably the keys used in the run configuration —
# confirm against the caller) to the SimSteps method implementing each step.
SIM_STEP_TYPES = {
    "scie_obs": "add_objects",
    "sky_background": "add_sky_background",
    "cosmic_rays": "add_cosmic_rays",
    "PRNU_effect": "apply_PRNU",
    "poisson_and_dark": "add_poisson_and_dark",
    "bright_fatter": "add_brighter_fatter",
    "detector_defects": "add_detector_defects",
    "nonlinearity": "add_nonlinearity",
    "blooming": "add_blooming",
    "CTE_effect": "apply_CTE",
    "prescan_overscan": "add_prescan_overscan",
    "bias": "add_bias",
    "readout_noise": "add_readout_noise",
    "gain": "apply_gain",
    "quantization_and_output": "quantization_and_output",
}