Skip to content
optics.py 7.4 KiB
Newer Older
GZhao's avatar
GZhao committed

import numpy as np
from scipy.signal import fftconvolve
from scipy.ndimage import rotate

from .config import config, S  # S is synphot
from .utils import region_replace
from .io import log
from .psf_simulation import single_band_masked_psf, single_band_psf

# Throughput curves of every configured band, built once at import time and
# keyed by the band names in config['bands'] (each value is handed to
# S.FileBandpass; presumably a throughput file path — confirm in config).
# A comprehension is used so no loop variables leak into the module namespace.
FILTERS = {
    band_name: S.FileBandpass(band_source)
    for band_name, band_source in config['bands'].items()
}
# Band used by filter_throughput for 'default' and for unknown band names.
default_band = config['default_band']


def filter_throughput(filter_name):
    """
    Return the total throughput curve of a CPIC band.

    The curve is the bandpass loaded for the band from ``config['bands']``.
    According to the original notes it includes the filter, telescope, CPIC
    optics and camera QE. If the name is not a configured band, a warning is
    logged and the curve of ``config['default_band']`` is returned instead.

    Parameters
    -----------
    filter_name: str or None
        The name of the band, case-insensitive. The original notes list
        ['f565', 'f661', 'f743', 'f883', 'f940', 'f1265', 'f1425', 'f1542'];
        the authoritative set is whatever ``config['bands']`` defines.
        The string 'default' (any case) or ``None`` selects the default band.

    Returns
    --------
    synphot.Bandpass
        The throughput of the band (the ``S.FileBandpass`` built at import).
    """
    # None previously crashed on .lower(); treat it like 'default'.
    if filter_name is None:
        filter_name = default_band
    else:
        filter_name = filter_name.lower()
        filter_name = default_band if filter_name == 'default' else filter_name

    if filter_name not in FILTERS:
        # Message: "wrong filter name ({filter_name}), returning the throughput
        # of the default filter ({default_band})".
        log.warning(f"滤光片名称错误({filter_name}),返回默认滤光片({default_band})透过率")
        filter_name = default_band

    return FILTERS[filter_name]


def _rotate_and_shift(shift, rotation, init_shifts):
    rotation_rad = rotation / 180 * np.pi
    return np.array([
        shift[0] * np.cos(rotation_rad) + shift[1] * np.sin(rotation_rad),
        -shift[0] * np.sin(rotation_rad) + shift[1] * np.cos(rotation_rad)
    ]) + np.array(init_shifts)


def ideal_focus_image(
        bandpass: S.spectrum.SpectralElement,
        targets: list,
        platescale,
        platesize: list = [1024, 1024],
        init_shifts: list = [0, 0],
        rotation: float = 0) -> np.ndarray:
    """Ideal (PSF-free) focal-plane image of the targets.

    Each target's flux is ``(spectrum * bandpass).integrate()``.

    * Target without an image: its flux is split bilinearly over the 2x2
      pixels around its sub-pixel position. Targets whose 2x2 footprint is
      not entirely inside the image are dropped.
    * Target with an image: the image is rotated by ``rotation`` and made
      non-negative (absolute value). It is then normalised to unit sum,
      scaled by the flux and placed with ``region_replace`` at a sub-pixel
      offset that aligns its centre with the target position.

    Target positions are offsets from the image centre, at ``(n - 1) / 2``
    along each axis. Each offset is rotated with ``_rotate_and_shift``,
    shifted by ``init_shifts`` and divided by ``platescale`` to get pixels.

    Parameters
    -----------
    bandpass: synphot.SpectralElement
        The bandpass of the filter.
    targets: list
        Sequence of ``(x, y, spectrum, image)`` tuples; ``x, y`` in arcsec and
        ``image`` may be None. See the output of `spectrum_generator` for details.
    platescale: float
        The platescale of the camera. Unit: arcsec/pixel
    platesize: list
        The shape of the output array (rows, columns). Unit: pixel
    init_shifts: list
        The shifts of the targets to simulate the misalignment. Unit: arcsec
    rotation: float
        The rotation of the image. Unit: degree

    Returns
    --------
    np.ndarray
        The ideal focus image, of shape ``platesize``.
    """
    focal_image = np.zeros(platesize)
    focal_shape = np.array(platesize)[::-1]  # (nx, ny): x indexes columns, y rows

    if not targets:
        return focal_image

    for target in targets:
        sub_x, sub_y, sub_spectrum, sub_image = target
        sub_shift = _rotate_and_shift([sub_x, sub_y], rotation, init_shifts) / platescale
        sed = (sub_spectrum * bandpass).integrate()

        if sub_image is None:
            x = (focal_shape[0] - 1) / 2 + sub_shift[0]
            y = (focal_shape[1] - 1) / 2 + sub_shift[1]

            # floor, not int(): int() truncates toward zero. For -1 < x < 0 it
            # would give pixel 0 with negative bilinear weights instead of
            # rejecting the partially off-image target.
            int_x = int(np.floor(x))
            int_y = int(np.floor(y))
            # The 2x2 footprint covers int_x..int_x+1 and int_y..int_y+1.
            if int_x < 0 or int_x >= focal_shape[0] - 1 or int_y < 0 or int_y >= focal_shape[1] - 1:
                continue

            dx1 = x - int_x
            dx0 = 1 - dx1
            dy1 = y - int_y
            dy0 = 1 - dy1

            # Bilinear weights, indexed [row (y), column (x)].
            sub = np.array([
                [dx0 * dy0, dx1 * dy0],
                [dx0 * dy1, dx1 * dy1]]) * sed

            focal_image[int_y: int_y + 2, int_x: int_x + 2] += sub
        else:
            sub_image = np.abs(rotate(sub_image, rotation, reshape=False))
            sub_image = sub_image / sub_image.sum()
            sub_img_shape = np.array(sub_image.shape)[::-1]  # (nx, ny)
            # Convert the centre-relative shift into a shift of the
            # sub-image origin so that both centres line up.
            sub_shift += (focal_shape - 1) / 2 - (sub_img_shape - 1) / 2
            focal_image = region_replace(
                focal_image,
                sub_image * sed,
                sub_shift,
                subpix=True
            )
    return focal_image


def focal_convolve(
        band: str,
        targets: list,
        init_shifts: list = [0, 0],
        rotation: float = 0,
        nsample: int = 5,
        error: float = 0,
        platesize: list = [1024, 1024]) -> np.ndarray:
    """PSF convolution of the ideal focus image.

    The band is restricted to the wavelength range where its throughput
    exceeds 10% of its peak. That range is split into ``nsample`` equal
    sub-bands. For each sub-band:

    * every target except the first is rendered with `ideal_focus_image`,
      convolved with `single_band_psf` at the sub-band centre wavelength and
      attenuated by `focal_mask`;
    * the first target (the central star) adds its sub-band flux times
      `single_band_masked_psf`.

    Each sub-band image is multiplied by ``config['aperature_area']`` and
    the results are summed.

    Parameters
    ----------
    band: str
        The name of the band (resolved by `filter_throughput`).
    targets: list
        The list of ``(x, y, spectrum, image)`` targets; the first one is the
        central star. See the output of `spectrum_generator` for details.
    init_shifts: list
        The shifts of the targets to simulate the misalignment. Unit: arcsec.
        If either component exceeds 4 in absolute value, both are reset to 0.
    rotation: float
        The rotation of the image. Unit: degree
    nsample: int
        Number of sub-bands the band is split into.
    error: float
        Error forwarded to the PSF generators (original note: "error of the
        DM acceleration"). Unit: nm
    platesize: list
        The shape of the image (rows, columns). Unit: pixel

    Returns
    --------
    np.ndarray
        The simulated focal-plane image. If ``targets`` is empty, the result
        is an all-zero array of shape ``platesize``; this matches the shape
        of the non-empty path, which the old code returned transposed.
    """
    if not targets:
        return np.zeros(platesize)

    platescale = config['platescale']
    area = config['aperature_area']
    iwa = config['mask_width'] / 2

    bandpass = filter_throughput(band)
    throughput = bandpass.throughput
    wave = bandpass.wave  # angstrom (matches waveunits below)

    # Keep only the part of the band where throughput exceeds 10% of its peak.
    in_band_wave = wave[throughput > throughput.max() * 0.1]
    min_wave = in_band_wave[0]
    max_wave = in_band_wave[-1]

    if abs(init_shifts[0]) > 4 or abs(init_shifts[1]) > 4:
        log.warning('Input shifts are too large, and are set to zero')
        init_shifts = [0, 0]

    # linspace makes the last edge exactly max_wave (no float drift).
    edges = np.linspace(min_wave, max_wave, nsample + 1)
    _, _, cstar_sp, _ = targets[0]
    companions = targets[1:]

    all_fp_image = []
    for i_wave in range(nsample):
        wave0 = edges[i_wave]
        wave1 = edges[i_wave + 1]
        center_wavelength = (wave0 + wave1) / 2 * 1e-10  # angstrom -> metre

        # Half-open [wave0, wave1): a sample lying exactly on a shared edge
        # belongs to one sub-band only (closed intervals counted it twice).
        # The last sub-band is closed so max_wave itself is kept.
        if i_wave == nsample - 1:
            in_sub_band = (wave >= wave0) & (wave <= wave1)
        else:
            in_sub_band = (wave >= wave0) & (wave < wave1)
        i_throughput = throughput.copy()
        i_throughput[~in_sub_band] = 0
        i_band = S.ArrayBandpass(wave, i_throughput, waveunits='angstrom')

        # Call order of the PSF generators is kept as before, in case they
        # draw random wavefront errors.
        i_fp_image = ideal_focus_image(i_band, companions, platescale, platesize, init_shifts, rotation)
        psf = single_band_psf(center_wavelength, error=error)

        cstar_flux = (cstar_sp * i_band).integrate()
        cstar_psf = single_band_masked_psf(center_wavelength, error=error, shift=init_shifts)

        c_fp_image = fftconvolve(i_fp_image, psf, mode='same')
        c_fp_image = focal_mask(c_fp_image, iwa, platescale)
        c_fp_image = c_fp_image + cstar_flux * cstar_psf

        all_fp_image.append(c_fp_image * area)  # trans to photon/second

    return np.array(all_fp_image).sum(axis=0)


def focal_mask(image, iwa, platescale, throughtput=1e-6):
    """
    Attenuate a cross-shaped strip through the centre of the image.

    A pixel is attenuated when either condition holds (both are strict):

    * its index along axis 0 is within ``iwa / platescale`` of the centre
      index ``(image.shape[0] - 1) / 2``;
    * its index along axis 1 is within that distance of
      ``(image.shape[1] - 1) / 2``.

    Attenuated pixels are multiplied by ``throughtput``; all other pixels
    are unchanged. The input array is not modified.

    Parameters
    -----------
    image: np.ndarray
        The image to be masked (at least 2-D; the strips span the first two axes).
    iwa: float
        Half-width of each strip (the inner working angle). Unit: arcsec.
    platescale: float
        The platescale of the image. Unit: arcsec/pixel.
    throughtput: float
        The transmission inside the strips. The default is 1e-6.
        (Misspelled name kept for backward compatibility with keyword callers.)

    Returns
    --------
    np.ndarray
        A masked copy of the image.
    """
    half_width = iwa / platescale
    # Broadcast 1-D index ranges instead of materialising two full mgrid arrays.
    rows = np.arange(image.shape[0])[:, None]
    cols = np.arange(image.shape[1])[None, :]
    row_center = (image.shape[0] - 1) / 2
    col_center = (image.shape[1] - 1) / 2
    mask = (np.abs(rows - row_center) < half_width) | (np.abs(cols - col_center) < half_width)

    image_out = image.copy()
    image_out[mask] *= throughtput
    return image_out