import os
import numpy as np
import mpi4py.MPI as MPI
import galsim
import logging
import psutil
from astropy.io import fits
from datetime import datetime

from ObservationSim.Config import config_dir, ChipOutput
from ObservationSim.Config.Header import generatePrimaryHeader, generateExtensionHeader
from ObservationSim.Instrument import Telescope, Filter, FilterParam, FocalPlane, Chip
from ObservationSim.Instrument.Chip import Effects
from ObservationSim.MockObject import calculateSkyMap_split_g
from ObservationSim.PSF import PSFGauss, FieldDistortion, PSFInterp
from ObservationSim._util import get_shear_field, makeSubDir_PointingList


class Observation(object):
    """Drive the simulation of exposures, one chip of one pointing at a time.

    The constructor builds the telescope, the focal plane and one
    ``Chip``/``Filter`` pair for every chip the focal plane does not ignore.
    ``run_one_chip`` simulates a single chip; ``runExposure_MPI_PointingList``
    loops over pointings and chips, optionally splitting the work over MPI
    ranks.
    """

    def __init__(self, config, Catalog, work_dir=None, data_dir=None):
        """Set up instrument objects, chip/filter lists and the shear field.

        Parameters
        ----------
        config : mapping
            Run configuration. Keys read in this class include
            ``obs_setting``, ``ins_effects``, ``psf_setting``,
            ``shear_setting``, ``run_option`` and ``random_seeds``.
        Catalog : callable
            Catalogue class; instantiated per chip in ``run_one_chip`` with
            ``config``, ``chip``, ``pointing``, ``cat_dir`` and ``sed_dir``.
        work_dir, data_dir : optional
            Forwarded unchanged to ``config_dir``.
        """
        self.path_dict = config_dir(config=config, work_dir=work_dir, data_dir=data_dir)
        self.config = config
        self.tel = Telescope()
        self.focal_plane = FocalPlane(survey_type=self.config["obs_setting"]["survey_type"])
        self.filter_param = FilterParam()
        self.chip_list = []
        self.filter_list = []
        self.Catalog = Catalog

        # Field distortion is optional; ``None`` means "do not apply".
        if self.config["ins_effects"]["field_dist"] == True:
            self.fd_model = FieldDistortion(fdModel_path=self.path_dict["fd_path"])
        else:
            self.fd_model = None

        # Construct chips & filters. Chip IDs are 1-based; chip_list and
        # filter_list are kept index-aligned.
        nchips = self.focal_plane.nchip_x * self.focal_plane.nchip_y
        for chipID in range(1, nchips + 1):
            if self.focal_plane.isIgnored(chipID=chipID):
                continue
            chip = Chip(chipID=chipID, config=self.config)
            filter_id, filter_type = chip.getChipFilter()
            filt = Filter(
                filter_id=filter_id,
                filter_type=filter_type,
                filter_param=self.filter_param,
                ccd_bandpass=chip.effCurve)
            self.chip_list.append(chip)
            self.filter_list.append(filt)

        # Read catalog and shear(s)
        self.g1_field, self.g2_field, self.nshear = get_shear_field(config=self.config)

    @staticmethod
    def _print_memory_usage(stage, pointing, chip):
        """Print the resident memory of the current process in GB."""
        pid = os.getpid()
        rss_gb = psutil.Process(pid).memory_info().rss / 1024 / 1024 / 1024
        print("check running:{:}: pointing-{:} chip-{:} pid-{:} memory-{:6.2}GB".format(
            stage, pointing.id, chip.chipID, pid, rss_gb), flush=True)

    def _make_sls_sky_map(self, chip, filt):
        """Build the sky map for a spectroscopic chip from a normalised flat."""
        flat_normal = np.ones_like(chip.img.array)
        if self.config["ins_effects"]["flat_fielding"] == True:
            print("SLS flat preprocess,CHIP %d : Creating and applying Flat-Fielding" % chip.chipID, flush=True)
            print(chip.img.bounds, flush=True)
            flat_img = Effects.MakeFlatSmooth(
                chip.img.bounds,
                int(self.config["random_seeds"]["seed_flat"]))
            # Normalise the flat to unit mean before applying it.
            flat_normal = flat_normal * flat_img.array / np.mean(flat_img.array)
        if self.config["ins_effects"]["shutter_effect"] == True:
            print("SLS flat preprocess,CHIP %d : Apply shutter effect" % chip.chipID, flush=True)
            # shutter effect normalized image for this chip
            shuttimg = Effects.ShutterEffectArr(chip.img, t_shutter=1.3, dist_bearing=735, dt=1E-3)
            flat_normal = flat_normal * shuttimg
        flat_normal = np.array(flat_normal, dtype='float32')
        return calculateSkyMap_split_g(
            skyMap=flat_normal,
            blueLimit=filt.blue_limit,
            redLimit=filt.red_limit,
            conf=chip.sls_conf,
            pixelSize=chip.pix_scale,
            isAlongY=0)

    def _assign_shear(self, obj, j):
        """Set ``obj.g1``/``obj.g2`` according to ``shear_setting.shear_type``.

        Raises
        ------
        ValueError
            If the configured shear type is not one of ``"constant"``,
            ``"extra"`` or ``"catalog"``.
        """
        shear_type = self.config["shear_setting"]["shear_type"]
        if shear_type == "constant":
            if obj.type == 'star':
                obj.g1, obj.g2 = 0., 0.
            else:
                obj.g1, obj.g2 = self.g1_field, self.g2_field
        elif shear_type == "extra":
            try:
                # TODO: every object with individual shear from input catalog(s)
                obj.g1, obj.g2 = self.g1_field[j], self.g2_field[j]
            except Exception:
                # Best effort: keep going, but do not swallow
                # KeyboardInterrupt/SystemExit as a bare ``except`` would.
                print("failed to load external shear.")
        elif shear_type == "catalog":
            # Nothing is assigned here for catalogue shear.
            pass
        else:
            raise ValueError("Unknown shear input")

    def _write_science_image(self, chip, pointing, chip_output):
        """Write ``chip.img`` as a two-HDU FITS file into ``chip_output.subdir``."""
        datetime_obs = datetime.fromtimestamp(pointing.timestamp)
        date_obs = datetime_obs.strftime("%y%m%d")
        time_obs = datetime_obs.strftime("%H%M%S")
        h_prim = generatePrimaryHeader(
            xlen=chip.npix_x,
            ylen=chip.npix_y,
            pointNum=str(pointing.id),
            ra=pointing.ra,
            dec=pointing.dec,
            psize=chip.pix_scale,
            row_num=chip.rowID,
            col_num=chip.colID,
            date=date_obs,
            time_obs=time_obs,
            exptime=pointing.exp_time,
            im_type='SCI',
            sat_pos=[pointing.sat_x, pointing.sat_y, pointing.sat_z],
            sat_vel=[pointing.sat_vx, pointing.sat_vy, pointing.sat_vz])
        h_ext = generateExtensionHeader(
            xlen=chip.npix_x,
            ylen=chip.npix_y,
            ra=pointing.ra,
            dec=pointing.dec,
            pa=pointing.img_pa.deg,
            gain=chip.gain,
            readout=chip.read_noise,
            dark=chip.dark_noise,
            saturation=90000,
            psize=chip.pix_scale,
            row_num=chip.rowID,
            col_num=chip.colID,
            extName='raw')
        # The image is stored as unsigned 16-bit integers.
        chip.img = galsim.Image(chip.img.array, dtype=np.uint16)
        primary_hdu = fits.PrimaryHDU(header=h_prim)
        image_hdu = fits.ImageHDU(chip.img.array, header=h_ext)
        hdu_list = fits.HDUList([primary_hdu, image_hdu])
        fname = os.path.join(chip_output.subdir, h_prim['FILENAME'] + '.fits')
        hdu_list.writeto(fname, output_verify='ignore', overwrite=True)

    def run_one_chip(self, chip, filt, pointing, chip_output, wcs_fp=None, psf_model=None, shear_cat_file=None, cat_dir=None, sed_dir=None):
        """Simulate one chip for one pointing.

        For pointings of type ``'MS'`` the catalogue is loaded, every accepted
        object is drawn (unless ``run_option.out_cat_only`` is set) and added
        to ``chip_output``, and the final image is written to a FITS file.
        Detector effects are applied through ``chip.addEffects`` for every
        pointing type.

        Parameters
        ----------
        chip, filt :
            The chip to simulate and its filter.
        pointing :
            Pointing description (``ra``, ``dec``, ``img_pa``, ``exp_time``,
            ``timestamp``, ``pointing_type``, ``id`` and satellite
            position/velocity attributes are read).
        chip_output :
            Output handler; receives ``cat_add_obj`` calls and provides
            ``subdir`` for the FITS file.
        wcs_fp : optional
            Focal-plane WCS. Computed from the pointing when ``None``.
        psf_model : optional
            Used only when ``psf_setting.psf_model`` is neither ``"Gauss"``
            nor ``"Interp"``; otherwise it is replaced.
        shear_cat_file : optional
            When given, the shear field is re-read from this file.
        cat_dir, sed_dir : optional
            Forwarded to the catalogue constructor.
        """
        print(':::::::::::::::::::Current Pointing Information::::::::::::::::::')
        print("RA: %f, DEC; %f" % (pointing.ra, pointing.dec))
        print("Time: %s" % datetime.fromtimestamp(pointing.timestamp).isoformat())
        print("Exposure time: %f" % pointing.exp_time)
        print("Satellite Position (x, y, z): (%f, %f, %f)" % (pointing.sat_x, pointing.sat_y, pointing.sat_z))
        print("Satellite Velocity (x, y, z): (%f, %f, %f)" % (pointing.sat_vx, pointing.sat_vy, pointing.sat_vz))
        print("Position Angle: %f" % pointing.img_pa.deg)
        print('Chip : %d' % chip.chipID)
        print(':::::::::::::::::::::::::::END:::::::::::::::::::::::::::::::::::')

        psf_type = self.config["psf_setting"]["psf_model"]
        if psf_type == "Gauss":
            psf_model = PSFGauss(chip=chip)
        elif psf_type == "Interp":
            psf_model = PSFInterp(chip=chip, PSF_data_file=self.path_dict["psf_dir"])
        else:
            # The caller-supplied psf_model (possibly None) is kept as is.
            print("unrecognized PSF model type!!", flush=True)

        # Get (extra) shear fields
        if shear_cat_file is not None:
            self.g1_field, self.g2_field, self.nshear = get_shear_field(config=self.config, shear_cat_file=shear_cat_file)

        # Get WCS for the focal plane
        if wcs_fp is None:
            wcs_fp = self.focal_plane.getTanWCS(pointing.ra, pointing.dec, pointing.img_pa, chip.pix_scale)

        # Create chip Image
        chip.img = galsim.ImageF(chip.npix_x, chip.npix_y)
        chip.img.setOrigin(chip.bound.xmin, chip.bound.ymin)
        chip.img.wcs = wcs_fp

        if chip.survey_type == "photometric":
            sky_map = None
        elif chip.survey_type == "spectroscopic":
            sky_map = self._make_sls_sky_map(chip=chip, filt=filt)

        if pointing.pointing_type == 'MS':
            # Load catalogues and templates
            self.cat = self.Catalog(config=self.config, chip=chip, pointing=pointing, cat_dir=cat_dir, sed_dir=sed_dir)
            self.nobj = len(self.cat.objs)
            run_option = self.config["run_option"]

            # Loop over objects
            missed_obj = 0
            bright_obj = 0
            dim_obj = 0
            for j in range(self.nobj):
                obj = self.cat.objs[j]
                if obj.type == 'star' and run_option["galaxy_only"]:
                    continue
                elif obj.type == 'galaxy' and run_option["star_only"]:
                    continue
                elif obj.type == 'quasar' and run_option["star_only"]:
                    continue

                # load SED
                try:
                    sed_data = self.cat.load_sed(obj)
                    norm_filt = self.cat.load_norm_filt(obj)
                    obj.sed, obj.param["mag_%s" % filt.filter_type] = self.cat.convert_sed(
                        mag=obj.param["mag_use_normal"],
                        sed=sed_data,
                        target_filt=filt,
                        norm_filt=norm_filt,
                    )
                except Exception as e:
                    print(e)
                    continue

                # Exclude very bright/dim objects (for now).
                # Only non-galaxies are dropped for being too bright.
                if filt.is_too_bright(mag=obj.getMagFilter(filt)):
                    if obj.type != 'galaxy':
                        bright_obj += 1
                        obj.unload_SED()
                        continue
                if filt.is_too_dim(mag=obj.getMagFilter(filt)):
                    dim_obj += 1
                    obj.unload_SED()
                    continue

                self._assign_shear(obj, j)

                pos_img, offset, local_wcs = obj.getPosImg_Offset_WCS(img=chip.img, fdmodel=self.fd_model, chip=chip, verbose=False)
                if pos_img.x == -1 or pos_img.y == -1:
                    # Exclude object which is outside the chip area (after field distortion)
                    missed_obj += 1
                    obj.unload_SED()
                    continue

                # Draw object & update output catalog
                try:
                    if run_option["out_cat_only"]:
                        isUpdated = True
                        # Nothing is drawn in catalogue-only mode, so there is
                        # no measured value to report. Bind a placeholder so
                        # cat_add_obj below does not hit an unbound name.
                        # NOTE(review): 0. is assumed to be an acceptable
                        # placeholder for cat_add_obj -- confirm.
                        pos_shear = 0.
                    if chip.survey_type == "photometric" and not run_option["out_cat_only"]:
                        isUpdated, pos_shear = obj.drawObj_multiband(
                            tel=self.tel,
                            pos_img=pos_img,
                            psf_model=psf_model,
                            bandpass_list=filt.bandpass_sub_list,
                            filt=filt,
                            chip=chip,
                            g1=obj.g1,
                            g2=obj.g2,
                            exptime=pointing.exp_time
                        )
                    elif chip.survey_type == "spectroscopic" and not run_option["out_cat_only"]:
                        isUpdated, pos_shear = obj.drawObj_slitless(
                            tel=self.tel,
                            pos_img=pos_img,
                            psf_model=psf_model,
                            bandpass_list=filt.bandpass_sub_list,
                            filt=filt,
                            chip=chip,
                            g1=obj.g1,
                            g2=obj.g2,
                            exptime=pointing.exp_time,
                            normFilter=norm_filt,
                        )
                    if isUpdated:
                        # TODO: add up stats
                        chip_output.cat_add_obj(obj, pos_img, pos_shear)
                except Exception as e:
                    print(e)

                # Unload SED. Reached for drawn, omitted and failed objects
                # alike, so the SED is always released.
                obj.unload_SED()
                del obj

            del psf_model
            del self.cat

        self._print_memory_usage(1, pointing, chip)

        # Detector Effects
        # ===========================================================
        # whether to output zero, dark, flat calibration images.
        chip.img = chip.addEffects(
            config=self.config,
            img=chip.img,
            chip_output=chip_output,
            filt=filt,
            ra_cen=pointing.ra,
            dec_cen=pointing.dec,
            img_rot=pointing.img_pa,
            pointing_ID=pointing.id,
            timestamp_obs=pointing.timestamp,
            pointing_type=pointing.pointing_type,
            sky_map=sky_map,
            tel=self.tel)

        if pointing.pointing_type == 'MS':
            self._write_science_image(chip=chip, pointing=pointing, chip_output=chip_output)
            print("# objects that are too bright %d out of %d" % (bright_obj, self.nobj))
            print("# objects that are too dim %d out of %d" % (dim_obj, self.nobj))
            print("# objects that are missed %d out of %d" % (missed_obj, self.nobj))
            del chip.img

        self._print_memory_usage(2, pointing, chip)

    def runExposure_MPI_PointingList(self, pointing_list, shear_cat_file=None, chips=None, use_mpi=False):
        """Run ``run_one_chip`` for every (pointing, chip) pair.

        Parameters
        ----------
        pointing_list : sequence
            Pointings to simulate.
        shear_cat_file : optional
            Accepted for interface compatibility; it is not forwarded to
            ``run_one_chip`` by this method.
        chips : container of chip IDs, optional
            Restrict the run to these chip IDs. ``None`` runs every chip in
            ``self.chip_list``. IDs with no matching chip are skipped.
        use_mpi : bool
            When true, (pointing, chip) pairs are distributed round-robin
            over the ranks of ``MPI.COMM_WORLD``.
        """
        if use_mpi:
            comm = MPI.COMM_WORLD
            ind_thread = comm.Get_rank()
            num_thread = comm.Get_size()

        if chips is None:
            run_chips = self.chip_list
            run_filts = self.filter_list
        else:
            # Only run a particular set of chips
            run_chips = []
            run_filts = []
            for chip, filt in zip(self.chip_list, self.filter_list):
                if chip.chipID in chips:
                    run_chips.append(chip)
                    run_filts.append(filt)
        # Count the chips actually selected, not the IDs requested: a
        # requested ID may be ignored by the focal plane or repeated, and
        # using len(chips) would then index past the end of run_chips.
        nchips_per_fp = len(run_chips)

        for ipoint, pointing in enumerate(pointing_list):
            pointing_ID = pointing.id
            for ichip in range(nchips_per_fp):
                i = ipoint * nchips_per_fp + ichip
                if use_mpi:
                    if i % num_thread != ind_thread:
                        continue
                pid = os.getpid()

                sub_img_dir, prefix = makeSubDir_PointingList(path_dict=self.path_dict, config=self.config, pointing_ID=pointing_ID)

                chip = run_chips[ichip]
                filt = run_filts[ichip]
                print("running pointing#%d, chip#%d, at PID#%d..." % (pointing_ID, chip.chipID, pid), flush=True)
                chip_output = ChipOutput(
                    config=self.config,
                    focal_plane=self.focal_plane,
                    chip=chip,
                    filt=filt,
                    exptime=pointing.exp_time,
                    pointing_type=pointing.pointing_type,
                    pointing_ID=pointing_ID,
                    subdir=sub_img_dir,
                    prefix=prefix)
                self.run_one_chip(
                    chip=chip,
                    filt=filt,
                    chip_output=chip_output,
                    pointing=pointing,
                    cat_dir=self.path_dict["cat_dir"])
                print("finished running chip#%d..." % (chip.chipID), flush=True)