calculate_completeness_fraction.py 7.37 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import argparse
import os
import numpy as np
import matplotlib.pyplot as plt
from astropy.io import ascii, fits
from cross_match_catalogs import read_catalog, match_catalogs_img
from ObservationSim.Instrument import Telescope, Filter, FilterParam

# Speed of light expressed in Angstrom per second (2.99792458e8 m/s * 1e10 A/m).
VC_A = 2.99792458e18

def define_options():
    """Build the command-line parser for this script.

    Returns:
        argparse.ArgumentParser with four required string options
        (--TU_catalog, --source_catalog, --orig_catalog, --image) and an
        optional --output_dir that defaults to './workspace'.
    """
    parser = argparse.ArgumentParser()
    # (option/dest name, help text) for every mandatory path argument,
    # registered in this order.
    required_paths = (
        ('TU_catalog', 'path to the (injected) truth catalog'),
        ('source_catalog', 'path to the (extracted) injected catalog'),
        ('orig_catalog', 'path to the (extracted) original catalog'),
        ('image', 'path to the image, used to get the header info'),
    )
    for opt_name, opt_help in required_paths:
        parser.add_argument('--' + opt_name, dest=opt_name, type=str,
                            required=True, help=opt_help)
    parser.add_argument('--output_dir', dest='output_dir', type=str,
                        required=False, default='./workspace',
                        help='output path')
    return parser

def getChipFilter(chipID, filter_layout=None):
    """Return the (filter index, filter name) pair assigned to chip `chipID`.

    If `filter_layout` is given, it is indexed by `chipID` and the first two
    items of that entry are returned unchanged (no range check is applied).
    Otherwise the built-in chip -> filter table below is used, and the index
    is the position of the filter name in
    ["nuv", "u", "g", "r", "i", "z", "y", "GU", "GV", "GI", "FGS"].

    Raises:
        ValueError: if no layout is given and chipID is outside [1, 42].
    """
    filter_names = ["nuv", "u", "g", "r", "i", "z", "y", "GU", "GV", "GI", "FGS"]
    if filter_layout is not None:
        entry = filter_layout[chipID]
        return entry[0], entry[1]

    # updated configurations
    if chipID < 1 or chipID > 42:
        raise ValueError("!!! Chip ID: [1,42]")

    # The chip groups are disjoint, so at most one of them matches.
    chip_groups = (
        ((6, 15, 16, 25), "y"),
        ((11, 20), "z"),
        ((7, 24), "i"),
        ((14, 17), "u"),
        ((9, 22), "r"),
        ((12, 13, 18, 19), "nuv"),
        ((8, 23), "g"),
        ((1, 10, 21, 30), "GI"),
        ((2, 5, 26, 29), "GV"),
        ((3, 4, 27, 28), "GU"),
        (range(31, 43), "FGS"),
    )
    for chips, group_filter in chip_groups:
        if chipID in chips:
            filter_type = group_filter

    return filter_names.index(filter_type), filter_type

def magToFlux(mag):
    """
    Convert an AB magnitude into a flux density.

    Inverts the AB definition m = -2.5 * log10(f_nu) - 48.6.

    Parameters:
    mag: AB magnitude (a number or a numpy array)

    Return:
    flux density in erg/s/cm^2/Hz, same shape as `mag`
    """
    exponent = -0.4 * (mag + 48.6)
    return 10 ** exponent

def getElectronFluxFilt(mag, filt, tel, exptime=150.):
    """Electron count (per the function name) for a source of AB magnitude `mag`.

    The AB flux density f_nu is treated as constant over the filter. The photon
    rate per unit area is f_nu / E_photon times the filter's frequency width,
    VC_A * (1/blue_limit - 1/red_limit). That rate is then scaled by the filter
    efficiency, the telescope pupil area and the exposure time.

    NOTE(review): VC_A is in A/s, so blue_limit/red_limit are presumably in
    Angstrom. The 1e4 factor presumably converts cm^-2 to m^-2, i.e.
    tel.pupil_area is in m^2 and getPhotonE() in erg. Confirm against
    ObservationSim.Instrument.

    Parameters:
    mag: AB magnitude (number or numpy array)
    filt: filter object providing getPhotonE(), blue_limit, red_limit, efficiency
    tel: telescope object providing pupil_area
    exptime: exposure time, default 150.

    Return:
    the collected signal, same shape as `mag`
    """
    photon_energy = filt.getPhotonE()
    flux_nu = magToFlux(mag)
    # Kept in the original left-to-right evaluation order so results are bit-identical.
    area_rate = 1.0e4 * flux_nu / photon_energy * VC_A * (1.0 / filt.blue_limit - 1.0 / filt.red_limit)
    return area_rate * filt.efficiency * tel.pupil_area * exptime

def convert_catalog(catname):
    """Re-save an ASCII catalog as a FITS table next to the original file.

    The output is written to `<dirname>/<basename>.fits`, i.e. '.fits' is
    appended to the full original file name (its extension is kept). An
    existing file with that name is overwritten.
    """
    table = ascii.read(catname)
    out_path = os.path.join(os.path.dirname(catname),
                            os.path.basename(catname) + '.fits')
    table.write(out_path, overwrite=True)

def validation_hist(val, idx, name="val", nbins=10, fig_name='detected_counts.png', output_dir='./'):
    """Plot histograms of all truth objects and of the matched (detected) subset.

    Parameters:
    val: 1-D values to histogram, one per truth object (e.g. injected magnitude)
    idx: sequence of per-object match results; an entry whose .size is 0 marks
        object i as undetected
    name: x-axis label
    nbins: number of bins; the edges span the full range of `val`
    fig_name: file name of the figure, written inside `output_dir`
    output_dir: directory for the figure (savefig does not create it)

    Return:
    counts, bins: histogram of all objects and its bin edges. The detected
        histogram is computed on these same edges.
    """
    val = np.asarray(val)
    counts, bins = np.histogram(val, bins=nbins)
    is_empty = np.full(len(val), False)
    for i, matches in enumerate(idx):
        if matches.size == 0:
            is_empty[i] = True
    # Reuse the full-sample edges. Passing `nbins` again would let numpy derive a
    # different range from the detected subset, and the bars would not line up.
    counts_detected, _ = np.histogram(val[~is_empty], bins=bins)

    fig = plt.figure()
    plt.stairs(counts, bins, color='r', label='TU objects')
    plt.stairs(counts_detected, bins, color='g', label='Detected')
    plt.xlabel(name, size='x-large')
    plt.title("Counts")
    plt.legend(loc='upper right', fancybox=True)
    fig_name = os.path.join(output_dir, fig_name)
    fig.savefig(fig_name)
    plt.close(fig)  # avoid accumulating open figures across calls
    return counts, bins

def hist_fraction(val, idx, name='val', nbins=10, normed=False, output_dir='./'):
    """Compute and plot the per-bin completeness fraction (detected / all).

    Parameters:
    val: 1-D values, one per truth object (e.g. injected magnitude)
    idx: sequence of per-object match results; an entry whose .size is 0 marks
        object i as undetected
    name: x-axis label, also used in the figure file name
    nbins: number of bins; the edges span the full range of `val`
    normed: passed as `density` to the detected-subset histogram only
    output_dir: directory for 'completeness_fraction_<name>.png' (savefig does
        not create it)

    Return:
    fraction: float array of length nbins. Bins with no objects get 0.
    """
    val = np.asarray(val)
    counts, bins = np.histogram(val, bins=nbins)
    is_empty = np.full(len(val), False)
    for i, matches in enumerate(idx):
        if matches.size == 0:
            is_empty[i] = True
    # Same edges for numerator and denominator; re-binning the subset with
    # `nbins` would divide counts taken over different value ranges.
    # NOTE(review): with normed=True only the numerator is a density while
    # `counts` stays raw, so the ratio is not a completeness fraction.
    # Kept as-is pending clarification of the intent.
    counts_detected, _ = np.histogram(val[~is_empty], bins=bins, density=normed)
    with np.errstate(divide='ignore', invalid='ignore'):
        fraction = counts_detected / counts
    fraction[~np.isfinite(fraction)] = 0.  # empty bins (0/0) -> 0

    fig = plt.figure()
    plt.stairs(fraction, bins, color='r', label='completeness fraction')
    plt.xlabel(name, size='x-large')
    plt.title("Completeness Fraction")
    fig_name = os.path.join(output_dir, "completeness_fraction_%s.png"%(name))
    fig.savefig(fig_name)
    plt.close(fig)  # avoid accumulating open figures across calls
    return fraction

def calculate_fraction(TU_catalog, source_catalog, output_dir, nbins=10):
    """Measure the detection completeness of injected (TU) sources.

    Steps:
    1. convert the ASCII truth catalog to FITS (`<TU_catalog>.fits`);
    2. read truth image positions (xImage, yImage) and the "mag" column;
    3. read detected image positions (X_IMAGE, Y_IMAGE) from `source_catalog`;
    4. cross-match the two position lists with match_catalogs_img;
    5. histogram the injected magnitudes, overall and for matched objects.

    Parameters:
    TU_catalog: path to the ASCII truth catalog of injected sources
    source_catalog: path to the extracted catalog of the injected image
    output_dir: directory for the diagnostic figures
    nbins: number of magnitude bins used for both histograms

    Return:
    counts: histogram of injected magnitudes (length nbins)
    bins: bin edges (length nbins + 1)
    fraction: detected / injected per bin (length nbins)
    """
    convert_catalog(TU_catalog)
    x_TU, y_TU, col_list = read_catalog(TU_catalog + '.fits', ext_num=1, ra_name="xImage", dec_name="yImage", col_list=["mag"])
    mag_TU = col_list[0]
    x_source, y_source, _ = read_catalog(source_catalog, ext_num=1, ra_name="X_IMAGE", dec_name="Y_IMAGE")
    # Only the truth-side match list is used. The histogram helpers treat an
    # empty entry as "truth object not detected".
    idx_TU, _ = match_catalogs_img(x1=x_TU, y1=y_TU, x2=x_source, y2=y_source)
    # Forward `nbins` to both helpers. Previously it was ignored (always 10 bins),
    # and both calls must agree so that `fraction` lines up with `bins`.
    counts, bins = validation_hist(val=mag_TU, idx=idx_TU, name="mag_injected", nbins=nbins, output_dir=output_dir)
    fraction = hist_fraction(val=mag_TU, idx=idx_TU, name="mag_injected", nbins=nbins, output_dir=output_dir)
    return counts, bins, fraction

def calculate_undetected_flux(orig_cat, mag_bins, fraction, mag_low=20.0, mag_high=26.0, image=None,  output_dir='./'):
    """Estimate the per-pixel signal contributed by sources missed by detection.

    The completeness fraction f_i measured from injected sources is applied to
    the magnitude histogram N_i of the original catalog (same bin edges). The
    number of missed sources in bin i is N_i / f_i - N_i; bins where that is not
    finite (f_i == 0) count as 0. Each missed source is assigned the signal of
    its bin-centre magnitude via getElectronFluxFilt, using the chip's filter and
    the image exposure time. Only bins whose centre lies in [mag_low, mag_high]
    are summed, and the total is divided by PIXSIZE1 * PIXSIZE2.

    Parameters:
    orig_cat: catalog extracted from the original image (RA, DEC, Mag_Kron)
    mag_bins: magnitude bin edges, length len(fraction) + 1
    fraction: completeness fraction per bin
    mag_low, mag_high: bin-centre magnitude range that is summed
    image: path to the original image FITS file (required)
    output_dir: directory for 'undetected_sources.png' (must exist)

    Return:
    float, summed signal (units of getElectronFluxFilt) divided by the pixel count

    Raises:
    ValueError: if `image` is None.
    """
    if image is None:
        raise ValueError("calculate_undetected_flux requires the image path to read its header")

    # Get info from original image
    with fits.open(image) as hdu:
        header0 = hdu[0].header
        # NOTE(review): used below as the number of pixels along each axis;
        # confirm PIXSIZE1/2 hold pixel counts rather than a physical pixel size.
        nx_pix, ny_pix = header0["PIXSIZE1"], header0["PIXSIZE2"]
        exp_time = float(header0["EXPTIME"])
        chipID = int(header0["DETECTOR"][-2:])

    # Get info from original catalog
    _, _, col_list_orig = read_catalog(orig_cat, ra_name='RA', dec_name='DEC', col_list=['Mag_Kron'])
    mag_orig = col_list_orig[0]
    mag_bins = np.asarray(mag_bins, dtype=float)
    fraction = np.asarray(fraction, dtype=float)
    # Bin on exactly the edges `fraction` was measured on. Passing only a bin
    # count would let numpy derive a different range from mag_orig (e.g. one
    # stretched by outliers), misaligning counts[i] with fraction[i].
    counts, _ = np.histogram(mag_orig, bins=mag_bins)

    mags = (mag_bins[:-1] + mag_bins[1:]) / 2.
    with np.errstate(divide='ignore', invalid='ignore'):
        counts_missing = (counts / fraction) - counts
    counts_missing[~np.isfinite(counts_missing)] = 0.  # zero-completeness bins -> 0
    print(counts_missing)
    print(counts_missing.sum())

    fig = plt.figure()
    plt.stairs(counts_missing, mag_bins, color='r', label='undetected counts')
    plt.xlabel("mag_injected", size='x-large')
    plt.title("Undetected Sources")
    fig_name = os.path.join(output_dir, "undetected_sources.png")
    fig.savefig(fig_name)
    plt.close(fig)

    tel = Telescope()
    filter_param = FilterParam()
    filter_id, filter_type = getChipFilter(chipID=chipID)
    filt = Filter(filter_id=filter_id,
                  filter_type=filter_type,
                  filter_param=filter_param)

    undetected_flux = 0.
    for mag, n_missing in zip(mags, counts_missing):
        if mag < mag_low or mag > mag_high:
            continue
        # Use the image's own exposure time instead of getElectronFluxFilt's default.
        undetected_flux += n_missing * getElectronFluxFilt(mag=mag, filt=filt, tel=tel, exptime=exp_time)

    undetected_flux /= (float(nx_pix) * float(ny_pix))
    return undetected_flux

if __name__ == "__main__":
    # Script entry: measure completeness from the injected catalog, then use it
    # to estimate the undetected signal in the original image.
    args = define_options().parse_args()
    # savefig() does not create directories, and the default './workspace'
    # may not exist yet.
    os.makedirs(args.output_dir, exist_ok=True)
    counts, bins, fraction = calculate_fraction(
        TU_catalog=args.TU_catalog,
        source_catalog=args.source_catalog,
        output_dir=args.output_dir,
        nbins=20
    )
    undetected_flux = calculate_undetected_flux(
        orig_cat=args.orig_catalog,
        mag_bins=bins,
        fraction=fraction,
        image=args.image,
        output_dir=args.output_dir,
    )
    print(undetected_flux)