Newer
Older
import sys
import os
from ObservationSim.Config import ConfigDir, ChipOutput
from ObservationSim.Config.Header import generatePrimaryHeader, generateExtensionHeader
from ObservationSim.Instrument import Telescope, Filter, FilterParam, FocalPlane, Chip
from ObservationSim.MockObject import calculateSkyMap_split_g
from ObservationSim.PSF import PSFGauss, FieldDistortion, PSFInterp
from ObservationSim._util import getShearFiled, makeSubDir_PointingList
import numpy as np
import mpi4py.MPI as MPI
import galsim
import logging
import psutil
class Observation(object):
# Drives a simulated observation: __init__ builds the telescope, focal plane,
# filter parameters, chip/filter lists and shear fields from `config`;
# runOneChip() renders one chip and runExposure_MPI_PointingList() loops
# over pointings and chips.
# NOTE(review): this text appears to be scraped from a diff view -- the
# indentation is lost and some source lines are missing -- so it cannot be
# parsed as-is. The comments below flag the gaps that are visible.
def __init__(self, config, Catalog, work_dir=None, data_dir=None):
# Set up all per-run state.
#   config   -- settings, indexed like a nested mapping
#               (e.g. config["obs_setting"]["survey_type"]).
#   Catalog  -- catalog class; runOneChip() instantiates `self.Catalog(...)`.
#               NOTE(review): no `self.Catalog = Catalog` assignment is
#               visible here; it is presumably among the missing lines.
#   work_dir, data_dir -- forwarded to ConfigDir, which builds
#               self.path_dict (the data-path lookup used below).
self.path_dict = ConfigDir(config=config, work_dir=work_dir, data_dir=data_dir)
# self.config = ReadConfig(self.path_dict["config_file"])
self.config = config
self.tel = Telescope(optEffCurve_path=self.path_dict["mirror_file"]) # Currently the default values are hard coded in
self.focal_plane = FocalPlane(survey_type=self.config["obs_setting"]["survey_type"]) # Currently the default values are hard coded in
self.filter_param = FilterParam(filter_dir=self.path_dict["filter_dir"]) # Currently the default values are hard coded in
self.chip_list = []
self.filter_list = []
# NOTE(review): the `if` that pairs with the `else:` below is missing from
# this view; presumably it decides whether a field-distortion model is
# loaded. Confirm. Without a model, self.fd_model is None.
self.fd_model = FieldDistortion(fdModel_path=self.path_dict["fd_path"])
else:
self.fd_model = None
# Construct chips & filters:
# Chip IDs are 1-based over the nchip_x * nchip_y grid. IDs the focal plane
# marks as ignored are skipped. chip_list and filter_list are appended
# together, so they stay index-aligned (one Filter per Chip).
nchips = self.focal_plane.nchip_x*self.focal_plane.nchip_y
for i in range(nchips):
chipID = i + 1
if self.focal_plane.isIgnored(chipID=chipID):
continue
# Make Chip & Filter lists
chip = Chip(chipID, ccdEffCurve_dir=self.path_dict["ccd_dir"], CRdata_dir=self.path_dict["CRdata_dir"], normalize_dir=self.path_dict["normalize_dir"], sls_dir=self.path_dict["sls_dir"], config=self.config) # currently there is no config file for chips
filter_id, filter_type = chip.getChipFilter()
filt = Filter(filter_id=filter_id, filter_type=filter_type, filter_param=self.filter_param, ccd_bandpass=chip.effCurve)
self.chip_list.append(chip)
self.filter_list.append(filt)
# Read shear field(s). The object catalog itself is loaded per chip in
# runOneChip().
self.g1_field, self.g2_field, self.nshear = getShearFiled(config=self.config)
def runOneChip(self, chip, filt, chip_output, wcs_fp=None, psf_model=None, pointing_ID=0, ra_cen=None, dec_cen=None, img_rot=None, exptime=150., timestamp_obs=1621915200, pointing_type="MS", shear_cat_file=None, cat_dir=None, sed_dir=None):
# Simulate one chip for one pointing. The steps are:
#   1. draw catalog objects onto a fresh chip image;
#   2. apply detector effects via chip.addEffects();
#   3. for pointing_type 'MS', write a FITS file.
# Parameters:
#   chip, filt       -- a Chip and its matching Filter (see __init__).
#   chip_output      -- receives per-object catalog rows via cat_add_obj().
#                       Its `subdir` is the directory of the output FITS file.
#   wcs_fp           -- focal-plane WCS; built with getTanWCS() if None.
#   psf_model        -- PSF model (its selection logic is partly missing below).
#   pointing_ID      -- written into the primary header as pointNum.
#   ra_cen, dec_cen  -- pointing centre (but see the NOTE just below).
#   img_rot          -- rotation. `img_rot.deg` is read for the header PA,
#                       so it is presumably a galsim.Angle; confirm.
#   exptime          -- passed to the draw calls.
#   timestamp_obs    -- timestamp converted by datetime.fromtimestamp() into
#                       the header date/time.
#   shear_cat_file   -- if given, the shear fields are reloaded from it.
#   cat_dir, sed_dir -- forwarded to the Catalog constructor.
# No return value is visible. Results go to chip_output and the FITS file.
# NOTE(review): indentation is lost in this view, so the nesting of several
# statements below (which `if`/loop they belong to) is ambiguous.
# NOTE(review): as shown, ra_cen/dec_cen are unconditionally replaced by the
# config values, so the arguments are ignored. A guarding
# `if ... is None:` is likely missing from this view; confirm.
ra_cen = self.config["obs_setting"]["ra_center"]
dec_cen = self.config["obs_setting"]["dec_center"]
# NOTE(review): the `if` that selects PSFInterp (and any other PSF branch)
# is missing from this view. The visible fallback only prints a message and
# leaves psf_model unchanged.
psf_model = PSFInterp(chip=chip, PSF_data_file=self.path_dict["psf_dir"])
else:
print("unrecognized PSF model type!!", flush=True)
# Get (extra) shear fields
# A per-call shear catalog overrides the fields loaded in __init__.
if shear_cat_file is not None:
self.g1_field, self.g2_field, self.nshear = getShearFiled(config=self.config, shear_cat_file=shear_cat_file)
# Get WCS for the focal plane
# NOTE(review): `wcs_fp == None` should be `wcs_fp is None` (None is a
# singleton, and `==` may be overloaded by WCS/array types).
if wcs_fp == None:
wcs_fp = self.focal_plane.getTanWCS(ra_cen, dec_cen, img_rot, chip.pix_scale)
# Create chip Image
# A float image of the chip size. Its origin is moved to the chip bounds'
# (xmin, ymin) and the focal-plane WCS is attached.
chip.img = galsim.ImageF(chip.npix_x, chip.npix_y)
chip.img.setOrigin(chip.bound.xmin, chip.bound.ymin)
chip.img.wcs = wcs_fp
# Spectroscopic chips get a sky map from calculateSkyMap_split_g;
# photometric chips get None.
# NOTE(review): sky_map is not used in the visible lines. It is presumably
# passed on in the missing part of the addEffects() call; confirm.
if chip.survey_type == "photometric":
sky_map = None
elif chip.survey_type == "spectroscopic":
sky_map = calculateSkyMap_split_g(xLen=chip.npix_x, yLen=chip.npix_y, blueLimit=filt.blue_limit, redLimit=filt.red_limit, skyfn=self.path_dict["sky_file"], conf=chip.sls_conf, pixelSize=chip.pix_scale, isAlongY=0)
if pointing_type == 'MS':
# Load catalogues and templates
self.cat = self.Catalog(config=self.config, chip=chip, cat_dir=cat_dir, sed_dir=sed_dir)
self.nobj = len(self.cat.objs)
# Loop over objects
# Counters reported after the image is written.
missed_obj = 0
bright_obj = 0
dim_obj = 0
# NOTE(review): `obj` is used below but never assigned in the visible lines
# (presumably `obj = self.cat.objs[j]`). The `try:` that matches
# `except Exception as e:` after the loop is also missing.
for j in range(self.nobj):
# Honour the galaxy_only / star_only switches. Under star_only, quasars
# are skipped along with galaxies.
if obj.type == 'star' and self.config["galaxy_only"]:
continue
elif obj.type == 'galaxy' and self.config["star_only"]:
continue
elif obj.type == 'quasar' and self.config["star_only"]:
continue
# load SED
# Load the SED and convert it to a magnitude in this chip's filter.
# Any failure skips the object; the error is only printed.
try:
sed_data = self.cat.load_sed(obj)
norm_filt = self.cat.load_norm_filt(obj)
obj.sed, obj.param["mag_%s"%filt.filter_type] = self.cat.convert_sed(
mag=obj.param["mag_use_normal"],
sed=sed_data,
target_filt=filt,
norm_filt=norm_filt,
)
except Exception as e:
print(e)
continue
# Exclude very bright/dim objects (for now)
# NOTE(review): with indentation lost, it cannot be told from here whether
# unload_SED()/continue apply to every too-bright object or only to
# non-galaxies. Only non-galaxies increment bright_obj.
if filt.is_too_bright(mag=obj.getMagFilter(filt)):
# print("obj too birght!!", flush=True)
if obj.type != 'galaxy':
bright_obj += 1
obj.unload_SED()
continue
# NOTE(review): too-dim objects are counted, but no unload_SED()/continue
# is visible after the increment. As shown, they are still drawn; lines may
# be missing here. Confirm.
if filt.is_too_dim(mag=obj.getMagFilter(filt)):
# print("obj too dim!!", flush=True)
dim_obj += 1
# Choose the shear for this object:
#   'constant' -> the global field values (zero for stars);
#   'extra'    -> entry j of the loaded per-object arrays;
#   'catalog'  -> nothing is assigned here (g1/g2 presumably come from
#                 elsewhere; confirm);
#   anything else -> ValueError.
if self.config["shear_setting"]["shear_type"] == "constant":
if obj.type == 'star':
g1, g2 = 0, 0
else:
g1, g2 = self.g1_field, self.g2_field
elif self.config["shear_setting"]["shear_type"] == "extra":
try:
# TODO: every object with individual shear from input catalog(s)
g1, g2 = self.g1_field[j], self.g2_field[j]
# NOTE(review): the bare `except:` only prints, leaving g1/g2 unset (or
# holding a previous object's values). Catch (IndexError, TypeError) and
# skip the object instead.
except:
print("failed to load external shear.")
pass
elif self.config["shear_setting"]["shear_type"] == "catalog":
pass
else:
raise ValueError("Unknown shear input")
# Map the object to chip pixel coordinates, passing self.fd_model for field
# distortion. An x or y of -1 marks an object off the chip; it is counted as
# missed and skipped.
pos_img, offset, local_wcs = obj.getPosImg_Offset_WCS(img=chip.img, fdmodel=self.fd_model, chip=chip, verbose=False)
if pos_img.x == -1 or pos_img.y == -1:
# Exclude object which is outside the chip area (after field distortion)
# print("obj missed!!")
missed_obj += 1
obj.unload_SED()
continue
# With out_cat_only, nothing is drawn and the object is only catalogued.
# NOTE(review): pos_shear is assigned only by the draw calls below. On the
# out_cat_only path, cat_add_obj() therefore receives an unset (or stale)
# pos_shear.
if self.config["out_cat_only"]:
isUpdated = True
# Render the object: multiband drawing for photometric chips, slitless
# drawing for spectroscopic chips.
if chip.survey_type == "photometric" and not self.config["out_cat_only"]:
isUpdated, pos_shear = obj.drawObj_multiband(
tel=self.tel,
pos_img=pos_img,
psf_model=psf_model,
bandpass_list=filt.bandpass_sub_list,
filt=filt,
chip=chip,
g1=g1,
g2=g2,
exptime=exptime)
elif chip.survey_type == "spectroscopic" and not self.config["out_cat_only"]:
isUpdated, pos_shear = obj.drawObj_slitless(
tel=self.tel,
pos_img=pos_img,
psf_model=psf_model,
bandpass_list=filt.bandpass_sub_list,
filt=filt,
chip=chip,
g1=g1,
g2=g2,
exptime=exptime,
# normFilter=normF,
normFilter=norm_filt,
)
# Record successfully handled objects in the output catalog.
if isUpdated:
# TODO: add up stats
chip_output.cat_add_obj(obj, pos_img, pos_shear, g1, g2)
pass
else:
# print("object omitted", flush=True)
continue
# Exceptions from the (missing) try body are printed, not re-raised.
except Exception as e:
print(e)
# Memory checkpoint: resident set size of this process, in GiB.
print("check running:1: pointing-{:} chip-{:} pid-{:} memory-{:6.2}GB".format(pointing_ID, chip.chipID, os.getpid(), (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024) ), flush=True)
# Detector Effects
# ===========================================================
# whether to output zero, dark, flat calibration images.
# NOTE(review): the addEffects(...) call below is cut off in this view; its
# remaining arguments and closing parenthesis are missing.
chip.img = chip.addEffects(
config=self.config,
img=chip.img,
chip_output=chip_output,
filt=filt,
ra_cen=ra_cen,
dec_cen=dec_cen,
img_rot=img_rot,
pointing_ID=pointing_ID,
timestamp_obs=timestamp_obs,
pointing_type=pointing_type,
# NOTE(review): the bare numbers 225-257 below look like line-number
# residue from a diff viewer, not code. The statements that stood there are
# missing from this view.
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
# For 'MS' pointings: build the primary and extension headers, cast the
# image to uint16, and write PrimaryHDU + ImageHDU to
# chip_output.subdir/<h_prim['FILENAME']>.fits (overwriting any existing file).
# NOTE(review): `datetime` and `fits` are not imported in the visible import
# block (presumably `from datetime import datetime` and
# `from astropy.io import fits`); confirm.
# datetime.fromtimestamp() without a tz argument returns local time. The
# header date/time therefore depend on the host's timezone.
if pointing_type == 'MS':
datetime_obs = datetime.fromtimestamp(timestamp_obs)
date_obs = datetime_obs.strftime("%y%m%d")
time_obs = datetime_obs.strftime("%H%M%S")
h_prim = generatePrimaryHeader(
xlen=chip.npix_x,
ylen=chip.npix_y,
pointNum = str(pointing_ID),
ra=ra_cen,
dec=dec_cen,
psize=chip.pix_scale,
row_num=chip.rowID,
col_num=chip.colID,
# date=self.config["date_obs"],
# time_obs=self.config["time_obs"],
date=date_obs,
time_obs=time_obs,
im_type='MS')
h_ext = generateExtensionHeader(
xlen=chip.npix_x,
ylen=chip.npix_y,
ra=ra_cen,
dec=dec_cen,
pa=img_rot.deg,
gain=chip.gain,
readout=chip.read_noise,
dark=chip.dark_noise,
# NOTE(review): 90000 exceeds the uint16 maximum (65535) that the image
# is cast to below; confirm this header value is intended.
saturation=90000,
psize=chip.pix_scale,
row_num=chip.rowID,
col_num=chip.colID,
extName='raw')
# Cast to uint16. No clipping is visible here, so out-of-range or negative
# values depend on the float->uint16 conversion. Confirm they are clipped
# upstream (e.g. in addEffects).
chip.img = galsim.Image(chip.img.array, dtype=np.uint16)
# chip.img = galsim.Image(chip.img.array, dtype=np.uint32)
hdu1 = fits.PrimaryHDU(header=h_prim)
hdu2 = fits.ImageHDU(chip.img.array, header=h_ext)
hdu1 = fits.HDUList([hdu1, hdu2])
fname = os.path.join(chip_output.subdir, h_prim['FILENAME'] + '.fits')
hdu1.writeto(fname, output_verify='ignore', overwrite=True)
# Report how many catalog objects were skipped by each filter.
print("# objects that are too bright %d out of %d"%(bright_obj, self.nobj))
print("# objects that are too dim %d out of %d"%(dim_obj, self.nobj))
print("# objects that are missed %d out of %d"%(missed_obj, self.nobj))
# Drop the chip image before the next chip, then log memory again.
del chip.img
print("check running:2: pointing-{:} chip-{:} pid-{:} memory-{:6.2}GB".format(pointing_ID, chip.chipID, os.getpid(), (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024) ), flush=True)
def runExposure_MPI_PointingList(self, ra_cen=None, dec_cen=None, pRange=None, timestamp_obs=np.array([1621915200]), pointing_type=np.array(['MS']), img_rot=None, exptime=150., shear_cat_file=None, chips=None, use_mpi=False):
# Run runOneChip() for every (pointing, chip) pair, optionally spread
# across MPI ranks.
#   ra_cen, dec_cen -- per-pointing centres. They are indexed and len()'d
#                      below, so despite the None defaults they must be
#                      array-like; None raises TypeError.
#   pRange          -- optional sequence of pointing indices. It selects a
#                      subset of pointings and its entries become the
#                      pointing IDs.
#   timestamp_obs, pointing_type -- per-pointing values.
#   img_rot, exptime -- passed unchanged to runOneChip().
#   shear_cat_file  -- not used in the visible lines (the runOneChip call
#                      below is cut off).
#   chips           -- optional collection of chip IDs to restrict the run.
#   use_mpi         -- if True, work is dealt round-robin over MPI ranks.
# NOTE(review): indentation is lost in this view, so statement nesting is
# inferred from the code's logic.
if use_mpi:
comm = MPI.COMM_WORLD
ind_thread = comm.Get_rank()
num_thread = comm.Get_size()
if chips is None:
nchips_per_fp = len(self.chip_list)
run_chips = self.chip_list
run_filts = self.filter_list
else:
# Only run a particular set of chips
# NOTE(review): nchips_per_fp is len(chips), but run_chips only holds chips
# found in self.chip_list. If any requested ID is absent (e.g. ignored by
# the focal plane), run_chips[ichip] below raises IndexError.
# len(run_chips) would be the safe count.
run_chips = []
run_filts = []
nchips_per_fp = len(chips)
for ichip in range(len(self.chip_list)):
chip = self.chip_list[ichip]
filt = self.filter_list[ichip]
if chip.chipID in chips:
run_chips.append(chip)
run_filts.append(filt)
# TEMP
# A length-1 timestamp_obs is tiled to one entry per pointing. pointing_type
# is tiled only in that same case, so a length-1 pointing_type combined with
# a longer timestamp_obs is not expanded.
if len(timestamp_obs) == 1:
timestamp_obs = np.tile(timestamp_obs, len(ra_cen))
pointing_type = np.tile(pointing_type, len(ra_cen))
# Restrict all per-pointing arrays to the selected pointings.
if pRange is not None:
timestamp_obs = timestamp_obs[pRange]
pointing_type = pointing_type[pRange]
ra_cen = ra_cen[pRange]
dec_cen = dec_cen[pRange]
# The Starting pointing ID
# pStart only matters when pRange is None, and it is always 0 then.
if pRange is not None:
pStart = pRange[0]
else:
pStart = 0
# Work items are numbered i = ipoint*nchips_per_fp + ichip. Under MPI,
# rank r handles exactly the items with i % size == r.
for ipoint in range(len(ra_cen)):
for ichip in range(nchips_per_fp):
i = ipoint*nchips_per_fp + ichip
if pRange is None:
pointing_ID = pStart + ipoint
else:
pointing_ID = pRange[ipoint]
if use_mpi:
if i % num_thread != ind_thread:
continue
pid = os.getpid()
sub_img_dir, prefix = makeSubDir_PointingList(path_dict=self.path_dict, config=self.config, pointing_ID=pointing_ID)
# chip = self.chip_list[ichip]
# filt = self.filter_list[ichip]
chip = run_chips[ichip]
filt = run_filts[ichip]
print("running pointing#%d, chip#%d, at PID#%d..."%(pointing_ID, chip.chipID, pid), flush=True)
chip_output = ChipOutput(
config=self.config,
focal_plane=self.focal_plane,
chip=chip,
filt=filt,
exptime=exptime,
pointing_ID=pointing_ID,
subdir=sub_img_dir,
prefix=prefix)
# NOTE(review): the runOneChip(...) call is cut off in this view. Its
# closing parenthesis and any remaining arguments are missing.
self.runOneChip(
chip=chip,
filt=filt,
chip_output=chip_output,
pointing_ID = pointing_ID,
ra_cen=ra_cen[ipoint],
dec_cen=dec_cen[ipoint],
img_rot=img_rot,
exptime=exptime,
timestamp_obs=timestamp_obs[ipoint],
pointing_type=pointing_type[ipoint],
print("finished running chip#%d..."%(chip.chipID), flush=True)