# aperture_noise.py
import argparse
import numpy as np
import os
from astropy.io import fits
from astropy.stats import sigma_clip
from astropy.stats import median_absolute_deviation
from astropy.stats import mad_std

def get_all_stats(values_arrays):
    """
    Summarise a collection of values with four statistics.

    Parameters
    ----------
    values_arrays : array-like
        Values to summarise; passed unchanged to each estimator.

    Returns
    -------
    dict
        'mean'   -> np.mean of the input
        'median' -> np.median of the input
        'std'    -> np.std of the input
        'mad'    -> astropy.stats.mad_std of the input (a MAD-based
                    standard-deviation estimate, not the raw MAD)
    """
    return {
        'mean': np.mean(values_arrays),
        'median': np.median(values_arrays),
        'std': np.std(values_arrays),
        'mad': mad_std(values_arrays),
    }

def define_options():
    """
    Build the command-line parser for the aperture noise script.

    Returns
    -------
    argparse.ArgumentParser
        Parser exposing the options --data_image and --seg_image (both
        required), --flag_image, --sky_image, --aper_min, --aper_max,
        --aper_sampling, --n_sample, --out_basename and --output_dir.
    """
    parser = argparse.ArgumentParser()
    # Required inputs have no meaningful default, so none is advertised.
    parser.add_argument('--data_image', dest='data_image', type=str, required=True,
                        help='Name of the data image (required)')
    parser.add_argument('--seg_image', dest='seg_image', type=str, required=True,
                        help='Name of the mask / segmentation image (required)')
    parser.add_argument('--flag_image', dest='flag_image', type=str, required=False,
                        default=None, help='Name of the flag image (default: "%(default)s")')
    parser.add_argument('--sky_image', dest='sky_image', type=str, required=False,
                        default=None, help='Name of the sky image (default: "%(default)s")')
    # The aperture value is used as a diameter by sampling() (radius = aperture // 2).
    parser.add_argument('--aper_min', dest='aper_min', type=int, required=False,
                        default=5, help='Smallest aperture diameter in pixels (default: "%(default)s")')
    parser.add_argument('--aper_max', dest='aper_max', type=int, required=False,
                        default=20, help='Largest aperture diameter in pixels (default: "%(default)s")')
    parser.add_argument('--aper_sampling', dest='aper_sampling', type=int, required=False,
                        default=1, help='Step in pixels between successive aperture diameters (default: "%(default)s")')
    parser.add_argument('--n_sample', dest='n_sample', type=int, required=False,
                        default=500, help='Number of random apertures drawn per aperture diameter (default: "%(default)s")')
    parser.add_argument('--out_basename', dest='out_basename', type=str, required=False,
                        default="aper", help='Base name for the output names (default: "%(default)s")')
    parser.add_argument('--output_dir', dest='output_dir', type=str, required=False,
                        default="./workspace", help='Directory path for the output (default: "%(default)s")')
    return parser

def create_circular_mask(h, w, center=None, radius=None):
    """
    Return a boolean (h, w) array that is True inside a filled circle.

    Parameters
    ----------
    h, w : int
        Number of rows and columns of the mask.
    center : sequence of two numbers, optional
        Circle centre as (column, row). Defaults to the middle of the
        array, [int(w / 2), int(h / 2)].
    radius : number, optional
        Circle radius in pixels. Defaults to the smallest distance from
        the centre to any edge of the array.

    Returns
    -------
    numpy.ndarray of bool
        True where the Euclidean distance to the centre is <= radius.
    """
    if center is None:
        center = [int(w / 2), int(h / 2)]
    col0, row0 = center[0], center[1]

    if radius is None:
        # Largest circle around the centre that still fits in the array.
        radius = min(col0, row0, w - col0, h - row0)

    # Open grids broadcast to the full (h, w) shape in the sum below.
    rows, cols = np.ogrid[:h, :w]
    d_col = cols - col0
    d_row = rows - row0
    return np.sqrt(d_col ** 2 + d_row ** 2) <= radius

def sampling(image_data, seg_data, flag_data, aperture, Nsample):
    """
    Measure the flux in randomly placed circular apertures free of sources.

    A random pixel that is non-zero in `image_data` and zero in both
    `seg_data` and `flag_data` is used as the corner of a square stamp of
    side `aperture + 1`. A circular mask of radius `aperture // 2` centred
    in the stamp selects the aperture pixels. The draw is rejected when the
    stamp touches the array edge or when any aperture pixel is non-zero in
    `seg_data` or `flag_data`. Drawing stops after more than
    100 * Nsample attempts.

    Parameters
    ----------
    image_data, seg_data, flag_data : 2-D numpy.ndarray
        Image, segmentation map and flag map; indexed identically.
    aperture : int
        Aperture size in pixels (used as a diameter).
    Nsample : int
        Number of accepted apertures requested.

    Returns
    -------
    flux_average, flux_median, x_position, y_position : numpy.ndarray
        Mean and median of the image inside each accepted aperture, and the
        stamp corner index plus half the stamp size along the first and
        second array axis respectively. The arrays contain only accepted
        samples, so they are shorter than `Nsample` (possibly empty) when
        not enough clean apertures could be found.
    """
    wx, wy = np.where((image_data != 0) & (seg_data == 0) & (flag_data == 0))
    Nx, Ny = image_data.shape

    flux_average = np.zeros(Nsample)
    flux_median = np.zeros(Nsample)
    x_position = np.zeros(Nsample)
    y_position = np.zeros(Nsample)

    stmpsize = aperture + 1
    n_candidates = len(wx)

    # The mask depends only on the aperture, so build it once. It is skipped
    # when there is nothing to sample from.
    mask = None
    if n_candidates > 0:
        mask = create_circular_mask(stmpsize, stmpsize,
                                    center=[stmpsize // 2, stmpsize // 2],
                                    radius=aperture // 2)

    i = 0
    i_iter = 0
    while i < Nsample:
        # Without candidates np.random.randint(0) would raise ValueError, so
        # that case is treated like an exhausted try budget.
        if n_candidates == 0 or i_iter > 100 * Nsample:
            print('# Not enough background pixels for image depth analysis!')
            break
        i_iter += 1

        idx = np.random.randint(n_candidates)
        x0 = wx[idx]
        y0 = wy[idx]

        # Reject stamps that would reach the edge of the image.
        if x0 + stmpsize >= Nx or y0 + stmpsize >= Ny:
            continue

        stamp = (slice(x0, x0 + stmpsize), slice(y0, y0 + stmpsize))

        # Any segmented or flagged pixel inside the aperture disqualifies it.
        # np.any is used instead of a sum so values cannot cancel out.
        if np.any(seg_data[stamp][mask]):
            continue
        if np.any(flag_data[stamp][mask]):
            continue

        aperture_pixels = image_data[stamp][mask]
        flux_average[i] = np.average(aperture_pixels)
        flux_median[i] = np.median(aperture_pixels)
        x_position[i] = x0 + stmpsize / 2.0
        y_position[i] = y0 + stmpsize / 2.0

        i += 1

    print('Needed %i tries for %i samples!' % (i_iter, i))

    # Drop the unused zero-filled slots so they are not mistaken for
    # measurements when fewer than Nsample apertures were accepted.
    return flux_average[:i], flux_median[:i], x_position[:i], y_position[:i]
    
def noise_statistics_aperture(fitsname, segname, flagname=None, sky_image=None, aperture_min=1, aperture_max=10, aperture_step=1, seed=None, Nsample=100, sigma_cl=10., base_name="aper", output_dir='./'):
    """
    Measure the background noise in random apertures of increasing size.

    For every aperture size from `aperture_min` to `aperture_max` (inclusive,
    in steps of `aperture_step`) the function draws random apertures with
    `sampling`, prints summary statistics and writes one text file per size
    named '<base_name>_<aperture>.txt' into `output_dir`. Each line of a file
    holds: mean flux, median flux, x position, y position.

    Parameters
    ----------
    fitsname : str
        Data image; extension 1 is read and multiplied by its GAIN1 header
        keyword.
    segname : str
        Segmentation image; extension 0 is read.
    flagname : str, optional
        Flag image; extension 1 is read. Without it nothing is flagged.
    sky_image : str, optional
        Sky image; extension 0 is read, multiplied by its GAIN1 header
        keyword and subtracted from the data image.
    aperture_min, aperture_max, aperture_step : int
        Range of aperture sizes in pixels.
    seed : int, optional
        Seed for numpy's global random generator.
    Nsample : int
        Number of apertures requested per aperture size.
    sigma_cl : float
        Currently unused.
    base_name : str
        Prefix of the per-aperture output files.
    output_dir : str
        Directory for the output files; created if missing.

    Returns
    -------
    aperture_list : numpy.ndarray of int
        The aperture sizes that were processed.
    sigma_output : numpy.ndarray
        Standard deviation (np.std) of the aperture mean fluxes per size.
    mad_output : numpy.ndarray
        Median absolute deviation of the aperture mean fluxes per size.
    mad_std_output : numpy.ndarray
        astropy mad_std of the aperture mean fluxes per size.
    """
    # Context managers guarantee every FITS file is closed, including the
    # sky image.
    with fits.open(fitsname) as f:
        image_data = f[1].data * f[1].header["GAIN1"]
    with fits.open(segname) as fseg:
        seg_data = fseg[0].data

    if flagname:
        with fits.open(flagname) as fflag:
            flag_data = fflag[1].data
    else:
        flag_data = np.zeros(image_data.shape)

    if sky_image:
        with fits.open(sky_image) as hdu:
            sky_data = hdu[0].data * hdu[0].header["GAIN1"]
        # Not done in place, so a dtype mismatch cannot make numpy refuse
        # the cast.
        image_data = image_data - sky_data

    if seed is not None:
        np.random.seed(seed)

    os.makedirs(output_dir, exist_ok=True)

    # Pixels belonging to detected objects are zeroed; sampling() treats
    # zero-valued pixels as unusable starting points.
    image_data[seg_data > 0] = 0.

    aperture_list = np.arange(aperture_min, aperture_max+1, aperture_step, dtype=int)
    sigma_output = np.zeros(len(aperture_list))
    mad_output   = np.zeros(len(aperture_list))
    mad_std_output = np.zeros(len(aperture_list))

    for j, aperture in enumerate(aperture_list):
        flux_average, flux_median, x_position, y_position = sampling(image_data, seg_data, flag_data, aperture, Nsample)

        mean_stats   = get_all_stats(flux_average)
        median_stats = get_all_stats(flux_median)
        print("Mean:   %e += %e +- %e"%(mean_stats['median'], mean_stats['mad'], mean_stats['std']))
        print("Median: %e += %e +- %e"%(median_stats['median'], median_stats['mad'], median_stats['std']))

        # Record the per-aperture scatter so the returned arrays carry the
        # measured values instead of their initial zeros.
        sigma_output[j] = mean_stats['std']
        mad_output[j] = median_absolute_deviation(flux_average)
        mad_std_output[j] = mean_stats['mad']

        aper_file = '%s_%03i.txt'%(base_name, aperture)
        aper_file = os.path.join(output_dir, aper_file)
        print('Aperture file: %s'%aper_file)
        with open(aper_file, "w") as aper_out:
            aper_out.writelines(
                "{:.7f} {:.7f} {:.1f} {:.1f}\n".format(*one_value)
                for one_value in zip(flux_average, flux_median, x_position, y_position)
            )

    return aperture_list, sigma_output, mad_output, mad_std_output

if __name__ == "__main__":
    # Script entry point: read the command-line options and run the aperture
    # noise measurement with them.
    args = define_options().parse_args()
    aperture_ap, sigma_ap, mad_ap, nmad_ap = noise_statistics_aperture(
        fitsname=args.data_image,
        segname=args.seg_image,
        flagname=args.flag_image,
        sky_image=args.sky_image,
        aperture_min=args.aper_min,
        aperture_max=args.aper_max,
        aperture_step=args.aper_sampling,
        Nsample=args.n_sample,
        base_name=args.out_basename,
        output_dir=args.output_dir,
    )