import sys
import os
from ObservationSim.Config import ConfigDir, ChipOutput
from ObservationSim.Config.Header import generatePrimaryHeader, generateExtensionHeader
from ObservationSim.Instrument import Telescope, Filter, FilterParam, FocalPlane, Chip
from ObservationSim.MockObject import calculateSkyMap_split_g
from ObservationSim.PSF import PSFGauss, FieldDistortion, PSFInterp
from ObservationSim._util import getShearFiled, makeSubDir_PointingList
from astropy.io import fits
from datetime import datetime

import numpy as np
import mpi4py.MPI as MPI
import galsim
import logging
import psutil


class Observation(object):
    """Drive an image simulation over the chips of a focal plane.

    The constructor builds the telescope, focal plane, filter parameters and
    one (Chip, Filter) pair per non-ignored chip. ``runOneChip`` renders the
    catalogue objects of one chip for one pointing, and
    ``runExposure_MPI_PointingList`` loops ``runOneChip`` over pointings and
    chips, optionally distributing the work over MPI ranks.
    """

    def __init__(self, config, Catalog, work_dir=None, data_dir=None):
        """Set up instruments, chips/filters and the default shear fields.

        Args:
            config: nested configuration mapping (e.g. keys
                ``obs_setting.survey_type`` and ``ins_effects.field_dist``).
            Catalog: catalogue class, later instantiated per chip as
                ``Catalog(config=..., chip=..., cat_dir=..., sed_dir=...)``.
            work_dir: forwarded to ``ConfigDir`` to resolve paths.
            data_dir: forwarded to ``ConfigDir`` to resolve paths.
        """
        self.path_dict = ConfigDir(config=config, work_dir=work_dir, data_dir=data_dir)
        self.config = config
        # Currently the default values are hard coded in these constructors.
        self.tel = Telescope(optEffCurve_path=self.path_dict["mirror_file"])
        self.focal_plane = FocalPlane(survey_type=self.config["obs_setting"]["survey_type"])
        self.filter_param = FilterParam(filter_dir=self.path_dict["filter_dir"])
        self.chip_list = []
        self.filter_list = []
        self.Catalog = Catalog

        # Optionally apply a field-distortion model.
        if self.config["ins_effects"]["field_dist"] == True:
            self.fd_model = FieldDistortion(fdModel_path=self.path_dict["fd_path"])
        else:
            self.fd_model = None

        # Construct chips & filters; chip IDs are 1-based and ignored chips are skipped.
        nchips = self.focal_plane.nchip_x * self.focal_plane.nchip_y
        for i in range(nchips):
            chipID = i + 1
            if self.focal_plane.isIgnored(chipID=chipID):
                continue
            # Currently there is no config file for chips.
            chip = Chip(
                chipID,
                ccdEffCurve_dir=self.path_dict["ccd_dir"],
                CRdata_dir=self.path_dict["CRdata_dir"],
                normalize_dir=self.path_dict["normalize_dir"],
                sls_dir=self.path_dict["sls_dir"],
                config=self.config)
            filter_id, filter_type = chip.getChipFilter()
            filt = Filter(
                filter_id=filter_id,
                filter_type=filter_type,
                filter_param=self.filter_param,
                ccd_bandpass=chip.effCurve)
            self.chip_list.append(chip)
            self.filter_list.append(filt)

        # Read default shear(s).
        self.g1_field, self.g2_field, self.nshear = getShearFiled(config=self.config)

    def _print_memory_check(self, stage, pointing_ID, chip):
        """Print this process's resident memory (GiB) at a named checkpoint."""
        rss_gb = psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024
        # ``6.2f`` (not ``6.2``) so values >= 10 GB are not printed in e-notation.
        print("check running:{:}: pointing-{:} chip-{:} pid-{:} memory-{:6.2f}GB".format(
            stage, pointing_ID, chip.chipID, os.getpid(), rss_gb), flush=True)

    def _select_psf_model(self, chip, psf_model=None):
        """Return the PSF model named by ``psf_setting.psf_model`` for ``chip``.

        For an unrecognised name a warning is printed and the caller-supplied
        ``psf_model`` is returned unchanged.
        """
        model_name = self.config["psf_setting"]["psf_model"]
        if model_name == "Gauss":
            return PSFGauss(chip=chip)
        if model_name == "Interp":
            return PSFInterp(chip=chip, PSF_data_file=self.path_dict["psf_dir"])
        print("unrecognized PSF model type!!", flush=True)
        return psf_model

    def _draw_objects(self, chip, filt, chip_output, psf_model, exptime, cat_dir, sed_dir):
        """Render all catalogue objects of ``chip`` onto ``chip.img``.

        Loads the catalogue into ``self.cat`` (deleted again before returning)
        and stores its size in ``self.nobj``. Every object whose SED was loaded
        has it unloaded again, whether or not the object was drawn.

        Returns:
            ``(bright_obj, dim_obj, missed_obj)``: the numbers of objects
            skipped as too bright (non-galaxies only), too dim, or outside
            the chip area.
        """
        self.cat = self.Catalog(config=self.config, chip=chip, cat_dir=cat_dir, sed_dir=sed_dir)
        self.nobj = len(self.cat.objs)
        missed_obj = 0
        bright_obj = 0
        dim_obj = 0

        for j in range(self.nobj):
            obj = self.cat.objs[j]
            if obj.type == 'star' and self.config["galaxy_only"]:
                continue
            if obj.type in ('galaxy', 'quasar') and self.config["star_only"]:
                continue

            # Load the SED and the magnitude in this chip's filter.
            try:
                sed_data = self.cat.load_sed(obj)
                norm_filt = self.cat.load_norm_filt(obj)
                obj.sed, obj.param["mag_%s" % filt.filter_type] = self.cat.convert_sed(
                    mag=obj.param["mag_use_normal"],
                    sed=sed_data,
                    target_filt=filt,
                    norm_filt=norm_filt,
                )
            except Exception as e:
                print(e)
                continue

            # Exclude very bright/dim objects (for now).
            if filt.is_too_bright(mag=obj.getMagFilter(filt)):
                # Only non-galaxy objects are counted as "too bright".
                if obj.type != 'galaxy':
                    bright_obj += 1
                obj.unload_SED()
                continue
            if filt.is_too_dim(mag=obj.getMagFilter(filt)):
                dim_obj += 1
                obj.unload_SED()
                continue

            shear_type = self.config["shear_setting"]["shear_type"]
            if shear_type == "constant":
                if obj.type == 'star':
                    g1, g2 = 0, 0
                else:
                    g1, g2 = self.g1_field, self.g2_field
            elif shear_type == "extra":
                # TODO: every object with individual shear from input catalog(s)
                try:
                    g1, g2 = self.g1_field[j], self.g2_field[j]
                except Exception:
                    # Skip rather than silently reuse the previous object's shear.
                    print("failed to load external shear.")
                    obj.unload_SED()
                    continue
            elif shear_type == "catalog":
                # NOTE(review): this branch assigns no g1/g2, so the draw below
                # raises UnboundLocalError (caught and printed) for every
                # object - confirm how catalogue shears should reach drawObj_*.
                pass
            else:
                raise ValueError("Unknown shear input")

            pos_img, offset, local_wcs = obj.getPosImg_Offset_WCS(
                img=chip.img, fdmodel=self.fd_model, chip=chip, verbose=False)
            if pos_img.x == -1 or pos_img.y == -1:
                # Object falls outside the chip area (after field distortion).
                missed_obj += 1
                obj.unload_SED()
                continue

            # Draw object & update output catalog.
            try:
                isUpdated = False
                # Placeholder when nothing is drawn (catalogue-only output);
                # assumes cat_add_obj accepts a scalar here - TODO confirm.
                pos_shear = 0.
                if self.config["out_cat_only"]:
                    isUpdated = True
                elif chip.survey_type == "photometric":
                    isUpdated, pos_shear = obj.drawObj_multiband(
                        tel=self.tel, pos_img=pos_img, psf_model=psf_model,
                        bandpass_list=filt.bandpass_sub_list, filt=filt, chip=chip,
                        g1=g1, g2=g2, exptime=exptime)
                elif chip.survey_type == "spectroscopic":
                    isUpdated, pos_shear = obj.drawObj_slitless(
                        tel=self.tel, pos_img=pos_img, psf_model=psf_model,
                        bandpass_list=filt.bandpass_sub_list, filt=filt, chip=chip,
                        g1=g1, g2=g2, exptime=exptime, normFilter=norm_filt)
                if isUpdated:
                    # TODO: add up stats
                    chip_output.cat_add_obj(obj, pos_img, pos_shear, g1, g2)
            except Exception as e:
                print(e)

            # Always release the SED, including for objects that were not drawn.
            obj.unload_SED()
            del obj

        del self.cat
        return bright_obj, dim_obj, missed_obj

    def _write_raw_image(self, chip, chip_output, pointing_ID, ra_cen, dec_cen, img_rot, timestamp_obs):
        """Write ``chip.img`` as a primary + 'raw' image HDU FITS file.

        The image is converted to ``np.uint16`` first and written to
        ``chip_output.subdir`` under the header's ``FILENAME`` (overwriting).
        """
        # NOTE(review): datetime.fromtimestamp() converts to the host's local
        # time zone, so the header date/time depend on where the code runs.
        datetime_obs = datetime.fromtimestamp(timestamp_obs)
        date_obs = datetime_obs.strftime("%y%m%d")
        time_obs = datetime_obs.strftime("%H%M%S")
        h_prim = generatePrimaryHeader(
            xlen=chip.npix_x, ylen=chip.npix_y, pointNum=str(pointing_ID),
            ra=ra_cen, dec=dec_cen, psize=chip.pix_scale,
            row_num=chip.rowID, col_num=chip.colID,
            date=date_obs, time_obs=time_obs, im_type='MS')
        h_ext = generateExtensionHeader(
            xlen=chip.npix_x, ylen=chip.npix_y, ra=ra_cen, dec=dec_cen,
            pa=img_rot.deg, gain=chip.gain, readout=chip.read_noise,
            dark=chip.dark_noise, saturation=90000, psize=chip.pix_scale,
            row_num=chip.rowID, col_num=chip.colID, extName='raw')
        chip.img = galsim.Image(chip.img.array, dtype=np.uint16)
        hdu_prim = fits.PrimaryHDU(header=h_prim)
        hdu_ext = fits.ImageHDU(chip.img.array, header=h_ext)
        fname = os.path.join(chip_output.subdir, h_prim['FILENAME'] + '.fits')
        fits.HDUList([hdu_prim, hdu_ext]).writeto(fname, output_verify='ignore', overwrite=True)

    def runOneChip(self, chip, filt, chip_output, wcs_fp=None, psf_model=None,
                   pointing_ID=0, ra_cen=None, dec_cen=None, img_rot=None,
                   exptime=150., timestamp_obs=1621915200, pointing_type="MS",
                   shear_cat_file=None, cat_dir=None, sed_dir=None):
        """Simulate one chip for one pointing.

        Missing ``ra_cen``/``dec_cen``/``img_rot`` fall back to
        ``config["obs_setting"]``. For ``pointing_type == 'MS'`` catalogue
        objects are drawn and a raw FITS file is written; detector effects
        (``chip.addEffects``) are applied for every pointing type.

        Args:
            chip, filt: the chip and its filter.
            chip_output: output handler; objects are added via ``cat_add_obj``
                and the FITS file goes to ``chip_output.subdir``.
            wcs_fp: focal-plane WCS; built with ``getTanWCS`` when None.
            psf_model: used only if the configured PSF model is unrecognised.
            shear_cat_file: if given, reloads the shear fields from it.
            cat_dir, sed_dir: forwarded to the ``Catalog`` constructor.
        """
        if (ra_cen is None) or (dec_cen is None):
            ra_cen = self.config["obs_setting"]["ra_center"]
            dec_cen = self.config["obs_setting"]["dec_center"]
        if img_rot is None:
            img_rot = self.config["obs_setting"]["image_rot"]

        psf_model = self._select_psf_model(chip, psf_model)

        # Get (extra) shear fields.
        if shear_cat_file is not None:
            self.g1_field, self.g2_field, self.nshear = getShearFiled(
                config=self.config, shear_cat_file=shear_cat_file)

        # Get WCS for the focal plane.
        if wcs_fp is None:
            wcs_fp = self.focal_plane.getTanWCS(ra_cen, dec_cen, img_rot, chip.pix_scale)

        # Create chip image.
        chip.img = galsim.ImageF(chip.npix_x, chip.npix_y)
        chip.img.setOrigin(chip.bound.xmin, chip.bound.ymin)
        chip.img.wcs = wcs_fp

        # Only slitless (spectroscopic) chips get a sky map.
        sky_map = None
        if chip.survey_type == "spectroscopic":
            sky_map = calculateSkyMap_split_g(
                xLen=chip.npix_x, yLen=chip.npix_y,
                blueLimit=filt.blue_limit, redLimit=filt.red_limit,
                skyfn=self.path_dict["sky_file"], conf=chip.sls_conf,
                pixelSize=chip.pix_scale, isAlongY=0)

        if pointing_type == 'MS':
            bright_obj, dim_obj, missed_obj = self._draw_objects(
                chip=chip, filt=filt, chip_output=chip_output, psf_model=psf_model,
                exptime=exptime, cat_dir=cat_dir, sed_dir=sed_dir)
            del psf_model

        self._print_memory_check(1, pointing_ID, chip)

        # Detector Effects
        # ===========================================================
        # whether to output zero, dark, flat calibration images.
        chip.img = chip.addEffects(
            config=self.config,
            img=chip.img,
            chip_output=chip_output,
            filt=filt,
            ra_cen=ra_cen,
            dec_cen=dec_cen,
            img_rot=img_rot,
            pointing_ID=pointing_ID,
            timestamp_obs=timestamp_obs,
            pointing_type=pointing_type,
            sky_map=sky_map,
            tel=self.tel)

        if pointing_type == 'MS':
            self._write_raw_image(
                chip=chip, chip_output=chip_output, pointing_ID=pointing_ID,
                ra_cen=ra_cen, dec_cen=dec_cen, img_rot=img_rot,
                timestamp_obs=timestamp_obs)
            print("# objects that are too bright %d out of %d" % (bright_obj, self.nobj))
            print("# objects that are too dim %d out of %d" % (dim_obj, self.nobj))
            print("# objects that are missed %d out of %d" % (missed_obj, self.nobj))

        del chip.img
        self._print_memory_check(2, pointing_ID, chip)

    def runExposure_MPI_PointingList(self, ra_cen=None, dec_cen=None, pRange=None,
                                     timestamp_obs=np.array([1621915200]),
                                     pointing_type=np.array(['MS']), img_rot=None,
                                     exptime=150., shear_cat_file=None, chips=None,
                                     use_mpi=False):
        """Run ``runOneChip`` for every (pointing, chip) pair.

        Args:
            ra_cen, dec_cen: pointing centres (scalar or sequence); when
                either is None the configured ``obs_setting`` centre is used
                as a single pointing.
            pRange: optional indices selecting a subset of pointings; the
                selected indices are also used as pointing IDs.
            timestamp_obs, pointing_type: per-pointing values; a length-1
                value is broadcast to every pointing.
            img_rot, exptime: forwarded to ``runOneChip``.
            shear_cat_file: if given, shear fields are loaded from it once.
            chips: optional collection of chip IDs to run (default: all).
            use_mpi: distribute (pointing, chip) pairs round-robin over
                ``MPI.COMM_WORLD`` ranks.
        """
        if use_mpi:
            comm = MPI.COMM_WORLD
            ind_thread = comm.Get_rank()
            num_thread = comm.Get_size()

        if chips is None:
            run_chips = self.chip_list
            run_filts = self.filter_list
        else:
            # Only run a particular set of chips.
            run_chips = []
            run_filts = []
            for chip, filt in zip(self.chip_list, self.filter_list):
                if chip.chipID in chips:
                    run_chips.append(chip)
                    run_filts.append(filt)
        # Count the chips actually selected: requested IDs that the focal
        # plane ignores have no entry in run_chips.
        nchips_per_fp = len(run_chips)

        if ra_cen is None or dec_cen is None:
            ra_cen = self.config["obs_setting"]["ra_center"]
            dec_cen = self.config["obs_setting"]["dec_center"]
        ra_cen = np.atleast_1d(ra_cen)
        dec_cen = np.atleast_1d(dec_cen)
        timestamp_obs = np.atleast_1d(timestamp_obs)
        pointing_type = np.atleast_1d(pointing_type)

        # TEMP: broadcast a single timestamp / pointing type to all pointings.
        if len(timestamp_obs) == 1:
            timestamp_obs = np.tile(timestamp_obs, len(ra_cen))
        if len(pointing_type) == 1:
            pointing_type = np.tile(pointing_type, len(ra_cen))

        if pRange is not None:
            timestamp_obs = timestamp_obs[pRange]
            pointing_type = pointing_type[pRange]
            ra_cen = ra_cen[pRange]
            dec_cen = dec_cen[pRange]

        # Load an extra shear catalogue once instead of ignoring the argument.
        if shear_cat_file is not None:
            self.g1_field, self.g2_field, self.nshear = getShearFiled(
                config=self.config, shear_cat_file=shear_cat_file)

        for ipoint in range(len(ra_cen)):
            # Without pRange the pointings are numbered from 0.
            pointing_ID = ipoint if pRange is None else pRange[ipoint]
            for ichip in range(nchips_per_fp):
                if use_mpi:
                    # Round-robin the flattened (pointing, chip) index over ranks.
                    i = ipoint * nchips_per_fp + ichip
                    if i % num_thread != ind_thread:
                        continue

                pid = os.getpid()
                sub_img_dir, prefix = makeSubDir_PointingList(
                    path_dict=self.path_dict, config=self.config, pointing_ID=pointing_ID)

                chip = run_chips[ichip]
                filt = run_filts[ichip]
                print("running pointing#%d, chip#%d, at PID#%d..." % (pointing_ID, chip.chipID, pid), flush=True)
                chip_output = ChipOutput(
                    config=self.config,
                    focal_plane=self.focal_plane,
                    chip=chip,
                    filt=filt,
                    exptime=exptime,
                    pointing_type=pointing_type[ipoint],
                    pointing_ID=pointing_ID,
                    subdir=sub_img_dir,
                    prefix=prefix)
                self.runOneChip(
                    chip=chip,
                    filt=filt,
                    chip_output=chip_output,
                    pointing_ID=pointing_ID,
                    ra_cen=ra_cen[ipoint],
                    dec_cen=dec_cen[ipoint],
                    img_rot=img_rot,
                    exptime=exptime,
                    timestamp_obs=timestamp_obs[ipoint],
                    pointing_type=pointing_type[ipoint],
                    cat_dir=self.path_dict["cat_dir"])
                print("finished running chip#%d..." % (chip.chipID), flush=True)