import os
import numpy as np
import mpi4py.MPI as MPI
import galsim
import psutil
import gc
from datetime import datetime
import traceback

from ObservationSim.Config import config_dir, ChipOutput
from ObservationSim.Instrument import Telescope, Filter, FilterParam, FocalPlane, Chip
from ObservationSim.Instrument.Chip import Effects
from ObservationSim.Instrument.Chip import ChipUtils as chip_utils
from ObservationSim._util import makeSubDir_PointingList
from ObservationSim.Astrometry.Astrometry_util import on_orbit_obs_position
from ObservationSim.SimSteps import SimSteps, SIM_STEP_TYPES


class Observation(object):
    """Run the chip-by-chip image simulation over a list of pointings.

    An instance holds the overall configuration, a ``Telescope``, the filter
    parameters and the catalogue class used to load sources for each chip.
    ``runExposure_MPI_PointingList`` is the driver: for every pointing it
    builds the focal plane, chips and filters, then calls ``run_one_chip``
    for each chip assigned to this process.
    """

    def __init__(self, config, Catalog, work_dir=None, data_dir=None):
        """Store the configuration and build the shared instrument models.

        Args:
            config: Overall simulation configuration (indexed like a dict,
                e.g. ``config["random_seeds"]``, ``config["obs_setting"]``).
            Catalog: Catalogue class/callable; instantiated once per chip in
                ``run_one_chip`` with ``config``, ``chip``, ``pointing``,
                ``cat_dir``, ``sed_dir``, ``chip_output`` and ``filt``.
            work_dir: Optional working directory passed to ``config_dir``.
            data_dir: Optional data directory passed to ``config_dir``.
        """
        self.path_dict = config_dir(config=config, work_dir=work_dir, data_dir=data_dir)
        self.config = config
        self.tel = Telescope()
        self.filter_param = FilterParam()
        self.Catalog = Catalog

    def prepare_chip_for_exposure(self, chip, ra_cen, dec_cen, pointing, wcs_fp=None):
        """Attach a blank image, WCS, RNGs and instrument maps to ``chip``.

        Args:
            chip: The ``Chip`` to prepare (modified in place).
            ra_cen, dec_cen: Pointing centre used to build the focal-plane WCS.
            pointing: Current pointing; ``img_pa`` and ``id`` are read.
            wcs_fp: Optional pre-built WCS. When ``None``, a TAN WCS is built
                from ``self.focal_plane``, which must already be set (it is
                assigned in ``runExposure_MPI_PointingList``).

        Returns:
            The same ``chip`` object. It now has ``img``, ``rng_poisson``,
            ``poisson_noise``, ``flat_img``, ``shutter_img`` and ``prnu_img``.
        """
        # Get WCS for the focal plane
        if wcs_fp is None:
            wcs_fp = self.focal_plane.getTanWCS(ra_cen, dec_cen, pointing.img_pa, chip.pix_scale)

        # Create chip Image
        chip.img = galsim.ImageF(chip.npix_x, chip.npix_y)
        chip.img.setOrigin(chip.bound.xmin, chip.bound.ymin)
        chip.img.wcs = wcs_fp

        # Get random generators for this chip; the seed is offset per pointing
        # and per chip so every (pointing, chip) pair gets its own stream.
        chip.rng_poisson, chip.poisson_noise = chip_utils.get_poisson(
            seed=int(self.config["random_seeds"]["seed_poisson"]) + pointing.id * 30 + chip.chipID,
            sky_level=0.)

        # Get flat, shutter, and PRNU images
        chip.flat_img, _ = chip_utils.get_flat(
            img=chip.img, seed=int(self.config["random_seeds"]["seed_flat"]))
        chip.shutter_img = Effects.ShutterEffectArr(
            chip.img, t_shutter=1.3, dist_bearing=735, dt=1E-3)
        # Convert the configured seed to int *before* adding the chip offset,
        # consistent with the other seeds (a string seed would otherwise fail).
        chip.prnu_img = Effects.PRNU_Img(
            xsize=chip.npix_x, ysize=chip.npix_y, sigma=0.01,
            seed=int(self.config["random_seeds"]["seed_prnu"]) + chip.chipID)
        return chip

    def run_one_chip(self, chip, filt, pointing, chip_output, wcs_fp=None, psf_model=None, cat_dir=None, sed_dir=None):
        """Simulate one chip for one pointing by running the configured steps.

        Steps are taken from ``pointing.obs_param["call_sequence"]``. Each key
        is mapped through ``SIM_STEP_TYPES`` to a ``SimSteps`` method. The
        first step that raises is logged (traceback printed, error logged) and
        stops the sequence; the exception is not re-raised.

        Requires ``self.focal_plane`` (when ``wcs_fp`` is None) and
        ``self.all_filters``. Both are set by ``runExposure_MPI_PointingList``.

        Args:
            chip, filt: The chip and its filter.
            pointing: Current pointing (position, time, satellite state, ...).
            chip_output: Per-chip output/logging object.
            wcs_fp: Optional pre-built WCS forwarded to
                ``prepare_chip_for_exposure``.
            psf_model: Currently unused in this method.
            cat_dir, sed_dir: Forwarded to the catalogue constructor.
        """
        obs_datetime = datetime.utcfromtimestamp(pointing.timestamp)

        chip_output.Log_info(':::::::::::::::::::Current Pointing Information::::::::::::::::::')
        chip_output.Log_info("RA: %f, DEC: %f" % (pointing.ra, pointing.dec))
        chip_output.Log_info("Time: %s" % obs_datetime.isoformat())
        chip_output.Log_info("Exposure time: %f" % pointing.exp_time)
        chip_output.Log_info("Satellite Position (x, y, z): (%f, %f, %f)" % (pointing.sat_x, pointing.sat_y, pointing.sat_z))
        chip_output.Log_info("Satellite Velocity (x, y, z): (%f, %f, %f)" % (pointing.sat_vx, pointing.sat_vy, pointing.sat_vz))
        chip_output.Log_info("Position Angle: %f" % pointing.img_pa.deg)
        chip_output.Log_info('Chip : %d' % chip.chipID)
        chip_output.Log_info(':::::::::::::::::::::::::::END:::::::::::::::::::::::::::::::::::')

        # Apply astrometric simulation for pointing
        if self.config["obs_setting"]["enable_astrometric_model"]:
            date_str = obs_datetime.date().isoformat()
            time_str = obs_datetime.time().isoformat()
            ra_cen, dec_cen = on_orbit_obs_position(
                input_ra_list=[pointing.ra],
                input_dec_list=[pointing.dec],
                input_pmra_list=[0.],
                input_pmdec_list=[0.],
                input_rv_list=[0.],
                input_parallax_list=[1e-9],
                input_nstars=1,
                input_x=pointing.sat_x,
                input_y=pointing.sat_y,
                input_z=pointing.sat_z,
                input_vx=pointing.sat_vx,
                input_vy=pointing.sat_vy,
                input_vz=pointing.sat_vz,
                input_epoch="J2000",
                input_date_str=date_str,
                input_time_str=time_str
            )
            ra_cen, dec_cen = ra_cen[0], dec_cen[0]
        else:
            ra_cen = pointing.ra
            dec_cen = pointing.dec

        # Prepare necessary chip properties for simulation. Forward the
        # caller's WCS (previously it was accepted but silently dropped).
        chip = self.prepare_chip_for_exposure(chip, ra_cen, dec_cen, pointing, wcs_fp=wcs_fp)

        # Load catalogues
        self.cat = self.Catalog(
            config=self.config, chip=chip, pointing=pointing,
            cat_dir=cat_dir, sed_dir=sed_dir, chip_output=chip_output, filt=filt)

        # Initialize SimSteps
        sim_steps = SimSteps(overall_config=self.config, chip_output=chip_output, all_filters=self.all_filters)

        call_sequence = pointing.obs_param["call_sequence"]
        for step in call_sequence:
            step_name = SIM_STEP_TYPES[step]
            chip_output.Log_info("Starting simulation step: %s, calling function: %s" % (step, step_name))
            obs_param = call_sequence[step]
            try:
                step_func = getattr(sim_steps, step_name)
                # Each step returns the (possibly updated) chip, filter,
                # telescope and pointing; the telescope is not reused here.
                chip, filt, _tel, pointing = step_func(
                    chip=chip,
                    filt=filt,
                    tel=self.tel,
                    pointing=pointing,
                    catalog=self.cat,
                    obs_param=obs_param)
                chip_output.Log_info("Finished simulation step: %s" % (step))
            except Exception as e:
                # Best-effort: report the failure and abandon the remaining
                # steps for this chip so the other chips can still run.
                traceback.print_exc()
                chip_output.Log_error(e)
                chip_output.Log_error("Failed simulation on step: %s" % (step))
                break

        chip_output.Log_info("check running:1: pointing-%d chip-%d pid-%d memory-%6.2fGB" % (
            pointing.id, chip.chipID, os.getpid(),
            (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024)))
        # Free the chip image now; chips are kept in self.chip_list.
        del chip.img

    # NOTE(review): the disabled calibration path below references names that
    # are not imported in this file (FlatLED, calculateSkyMap_split_g,
    # generatePrimaryHeader, generateExtensionHeader, fits). Restore those
    # imports before re-enabling it.
    # def run_one_chip_calibration(self, chip, filt, pointing, chip_output, skyback_level = 20000, sky_level_filt = 'g', wcs_fp=None, psf_model=None, cat_dir=None, sed_dir=None):

    #     # # Get WCS for the focal plane
    #     # if wcs_fp == None:
    #     #     wcs_fp = self.focal_plane.getTanWCS(ra_cen, dec_cen, pointing.img_pa, chip.pix_scale)

    #     # Create chip Image
    #     chip.img = galsim.ImageF(chip.npix_x, chip.npix_y)
    #     chip.img.setOrigin(chip.bound.xmin, chip.bound.ymin)
    #     # chip.img.wcs = wcs_fp
    #     pf_map = np.zeros_like(chip.img.array)
    #     if self.config["obs_setting"]["LED_TYPE"] is not None:
    #         if len(self.config["obs_setting"]["LED_TYPE"]) != 0:
    #             print("LED OPEN--------")

    #             led_obj = FlatLED(chip, filt)
    #             led_flat = led_obj.drawObj_LEDFlat(led_type_list=self.config["obs_setting"]["LED_TYPE"], exp_t_list=self.config["obs_setting"]["LED_TIME"])
    #             pf_map = led_flat
    #     # whether to output zero, dark, flat calibration images.
    #     expTime = self.config["obs_setting"]["exp_time"]
    #     norm_scaler = skyback_level/expTime / self.filter_param.param[sky_level_filt][5]
    #     if skyback_level == 0:
    #         self.config["ins_effects"]["shutter_effect"] = False

    #     if chip.survey_type == "photometric":
    #         sky_map = np.ones_like(chip.img.array) * norm_scaler * self.filter_param.param[chip.filter_type][5] / self.tel.pupil_area
    #     elif chip.survey_type == "spectroscopic":
    #         flat_normal = np.ones_like(chip.img.array)
    #         if self.config["ins_effects"]["flat_fielding"] == True:
    #             chip_output.Log_info("SLS flat preprocess,CHIP %d : Creating and applying Flat-Fielding" % chip.chipID)
    #             msg = str(chip.img.bounds)
    #             chip_output.Log_info(msg)
    #             flat_img = Effects.MakeFlatSmooth(
    #                 chip.img.bounds,
    #                 int(self.config["random_seeds"]["seed_flat"]))
    #             flat_normal = flat_normal * flat_img.array / np.mean(flat_img.array)
    #         if self.config["ins_effects"]["shutter_effect"] == True:
    #             chip_output.Log_info("SLS flat preprocess,CHIP %d : Apply shutter effect" % chip.chipID)
    #             shuttimg = Effects.ShutterEffectArr(chip.img, t_shutter=1.3, dist_bearing=735,
    #                                                 dt=1E-3)  # shutter effect normalized image for this chip
    #             flat_normal = flat_normal * shuttimg
    #             flat_normal = np.array(flat_normal, dtype='float32')
    #         sky_map = calculateSkyMap_split_g(
    #             skyMap=flat_normal,
    #             blueLimit=filt.blue_limit,
    #             redLimit=filt.red_limit,
    #             conf=chip.sls_conf,
    #             pixelSize=chip.pix_scale,
    #             isAlongY=0,
    #             flat_cube=chip.flat_cube)
    #         sky_map = sky_map * norm_scaler

    #     chip.img = chip.addEffects(
    #         config=self.config,
    #         img=chip.img,
    #         chip_output=chip_output,
    #         filt=filt,
    #         ra_cen=pointing.ra,
    #         dec_cen=pointing.dec,
    #         img_rot=pointing.img_pa,
    #         exptime=self.config["obs_setting"]["exp_time"],
    #         pointing_ID=pointing.id,
    #         timestamp_obs=pointing.timestamp,
    #         pointing_type=pointing.pointing_type,
    #         sky_map=sky_map, tel=self.tel,
    #         post_flash_map=pf_map,
    #         logger=chip_output.logger)

    #     datetime_obs = datetime.utcfromtimestamp(pointing.timestamp)
    #     date_obs = datetime_obs.strftime("%y%m%d")
    #     time_obs = datetime_obs.strftime("%H%M%S")
    #     h_prim = generatePrimaryHeader(
    #         xlen=chip.npix_x,
    #         ylen=chip.npix_y,
    #         pointNum=str(pointing.id),
    #         ra=pointing.ra,
    #         dec=pointing.dec,
    #         pixel_scale=chip.pix_scale,
    #         date=date_obs,
    #         time_obs=time_obs,
    #         exptime=self.config["obs_setting"]["exp_time"],
    #         im_type='DARKPF',
    #         sat_pos=[pointing.sat_x, pointing.sat_y, pointing.sat_z],
    #         sat_vel=[pointing.sat_vx, pointing.sat_vy, pointing.sat_vz],
    #         chip_name=str(chip.chipID).rjust(2, '0'))
    #     h_ext = generateExtensionHeader(
    #         chip=chip,
    #         xlen=chip.npix_x,
    #         ylen=chip.npix_y,
    #         ra=pointing.ra,
    #         dec=pointing.dec,
    #         pa=pointing.img_pa.deg,
    #         gain=chip.gain,
    #         readout=chip.read_noise,
    #         dark=chip.dark_noise,
    #         saturation=90000,
    #         pixel_scale=chip.pix_scale,
    #         pixel_size=chip.pix_size,
    #         xcen=chip.x_cen,
    #         ycen=chip.y_cen,
    #         extName='SCI',
    #         timestamp=pointing.timestamp,
    #         exptime=self.config["obs_setting"]["exp_time"],
    #         readoutTime=chip.readout_time)

    #     chip.img = galsim.Image(chip.img.array, dtype=np.uint16)
    #     hdu1 = fits.PrimaryHDU(header=h_prim)
    #     hdu1.add_checksum()
    #     hdu1.header.comments['CHECKSUM'] = 'HDU checksum'
    #     hdu1.header.comments['DATASUM'] = 'data unit checksum'
    #     hdu2 = fits.ImageHDU(chip.img.array, header=h_ext)
    #     hdu2.add_checksum()
    #     hdu2.header.comments['XTENSION'] = 'extension type'
    #     hdu2.header.comments['CHECKSUM'] = 'HDU checksum'
    #     hdu2.header.comments['DATASUM'] = 'data unit checksum'
    #     hdu1 = fits.HDUList([hdu1, hdu2])
    #     fname = os.path.join(chip_output.subdir, h_prim['FILENAME'] + '.fits')
    #     hdu1.writeto(fname, output_verify='ignore', overwrite=True)

    #     # chip_output.Log_info("# objects that are too bright %d out of %d" % (bright_obj, self.nobj))
    #     # chip_output.Log_info("# objects that are too dim %d out of %d" % (dim_obj, self.nobj))
    #     # chip_output.Log_info("# objects that are missed %d out of %d" % (missed_obj, self.nobj))
    #     del chip.img

    #     chip_output.Log_info("check running:2: pointing-%d chip-%d pid-%d memory-%6.2fGB" % (
    #         pointing.id, chip.chipID, os.getpid(), (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024)))

    def runExposure_MPI_PointingList(self, pointing_list, chips=None, use_mpi=False):
        """Simulate every pointing in ``pointing_list``, optionally under MPI.

        For each pointing, a ``FocalPlane`` is built from
        ``pointing.obs_param["run_chips"]``. One ``Chip``/``Filter`` pair is
        created per focal-plane slot. Non-ignored chips go into
        ``self.chip_list``/``self.filter_list``; every filter goes into
        ``self.all_filters``.

        With ``use_mpi``, chips are dealt round-robin to ranks. A running
        counter spans all pointings, so rank ``r`` handles every chip whose
        global index ``i`` satisfies ``i % size == r``.

        Args:
            pointing_list: Sequence of pointing objects to simulate.
            chips: Optional collection of chip IDs to restrict the run to.
                IDs not present among this pointing's non-ignored chips are
                skipped.
            use_mpi: Distribute chips across ``MPI.COMM_WORLD`` ranks.
        """
        if use_mpi:
            comm = MPI.COMM_WORLD
            ind_thread = comm.Get_rank()
            num_thread = comm.Get_size()

        process_counter = 0
        for pointing in pointing_list:
            # Construct chips & filters:
            pointing_ID = pointing.id
            self.focal_plane = FocalPlane(chip_list=pointing.obs_param["run_chips"])

            # Make Chip & Filter lists
            self.chip_list = []
            self.filter_list = []
            self.all_filters = []
            for chipID in range(1, self.focal_plane.nchips + 1):
                chip = Chip(chipID=chipID, config=self.config)
                filter_id, filter_type = chip.getChipFilter()
                filt = Filter(
                    filter_id=filter_id,
                    filter_type=filter_type,
                    filter_param=self.filter_param)
                if not self.focal_plane.isIgnored(chipID=chipID):
                    self.chip_list.append(chip)
                    self.filter_list.append(filt)
                self.all_filters.append(filt)

            if chips is None:
                # Run all chips defined in configuration of this pointing
                run_chips = self.chip_list
                run_filts = self.filter_list
            else:
                # Only run a particular set of chips (defined in the overall config file)
                run_chips = []
                run_filts = []
                for chip, filt in zip(self.chip_list, self.filter_list):
                    if chip.chipID in chips:
                        run_chips.append(chip)
                        run_filts.append(filt)
            # Size the loop from the chips actually selected. Using len(chips)
            # raised IndexError when a requested ID was ignored, absent or
            # duplicated.
            nchips_per_fp = len(run_chips)

            for ichip, (chip, filt) in enumerate(zip(run_chips, run_filts)):
                i_process = process_counter + ichip
                if use_mpi and i_process % num_thread != ind_thread:
                    continue

                pid = os.getpid()

                sub_img_dir, prefix = makeSubDir_PointingList(
                    path_dict=self.path_dict, config=self.config, pointing_ID=pointing_ID)

                # chip_output.Log_info("running pointing#%d, chip#%d, at PID#%d..."%(pointing_ID, chip.chipID, pid))
                chip_output = ChipOutput(
                    config=self.config,
                    focal_plane=self.focal_plane,
                    chip=chip,
                    filt=filt,
                    exptime=pointing.exp_time,
                    pointing_type=pointing.pointing_type,
                    pointing_ID=pointing_ID,
                    subdir=sub_img_dir,
                    prefix=prefix,
                    ra_cen=pointing.ra,
                    dec_cen=pointing.dec)
                chip_output.Log_info("running pointing#%d, chip#%d, at PID#%d..." % (pointing_ID, chip.chipID, pid))
                self.run_one_chip(
                    chip=chip,
                    filt=filt,
                    chip_output=chip_output,
                    pointing=pointing)
                # if self.config["obs_setting"]["survey_type"] == "CALIBRATION":
                #     self.run_one_chip_calibration(chip=chip,
                #         filt=filt,
                #         chip_output=chip_output,
                #         pointing=pointing,
                #         skyback_level = self.config["obs_setting"]["FLAT_LEVEL"],
                #         sky_level_filt = self.config["obs_setting"]["FLAT_LEVEL_FIL"])
                # else:
                #     self.run_one_chip(
                #         chip=chip,
                #         filt=filt,
                #         chip_output=chip_output,
                #         pointing=pointing)
                chip_output.Log_info("finished running chip#%d..." % (chip.chipID))
                # Detach AND close the per-chip handlers. removeHandler alone
                # leaves file-backed handlers open, leaking one set of file
                # descriptors per chip over a long run.
                for handler in chip_output.logger.handlers[:]:
                    chip_output.logger.removeHandler(handler)
                    handler.close()
                gc.collect()
            process_counter += nchips_per_fp