# SingleEpochImage.py
# Author: Fang Yuedong
import numpy as np
import copy
from astropy import wcs
from astropy.io import fits
import galsim

from ObservationSim.Instrument import Chip, Filter, FilterParam, FocalPlane, Telescope
from ObservationSim.PSF import PSFGauss, PSFInterp

class SingleEpochImage(object):
    """Inject simulated catalog objects into an existing single-epoch FITS image.

    The input file's primary header supplies ``DETECTOR`` and ``EXPTIME``; its
    first extension header supplies the WCS (``CRPIX*``, ``CRVAL*``, ``CTYPE*``,
    ``CD*_*``), ``POS_ANG`` and the image size (``NAXIS1``/``NAXIS2``). These are
    used to build a ``Chip``, PSF model and ``Filter`` so that catalog objects
    can be drawn onto the image, which is then written out as uint16 FITS.
    """

    # Largest value representable in the uint16 FITS images written by this class.
    _UINT16_MAX = 65535

    def __init__(self, config, filepath):
        """Read ``filepath`` and set up everything needed for injection.

        Args:
            config: dict-like configuration. Keys read here: ``output_img_name``,
                ``n_objects``, ``object_density``, ``psf_setting`` (``psf_model``
                and, for ``"Interp"``, ``psf_dir``). ``config["obs_setting"]``
                is created if needed and its ``exp_time`` is set from the header.
                The dict is also passed to ``Chip``.
            filepath: path of the FITS image to inject into.

        Raises:
            ValueError: if ``psf_setting.psf_model`` is not "Gauss" or "Interp".
        """
        self.header0, self.header_img, self.img = self.read_initial_image(filepath)
        self._get_wcs(self.header_img)
        self._determine_unique_area(config)
        self.output_img_fname = config['output_img_name']

        if config.get('n_objects') is not None:
            # Fixed number of objects per image
            self.objs_per_real = config['n_objects']
        elif config.get('object_density') is not None:
            # Fixed number density of objects; u_area is in arcmin^2, so the
            # density is presumably per arcmin^2 -- TODO confirm with config docs.
            self.objs_per_real = round(self.u_area * config['object_density'])
        else:
            # Grid types: calculate nobjects later
            self.objs_per_real = None

        self.tel = Telescope()
        # Determine which CCD: the last two characters of DETECTOR are the chip ID.
        self.chip_ID = int(self.header0['DETECTOR'][-2:])
        # Determine exposure time
        self.exp_time = float(self.header0['EXPTIME'])
        # Record the exposure time without discarding any obs settings that
        # may already be present in the config.
        obs_setting = config.get("obs_setting") or {}
        obs_setting["exp_time"] = self.exp_time
        config["obs_setting"] = obs_setting
        # Construct Chip object
        self.chip = Chip(chipID=self.chip_ID, config=config)

        # Load PSF model. Fail early on an unknown model: otherwise psf_model is
        # never set and every draw in inject_objects fails silently.
        psf_setting = config["psf_setting"]
        psf_model_name = psf_setting["psf_model"]
        if psf_model_name == "Gauss":
            self.psf_model = PSFGauss(chip=self.chip)
        elif psf_model_name == "Interp":
            self.psf_model = PSFInterp(chip=self.chip, PSF_data_file=psf_setting["psf_dir"])
        else:
            raise ValueError(
                "Unsupported psf_model %r (expected 'Gauss' or 'Interp')" % (psf_model_name,))

        filter_id, filter_type = self.chip.getChipFilter()
        filter_param = FilterParam()
        self.filt = Filter(filter_id=filter_id,
                           filter_type=filter_type,
                           filter_param=filter_param)
        self.focal_plane = FocalPlane()

        self.setup_image_for_injection()

    def setup_image_for_injection(self):
        """Copy the input image into ``self.chip.img`` with a focal-plane TAN WCS.

        The TAN WCS is centred on the header CRVAL and uses ``self.pos_ang``
        and ``self.pixel_scale``. The image origin is set to the chip bounds.
        """
        ra_cen = self.wcs.wcs.crval[0]
        dec_cen = self.wcs.wcs.crval[1]
        self.wcs_fp = self.focal_plane.getTanWCS(ra_cen, dec_cen, self.pos_ang*galsim.degrees, self.pixel_scale)
        self.chip.img = galsim.Image(self.img, copy=True)
        self.chip.img.setOrigin(self.chip.bound.xmin, self.chip.bound.ymin)
        self.chip.img.wcs = self.wcs_fp
        print(self.chip.img.array)

    def read_initial_image(self, filepath):
        """Read the headers and pixel data of the input image.

        As a side effect, a clipped uint16 copy of the (rescaled) image is
        written to ``nullwt_image_for_injection.fits`` in the working directory.

        Returns:
            (header0, header1, image): primary header, first-extension header,
            and the rescaled float64 pixel array (not clipped).
        """
        # Close the file handle once the headers are read.
        with fits.open(filepath) as hdul:
            header0 = hdul[0].header
            header1 = hdul[1].header
        image = fits.getdata(filepath)

        # (TEMP) ad-hoc rescaling of the input pixels.
        # NOTE(review): origin of the 1.1 gain / 500 offset is not visible here.
        image = np.float64(image)
        image *= 1.1
        image -= 500.

        null_img = self._to_uint16_image(image)
        self._write_image_fits("nullwt_image_for_injection.fits",
                               header0, header1, null_img.array)
        return header0, header1, image

    def _get_wcs(self, header):
        """Build ``self.wcs`` from the CD-matrix keywords in ``header``.

        Also sets ``self.pos_ang`` (from POS_ANG), ``self.pixel_scale`` and the
        image size ``self.Npix_x``/``self.Npix_y`` (from NAXIS1/NAXIS2).
        """
        crpix1 = float(header['CRPIX1'])
        crpix2 = float(header['CRPIX2'])

        crval1 = float(header['CRVAL1'])
        crval2 = float(header['CRVAL2'])

        ctype1 = str(header['CTYPE1'])
        ctype2 = str(header['CTYPE2'])

        cd1_1 = float(header['CD1_1'])
        cd1_2 = float(header['CD1_2'])
        cd2_1 = float(header['CD2_1'])
        cd2_2 = float(header['CD2_2'])
        self.pos_ang = float(header['POS_ANG'])

        # Create WCS object
        self.wcs = wcs.WCS()
        self.wcs.wcs.crpix = [crpix1, crpix2]
        self.wcs.wcs.crval = [crval1, crval2]
        self.wcs.wcs.ctype = [ctype1, ctype2]
        self.wcs.wcs.cd = [[cd1_1, cd1_2], [cd2_1, cd2_2]]

        # Hard-coded pixel scale, presumably arcsec/pixel -- TODO confirm.
        self.pixel_scale = 0.074
        self.Npix_x = int(header['NAXIS1'])
        self.Npix_y = int(header['NAXIS2'])

    def _determine_unique_area(self, config):
        """Compute the sky area covered by the image, in arcmin^2.

        Sets ``ramin``/``ramax``/``decmin``/``decmax`` from the raw corner
        coordinates and ``ra_boundary_cross``. The latter is True when the raw
        RA range exceeds 1 deg, which is taken to mean the footprint straddles
        RA = 0/360. Stores the area in ``self.u_area``.

        ``config`` is unused and kept for interface compatibility.
        """
        corners = np.array([(1, 1), (1, self.Npix_y), (self.Npix_x, 1), (self.Npix_x, self.Npix_y)])
        corners = self.wcs.wcs_pix2world(corners, 1)
        ra_corners = corners[:, 0]
        dec_corners = corners[:, 1]
        self.ramin, self.ramax = min(ra_corners), max(ra_corners)
        self.decmin, self.decmax = min(dec_corners), max(dec_corners)

        self.ra_boundary_cross = bool(self.ramax - self.ramin > 1.)

        if self.ra_boundary_cross:
            # Unwrap the corners that sit just above RA=0 so that all corners
            # form a contiguous interval, then take the true RA extent.
            ra_unwrapped = np.where(ra_corners < 180., ra_corners + 360., ra_corners)
            r1, r2 = ra_unwrapped.min(), ra_unwrapped.max()
        else:
            r1, r2 = self.ramin, self.ramax

        d1, d2 = np.deg2rad([self.decmin, self.decmax])

        # Area of an RA/Dec rectangle on the sphere, in deg^2
        a = (180. / np.pi) * (r2 - r1) * (np.sin(d2) - np.sin(d1))
        # Save in arcmin^2
        self.u_area = 3600. * a

    def inject_objects(self, pos, cat):
        """Draw the first ``len(pos)`` objects of ``cat`` at the given positions.

        Args:
            pos: sequence of positions; ``pos[i][0]`` and ``pos[i][1]`` are
                written to the object's ``ra`` and ``dec`` params.
            cat: catalog exposing ``objs``, ``load_sed``, ``load_norm_filt``
                and ``convert_sed``.

        Raises:
            ValueError: if ``pos`` has more entries than ``cat.objs``.

        Objects whose SED conversion or drawing raises are reported via
        ``print`` and skipped. Once an SED is loaded, it is always unloaded.
        """
        nobj = len(pos)
        # Make sure we have enough objects to inject
        if nobj > len(cat.objs):
            raise ValueError(
                "Requested %d injections but catalog has only %d objects"
                % (nobj, len(cat.objs)))

        mag_key = "mag_%s" % self.filt.filter_type
        flux_key = "flux_%s" % self.filt.filter_type

        for obj, obj_pos in zip(cat.objs, pos):
            try:
                sed_data = cat.load_sed(obj)
                norm_filt = cat.load_norm_filt(obj)
                obj.sed, obj.param[mag_key], obj.param[flux_key] = cat.convert_sed(
                        mag=obj.param["mag_use_normal"],
                        sed=sed_data,
                        target_filt=self.filt,
                        norm_filt=norm_filt)
            except Exception as e:
                print(e)
                continue

            try:
                # Update object position to a point on grid
                obj.param['ra'], obj.param['dec'] = obj_pos[0], obj_pos[1]

                pos_img, offset, local_wcs = obj.getPosImg_Offset_WCS(img=self.chip.img)
                print(pos_img.x, pos_img.y)
                try:
                    isUpdated, pos_shear = obj.drawObj_multiband(
                        tel=self.tel,
                        pos_img=pos_img,
                        psf_model=self.psf_model,
                        bandpass_list=self.filt.bandpass_sub_list,
                        filt=self.filt,
                        chip=self.chip,
                        g1=obj.g1,
                        g2=obj.g2,
                        exptime=self.exp_time)
                except Exception as e:
                    # Best-effort: a failed draw should not abort the batch.
                    print(e)
                else:
                    if isUpdated:
                        # TODO: add up stats
                        print('Updated')
            finally:
                # Release the SED on every path, including omitted objects.
                obj.unload_SED()

    def save_injected_img(self):
        """Convert ``self.chip.img`` to uint16 and write it to ``output_img_fname``.

        ``self.chip.img`` is replaced by the uint16 image that was written.
        """
        self.chip.img = self._to_uint16_image(self.chip.img.array)
        self._write_image_fits(self.output_img_fname,
                               self.header0, self.header_img, self.chip.img.array)

    @staticmethod
    def _to_uint16_image(array):
        """Clip ``array`` to [0, 65535], round it, and return a uint16 galsim.Image."""
        img = galsim.Image(array, copy=True)
        img.array[img.array > SingleEpochImage._UINT16_MAX] = SingleEpochImage._UINT16_MAX
        img.replaceNegative(replace_value=0)
        img.quantize()
        return galsim.Image(img.array, dtype=np.uint16)

    @staticmethod
    def _write_image_fits(fname, header0, header_img, image_array):
        """Write a primary header plus one image extension to ``fname`` (overwrite)."""
        primary = fits.PrimaryHDU(header=header0)
        image_hdu = fits.ImageHDU(image_array, header=header_img)
        fits.HDUList([primary, image_hdu]).writeto(fname, output_verify='ignore', overwrite=True)