# calculate_completeness_fraction.py
import argparse
import os
import numpy as np
import matplotlib.pyplot as plt
from astropy.io import ascii, fits
from cross_match_catalogs import read_catalog, match_catalogs_img
from ObservationSim.Instrument import Telescope, Filter, FilterParam

VC_A = 2.99792458e+18  # speed of light in Angstrom per second (2.99792458e8 m/s * 1e10 A/m)

def define_options():
    """
    build the command-line parser of this script

    Four path options are mandatory (--TU_catalog, --source_catalog,
    --orig_catalog, --image); --output_dir is optional and falls back to
    './workspace'.

    Return:
    parser: the configured argparse.ArgumentParser (arguments not parsed yet)
    """
    parser = argparse.ArgumentParser()

    # (option name, help text) for every mandatory path, in registration order
    required_paths = (
        ('TU_catalog', 'path to the (injected) truth catalog'),
        ('source_catalog', 'path to the (extracted) injected catalog'),
        ('orig_catalog', 'path to the (extracted) original catalog'),
        ('image', 'path to the image, used to get the header info'),
    )
    for option, description in required_paths:
        parser.add_argument('--' + option, dest=option, type=str,
                            required=True, help=description)

    parser.add_argument('--output_dir', dest='output_dir', type=str,
                        required=False, default='./workspace',
                        help='output path')
    return parser


def getChipFilter(chipID, filter_layout=None):
    """
    filter index and filter type of a given chip

    Parameters:
    chipID: chip number; must lie in [1, 42] when the built-in layout is used
    filter_layout: optional mapping; when given, the first two items of
                   filter_layout[chipID] are returned and the built-in
                   layout (including its range check) is bypassed

    Return:
    filter_id: position of the filter type in the internal filter-type list
    filter_type: filter name, e.g. "g", "GU" or "FGS"

    Raises:
    ValueError: chipID is outside [1, 42] (built-in layout only)
    """
    filter_type_list = ["nuv", "u", "g", "r", "i", "z", "y", "GU", "GV", "GI", "FGS"]

    if filter_layout is not None:
        custom = filter_layout[chipID]
        return custom[0], custom[1]

    # updated configurations
    if chipID < 1 or chipID > 42:
        raise ValueError("!!! Chip ID: [1,42]")

    # chips sharing a filter, listed in the same order as they are tested
    chip_groups = (
        ([6, 15, 16, 25], "y"),
        ([11, 20], "z"),
        ([7, 24], "i"),
        ([14, 17], "u"),
        ([9, 22], "r"),
        ([12, 13, 18, 19], "nuv"),
        ([8, 23], "g"),
        ([1, 10, 21, 30], "GI"),
        ([2, 5, 26, 29], "GV"),
        ([3, 4, 27, 28], "GU"),
        (range(31, 43), "FGS"),
    )
    for members, band in chip_groups:
        if chipID in members:
            filter_type = band

    return filter_type_list.index(filter_type), filter_type

def magToFlux(mag):
    """
    convert an AB magnitude into a flux density

    Parameters:
    mag: magnitude in unit of AB

    Return:
    flux: flux in unit of erg/s/cm^2/Hz, i.e. 10**(-0.4*(mag+48.6))
    """
    # 48.6 is the zero point used by this conversion
    exponent = -0.4 * (mag + 48.6)
    return 10 ** exponent

def getElectronFluxFilt(mag, filt, tel, exptime=150.):
    """
    signal collected through a filter from a source of a given AB magnitude

    Parameters:
    mag: magnitude in unit of AB
    filt: filter object; getPhotonE(), blue_limit, red_limit and efficiency are used
    tel: telescope object; pupil_area is used
    exptime: exposure time multiplied into the result (default 150.)

    Return:
    the photon-rate factor multiplied by filter efficiency, pupil area and
    exptime (presumably a number of electrons -- units depend on filt/tel)
    """
    photon_energy = filt.getPhotonE()
    source_flux = magToFlux(mag)
    # width of the band expressed through the inverse of its two limits
    inverse_band = 1.0/filt.blue_limit - 1.0/filt.red_limit
    rate = 1.0e4 * source_flux / photon_energy * VC_A * inverse_band
    return rate * filt.efficiency * tel.pupil_area * exptime

def convert_catalog(catname):
    """
    rewrite a text catalog as '<catname>.fits'

    The catalog is read with astropy's ascii reader and written into the
    same directory under the original file name with '.fits' appended;
    an existing file of that name is overwritten.

    Parameters:
    catname: path to the text catalog
    """
    folder, filename = os.path.split(catname)
    table = ascii.read(catname)
    table.write(os.path.join(folder, filename + '.fits'), overwrite=True)

def validation_hist(val, idx, name="val", nbins=10, fig_name='detected_counts.png', output_dir='./'):
    """
    histogram all objects and the detected ones, and save a comparison plot

    Parameters:
    val: array of the quantity to histogram, one entry per object
    idx: sequence of match-index arrays, one per object; an empty array
         marks the object as undetected
    name: x-axis label of the plot
    nbins: number of bins (or bin edges) handed to np.histogram for `val`
    fig_name: file name of the saved figure
    output_dir: directory the figure is written into

    Return:
    counts: histogram of all values in `val`
    bins: the bin edges of that histogram
    """
    counts, bins = np.histogram(val, bins=nbins)

    # An object counts as undetected only when its match list is empty.
    is_empty = np.full(len(val), False)
    for i, matches in enumerate(idx):
        if np.size(matches) == 0:
            is_empty[i] = True

    # Reuse the edges of the full sample: passing the bin *count* again would
    # derive new edges from the detected subset only, so the two histograms
    # would not be comparable bin by bin (and are drawn on `bins` below).
    counts_detected, _ = np.histogram(val[~is_empty], bins=bins)

    fig = plt.figure()
    plt.stairs(counts, bins, color='r', label='TU objects')
    plt.stairs(counts_detected, bins, color='g', label='Detected')
    plt.xlabel(name, size='x-large')
    plt.title("Counts")
    plt.legend(loc='upper right', fancybox=True)
    plt.savefig(os.path.join(output_dir, fig_name))
    # Release the figure; pyplot otherwise keeps every figure alive.
    plt.close(fig)
    return counts, bins

def hist_fraction(val, idx, name='val', nbins=10, normed=False, output_dir='./'):
    """
    completeness fraction (detected / all) per bin, saved as a plot

    Parameters:
    val: array of the quantity to bin, one entry per object
    idx: sequence of match-index arrays, one per object; an empty array
         marks the object as undetected
    name: x-axis label, also used in the output file name
    nbins: number of bins (or bin edges) handed to np.histogram for `val`
    normed: passed as `density` to the histogram of the detected objects only
    output_dir: directory the figure is written into

    Return:
    fraction: detected counts divided by total counts for every bin;
              bins without any object are set to 0
    """
    counts, bins = np.histogram(val, bins=nbins)

    # An object counts as undetected only when its match list is empty.
    is_empty = np.full(len(val), False)
    for i, matches in enumerate(idx):
        if np.size(matches) == 0:
            is_empty[i] = True

    # Bin the detected objects on the edges of the full sample; otherwise
    # numerator and denominator of the ratio refer to different bins.
    counts_detected, _ = np.histogram(val[~is_empty], bins=bins, density=normed)

    # Empty bins give 0/0; that is expected, so silence the warning and
    # report a fraction of 0 there.
    with np.errstate(divide='ignore', invalid='ignore'):
        fraction = counts_detected / counts
    fraction[np.isnan(fraction)] = 0.

    fig = plt.figure()
    plt.stairs(fraction, bins, color='r', label='completeness fraction')
    plt.xlabel(name, size='x-large')
    plt.title("Completeness Fraction")
    fig_name = os.path.join(output_dir, "completeness_fraction_%s.png"%(name))
    plt.savefig(fig_name)
    # Release the figure; pyplot otherwise keeps every figure alive.
    plt.close(fig)
    return fraction

def calculate_fraction(TU_catalog, source_catalog, output_dir, nbins=10):
    """
    completeness fraction of the injected objects as a function of magnitude

    The truth catalog is first converted to '<TU_catalog>.fits', then its
    image positions (xImage, yImage) are matched against the extracted
    positions (X_IMAGE, Y_IMAGE) of the source catalog.

    Parameters:
    TU_catalog: path to the (injected) truth catalog in text form
    source_catalog: path to the (extracted) injected catalog
    output_dir: directory the diagnostic plots are written into
    nbins: number of magnitude bins used for both histograms

    Return:
    counts: number of injected objects per magnitude bin
    bins: edges of the magnitude bins
    fraction: detected / injected per magnitude bin
    """
    convert_catalog(TU_catalog)
    x_TU, y_TU, col_list = read_catalog(TU_catalog + '.fits', ext_num=1, ra_name="xImage", dec_name="yImage", col_list=["mag"])
    mag_TU = col_list[0]
    x_source, y_source, _ = read_catalog(source_catalog, ext_num=1, ra_name="X_IMAGE", dec_name="Y_IMAGE")

    # Only the matches seen from the truth-catalog side are needed.
    idx1, _ = match_catalogs_img(x1=x_TU, y1=y_TU, x2=x_source, y2=y_source)

    # Forward the caller's nbins to both helpers so that counts, bins and
    # fraction share one binning (it used to be silently fixed at 10).
    counts, bins = validation_hist(val=mag_TU, idx=idx1, name="mag_injected", nbins=nbins, output_dir=output_dir)
    fraction = hist_fraction(val=mag_TU, idx=idx1, name="mag_injected", nbins=nbins, output_dir=output_dir)
    return counts, bins, fraction

def calculate_undetected_flux(orig_cat, mag_bins, fraction, mag_low=20.0, mag_high=26.0, image=None, output_dir='./'):
    """
    estimate the per-pixel signal of the sources missed by the detection

    The original catalog is binned on `mag_bins`; in every bin the number of
    missed sources is counts / fraction - counts. Bins whose centre lies in
    [mag_low, mag_high] contribute getElectronFluxFilt(bin centre) per
    missed source, and the sum is divided by PIXSIZE1 * PIXSIZE2.

    Parameters:
    orig_cat: path to the (extracted) original catalog; columns RA, DEC
              and Mag_Kron are read
    mag_bins: edges of the magnitude bins `fraction` was computed on
    fraction: completeness fraction per bin (len(mag_bins) - 1 values)
    mag_low: bins with a centre below this magnitude are skipped
    mag_high: bins with a centre above this magnitude are skipped
    image: path to the FITS image; PIXSIZE1, PIXSIZE2 and DETECTOR are
           read from its primary header
    output_dir: directory the "undetected_sources.png" plot is written into

    Return:
    undetected_flux: summed signal of the missed sources divided by
                     PIXSIZE1 * PIXSIZE2
    """
    # Get info from original image; the context manager closes the file
    # even when a header keyword is missing.
    with fits.open(image) as hdu:
        header0 = hdu[0].header
        # presumably the number of pixels along each axis -- TODO confirm
        nx_pix, ny_pix = header0["PIXSIZE1"], header0["PIXSIZE2"]
        # the chip number is taken from the last two characters of DETECTOR
        chipID = int(header0["DETECTOR"][-2:])
    # NOTE(review): the header's EXPTIME is not passed on, so the helper's
    # default exposure time (150 s) applies -- confirm this matches the image.

    # Get info from original catalog
    ra_orig, dec_orig, col_list_orig = read_catalog(orig_cat, ra_name='RA', dec_name='DEC', col_list=['Mag_Kron'])
    mag_orig = col_list_orig[0]

    # Bin on the very edges the completeness fraction refers to. Passing only
    # the number of bins would re-derive edges from this catalog's own
    # magnitude range and misalign counts with `fraction`. Sources outside
    # the edges are ignored: no completeness estimate exists for them.
    mag_bins = np.asarray(mag_bins)
    counts, _ = np.histogram(mag_orig, bins=mag_bins)

    mags = (mag_bins[:-1] + mag_bins[1:])/2.
    # fraction == 0 yields inf or nan; those bins carry no usable estimate
    # and are set to zero missed sources.
    with np.errstate(divide='ignore', invalid='ignore'):
        counts_missing = (counts / fraction) - counts
    counts_missing[~np.isfinite(counts_missing)] = 0.
    print(counts_missing)
    print(counts_missing.sum())

    fig = plt.figure()
    plt.stairs(counts_missing, mag_bins, color='r', label='undetected counts')
    plt.xlabel("mag_injected", size='x-large')
    plt.title("Undetected Sources")
    fig_name = os.path.join(output_dir, "undetected_sources.png")
    plt.savefig(fig_name)
    # Release the figure; pyplot otherwise keeps every figure alive.
    plt.close(fig)

    tel = Telescope()
    filter_param = FilterParam()
    filter_id, filter_type = getChipFilter(chipID=chipID)

    filt = Filter(filter_id=filter_id,
                filter_type=filter_type,
                filter_param=filter_param)

    undetected_flux = 0.
    for mag, n_missing in zip(mags, counts_missing):
        if mag < mag_low or mag > mag_high:
            continue
        undetected_flux += n_missing * getElectronFluxFilt(mag=mag, filt=filt, tel=tel)

    undetected_flux /= (float(nx_pix) * float(ny_pix))
    return undetected_flux

if __name__ == "__main__":
    # Script entry: measure the completeness fraction from the injected
    # catalogs, then use it to estimate the undetected flux of the original
    # catalog and print the result.
    args = define_options().parse_args()

    # The plots are saved into output_dir; create it up front, since the
    # default './workspace' need not exist and saving would fail otherwise.
    os.makedirs(args.output_dir, exist_ok=True)

    counts, bins, fraction = calculate_fraction(
        TU_catalog=args.TU_catalog,
        source_catalog=args.source_catalog,
        output_dir=args.output_dir,
        nbins=20
    )
    undetected_flux = calculate_undetected_flux(
        orig_cat=args.orig_catalog,
        mag_bins=bins,
        fraction=fraction,
        image=args.image,
        output_dir=args.output_dir,
    )
    print(undetected_flux)