import numpy as np
from scipy.interpolate import interp1d


def binningPSF(img, ngg):
    """Block-average *img* down to an ``ngg x ngg`` array.

    The image (expected 2-D, shape ``(H, W)``) is split into ``ngg x ngg``
    tiles of ``(H // ngg) x (W // ngg)`` pixels each, and every output pixel
    is the mean of one tile.  ``H`` and ``W`` should be multiples of *ngg*;
    otherwise the underlying ``reshape`` raises ``ValueError``.
    """
    tile_rows = img.shape[0] // ngg
    tile_cols = img.shape[1] // ngg
    tiles = img.reshape(ngg, tile_rows, ngg, tile_cols)
    # Average inside each tile: across its columns first, then its rows.
    per_row_means = tiles.mean(-1)
    return per_row_means.mean(1)


def radial_average_at_pixel(image, center_x, center_y, dr=10):
    """Azimuthally average a 2-D image in annuli of width *dr* about a point.

    Parameters
    ----------
    image : 2-D array_like
        Real-valued image to average.
    center_x, center_y : float
        Column (x) and row (y) pixel coordinates of the centre.
    dr : float, optional
        Annulus width in pixels (default 10).

    Returns
    -------
    radii : ndarray
        Inner edge of each annulus: ``0, dr, 2*dr, ...``.
    means : list of float
        Mean pixel value over ``radii[i] <= r < radii[i] + dr``; ``0`` for an
        annulus that contains no pixels.

    Notes
    -----
    Edges are ``np.arange(0, int(r.max()), dr)`` and only complete intervals
    between consecutive edges are reported, so pixels at or beyond the last
    edge are ignored.  Everything is computed in a single pass over the
    pixels instead of one full-image mask per annulus.
    """
    image = np.asarray(image)
    y, x = np.indices(image.shape)
    r = np.sqrt((x - center_x)**2 + (y - center_y)**2)

    max_radius = int(r.max())  # largest (truncated) distance from the centre
    radial_bins = np.arange(0, max_radius, dr)
    n_bins = len(radial_bins) - 1
    if n_bins <= 0:
        # Fewer than two edges: no complete annulus to average.
        return radial_bins[:-1], []

    # searchsorted(side="right") - 1 is the index i with
    # radial_bins[i] <= r < radial_bins[i + 1] -- the same comparisons as a
    # per-annulus mask.  r >= 0 == radial_bins[0], so the index is never < 0.
    bin_idx = np.searchsorted(radial_bins, r.ravel(), side="right") - 1
    inside = bin_idx < n_bins  # drop pixels beyond the last complete annulus
    bin_idx = bin_idx[inside]
    values = image.ravel()[inside]

    counts = np.bincount(bin_idx, minlength=n_bins)
    sums = np.bincount(bin_idx, weights=values, minlength=n_bins)
    means = np.zeros(n_bins)
    np.divide(sums, counts, out=means, where=counts > 0)  # empty annuli stay 0
    return radial_bins[:-1], means.tolist()


def psf_extrapolate(psf, rr_trim=64, ngg=256):
    """Embed a PSF in an ``ngg x ngg`` grid with power-law-extrapolated wings.

    The PSF is trimmed to radius *rr_trim* about its centre, and its radial
    profile (annuli of 4 px) is interpolated piecewise-linearly in log-log
    space, extrapolating linearly beyond the last fitted point.  In the output
    grid, pixels within *rr_trim* of the centre and inside the original PSF
    footprint keep the measured PSF.  Every other pixel takes the fitted
    profile.  Pixels farther than ``ngg // 2`` from the centre are zero.  The
    result is normalized to unit ``nansum``.

    Parameters
    ----------
    psf : 2-D array_like
        Square PSF image of side ``n``, centred on pixel ``(n // 2, n // 2)``.
        The caller's array is not modified.
    rr_trim : float, optional
        Radius (pixels) beyond which the measured PSF is discarded.
    ngg : int, optional
        Side length of the output grid; must be at least ``n``.

    Returns
    -------
    ndarray of shape ``(ngg, ngg)``

    Raises
    ------
    ValueError
        If *psf* is not square or is larger than *ngg*.  ``interp1d`` also
        raises ``ValueError`` if fewer than two usable profile points remain.
    """
    # Work on a copy: trimming must not zero out the caller's PSF.
    psf_temp = np.array(psf, copy=True)
    if psf_temp.ndim != 2 or psf_temp.shape[0] != psf_temp.shape[1]:
        raise ValueError("psf must be a square 2-D array")
    n = psf_temp.shape[0]
    if n > ngg:
        raise ValueError("ngg must be at least psf.shape[0]")
    half = n // 2
    eps = np.finfo(float).eps

    # Radial distance of every PSF pixel from the PSF centre.
    offsets_psf = np.arange(n) - half
    xp, yp = np.meshgrid(offsets_psf, offsets_psf)
    r_psf = np.sqrt(xp**2 + yp**2)

    psf_temp[r_psf > rr_trim] = 0
    radii, means = radial_average_at_pixel(psf_temp, half, half, dr=4)

    # radii[0] == 0 has no logarithm, so skip it.  Annuli that are empty,
    # fully trimmed or non-positive give non-finite logs and are excluded.
    with np.errstate(divide="ignore", invalid="ignore"):
        radii_log = np.log(radii[1:])
        means_log = np.log(means[1:])
    finite_mask = np.isfinite(means_log)
    # The last finite annulus typically straddles rr_trim and is diluted by
    # the zeroed pixels, so it is left out of the fit.
    f_interp = interp1d(radii_log[finite_mask][:-1], means_log[finite_mask][:-1],
                        kind='linear', fill_value="extrapolate")

    # Radial distance of every output pixel from the output centre.
    centre = ngg // 2
    offsets = np.arange(ngg) - centre
    xim, yim = np.meshgrid(offsets, offsets)
    rim = np.sqrt(xim**2 + yim**2)
    rim[centre, centre] = eps  # keep log(rim) finite at the centre
    y_new = f_interp(np.log(rim))

    lo = centre - half
    box = (slice(lo, lo + n), slice(lo, lo + n))  # PSF footprint in the grid
    arr = np.zeros([ngg, ngg])
    arr[box] = np.log(psf_temp + eps)

    # Explicit fill mask (outside the footprint, or beyond rr_trim) rather
    # than using arr == 0 as a sentinel, which would also hit real pixels
    # whose log happens to be exactly 0.
    fill = np.ones((ngg, ngg), dtype=bool)
    fill[box] = False
    fill |= rim > rr_trim
    arr[fill] = y_new[fill]

    imPSF = np.exp(arr)
    imPSF[rim > centre] = 0
    return imPSF / np.nansum(imPSF)


def psf_extrapolate1(psf, rr_trim=64, ngg=256):
    """Embed a PSF in an ``ngg x ngg`` grid, anchoring the wings to ~eps at the edge.

    *psf* is placed at the centre of a zero ``ngg x ngg`` grid.  Pixels at
    radius ``>= ngg / 2 - 2`` are set to machine epsilon.  The radial profile
    of that grid (annuli of 2 px) is then interpolated piecewise-linearly in
    log-log space, extrapolating linearly beyond the fitted range.  Pixels
    within ``n // 2 - 2`` of the centre and inside the PSF footprint keep the
    measured PSF; all others take the interpolated profile.  Unlike
    ``psf_extrapolate``, the result is NOT normalized.

    Parameters
    ----------
    psf : 2-D array_like
        Square PSF image of side ``n``.  It is not modified.
    rr_trim : float, optional
        Unused; kept so the signature matches ``psf_extrapolate``.
    ngg : int, optional
        Side length of the output grid; must be at least ``n``.

    Returns
    -------
    ndarray of shape ``(ngg, ngg)``

    Raises
    ------
    ValueError
        If *psf* is not square or is larger than *ngg*.  ``interp1d`` also
        raises ``ValueError`` if fewer than two usable profile points remain.
    """
    psf = np.asarray(psf)
    if psf.ndim != 2 or psf.shape[0] != psf.shape[1]:
        raise ValueError("psf must be a square 2-D array")
    n = psf.shape[0]
    if n > ngg:
        raise ValueError("ngg must be at least psf.shape[0]")
    half = n // 2
    eps = np.finfo(float).eps

    # The grid centre is ngg / 2, which is half-integer when ngg is odd.
    # The footprint start is truncated with int(), as in the original layout.
    centre = ngg / 2
    lo = int(centre - half)
    box = (slice(lo, lo + n), slice(lo, lo + n))

    offsets = np.arange(ngg) - centre
    xim, yim = np.meshgrid(offsets, offsets)
    rim = np.sqrt(xim**2 + yim**2)

    psf_enlar = np.zeros([ngg, ngg])
    psf_enlar[box] = psf
    # Pin the outermost ring to eps so the profile has an anchor at the edge.
    psf_enlar[rim >= centre - 2] = eps
    radii, means = radial_average_at_pixel(psf_enlar, centre, centre, dr=2)

    # radii[0] == 0 is skipped.  Empty (all-zero) annuli give log(0) = -inf
    # and are excluded by the finite mask; the last finite point is dropped.
    with np.errstate(divide="ignore", invalid="ignore"):
        radii_log = np.log(radii[1:])
        means_log = np.log(means[1:])
    finite_mask = np.isfinite(means_log)
    f_interp = interp1d(radii_log[finite_mask][:-1], means_log[finite_mask][:-1],
                        kind='linear', fill_value="extrapolate")

    arr = np.zeros([ngg, ngg])
    arr[box] = np.log(psf + eps)

    # Explicit fill mask rather than using arr == 0 as a sentinel.  The
    # profile is evaluated only where it is used, which avoids log(0) at the
    # centre.
    fill = np.ones((ngg, ngg), dtype=bool)
    fill[box] = False
    fill |= rim > half - 2
    with np.errstate(divide="ignore"):
        arr[fill] = f_interp(np.log(rim[fill]))
    return np.exp(arr)