import os
import numpy as np
import mpi4py.MPI as MPI
import galsim
import logging
import psutil
import gc
from astropy.io import fits
from datetime import datetime
import traceback

from ObservationSim.Config import config_dir, ChipOutput
from ObservationSim.Config.Header import generatePrimaryHeader, generateExtensionHeader
from ObservationSim.Instrument import Telescope, Filter, FilterParam, FocalPlane, Chip
from ObservationSim.Instrument.Chip import Effects
from ObservationSim.Straylight import calculateSkyMap_split_g
from ObservationSim.PSF import PSFGauss, FieldDistortion, PSFInterp
from ObservationSim._util import get_shear_field, makeSubDir_PointingList
from ObservationSim.Astrometry.Astrometry_util import on_orbit_obs_position


def _rss_gb():
    """Return the resident set size of the current process, in GB."""
    return psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024


class Observation(object):
    """Simulate focal-plane chips over a list of pointings.

    The constructor builds the chip/filter lists from ``config``.
    :meth:`runExposure_MPI_PointingList` then calls :meth:`run_one_chip` for
    each (pointing, chip) pair assigned to this process. That method draws
    catalogue objects, applies detector effects and writes a FITS file.
    """

    def __init__(self, config, Catalog, work_dir=None, data_dir=None):
        """Set up the telescope, focal plane, chips and filters.

        Parameters
        ----------
        config : dict
            Parsed simulation configuration.
        Catalog : type
            Catalogue class, instantiated once per simulated chip.
        work_dir, data_dir : str, optional
            Forwarded to ``config_dir`` to build ``self.path_dict``.
        """
        self.path_dict = config_dir(config=config, work_dir=work_dir, data_dir=data_dir)
        self.config = config
        self.tel = Telescope()
        self.focal_plane = FocalPlane(survey_type=self.config["obs_setting"]["survey_type"])
        self.filter_param = FilterParam()
        self.chip_list = []
        self.filter_list = []
        self.all_filter = []
        self.Catalog = Catalog

        # Construct chips & filters. Every chip's filter is kept in
        # ``all_filter`` (used to choose the cut-in band), but only chips the
        # focal plane does not ignore are scheduled for simulation.
        for i in range(self.focal_plane.nchips):
            chipID = i + 1
            chip = Chip(chipID=chipID, config=self.config)
            filter_id, filter_type = chip.getChipFilter()
            filt = Filter(filter_id=filter_id, filter_type=filter_type, filter_param=self.filter_param)
            if not self.focal_plane.isIgnored(chipID=chipID):
                self.chip_list.append(chip)
                self.filter_list.append(filt)
            self.all_filter.append(filt)

    # ------------------------------------------------------------------ #
    # Helpers for run_one_chip
    # ------------------------------------------------------------------ #
    @staticmethod
    def _log_pointing(chip, pointing, chip_output, obs_time):
        """Log a summary of the current pointing and chip."""
        chip_output.Log_info(':::::::::::::::::::Current Pointing Information::::::::::::::::::')
        chip_output.Log_info("RA: %f, DEC; %f" % (pointing.ra, pointing.dec))
        chip_output.Log_info("Time: %s" % obs_time.isoformat())
        chip_output.Log_info("Exposure time: %f" % pointing.exp_time)
        chip_output.Log_info("Satellite Position (x, y, z): (%f, %f, %f)" % (pointing.sat_x, pointing.sat_y, pointing.sat_z))
        chip_output.Log_info("Satellite Velocity (x, y, z): (%f, %f, %f)" % (pointing.sat_vx, pointing.sat_vy, pointing.sat_vz))
        chip_output.Log_info("Position Angle: %f" % pointing.img_pa.deg)
        chip_output.Log_info('Chip : %d' % chip.chipID)
        chip_output.Log_info(':::::::::::::::::::::::::::END:::::::::::::::::::::::::::::::::::')

    def _make_psf_model(self, chip, chip_output, psf_model=None):
        """Build the PSF model named by ``config["psf_setting"]["psf_model"]``.

        For an unrecognised type, the caller-supplied ``psf_model`` is used.
        If none was supplied, a ``ValueError`` is raised instead of continuing
        with no PSF.
        """
        psf_type = self.config["psf_setting"]["psf_model"]
        if psf_type == "Gauss":
            return PSFGauss(chip=chip, psfRa=self.config["psf_setting"]["psf_rcont"])
        if psf_type == "Interp":
            return PSFInterp(chip=chip, npsf=chip.n_psf_samples, PSF_data_file=self.path_dict["psf_dir"])
        chip_output.Log_error("unrecognized PSF model type!!")
        if psf_model is None:
            raise ValueError("unrecognized PSF model type: %r" % (psf_type,))
        return psf_model

    def _pointing_center(self, pointing, obs_time):
        """Return the (RA, Dec) used as the centre of the focal-plane WCS.

        When the astrometric model is enabled, the pointing centre is passed
        through ``on_orbit_obs_position`` with the satellite state. Otherwise
        the nominal pointing RA/Dec is returned.
        """
        if not self.config["obs_setting"]["enable_astrometric_model"]:
            return pointing.ra, pointing.dec
        ra_cen, dec_cen = on_orbit_obs_position(
            input_ra_list=[pointing.ra],
            input_dec_list=[pointing.dec],
            input_pmra_list=[0.],
            input_pmdec_list=[0.],
            input_rv_list=[0.],
            input_parallax_list=[1e-9],
            input_nstars=1,
            input_x=pointing.sat_x,
            input_y=pointing.sat_y,
            input_z=pointing.sat_z,
            input_vx=pointing.sat_vx,
            input_vy=pointing.sat_vy,
            input_vz=pointing.sat_vz,
            input_epoch="J2000",
            input_date_str=obs_time.date().isoformat(),
            input_time_str=obs_time.time().isoformat(),
        )
        return ra_cen[0], dec_cen[0]

    def _make_sky_map(self, chip, filt, chip_output):
        """Return the sky-background map for ``chip``.

        Only spectroscopic chips get a map. Every other survey type returns
        ``None``; previously, an unknown survey type left ``sky_map`` unbound.
        """
        if chip.survey_type != "spectroscopic":
            return None

        flat_normal = np.ones_like(chip.img.array)
        if self.config["ins_effects"]["flat_fielding"] == True:
            chip_output.Log_info("SLS flat preprocess,CHIP %d : Creating and applying Flat-Fielding" % chip.chipID)
            chip_output.Log_info(str(chip.img.bounds))
            flat_img = Effects.MakeFlatSmooth(
                chip.img.bounds,
                int(self.config["random_seeds"]["seed_flat"]))
            flat_normal = flat_normal * flat_img.array / np.mean(flat_img.array)
        if self.config["ins_effects"]["shutter_effect"] == True:
            chip_output.Log_info("SLS flat preprocess,CHIP %d : Apply shutter effect" % chip.chipID)
            # shutter effect normalized image for this chip
            shuttimg = Effects.ShutterEffectArr(chip.img, t_shutter=1.3, dist_bearing=735, dt=1E-3)
            flat_normal = flat_normal * shuttimg
        # Always hand a float32 array to the sky-map routine.
        flat_normal = np.array(flat_normal, dtype='float32')

        sky_map = calculateSkyMap_split_g(
            skyMap=flat_normal,
            blueLimit=filt.blue_limit,
            redLimit=filt.red_limit,
            conf=chip.sls_conf,
            pixelSize=chip.pix_scale,
            isAlongY=0,
            flat_cube=chip.flat_cube,
            zoldial_spec=filt.zodical_spec)
        return sky_map + filt.sky_background

    @staticmethod
    def _sci_ext_header(chip, pointing):
        """Build the 'SCI' extension header shared by drawing and FITS output."""
        return generateExtensionHeader(
            chip=chip,
            xlen=chip.npix_x,
            ylen=chip.npix_y,
            ra=pointing.ra,
            dec=pointing.dec,
            pa=pointing.img_pa.deg,
            gain=chip.gain,
            readout=chip.read_noise,
            dark=chip.dark_noise,
            saturation=90000,
            pixel_scale=chip.pix_scale,
            pixel_size=chip.pix_size,
            xcen=chip.x_cen,
            ycen=chip.y_cen,
            extName='SCI',
            timestamp=pointing.timestamp,
            exptime=pointing.exp_time,
            readoutTime=40.)

    def _simulate_objects(self, chip, filt, pointing, chip_output, psf_model, cat_dir, sed_dir):
        """Load the catalogue for ``chip`` and draw every selected object.

        Sets ``self.cat``, ``self.nobj`` and ``self.fd_model``.

        Returns
        -------
        tuple of int
            ``(bright_obj, dim_obj, missed_obj)`` counts.

        Raises
        ------
        ValueError
            If no filter matches ``cut_in_band``, or if the shear type is unknown.
        """
        # Load catalogues and templates
        self.cat = self.Catalog(config=self.config, chip=chip, pointing=pointing, cat_dir=cat_dir,
                                sed_dir=sed_dir, chip_output=chip_output, filt=filt)
        chip_output.create_output_file()
        self.nobj = len(self.cat.objs)

        cut_band = self.config["obs_setting"]["cut_in_band"].lower()
        cut_filter = None
        for temp_filter in self.all_filter:
            # Update the limiting magnitude using exposure time in pointing
            temp_filter.update_limit_saturation_mags(exptime=pointing.exp_time, chip=chip)
            # Select cutting band filter for saturation/limiting magnitude
            if temp_filter.filter_type.lower() == cut_band:
                cut_filter = temp_filter
        if cut_filter is None:
            # Without this check every object hits a NameError inside the SED
            # try-block and is silently skipped.
            msg = "No filter matches cut_in_band '%s'" % cut_band
            chip_output.Log_error(msg)
            raise ValueError(msg)

        if self.config["ins_effects"]["field_dist"] == True:
            self.fd_model = FieldDistortion(chip=chip, img_rot=pointing.img_pa.deg)
        else:
            self.fd_model = None

        missed_obj = 0
        bright_obj = 0
        dim_obj = 0

        h_ext = self._sci_ext_header(chip, pointing)
        chip_wcs = galsim.FitsWCS(header=h_ext)

        filt_key = filt.filter_type.lower()
        shear_type = self.config["shear_setting"]["shear_type"]
        out_cat_only = self.config["run_option"]["out_cat_only"]

        for j in range(self.nobj):
            obj = self.cat.objs[j]

            # Load and convert the SED; also compute the object's magnitude
            # and flux in the chip band and in the cut band.
            try:
                sed_data = self.cat.load_sed(obj)
                norm_filt = self.cat.load_norm_filt(obj)
                obj.sed, obj.param["mag_%s" % filt_key], obj.param["flux_%s" % filt_key] = self.cat.convert_sed(
                    mag=obj.param["mag_use_normal"],
                    sed=sed_data,
                    target_filt=filt,
                    norm_filt=norm_filt,
                )
                _, obj.param["mag_%s" % cut_band], obj.param["flux_%s" % cut_band] = self.cat.convert_sed(
                    mag=obj.param["mag_use_normal"],
                    sed=sed_data,
                    target_filt=cut_filter,
                    norm_filt=norm_filt,
                )
            except Exception as e:
                traceback.print_exc()
                chip_output.Log_error(e)
                continue

            # Exclude very bright/dim objects (for now)
            cut_mag = obj.param["mag_%s" % cut_band]
            if cut_filter.is_too_bright(mag=cut_mag, margin=self.config["obs_setting"]["mag_sat_margin"]):
                chip_output.Log_info("obj %s too birght!! mag_%s = %.3f" % (obj.id, cut_filter.filter_type, cut_mag))
                bright_obj += 1
                obj.unload_SED()
                continue
            if filt.is_too_dim(mag=obj.getMagFilter(filt), margin=self.config["obs_setting"]["mag_lim_margin"]):
                chip_output.Log_info("obj %s too dim!! mag_%s = %.3f" % (obj.id, filt.filter_type, obj.getMagFilter(filt)))
                dim_obj += 1
                obj.unload_SED()
                continue

            # Get corresponding shear values
            if shear_type == "constant":
                if obj.type == 'star':
                    obj.g1, obj.g2 = 0., 0.
                else:
                    obj.g1, obj.g2 = self.g1_field, self.g2_field
            elif shear_type == "catalog":
                pass
            else:
                chip_output.Log_error("Unknown shear input")
                raise ValueError("Unknown shear input")

            # Get position of object on the focal plane
            pos_img, _offset, _local_wcs, _real_wcs, fd_shear = obj.getPosImg_Offset_WCS(
                img=chip.img, fdmodel=self.fd_model, chip=chip, verbose=False,
                chip_wcs=chip_wcs, img_header=h_ext)

            # [TODO] For now only objects whose centres (after field distortion)
            # project within the focal plane are kept; others count as missed.
            if pos_img.x == -1 or pos_img.y == -1:
                chip_output.Log_info('obj_ra = %.6f, obj_dec = %.6f, obj_ra_orig = %.6f, obj_dec_orig = %.6f'
                                     % (obj.ra, obj.dec, obj.ra_orig, obj.dec_orig))
                chip_output.Log_error("Objected missed: %s" % (obj.id))
                missed_obj += 1
                obj.unload_SED()
                continue

            # Draw object & update output catalog
            try:
                if out_cat_only:
                    isUpdated = True
                    obj.real_pos = obj.getRealPos(chip.img, global_x=obj.posImg.x, global_y=obj.posImg.y,
                                                  img_real_wcs=obj.real_wcs)
                    pos_shear = 0.
                elif chip.survey_type == "photometric":
                    isUpdated, pos_shear = obj.drawObj_multiband(
                        tel=self.tel,
                        pos_img=pos_img,
                        psf_model=psf_model,
                        bandpass_list=filt.bandpass_sub_list,
                        filt=filt,
                        chip=chip,
                        g1=obj.g1,
                        g2=obj.g2,
                        exptime=pointing.exp_time,
                        fd_shear=fd_shear)
                elif chip.survey_type == "spectroscopic":
                    isUpdated, pos_shear = obj.drawObj_slitless(
                        tel=self.tel,
                        pos_img=pos_img,
                        psf_model=psf_model,
                        bandpass_list=filt.bandpass_sub_list,
                        filt=filt,
                        chip=chip,
                        g1=obj.g1,
                        g2=obj.g2,
                        exptime=pointing.exp_time,
                        normFilter=norm_filt,
                        fd_shear=fd_shear)
                if isUpdated == 1:
                    chip_output.cat_add_obj(obj, pos_img, pos_shear)
                elif isUpdated == 0:
                    missed_obj += 1
                    chip_output.Log_error("Objected missed: %s" % (obj.id))
                else:
                    # No `continue` here: fall through so the SED is unloaded.
                    chip_output.Log_error("Draw error, object omitted: %s" % (obj.id))
            except Exception as e:
                traceback.print_exc()
                chip_output.Log_error(e)

            # Unload SED
            obj.unload_SED()
            del obj
            gc.collect()

        return bright_obj, dim_obj, missed_obj

    def _write_science_image(self, chip, pointing, chip_output, obs_time):
        """Convert ``chip.img`` to uint16 and write it as a 2-HDU FITS file.

        The file goes to ``chip_output.subdir`` and is named after the
        primary header's ``FILENAME`` card.
        """
        h_prim = generatePrimaryHeader(
            xlen=chip.npix_x,
            ylen=chip.npix_y,
            pointNum=str(pointing.id),
            ra=pointing.ra,
            dec=pointing.dec,
            pixel_scale=chip.pix_scale,
            date=obs_time.strftime("%y%m%d"),
            time_obs=obs_time.strftime("%H%M%S"),
            exptime=pointing.exp_time,
            im_type='SCI',
            sat_pos=[pointing.sat_x, pointing.sat_y, pointing.sat_z],
            sat_vel=[pointing.sat_vx, pointing.sat_vy, pointing.sat_vz],
            chip_name=str(chip.chipID).rjust(2, '0'))
        h_ext = self._sci_ext_header(chip, pointing)

        chip.img = galsim.Image(chip.img.array, dtype=np.uint16)
        hdu_prim = fits.PrimaryHDU(header=h_prim)
        hdu_prim.add_checksum()
        hdu_prim.header.comments['CHECKSUM'] = 'HDU checksum'
        hdu_prim.header.comments['DATASUM'] = 'data unit checksum'
        hdu_sci = fits.ImageHDU(chip.img.array, header=h_ext)
        hdu_sci.add_checksum()
        hdu_sci.header.comments['XTENSION'] = 'extension type'
        hdu_sci.header.comments['CHECKSUM'] = 'HDU checksum'
        hdu_sci.header.comments['DATASUM'] = 'data unit checksum'
        hdul = fits.HDUList([hdu_prim, hdu_sci])
        fname = os.path.join(chip_output.subdir, h_prim['FILENAME'] + '.fits')
        hdul.writeto(fname, output_verify='ignore', overwrite=True)

    # ------------------------------------------------------------------ #
    # Public entry points
    # ------------------------------------------------------------------ #
    def run_one_chip(self, chip, filt, pointing, chip_output, wcs_fp=None, psf_model=None, cat_dir=None, sed_dir=None):
        """Simulate one chip for one pointing.

        Parameters
        ----------
        chip : Chip
            Chip to simulate. ``chip.img`` is created here and deleted before
            returning.
        filt : Filter
            Filter belonging to ``chip``.
        pointing : object
            Pointing record (position, timestamp, exposure time, satellite
            state, pointing type).
        chip_output : ChipOutput
            Logger / output-catalogue writer for this chip.
        wcs_fp : galsim WCS, optional
            Focal-plane WCS. Built from the pointing centre when ``None``.
        psf_model : optional
            Used only when ``config["psf_setting"]["psf_model"]`` is neither
            ``"Gauss"`` nor ``"Interp"``.
        cat_dir, sed_dir : str, optional
            Forwarded to the catalogue class.

        Raises
        ------
        ValueError
            Raised when:

            * the PSF type is unrecognised and no ``psf_model`` was given;
            * no filter matches ``cut_in_band``;
            * the shear type is unknown.
        """
        obs_time = datetime.utcfromtimestamp(pointing.timestamp)
        self._log_pointing(chip, pointing, chip_output, obs_time)

        psf_model = self._make_psf_model(chip, chip_output, psf_model)

        # Figure out shear fields
        self.g1_field, self.g2_field, self.nshear = get_shear_field(config=self.config)

        # Apply astrometric simulation for pointing
        ra_cen, dec_cen = self._pointing_center(pointing, obs_time)

        # Get WCS for the focal plane
        if wcs_fp is None:
            wcs_fp = self.focal_plane.getTanWCS(ra_cen, dec_cen, pointing.img_pa, chip.pix_scale)

        # Create chip Image
        chip.img = galsim.ImageF(chip.npix_x, chip.npix_y)
        chip.img.setOrigin(chip.bound.xmin, chip.bound.ymin)
        chip.img.wcs = wcs_fp

        if self.config["obs_setting"]["enable_straylight_model"]:
            filt.setFilterStrayLightPixel(
                jtime=pointing.jdt,
                sat_pos=np.array([pointing.sat_x, pointing.sat_y, pointing.sat_z]),
                pointing_radec=np.array([pointing.ra, pointing.dec]),
                sun_pos=np.array([pointing.sun_x, pointing.sun_y, pointing.sun_z]))
            chip_output.Log_info("========================sky pix========================")
            chip_output.Log_info(filt.sky_background)

        sky_map = self._make_sky_map(chip, filt, chip_output)

        stats = None
        if pointing.pointing_type == 'MS':
            stats = self._simulate_objects(chip, filt, pointing, chip_output, psf_model, cat_dir, sed_dir)
            del psf_model
            del self.cat
            gc.collect()
            chip_output.Log_info("check running:1: pointing-%d chip-%d pid-%d memory-%6.2fGB"
                                 % (pointing.id, chip.chipID, os.getpid(), _rss_gb()))

        # Detector Effects
        # ===========================================================
        if not self.config["run_option"]["out_cat_only"]:
            chip.img = chip.addEffects(
                config=self.config,
                img=chip.img,
                chip_output=chip_output,
                filt=filt,
                ra_cen=pointing.ra,
                dec_cen=pointing.dec,
                img_rot=pointing.img_pa,
                exptime=pointing.exp_time,
                pointing_ID=pointing.id,
                timestamp_obs=pointing.timestamp,
                pointing_type=pointing.pointing_type,
                sky_map=sky_map,
                tel=self.tel,
                logger=chip_output.logger)
            if pointing.pointing_type == 'MS':
                self._write_science_image(chip, pointing, chip_output, obs_time)

        # Object statistics exist only when objects were simulated ('MS').
        if stats is not None:
            bright_obj, dim_obj, missed_obj = stats
            chip_output.Log_info("# objects that are too bright %d out of %d" % (bright_obj, self.nobj))
            chip_output.Log_info("# objects that are too dim %d out of %d" % (dim_obj, self.nobj))
            chip_output.Log_info("# objects that are missed %d out of %d" % (missed_obj, self.nobj))

        del chip.img

        chip_output.Log_info("check running:2: pointing-%d chip-%d pid-%d memory-%6.2fGB"
                             % (pointing.id, chip.chipID, os.getpid(), _rss_gb()))

    def runExposure_MPI_PointingList(self, pointing_list, chips=None, use_mpi=False):
        """Run :meth:`run_one_chip` for every (pointing, chip) pair.

        Parameters
        ----------
        pointing_list : sequence
            Pointings to simulate.
        chips : container of int, optional
            Restrict the run to chips whose ``chipID`` is in ``chips``.
            Chips ignored by the focal plane are never run.
        use_mpi : bool
            When True, tasks are numbered ``ipoint * nchips + ichip`` and
            distributed round-robin over the MPI ranks.
        """
        if use_mpi:
            comm = MPI.COMM_WORLD
            ind_thread = comm.Get_rank()
            num_thread = comm.Get_size()

        if chips is None:
            run_chips = self.chip_list
            run_filts = self.filter_list
        else:
            # Only run a particular set of chips
            run_chips = []
            run_filts = []
            for chip, filt in zip(self.chip_list, self.filter_list):
                if chip.chipID in chips:
                    run_chips.append(chip)
                    run_filts.append(filt)
        # Count the chips actually selected. ``len(chips)`` over-counts when it
        # lists ignored or duplicate IDs, which used to cause an IndexError.
        nchips_per_fp = len(run_chips)

        for ipoint, pointing in enumerate(pointing_list):
            pointing_ID = pointing.id
            for ichip, (chip, filt) in enumerate(zip(run_chips, run_filts)):
                task_index = ipoint * nchips_per_fp + ichip
                if use_mpi and task_index % num_thread != ind_thread:
                    continue

                pid = os.getpid()
                sub_img_dir, prefix = makeSubDir_PointingList(
                    path_dict=self.path_dict, config=self.config, pointing_ID=pointing_ID)

                chip_output = ChipOutput(
                    config=self.config,
                    focal_plane=self.focal_plane,
                    chip=chip,
                    filt=filt,
                    exptime=pointing.exp_time,
                    pointing_type=pointing.pointing_type,
                    pointing_ID=pointing_ID,
                    subdir=sub_img_dir,
                    prefix=prefix)
                try:
                    chip_output.Log_info("running pointing#%d, chip#%d, at PID#%d..." % (pointing_ID, chip.chipID, pid))
                    self.run_one_chip(
                        chip=chip,
                        filt=filt,
                        chip_output=chip_output,
                        pointing=pointing)
                    chip_output.Log_info("finished running chip#%d..." % (chip.chipID))
                finally:
                    # Detach AND close handlers so per-chip log files are not
                    # left open (one leaked descriptor per chip otherwise).
                    for handler in chip_output.logger.handlers[:]:
                        chip_output.logger.removeHandler(handler)
                        handler.close()
                gc.collect()