import os
import numpy as np
import mpi4py.MPI as MPI
import galsim
import logging
import psutil
from astropy.io import fits
from datetime import datetime

from ObservationSim.Config import config_dir, ChipOutput
from ObservationSim.Config.Header import generatePrimaryHeader, generateExtensionHeader
from ObservationSim.Instrument import Telescope, Filter, FilterParam, FocalPlane, Chip
from ObservationSim.Instrument.Chip import Effects
from ObservationSim.MockObject import calculateSkyMap_split_g
from ObservationSim.PSF import PSFGauss, FieldDistortion, PSFInterp
from ObservationSim._util import get_shear_field, makeSubDir_PointingList
from ObservationSim.Astrometry.Astrometry_util import on_orbit_obs_position


class Observation(object):
    """Simulate exposures chip by chip for a list of pointings.

    The constructor builds the telescope, focal plane, chips and filters
    from the configuration. ``runExposure_MPI_PointingList`` then iterates
    over (pointing, chip) pairs and delegates each one to ``run_one_chip``.
    """

    def __init__(self, config, Catalog, work_dir=None, data_dir=None):
        """Set up instrument objects and read the shear field.

        Args:
            config: Nested configuration mapping; this class reads the
                "obs_setting", "ins_effects", "psf_setting",
                "shear_setting", "run_option" and "random_seeds" sections.
            Catalog: Catalog class (not an instance); it is instantiated
                once per chip in ``run_one_chip``.
            work_dir: Forwarded to ``config_dir``.
            data_dir: Forwarded to ``config_dir``.
        """
        self.path_dict = config_dir(config=config, work_dir=work_dir, data_dir=data_dir)
        self.config = config
        self.tel = Telescope()
        self.focal_plane = FocalPlane(survey_type=self.config["obs_setting"]["survey_type"])
        self.filter_param = FilterParam()
        self.chip_list = []
        self.filter_list = []
        self.all_filter = []
        self.Catalog = Catalog

        # Field distortion is optional; None disables it downstream.
        if self.config["ins_effects"]["field_dist"] == True:
            self.fd_model = FieldDistortion(fdModel_path=self.path_dict["fd_path"])
        else:
            self.fd_model = None

        # Construct chips & filters. Chip IDs are 1-based.
        nchips = self.focal_plane.nchip_x * self.focal_plane.nchip_y
        for chipID in range(1, nchips + 1):
            chip = Chip(chipID=chipID, config=self.config)
            filter_id, filter_type = chip.getChipFilter()
            filt = Filter(filter_id=filter_id, filter_type=filter_type, filter_param=self.filter_param)
            # chip_list / filter_list stay index-aligned and skip ignored
            # chips; all_filter keeps the filter of every chip.
            if not self.focal_plane.isIgnored(chipID=chipID):
                self.chip_list.append(chip)
                self.filter_list.append(filt)
            self.all_filter.append(filt)

        # Read catalog and shear(s)
        self.g1_field, self.g2_field, self.nshear = get_shear_field(config=self.config)

    @staticmethod
    def _log_pointing_info(pointing, chip, logger):
        """Write a summary of the current pointing and chip to the log."""
        logger.info(':::::::::::::::::::Current Pointing Information::::::::::::::::::')
        logger.info("RA: %f, DEC; %f", pointing.ra, pointing.dec)
        logger.info("Time: %s", datetime.fromtimestamp(pointing.timestamp).isoformat())
        logger.info("Exposure time: %f", pointing.exp_time)
        logger.info("Satellite Position (x, y, z): (%f, %f, %f)", pointing.sat_x, pointing.sat_y, pointing.sat_z)
        logger.info("Satellite Velocity (x, y, z): (%f, %f, %f)", pointing.sat_vx, pointing.sat_vy, pointing.sat_vz)
        logger.info("Position Angle: %f", pointing.img_pa.deg)
        logger.info('Chip : %d', chip.chipID)
        logger.info(':::::::::::::::::::::::::::END:::::::::::::::::::::::::::::::::::')

    @staticmethod
    def _log_memory(stage, pointing, chip, logger):
        """Log the resident memory of the current process, in GB."""
        pid = os.getpid()
        rss_gb = psutil.Process(pid).memory_info().rss / 1024 / 1024 / 1024
        logger.info("check running:%d: pointing-%d chip-%d pid-%d memory-%6.2fGB",
                    stage, pointing.id, chip.chipID, pid, rss_gb)

    def _select_psf_model(self, chip, psf_model, logger):
        """Return the PSF model named in the configuration.

        If the configured name is not recognised, an error is logged and
        the caller-supplied ``psf_model`` (possibly None) is returned
        unchanged.
        """
        model_name = self.config["psf_setting"]["psf_model"]
        if model_name == "Gauss":
            return PSFGauss(chip=chip)
        if model_name == "Interp":
            return PSFInterp(chip=chip, PSF_data_file=self.path_dict["psf_dir"])
        # Logger.error() takes no ``flush`` keyword; passing one raised a
        # TypeError instead of emitting this message.
        logger.error("unrecognized PSF model type!!")
        return psf_model

    def _pointing_center(self, pointing):
        """Return (ra, dec) of the pointing centre.

        When "enable_astrometric_model" is set, the nominal position is
        passed through ``on_orbit_obs_position``; otherwise the nominal
        pointing coordinates are returned as-is.
        """
        if not self.config["obs_setting"]["enable_astrometric_model"]:
            return pointing.ra, pointing.dec

        dt = datetime.fromtimestamp(pointing.timestamp)
        ra_cen, dec_cen = on_orbit_obs_position(
            input_ra_list=[pointing.ra],
            input_dec_list=[pointing.dec],
            input_pmra_list=[0.],
            input_pmdec_list=[0.],
            input_rv_list=[0.],
            input_parallax_list=[1e-9],
            input_nstars=1,
            input_x=pointing.sat_x,
            input_y=pointing.sat_y,
            input_z=pointing.sat_z,
            input_vx=pointing.sat_vx,
            input_vy=pointing.sat_vy,
            input_vz=pointing.sat_vz,
            input_epoch="J2015.5",
            input_date_str=dt.date().isoformat(),
            input_time_str=dt.time().isoformat()
        )
        return ra_cen[0], dec_cen[0]

    def _make_slitless_sky_map(self, chip, filt, logger):
        """Build the sky map for a spectroscopic chip.

        A unit array shaped like the chip image is optionally multiplied
        by a mean-normalised flat field and by the shutter-effect array,
        then handed to ``calculateSkyMap_split_g``.
        """
        flat_normal = np.ones_like(chip.img.array)
        if self.config["ins_effects"]["flat_fielding"] == True:
            logger.info("SLS flat preprocess,CHIP %d : Creating and applying Flat-Fielding", chip.chipID)
            logger.info(str(chip.img.bounds))
            flat_img = Effects.MakeFlatSmooth(
                chip.img.bounds,
                int(self.config["random_seeds"]["seed_flat"]))
            flat_normal = flat_normal * flat_img.array / np.mean(flat_img.array)
        if self.config["ins_effects"]["shutter_effect"] == True:
            logger.info("SLS flat preprocess,CHIP %d : Apply shutter effect", chip.chipID)
            # shutter effect normalized image for this chip
            shuttimg = Effects.ShutterEffectArr(chip.img, t_shutter=1.3, dist_bearing=735, dt=1E-3)
            flat_normal = flat_normal * shuttimg
        flat_normal = np.array(flat_normal, dtype='float32')
        return calculateSkyMap_split_g(
            skyMap=flat_normal,
            blueLimit=filt.blue_limit,
            redLimit=filt.red_limit,
            conf=chip.sls_conf,
            pixelSize=chip.pix_scale,
            isAlongY=0)

    def _assign_shear(self, obj, index, logger):
        """Set ``obj.g1`` / ``obj.g2`` according to the configured shear type.

        Raises:
            ValueError: If the configured shear type is not one of
                "constant", "extra" or "catalog".
        """
        shear_type = self.config["shear_setting"]["shear_type"]
        if shear_type == "constant":
            if obj.type == 'star':
                obj.g1, obj.g2 = 0., 0.
            else:
                obj.g1, obj.g2 = self.g1_field, self.g2_field
        elif shear_type == "extra":
            try:
                # TODO: every object with individual shear from input catalog(s)
                obj.g1, obj.g2 = self.g1_field[index], self.g2_field[index]
            except Exception:
                # Best effort: the object keeps whatever shear it already has.
                logger.error("failed to load external shear.")
        elif shear_type == "catalog":
            # Nothing is assigned here for this shear type.
            pass
        else:
            logger.error("Unknown shear input")
            raise ValueError("Unknown shear input")

    @staticmethod
    def _write_science_image(chip, pointing, chip_output):
        """Convert the chip image to uint16 and write it as a FITS file.

        The file name is taken from the 'FILENAME' card of the primary
        header and the file is written into ``chip_output.subdir``.
        """
        datetime_obs = datetime.fromtimestamp(pointing.timestamp)
        date_obs = datetime_obs.strftime("%y%m%d")
        time_obs = datetime_obs.strftime("%H%M%S")
        h_prim = generatePrimaryHeader(
            xlen=chip.npix_x,
            ylen=chip.npix_y,
            pointNum=str(pointing.id),
            ra=pointing.ra,
            dec=pointing.dec,
            psize=chip.pix_scale,
            row_num=chip.rowID,
            col_num=chip.colID,
            date=date_obs,
            time_obs=time_obs,
            exptime=pointing.exp_time,
            im_type='SCI',
            sat_pos=[pointing.sat_x, pointing.sat_y, pointing.sat_z],
            sat_vel=[pointing.sat_vx, pointing.sat_vy, pointing.sat_vz])
        h_ext = generateExtensionHeader(
            xlen=chip.npix_x,
            ylen=chip.npix_y,
            ra=pointing.ra,
            dec=pointing.dec,
            pa=pointing.img_pa.deg,
            gain=chip.gain,
            readout=chip.read_noise,
            dark=chip.dark_noise,
            saturation=90000,
            psize=chip.pix_scale,
            row_num=chip.rowID,
            col_num=chip.colID,
            extName='raw')
        chip.img = galsim.Image(chip.img.array, dtype=np.uint16)
        hdu_list = fits.HDUList([
            fits.PrimaryHDU(header=h_prim),
            fits.ImageHDU(chip.img.array, header=h_ext),
        ])
        fname = os.path.join(chip_output.subdir, h_prim['FILENAME'] + '.fits')
        hdu_list.writeto(fname, output_verify='ignore', overwrite=True)

    def run_one_chip(self, chip, filt, pointing, chip_output, wcs_fp=None, psf_model=None, shear_cat_file=None, cat_dir=None, sed_dir=None):
        """Simulate one chip for one pointing.

        Args:
            chip: Chip to simulate; its ``img`` attribute is created here
                and deleted again before returning.
            filt: Filter mounted on ``chip``.
            pointing: Pointing description (position, time, exposure time,
                satellite state, ``pointing_type``).
            chip_output: Output handler providing ``logger``, ``subdir``,
                ``create_output_file`` and ``cat_add_obj``.
            wcs_fp: Focal-plane WCS; built from the pointing when None.
            psf_model: Used only when the configured PSF model name is not
                recognised; otherwise replaced by the configured model.
            shear_cat_file: If given, the shear field is reloaded from it.
            cat_dir: Forwarded to the catalog constructor.
            sed_dir: Forwarded to the catalog constructor.

        Raises:
            ValueError: If no filter matches the configured "cut_in_band",
                or if the configured shear type is unknown.
        """
        logger = chip_output.logger
        run_option = self.config["run_option"]
        obs_setting = self.config["obs_setting"]

        self._log_pointing_info(pointing, chip, logger)
        psf_model = self._select_psf_model(chip, psf_model, logger)

        # Get (extra) shear fields
        if shear_cat_file is not None:
            self.g1_field, self.g2_field, self.nshear = get_shear_field(
                config=self.config, shear_cat_file=shear_cat_file)

        # Apply astrometric simulation for pointing
        ra_cen, dec_cen = self._pointing_center(pointing)

        # Get WCS for the focal plane
        if wcs_fp is None:
            wcs_fp = self.focal_plane.getTanWCS(ra_cen, dec_cen, pointing.img_pa, chip.pix_scale)

        # Create chip Image
        chip.img = galsim.ImageF(chip.npix_x, chip.npix_y)
        chip.img.setOrigin(chip.bound.xmin, chip.bound.ymin)
        chip.img.wcs = wcs_fp

        if chip.survey_type == "photometric":
            sky_map = None
        elif chip.survey_type == "spectroscopic":
            sky_map = self._make_slitless_sky_map(chip, filt, logger)

        if pointing.pointing_type == 'MS':
            # Load catalogues and templates
            self.cat = self.Catalog(config=self.config, chip=chip, pointing=pointing,
                                    cat_dir=cat_dir, sed_dir=sed_dir, chip_output=chip_output)
            chip_output.create_output_file()
            self.nobj = len(self.cat.objs)

            cut_band = obs_setting["cut_in_band"].lower()
            cut_filter = None
            for temp_filter in self.all_filter:
                # Update the limiting magnitude using exposure time in pointing
                temp_filter.update_limit_saturation_mags(exptime=pointing.exp_time, chip=chip)
                # Select cutting band filter for saturation/limiting magnitude
                if temp_filter.filter_type.lower() == cut_band:
                    cut_filter = temp_filter
            if cut_filter is None:
                # Without this check the missing filter only showed up as a
                # swallowed NameError for every single object.
                logger.error("no filter matches cut_in_band '%s'", obs_setting["cut_in_band"])
                raise ValueError("no filter matches cut_in_band '%s'" % obs_setting["cut_in_band"])
            # Same key for storing and reading the cut-band magnitude.
            cut_mag_key = "mag_%s" % cut_filter.filter_type
            cut_flux_key = "flux_%s" % cut_filter.filter_type

            # Loop over objects
            missed_obj = 0
            bright_obj = 0
            dim_obj = 0
            for j, obj in enumerate(self.cat.objs):
                if obj.type == 'star' and run_option["galaxy_only"]:
                    continue
                if obj.type in ('galaxy', 'quasar') and run_option["star_only"]:
                    continue

                # load and convert SED; also calculate object's magnitude in all CSST bands
                try:
                    sed_data = self.cat.load_sed(obj)
                    norm_filt = self.cat.load_norm_filt(obj)
                    obj.sed, obj.param["mag_%s" % filt.filter_type], obj.param["flux_%s" % filt.filter_type] = self.cat.convert_sed(
                        mag=obj.param["mag_use_normal"],
                        sed=sed_data,
                        target_filt=filt,
                        norm_filt=norm_filt,
                    )
                    _, obj.param[cut_mag_key], obj.param[cut_flux_key] = self.cat.convert_sed(
                        mag=obj.param["mag_use_normal"],
                        sed=sed_data,
                        target_filt=cut_filter,
                        norm_filt=norm_filt,
                    )
                except Exception as e:
                    print(e)
                    logger.error(e)
                    continue

                # Exclude very bright/dim objects (for now)
                if cut_filter.is_too_bright(
                        mag=obj.param[cut_mag_key],
                        margin=obs_setting["mag_sat_margin"]):
                    # Too-bright galaxies are not skipped, only other types.
                    if obj.type != 'galaxy':
                        bright_obj += 1
                        obj.unload_SED()
                        continue
                if filt.is_too_dim(
                        mag=obj.getMagFilter(filt),
                        margin=obs_setting["mag_lim_margin"]):
                    dim_obj += 1
                    obj.unload_SED()
                    continue

                self._assign_shear(obj, j, logger)

                pos_img, offset, local_wcs = obj.getPosImg_Offset_WCS(
                    img=chip.img, fdmodel=self.fd_model, chip=chip, verbose=False)
                if pos_img.x == -1 or pos_img.y == -1:
                    # Exclude object which is outside the chip area (after field distortion)
                    missed_obj += 1
                    obj.unload_SED()
                    continue

                # Draw object & update output catalog
                try:
                    if run_option["out_cat_only"]:
                        isUpdated = True
                        pos_shear = 0.
                    elif chip.survey_type == "photometric":
                        isUpdated, pos_shear = obj.drawObj_multiband(
                            tel=self.tel,
                            pos_img=pos_img,
                            psf_model=psf_model,
                            bandpass_list=filt.bandpass_sub_list,
                            filt=filt,
                            chip=chip,
                            g1=obj.g1,
                            g2=obj.g2,
                            exptime=pointing.exp_time
                        )
                    elif chip.survey_type == "spectroscopic":
                        isUpdated, pos_shear = obj.drawObj_slitless(
                            tel=self.tel,
                            pos_img=pos_img,
                            psf_model=psf_model,
                            bandpass_list=filt.bandpass_sub_list,
                            filt=filt,
                            chip=chip,
                            g1=obj.g1,
                            g2=obj.g2,
                            exptime=pointing.exp_time,
                            normFilter=norm_filt,
                        )
                    if isUpdated:
                        # TODO: add up stats
                        chip_output.cat_add_obj(obj, pos_img, pos_shear)
                except Exception as e:
                    # A failure on one object must not abort the chip.
                    print(e)
                    logger.error(e)

                # Unload SED. Objects the draw routine omitted are unloaded
                # too, like every other skip path above.
                obj.unload_SED()
                del obj

            del psf_model
            del self.cat

        self._log_memory(1, pointing, chip, logger)

        # Detector Effects
        # ===========================================================
        if not run_option["out_cat_only"]:
            chip.img = chip.addEffects(
                config=self.config,
                img=chip.img,
                chip_output=chip_output,
                filt=filt,
                ra_cen=pointing.ra,
                dec_cen=pointing.dec,
                img_rot=pointing.img_pa,
                pointing_ID=pointing.id,
                timestamp_obs=pointing.timestamp,
                pointing_type=pointing.pointing_type,
                sky_map=sky_map, tel=self.tel,
                logger=logger)

            if pointing.pointing_type == 'MS':
                self._write_science_image(chip, pointing, chip_output)
                # The counters only exist for 'MS' pointings.
                logger.info("# objects that are too bright %d out of %d", bright_obj, self.nobj)
                logger.info("# objects that are too dim %d out of %d", dim_obj, self.nobj)
                logger.info("# objects that are missed %d out of %d", missed_obj, self.nobj)

        del chip.img
        self._log_memory(2, pointing, chip, logger)

    def runExposure_MPI_PointingList(self, pointing_list, shear_cat_file=None, chips=None, use_mpi=False):
        """Run ``run_one_chip`` for every (pointing, chip) pair.

        Args:
            pointing_list: Sequence of pointings to simulate.
            shear_cat_file: Accepted for interface compatibility; it is not
                forwarded to ``run_one_chip`` by this method.
            chips: Optional collection of chip IDs to restrict the run to.
                IDs that are not in ``self.chip_list`` are skipped.
            use_mpi: If True, the (pointing, chip) tasks are distributed
                round-robin over the ranks of ``MPI.COMM_WORLD``.
        """
        if use_mpi:
            comm = MPI.COMM_WORLD
            ind_thread = comm.Get_rank()
            num_thread = comm.Get_size()

        if chips is None:
            run_chips = self.chip_list
            run_filts = self.filter_list
        else:
            # Only run a particular set of chips
            run_chips = []
            run_filts = []
            for chip, filt in zip(self.chip_list, self.filter_list):
                if chip.chipID in chips:
                    run_chips.append(chip)
                    run_filts.append(filt)
        # Count the chips actually selected, not the IDs requested: a
        # requested ID missing from chip_list used to cause an IndexError.
        nchips_per_fp = len(run_chips)

        for ipoint, pointing in enumerate(pointing_list):
            pointing_ID = pointing.id
            for ichip, (chip, filt) in enumerate(zip(run_chips, run_filts)):
                task_index = ipoint * nchips_per_fp + ichip
                if use_mpi and task_index % num_thread != ind_thread:
                    continue

                pid = os.getpid()
                sub_img_dir, prefix = makeSubDir_PointingList(
                    path_dict=self.path_dict, config=self.config, pointing_ID=pointing_ID)

                chip_output = ChipOutput(
                    config=self.config,
                    focal_plane=self.focal_plane,
                    chip=chip,
                    filt=filt,
                    exptime=pointing.exp_time,
                    pointing_type=pointing.pointing_type,
                    pointing_ID=pointing_ID,
                    subdir=sub_img_dir,
                    prefix=prefix)
                try:
                    chip_output.logger.info("running pointing#%d, chip#%d, at PID#%d...",
                                            pointing_ID, chip.chipID, pid)
                    self.run_one_chip(
                        chip=chip,
                        filt=filt,
                        chip_output=chip_output,
                        pointing=pointing,
                        cat_dir=self.path_dict["cat_dir"])
                    print("finished running chip#%d..." % (chip.chipID), flush=True)
                    chip_output.logger.info("finished running chip#%d...", chip.chipID)
                finally:
                    # Detach and close the handlers even when the chip fails,
                    # so their resources are released.
                    for handler in chip_output.logger.handlers[:]:
                        chip_output.logger.removeHandler(handler)
                        handler.close()