Skip to content
evaluation_utils.py 8.62 KiB
Newer Older
from astropy.wcs import WCS
from astropy.io import ascii, fits
from matplotlib.colors import LogNorm
from scipy.stats import binned_statistic
from astropy.visualization import ZScaleInterval

import os
import numpy as np
import matplotlib.pyplot as plt
from cross_match_catalogs import match_catalogs_img


def _load_masked_image(img_path, flg_mask=None):
    """Return ``(data, wcs)`` read from extension 1 of the FITS file ``img_path``.

    The pixel data are copied out of the file so it can be closed immediately
    (no leaked file handle / memmap). Pixels where ``flg_mask`` is True are
    set to 0 in the returned copy.
    """
    with fits.open(img_path) as hdul:
        hdu = hdul[1]
        data = np.array(hdu.data)
        wcs = WCS(hdu.header)
    if flg_mask is not None:
        data[flg_mask] = 0.
    return data, wcs


def _show_wcs_image(data, wcs, vmin, vmax, figsize, save_path=None):
    """Draw ``data`` in grey scale on WCS axes, save it if ``save_path`` is given, then show it."""
    plt.figure(figsize=figsize, dpi=100)
    plt.subplot(projection=wcs)
    plt.imshow(data, origin='lower', cmap='gray', vmin=vmin, vmax=vmax)
    plt.grid(color='white', ls='solid')
    if save_path is not None:
        plt.savefig(save_path)
    plt.show()


def plot_injection_comparison(orig_img, injected_img, flg_img=None, save_fig_dir=None, fig_prefix="", figsize=(12, 8),
                              diff_vmin=0., diff_vmax=0.001):
    """Show the original, injected and difference images as three separate figures.

    Each image is read from extension 1 of its FITS file and drawn on axes
    projected with that extension's WCS. The original and injected images use
    a z-scale stretch. The difference image (``injected - original``) is drawn
    with the injected image's WCS and a fixed ``[diff_vmin, diff_vmax]`` stretch.

    Parameters
    ----------
    orig_img, injected_img : str
        Paths to FITS files whose extension 1 holds the image and its WCS header.
    flg_img : str, optional
        Path to a flag image read with ``fits.getdata``. Pixels where it is
        ``> 0`` are set to 0 in both images before plotting. Assumes it has
        the same shape as the images (TODO: confirm).
    save_fig_dir : str, optional
        If given, the figures are saved there as ``<fig_prefix>original_img.png``,
        ``<fig_prefix>injected_img.png`` and ``<fig_prefix>diff_img.png``.
    fig_prefix : str
        Prefix prepended to the saved file names.
    figsize : tuple
        Matplotlib figure size in inches.
    diff_vmin, diff_vmax : float
        Display limits of the difference image. The defaults reproduce the
        previously hard-coded values.
    """
    zscale = ZScaleInterval()

    # Load the flag image once; the same mask is applied to both images.
    flg_mask = None
    if flg_img is not None:
        flg_mask = fits.getdata(flg_img) > 0

    def _out_path(suffix):
        if save_fig_dir is None:
            return None
        return os.path.join(save_fig_dir, fig_prefix + suffix)

    data_orig, wcs_orig = _load_masked_image(orig_img, flg_mask)
    z1, z2 = zscale.get_limits(data_orig)
    _show_wcs_image(data_orig, wcs_orig, z1, z2, figsize, _out_path("original_img.png"))

    data_inj, wcs_inj = _load_masked_image(injected_img, flg_mask)
    z1, z2 = zscale.get_limits(data_inj)
    _show_wcs_image(data_inj, wcs_inj, z1, z2, figsize, _out_path("injected_img.png"))

    # Subtract in floating point so integer-typed images cannot wrap around.
    img_diff = np.asarray(data_inj, dtype=float) - np.asarray(data_orig, dtype=float)
    _show_wcs_image(img_diff, wcs_inj, diff_vmin, diff_vmax, figsize, _out_path("diff_img.png"))


def plot_ensemble_hist(cat_path_list, column_name="Mag_Kron", column_unit="mag", title="Total KRON MAG distribution", save_fig_dir=None, fig_prefix="",
                       nbins=50, low=16., high=28., density=False):
    """Plot a single histogram of one column pooled over several catalogs.

    Parameters
    ----------
    cat_path_list : sequence of str
        Catalog paths. ``.fits`` files are read from extension 1; ``.cat``
        files are read with ``astropy.io.ascii.read``.
    column_name : str
        Column to histogram in every catalog.
    column_unit : str
        Unit appended to the x-axis label as ``<column_name>/<column_unit>``.
    title : str
        Figure title.
    save_fig_dir : str, optional
        If given, the figure is saved there as
        ``<fig_prefix><column_name>_ensemble_hist.png``.
    fig_prefix : str
        Prefix for the saved file name.
    nbins, low, high : int, float, float
        ``nbins`` equal-width bins spanning ``[low, high]``.
    density : bool
        Passed to ``plt.hist``. The "Counts" y-label is shown only when falsy.

    Raises
    ------
    ValueError
        If a path ends in neither ``.fits`` nor ``.cat``. Previously such a
        path either crashed with ``UnboundLocalError`` or silently reused the
        preceding catalog's values.
    """
    bins = np.linspace(low, high, nbins + 1)
    chunks = []
    for cat_path in cat_path_list:
        if cat_path.endswith(".fits"):
            with fits.open(cat_path) as hdul:
                # Copy the column out before the file is closed.
                value_temp = np.array(hdul[1].data[column_name])
        elif cat_path.endswith(".cat"):
            value_temp = np.asarray(ascii.read(cat_path)[column_name])
        else:
            raise ValueError("unsupported catalog format (expected .fits or .cat): %r" % (cat_path,))
        print("number of objects in %s: %d" %
              (os.path.basename(cat_path), len(value_temp)))
        chunks.append(np.ravel(value_temp))

    # Concatenate once instead of np.append per catalog (quadratic copying).
    values = np.concatenate(chunks) if chunks else np.array([])

    plt.figure()
    plt.hist(values, bins=bins, density=density)
    plt.xlabel(column_name + '/' + column_unit, size='x-large')
    if not density:
        plt.ylabel("Counts", size='x-large')
    plt.title(title, size='x-large')
    if save_fig_dir is not None:
        output_filename = fig_prefix + "%s_ensemble_hist.png" % (column_name)
        output_img_path = os.path.join(save_fig_dir, output_filename)
        plt.savefig(output_img_path)
    plt.show()


def create_hist_figure(counts, counts_detected, bins, name="val", output_dir='./', fig_name='detected_counts.png', save_figure=False, title=None):
    """Draw pre-binned total and detected counts as two overlaid step curves.

    ``counts`` is drawn in red ("TU objects") and ``counts_detected`` in green
    ("Detected"), both over the same bin edges ``bins``. When ``save_figure``
    is true, the figure is written to ``os.path.join(output_dir, fig_name)``.

    Returns
    -------
    (fig, ax) : the new matplotlib Figure and its single Axes.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111)

    label_size = 'x-large'
    ax.set_xlabel(name, size=label_size)
    ax.set_ylabel("Counts", size=label_size)
    if title is not None:
        ax.set_title(title, size=label_size)

    # The full input population is drawn first, then the detected subset.
    series = (
        (counts, 'r', 'TU objects'),
        (counts_detected, 'g', 'Detected'),
    )
    for heights, colour, legend_text in series:
        ax.stairs(heights, bins, color=colour, label=legend_text)
    ax.legend(loc='upper right', fancybox=True)

    if save_figure:
        fig.savefig(os.path.join(output_dir, fig_name))
    return fig, ax


def create_fraction_figure(counts, counts_detected, bins, name='val', output_dir='./', fig_name="completeness_fraction.png",
                           save_figure=False, title=None, figure=None, color='r', label='patch_1', show_legend=False):
    """Plot the per-bin detection (completeness) fraction as a step curve.

    Parameters
    ----------
    counts, counts_detected : array-like
        Per-bin numbers of all objects and of detected objects.
    bins : array-like
        Bin edges passed to ``Axes.stairs``.
    name : str
        X-axis label. Only set when a new figure is created.
    output_dir, fig_name : str
        Output location used when ``save_figure`` is true.
    save_figure : bool
        Save the figure to ``os.path.join(output_dir, fig_name)``.
    title : str, optional
        Axes title. A new figure without a title gets "Completeness Fraction".
    figure : matplotlib Figure, optional
        Existing figure to draw onto (its first Axes is used), e.g. to overlay
        several patches. If None, a new figure is created.
    color, label : str
        Style and legend label of this curve.
    show_legend : bool
        Draw the legend.

    Returns
    -------
    (fig, ax, fraction)
        ``fraction`` is ``counts_detected / counts`` as a float array. Bins with
        ``counts == 0`` are 0 (no nan/inf, no divide-by-zero warning).
    """
    counts = np.asarray(counts, dtype=float)
    counts_detected = np.asarray(counts_detected, dtype=float)
    # Only divide where there is something to divide by; empty bins stay 0.
    fraction = np.zeros(np.broadcast(counts_detected, counts).shape)
    np.divide(counts_detected, counts, out=fraction, where=counts != 0)

    if figure is not None:
        fig = figure
        ax = fig.axes[0]
    else:
        fig = plt.figure()
        ax = fig.add_subplot(111)
        ax.set_xlabel(name, size='x-large')

    ax.stairs(fraction, bins, color=color, label=label)
    if title is not None:
        ax.set_title(title, size='x-large')
    elif figure is None:
        # Default title only for freshly created figures.
        ax.set_title("Completeness Fraction")

    if show_legend:
        ax.legend(loc='upper right', fancybox=True)
    if save_figure:
        fig.savefig(os.path.join(output_dir, fig_name))
    return fig, ax, fraction


def validation_hist(val, idx, name="val", nbins=10, bins=None, fig_name='detected_counts.png', output_dir='./', create_figure=True,
                    save_figure=False, title=None):
    """Histogram a quantity for all objects and for the subset that was matched.

    Parameters
    ----------
    val : array-like
        One value per object (e.g. the injected magnitude).
    idx : sequence
        Per-object match results aligned with ``val``. An entry of size 0 means
        the object has no match (presumably the per-object index lists from
        the cross-match step; verify against the caller).
    name : str
        X-axis label for the optional figure.
    nbins : int
        Number of bins when ``bins`` is None.
    bins : array-like, optional
        Explicit bin edges. Overrides ``nbins``.
    fig_name, output_dir : str
        Output location forwarded to ``create_hist_figure``.
    create_figure : bool
        Draw the histogram figure via ``create_hist_figure``.
    save_figure : bool
        Forwarded to ``create_hist_figure``. Previously it was never passed, so
        ``fig_name``/``output_dir`` had no effect.
    title : str, optional
        Forwarded to ``create_hist_figure``.

    Returns
    -------
    (counts, counts_detected, bins)
        Both histograms use the same edges ``bins``.

    Raises
    ------
    ValueError
        If ``val`` and ``idx`` have different lengths.
    """
    val = np.asarray(val)
    counts, edges = np.histogram(val, bins=nbins if bins is None else bins)

    detected = np.array([np.size(match) > 0 for match in idx], dtype=bool)
    if len(detected) != len(val):
        raise ValueError("val and idx must have the same length (got %d and %d)"
                         % (len(val), len(detected)))
    # Re-use the edges computed above so both histograms share identical bins.
    counts_detected, _ = np.histogram(val[detected], bins=edges)

    if create_figure:
        create_hist_figure(counts, counts_detected, edges,
                           name, output_dir, fig_name,
                           save_figure=save_figure, title=title)
    return counts, counts_detected, edges


def plot_mag_comparison(truth_cat_list, measured_cat_root_dir, mag1_name="mag", mag2_name="Mag_Kron", save_fig_dir=None, fig_prefix="",
                        nbins=20, low=18., high=26., ylim=(-1., 1.), title=None):
    """Plot measured-minus-true magnitude against true magnitude for matched objects.

    For each truth catalog, the matching measured catalog is located at
    ``<measured_cat_root_dir>/<obs_id>/<name>``. Here ``obs_id`` is the name of
    the directory holding the truth catalog. ``<name>`` is the truth file name
    with "img" replaced by "cat" and then ".cat" replaced by ".fits".
    Objects are matched on pixel position with ``match_catalogs_img``. Truth
    objects with no match are skipped. When an object has several matches, the
    first one is used.

    The figure shows every matched pair as a point. It overlays the per-bin
    median, plus the per-bin mean with a standard-deviation error bar, over
    ``nbins`` bins on ``[low, high]``.

    Parameters
    ----------
    truth_cat_list : sequence of str
        Truth catalogs readable by ``astropy.io.ascii.read``, with ``xImage``,
        ``yImage`` and ``mag1_name`` columns.
    measured_cat_root_dir : str
        Root directory of the measured FITS catalogs. Extension 1 must contain
        ``X``, ``Y`` and ``mag2_name`` columns.
    mag1_name, mag2_name : str
        Truth and measured magnitude column names. ``mag2_name`` was previously
        ignored in favour of a hard-coded "Mag_Kron", which is still the default.
    save_fig_dir : str, optional
        If given, the figure is saved there as ``<fig_prefix>measured-true_mag.png``.
    fig_prefix : str
        Prefix for the saved file name.
    nbins, low, high : int, float, float
        Binning used for the summary statistics and the x-range of the plot.
    ylim : pair of float
        Y-axis limits.
    title : str, optional
        Figure title.
    """
    truth_list = []
    diff_list = []

    for cat_path_truth in truth_cat_list:
        print("Injected truth catalog: ",
              os.path.basename(cat_path_truth))
        # The observation id is the name of the directory holding the truth catalog.
        obs_id = os.path.basename(os.path.dirname(cat_path_truth))

        # Read truth catalog
        truth = ascii.read(cat_path_truth)
        x_truth = truth["xImage"]
        y_truth = truth["yImage"]
        mag_truth = truth[mag1_name]

        # Read measured catalog (copy columns out so the file can be closed)
        measured_name = os.path.basename(cat_path_truth).replace("img", "cat").replace(".cat", ".fits")
        cat_path_measured = os.path.join(measured_cat_root_dir, obs_id, measured_name)
        print("L1 processed photometry catalog: ",
              os.path.basename(cat_path_measured))
        with fits.open(cat_path_measured) as hdul:
            measured = hdul[1].data
            x_measure = np.array(measured["X"])
            y_measure = np.array(measured["Y"])
            mag_measure = np.array(measured[mag2_name])

        # Match measured objects vs truth
        idx1, _ = match_catalogs_img(
            x1=x_truth, y1=y_truth, x2=x_measure, y2=y_measure)

        for i, matches in enumerate(idx1):
            if matches.size == 0:
                continue
            diff_list.append(mag_measure[matches[0]] - mag_truth[i])
            truth_list.append(mag_truth[i])

    stat_range = [low, high]
    bin_means, bin_edges, _ = binned_statistic(truth_list, diff_list, 'mean',
                                               bins=nbins, range=stat_range)
    bin_median, _, _ = binned_statistic(truth_list, diff_list, 'median',
                                        bins=nbins, range=stat_range)
    bin_std, _, _ = binned_statistic(truth_list, diff_list, 'std',
                                     bins=nbins, range=stat_range)
    bin_centers = 0.5 * (bin_edges[1:] + bin_edges[:-1])

    plt.figure()
    plt.plot(truth_list, diff_list, 'ro', alpha=0.1)
    plt.axhline(y=0., color='k', alpha=0.6)
    plt.plot(bin_centers, bin_median, '--',
             label=r'$\rm{median}\ \Delta mag$', alpha=0.6)
    plt.errorbar(bin_centers, bin_means, yerr=bin_std, fmt='bo', capsize=2,
                 label=r'$\rm{mean}\ \Delta mag$', alpha=0.6)
    plt.xlim([low, high])
    plt.ylim(ylim)
    plt.xlabel("True mag", size='x-large')
    plt.ylabel("Measured (Kron) - True mag", size='x-large')
    plt.legend(loc='upper left', fancybox=True)
    if title is not None:
        plt.title(title, size='x-large')

    if save_fig_dir is not None:
        output_filename = fig_prefix + "measured-true_mag.png"
        output_img_path = os.path.join(save_fig_dir, output_filename)
        plt.savefig(output_img_path)
    plt.show()