import os
import numpy as np
import mpi4py.MPI as MPI
import galsim
import logging
import psutil
import gc
from astropy.io import fits
from datetime import datetime
import traceback

from ObservationSim.Config import config_dir, ChipOutput
from ObservationSim.Config.Header import generatePrimaryHeader, generateExtensionHeader
from ObservationSim.Instrument import Telescope, Filter, FilterParam, FocalPlane, Chip
from ObservationSim.Instrument.Chip import Effects
from ObservationSim.MockObject import calculateSkyMap_split_g
from ObservationSim.PSF import PSFGauss, FieldDistortion, PSFInterp
from ObservationSim._util import get_shear_field, makeSubDir_PointingList
from ObservationSim.Astrometry.Astrometry_util import on_orbit_obs_position

# import tracemalloc


class Observation(object):
    """Drive the simulation of exposures, one (pointing, chip) pair at a time.

    The constructor builds the telescope, focal plane, chips and filters from
    ``config``. ``run_one_chip`` simulates a single chip for a single pointing
    and ``runExposure_MPI_PointingList`` loops over pointings and chips,
    optionally splitting the work between MPI ranks.
    """

    def __init__(self, config, Catalog, work_dir=None, data_dir=None):
        """Set up instrument objects, chip/filter lists and the shear field.

        Parameters
        ----------
        config : mapping
            Simulation configuration; the sections read in this class are
            "obs_setting", "ins_effects", "psf_setting", "shear_setting",
            "run_option" and "random_seeds".
        Catalog : callable
            Catalogue class; instantiated per chip in ``run_one_chip``.
        work_dir, data_dir : optional
            Forwarded unchanged to ``config_dir``.
        """
        self.path_dict = config_dir(config=config, work_dir=work_dir, data_dir=data_dir)
        self.config = config
        self.tel = Telescope()
        self.focal_plane = FocalPlane(survey_type=self.config["obs_setting"]["survey_type"])
        self.filter_param = FilterParam()
        self.chip_list = []
        self.filter_list = []
        self.all_filter = []
        self.Catalog = Catalog

        # Only build a field-distortion model when the effect is switched on.
        if self.config["ins_effects"]["field_dist"] == True:
            self.fd_model = FieldDistortion(fdModel_path=self.path_dict["fd_path"])
        else:
            self.fd_model = None

        # Construct chips & filters. Chip IDs are 1-based.
        nchips = self.focal_plane.nchip_x * self.focal_plane.nchip_y
        for chipID in range(1, nchips + 1):
            chip = Chip(chipID=chipID, config=self.config)
            filter_id, filter_type = chip.getChipFilter()
            filt = Filter(filter_id=filter_id, filter_type=filter_type, filter_param=self.filter_param)
            # chip_list / filter_list stay index-aligned and skip ignored chips;
            # all_filter keeps the filter of every chip.
            if not self.focal_plane.isIgnored(chipID=chipID):
                self.chip_list.append(chip)
                self.filter_list.append(filt)
            self.all_filter.append(filt)

        # Read catalog and shear(s)
        self.g1_field, self.g2_field, self.nshear = get_shear_field(config=self.config)

    @staticmethod
    def _log_pointing_info(logger, pointing, chip):
        """Write the pointing / chip summary banner to ``logger``."""
        logger.info(':::::::::::::::::::Current Pointing Information::::::::::::::::::')
        logger.info("RA: %f, DEC; %f" % (pointing.ra, pointing.dec))
        logger.info("Time: %s" % datetime.utcfromtimestamp(pointing.timestamp).isoformat())
        logger.info("Exposure time: %f" % pointing.exp_time)
        logger.info("Satellite Position (x, y, z): (%f, %f, %f)" % (pointing.sat_x, pointing.sat_y, pointing.sat_z))
        logger.info("Satellite Velocity (x, y, z): (%f, %f, %f)" % (pointing.sat_vx, pointing.sat_vy, pointing.sat_vz))
        logger.info("Position Angle: %f" % pointing.img_pa.deg)
        logger.info('Chip : %d' % chip.chipID)
        logger.info(':::::::::::::::::::::::::::END:::::::::::::::::::::::::::::::::::')

    @staticmethod
    def _log_memory(logger, stage, pointing, chip):
        """Log the resident memory (in GB) of the current process."""
        pid = os.getpid()
        rss_gb = psutil.Process(pid).memory_info().rss / 1024 / 1024 / 1024
        logger.info("check running:%d: pointing-%d chip-%d pid-%d memory-%6.2fGB" % (stage, pointing.id, chip.chipID, pid, rss_gb))

    def _select_psf_model(self, chip, chip_output, psf_model=None):
        """Return the PSF model named by ``config["psf_setting"]["psf_model"]``.

        For an unrecognized model name the error is logged and the
        caller-supplied ``psf_model`` is returned if there is one.

        Raises
        ------
        ValueError
            If the model name is unrecognized and no ``psf_model`` was given.
        """
        model_name = self.config["psf_setting"]["psf_model"]
        if model_name == "Gauss":
            return PSFGauss(chip=chip, psfRa=self.config["psf_setting"]["psf_rcont"])
        if model_name == "Interp":
            return PSFInterp(chip=chip, PSF_data_file=self.path_dict["psf_dir"])
        # Logger.error() takes no ``flush`` keyword (that belongs to print());
        # passing it raised TypeError instead of reporting the problem.
        chip_output.logger.error("unrecognized PSF model type!!")
        if psf_model is None:
            raise ValueError("unrecognized PSF model type!!")
        return psf_model

    def _pointing_center(self, pointing):
        """Return ``(ra_cen, dec_cen)`` used to build the focal-plane WCS.

        With "enable_astrometric_model" set, the pointing direction is passed
        through ``on_orbit_obs_position`` as a single star with zero proper
        motion and radial velocity; otherwise the raw pointing is returned.
        """
        if not self.config["obs_setting"]["enable_astrometric_model"]:
            return pointing.ra, pointing.dec

        dt = datetime.utcfromtimestamp(pointing.timestamp)
        date_str = dt.date().isoformat()
        time_str = dt.time().isoformat()
        ra_cen, dec_cen = on_orbit_obs_position(
            input_ra_list=[pointing.ra],
            input_dec_list=[pointing.dec],
            input_pmra_list=[0.],
            input_pmdec_list=[0.],
            input_rv_list=[0.],
            input_parallax_list=[1e-9],
            input_nstars=1,
            input_x=pointing.sat_x,
            input_y=pointing.sat_y,
            input_z=pointing.sat_z,
            input_vx=pointing.sat_vx,
            input_vy=pointing.sat_vy,
            input_vz=pointing.sat_vz,
            input_epoch="J2000",
            input_date_str=date_str,
            input_time_str=time_str
        )
        return ra_cen[0], dec_cen[0]

    def _sls_sky_map(self, chip, filt, chip_output):
        """Build the sky map for a spectroscopic chip.

        Starts from an all-ones array shaped like ``chip.img.array``,
        multiplies in the mean-normalized flat field and the shutter-effect
        image when those effects are enabled, and hands the float32 result to
        ``calculateSkyMap_split_g``.
        """
        flat_normal = np.ones_like(chip.img.array)
        if self.config["ins_effects"]["flat_fielding"] == True:
            chip_output.logger.info("SLS flat preprocess,CHIP %d : Creating and applying Flat-Fielding" % chip.chipID)
            msg = str(chip.img.bounds)
            chip_output.logger.info(msg)
            flat_img = Effects.MakeFlatSmooth(
                chip.img.bounds,
                int(self.config["random_seeds"]["seed_flat"]))
            flat_normal = flat_normal * flat_img.array / np.mean(flat_img.array)
        if self.config["ins_effects"]["shutter_effect"] == True:
            chip_output.logger.info("SLS flat preprocess,CHIP %d : Apply shutter effect" % chip.chipID)
            # shutter effect normalized image for this chip
            shuttimg = Effects.ShutterEffectArr(chip.img, t_shutter=1.3, dist_bearing=735, dt=1E-3)
            flat_normal = flat_normal * shuttimg
        flat_normal = np.array(flat_normal, dtype='float32')
        sky_map = calculateSkyMap_split_g(
            skyMap=flat_normal,
            blueLimit=filt.blue_limit,
            redLimit=filt.red_limit,
            conf=chip.sls_conf,
            pixelSize=chip.pix_scale,
            isAlongY=0,
            flat_cube=chip.flat_cube)
        del flat_normal
        return sky_map

    @staticmethod
    def _write_science_image(chip, pointing, chip_output):
        """Convert ``chip.img`` to uint16 and write it as a two-HDU FITS file.

        The file goes to ``chip_output.subdir`` and is named after the
        'FILENAME' card of the generated primary header.
        """
        datetime_obs = datetime.utcfromtimestamp(pointing.timestamp)
        date_obs = datetime_obs.strftime("%y%m%d")
        time_obs = datetime_obs.strftime("%H%M%S")
        h_prim = generatePrimaryHeader(
            xlen=chip.npix_x,
            ylen=chip.npix_y,
            pointNum=str(pointing.id),
            ra=pointing.ra,
            dec=pointing.dec,
            psize=chip.pix_scale,
            row_num=chip.rowID,
            col_num=chip.colID,
            date=date_obs,
            time_obs=time_obs,
            exptime=pointing.exp_time,
            im_type='SCI',
            sat_pos=[pointing.sat_x, pointing.sat_y, pointing.sat_z],
            sat_vel=[pointing.sat_vx, pointing.sat_vy, pointing.sat_vz])
        h_ext = generateExtensionHeader(
            xlen=chip.npix_x,
            ylen=chip.npix_y,
            ra=pointing.ra,
            dec=pointing.dec,
            pa=pointing.img_pa.deg,
            gain=chip.gain,
            readout=chip.read_noise,
            dark=chip.dark_noise,
            saturation=90000,
            psize=chip.pix_scale,
            row_num=chip.rowID,
            col_num=chip.colID,
            extName='raw')
        chip.img = galsim.Image(chip.img.array, dtype=np.uint16)
        primary_hdu = fits.PrimaryHDU(header=h_prim)
        primary_hdu.add_checksum()
        image_hdu = fits.ImageHDU(chip.img.array, header=h_ext)
        image_hdu.add_checksum()
        hdul = fits.HDUList([primary_hdu, image_hdu])
        fname = os.path.join(chip_output.subdir, h_prim['FILENAME'] + '.fits')
        hdul.writeto(fname, output_verify='ignore', overwrite=True)

    def run_one_chip(self, chip, filt, pointing, chip_output, wcs_fp=None, psf_model=None, shear_cat_file=None, cat_dir=None, sed_dir=None):
        """Simulate one chip for one pointing.

        Parameters
        ----------
        chip, filt
            Chip to simulate and the filter mounted on it.
        pointing
            Pointing description (ra, dec, timestamp, exp_time, satellite
            position/velocity, img_pa, pointing_type, id are read here).
        chip_output
            Output handler; supplies the logger, the output catalogue and the
            output sub-directory.
        wcs_fp : optional
            Focal-plane WCS; built from the pointing centre when ``None``.
        psf_model : optional
            Only used as a fallback when the configured PSF model name is not
            recognized.
        shear_cat_file : optional
            When given, the shear field is re-read with this file.
        cat_dir, sed_dir : optional
            Forwarded to the catalogue constructor.

        Raises
        ------
        ValueError
            For an unknown "shear_type", or an unrecognized PSF model name
            when no ``psf_model`` fallback was passed.
        """
        logger = chip_output.logger
        self._log_pointing_info(logger, pointing, chip)

        psf_model = self._select_psf_model(chip, chip_output, psf_model)

        # Get (extra) shear fields
        if shear_cat_file is not None:
            self.g1_field, self.g2_field, self.nshear = get_shear_field(config=self.config, shear_cat_file=shear_cat_file)

        # Apply astrometric simulation for pointing
        ra_cen, dec_cen = self._pointing_center(pointing)

        # Get WCS for the focal plane
        if wcs_fp is None:
            wcs_fp = self.focal_plane.getTanWCS(ra_cen, dec_cen, pointing.img_pa, chip.pix_scale)

        # Create chip Image
        chip.img = galsim.ImageF(chip.npix_x, chip.npix_y)
        chip.img.setOrigin(chip.bound.xmin, chip.bound.ymin)
        chip.img.wcs = wcs_fp

        if chip.survey_type == "photometric":
            sky_map = None
        elif chip.survey_type == "spectroscopic":
            sky_map = self._sls_sky_map(chip, filt, chip_output)

        if pointing.pointing_type == 'MS':
            # Load catalogues and templates
            self.cat = self.Catalog(config=self.config, chip=chip, pointing=pointing, cat_dir=cat_dir, sed_dir=sed_dir, chip_output=chip_output, filt=filt)
            chip_output.create_output_file()
            self.nobj = len(self.cat.objs)

            cut_band = self.config["obs_setting"]["cut_in_band"].lower()
            # NOTE(review): if no filter matches cut_in_band, cut_filter stays
            # unbound and every object is rejected in the SED step below.
            for temp_filter in self.all_filter:
                # Update the limiting magnitude using exposure time in pointing
                temp_filter.update_limit_saturation_mags(exptime=pointing.exp_time, chip=chip)
                # Select cutting band filter for saturation/limiting magnitude
                if temp_filter.filter_type.lower() == cut_band:
                    cut_filter = temp_filter

            run_option = self.config["run_option"]
            shear_type = self.config["shear_setting"]["shear_type"]

            # Loop over objects
            missed_obj = 0
            bright_obj = 0
            dim_obj = 0
            for j in range(self.nobj):
                obj = self.cat.objs[j]
                if obj.type == 'star' and run_option["galaxy_only"]:
                    continue
                elif obj.type == 'galaxy' and run_option["star_only"]:
                    continue
                elif obj.type == 'quasar' and run_option["star_only"]:
                    continue

                # Load and convert SED; also calculate the object's magnitude
                # in the chip band and in the cut band.
                try:
                    sed_data = self.cat.load_sed(obj)
                    norm_filt = self.cat.load_norm_filt(obj)
                    obj.sed, obj.param["mag_%s" % filt.filter_type], obj.param["flux_%s" % filt.filter_type] = self.cat.convert_sed(
                        mag=obj.param["mag_use_normal"],
                        sed=sed_data,
                        target_filt=filt,
                        norm_filt=norm_filt,
                    )
                    _, obj.param["mag_%s" % cut_filter.filter_type], obj.param["flux_%s" % cut_filter.filter_type] = self.cat.convert_sed(
                        mag=obj.param["mag_use_normal"],
                        sed=sed_data,
                        target_filt=cut_filter,
                        norm_filt=norm_filt,
                    )
                except Exception as e:
                    traceback.print_exc()
                    logger.error(e)
                    continue

                # Exclude very bright/dim objects (for now)
                if cut_filter.is_too_bright(
                        mag=obj.param["mag_%s" % cut_band],
                        margin=self.config["obs_setting"]["mag_sat_margin"]):
                    bright_obj += 1
                    obj.unload_SED()
                    continue
                if filt.is_too_dim(
                        mag=obj.getMagFilter(filt),
                        margin=self.config["obs_setting"]["mag_lim_margin"]):
                    dim_obj += 1
                    obj.unload_SED()
                    continue

                if shear_type == "constant":
                    if obj.type == 'star':
                        obj.g1, obj.g2 = 0., 0.
                    else:
                        obj.g1, obj.g2 = self.g1_field, self.g2_field
                elif shear_type == "extra":
                    try:
                        # TODO: every object with individual shear from input catalog(s)
                        obj.g1, obj.g2 = self.g1_field[j], self.g2_field[j]
                    except Exception:
                        # Best effort: report and carry on; a bare ``except``
                        # would also swallow KeyboardInterrupt / SystemExit.
                        logger.error("failed to load external shear.")
                elif shear_type == "catalog":
                    pass
                else:
                    logger.error("Unknown shear input")
                    raise ValueError("Unknown shear input")

                header_wcs = generateExtensionHeader(
                    xlen=chip.npix_x,
                    ylen=chip.npix_y,
                    ra=ra_cen,
                    dec=dec_cen,
                    pa=pointing.img_pa.deg,
                    gain=chip.gain,
                    readout=chip.read_noise,
                    dark=chip.dark_noise,
                    saturation=90000,
                    psize=chip.pix_scale,
                    row_num=chip.rowID,
                    col_num=chip.colID,
                    extName='raw')

                pos_img, offset, local_wcs, real_wcs = obj.getPosImg_Offset_WCS(img=chip.img, fdmodel=self.fd_model, chip=chip, verbose=False, img_header=header_wcs)
                if pos_img.x == -1 or pos_img.y == -1:
                    # Exclude object which is outside the chip area (after field distortion)
                    logger.error("obj missed: %s" % (obj.id))
                    missed_obj += 1
                    obj.unload_SED()
                    continue

                # Draw object & update output catalog
                try:
                    if run_option["out_cat_only"]:
                        isUpdated = True
                        obj.real_pos = obj.getRealPos(chip.img, global_x=obj.posImg.x, global_y=obj.posImg.y, img_real_wcs=obj.real_wcs)
                        pos_shear = 0.
                    elif chip.survey_type == "photometric" and not run_option["out_cat_only"]:
                        isUpdated, pos_shear = obj.drawObj_multiband(
                            tel=self.tel,
                            pos_img=pos_img,
                            psf_model=psf_model,
                            bandpass_list=filt.bandpass_sub_list,
                            filt=filt,
                            chip=chip,
                            g1=obj.g1,
                            g2=obj.g2,
                            exptime=pointing.exp_time
                        )
                    elif chip.survey_type == "spectroscopic" and not run_option["out_cat_only"]:
                        isUpdated, pos_shear = obj.drawObj_slitless(
                            tel=self.tel,
                            pos_img=pos_img,
                            psf_model=psf_model,
                            bandpass_list=filt.bandpass_sub_list,
                            filt=filt,
                            chip=chip,
                            g1=obj.g1,
                            g2=obj.g2,
                            exptime=pointing.exp_time,
                            normFilter=norm_filt)

                    if isUpdated:
                        # TODO: add up stats
                        chip_output.cat_add_obj(obj, pos_img, pos_shear)
                    else:
                        # NOTE(review): omitted objects skip the unload below;
                        # confirm whether the draw routines release the SED.
                        continue
                except Exception as e:
                    traceback.print_exc()
                    logger.error(e)

                # Unload SED:
                obj.unload_SED()
                del obj
                gc.collect()

            del psf_model
            del self.cat
            gc.collect()

        self._log_memory(logger, 1, pointing, chip)

        # Detector Effects
        # ===========================================================
        # whether to output zero, dark, flat calibration images.
        if not self.config["run_option"]["out_cat_only"]:
            chip.img = chip.addEffects(
                config=self.config,
                img=chip.img,
                chip_output=chip_output,
                filt=filt,
                ra_cen=pointing.ra,
                dec_cen=pointing.dec,
                img_rot=pointing.img_pa,
                pointing_ID=pointing.id,
                timestamp_obs=pointing.timestamp,
                pointing_type=pointing.pointing_type,
                sky_map=sky_map,
                tel=self.tel,
                logger=logger)

            if pointing.pointing_type == 'MS':
                self._write_science_image(chip, pointing, chip_output)
                logger.info("# objects that are too bright %d out of %d" % (bright_obj, self.nobj))
                logger.info("# objects that are too dim %d out of %d" % (dim_obj, self.nobj))
                logger.info("# objects that are missed %d out of %d" % (missed_obj, self.nobj))

        del chip.img

        self._log_memory(logger, 2, pointing, chip)

    def runExposure_MPI_PointingList(self, pointing_list, shear_cat_file=None, chips=None, use_mpi=False):
        """Run ``run_one_chip`` for every (pointing, chip) combination.

        Parameters
        ----------
        pointing_list : sequence
            Pointings to simulate.
        shear_cat_file : optional
            NOTE(review): accepted but not forwarded to ``run_one_chip``;
            confirm whether it should be.
        chips : container of chip IDs, optional
            Restrict the run to these chip IDs. IDs that are not in
            ``self.chip_list`` are skipped. ``None`` runs every chip in
            ``self.chip_list``.
        use_mpi : bool
            When true, job ``i`` is run only by the rank for which
            ``i % size == rank``.
        """
        if use_mpi:
            comm = MPI.COMM_WORLD
            ind_thread = comm.Get_rank()
            num_thread = comm.Get_size()

        if chips is None:
            run_chips = self.chip_list
            run_filts = self.filter_list
        else:
            # Only run a particular set of chips
            run_chips = []
            run_filts = []
            for chip, filt in zip(self.chip_list, self.filter_list):
                if chip.chipID in chips:
                    run_chips.append(chip)
                    run_filts.append(filt)

        # Count the chips that will actually run. Using len(chips) overran
        # run_chips whenever a requested ID was ignored, unknown or repeated.
        nchips_per_fp = len(run_chips)

        for ipoint, pointing in enumerate(pointing_list):
            pointing_ID = pointing.id
            for ichip in range(nchips_per_fp):
                i = ipoint * nchips_per_fp + ichip
                if use_mpi and i % num_thread != ind_thread:
                    continue

                pid = os.getpid()
                sub_img_dir, prefix = makeSubDir_PointingList(path_dict=self.path_dict, config=self.config, pointing_ID=pointing_ID)

                chip = run_chips[ichip]
                filt = run_filts[ichip]
                chip_output = ChipOutput(
                    config=self.config,
                    focal_plane=self.focal_plane,
                    chip=chip,
                    filt=filt,
                    exptime=pointing.exp_time,
                    pointing_type=pointing.pointing_type,
                    pointing_ID=pointing_ID,
                    subdir=sub_img_dir,
                    prefix=prefix)
                chip_output.logger.info("running pointing#%d, chip#%d, at PID#%d..." % (pointing_ID, chip.chipID, pid))
                self.run_one_chip(
                    chip=chip,
                    filt=filt,
                    chip_output=chip_output,
                    pointing=pointing,
                    cat_dir=self.path_dict["cat_dir"])
                print("finished running chip#%d..." % (chip.chipID), flush=True)
                chip_output.logger.info("finished running chip#%d..." % (chip.chipID))
                # Detach every handler from this chip's logger; iterate over
                # a copy because removeHandler mutates the list.
                for handler in chip_output.logger.handlers[:]:
                    chip_output.logger.removeHandler(handler)
                gc.collect()