import os
import time

import numpy as np
import scipy as sp
import yaml
from astropy.convolution import convolve_fft
from scipy.ndimage import rotate
from scipy.signal import fftconvolve

from CpicImgSim.config import cpism_refdata, which_focalplane, S  # S is synphot
from CpicImgSim.config import optics_config
from CpicImgSim.utils import region_replace
from CpicImgSim.io import log


# Names of the supported CPIC bands; each maps to a reference throughput file
# `<cpism_refdata>/throughtput/<name>_total.fits` (directory name as on disk).
_FILTER_NAMES = (
    'f565', 'f661', 'f743', 'f883', 'f940', 'f1265', 'f1425', 'f1542',
)

# Total-throughput bandpasses, loaded once at import time.
FILTERS = {
    name: S.FileBandpass(f'{cpism_refdata}/throughtput/{name}_total.fits')
    for name in _FILTER_NAMES
}


def filter_throughput(filter_name):
    """
    Total throughput of a CPIC band.

    The throughput includes the filter, telescope, CPIC optics and camera QE.
    The name is matched case-insensitively; ``'default'`` selects f661. If the
    filter name is not supported, a warning is logged and the throughput of
    the default filter (f661) is returned.

    Parameters
    -----------
    filter_name: str
        The name of the filter. One of ['f565', 'f661'(default), 'f743',
        'f883', 'f940', 'f1265', 'f1425', 'f1542'].

    Returns
    --------
    Bandpass
        The throughput of the filter (an entry of ``FILTERS``).
    """
    filter_name = filter_name.lower()
    if filter_name == 'default':
        filter_name = 'f661'
    if filter_name not in FILTERS:
        # Message: "wrong filter name (...), returning default filter (f661) throughput"
        log.warning(f"滤光片名称错误({filter_name}),返回默认滤光片(f661)透过率")
        filter_name = 'f661'
    return FILTERS[filter_name]


def example_psf_func(band, spectrum, frame_size, error=0.1):
    """
    Example (stub) of a broadband PSF-generating function.

    Documents the call signature expected by ``make_focus_image`` for its
    ``psf_function`` argument. This stub is not implemented and returns None.

    Parameters
    -------------
    band: str
        The name of the band.
    spectrum: synphot.Spectrum or synphot.SourceSpectrum
        The spectrum of the target.
    frame_size: int
        The size of the frame.
    error: float
        Phase RMS error.

    Returns
    ---------------
    2D array
        PSF image with shape of `frame_size`.
    """
    pass


def example_monochromatic_psf(wavelength, error=0.1):
    """
    Example (stub) of a monochromatic PSF-generating function.

    Documents the call signature expected by ``convolve_psf`` for its
    ``psf_function`` argument: ``convolve_psf`` calls it as
    ``psf_function(center_wavelength, error=error)`` with the wavelength in
    metres, and uses the result as a 2D convolution kernel. This stub is not
    implemented and returns None.
    """
    pass


def rotate_and_shift(shift, rotation, init_shifts):
    """
    Rotate a 2D offset and add a constant shift.

    The offset ``(x, y)`` is multiplied by the matrix
    ``[[cos t, sin t], [-sin t, cos t]]`` with ``t = rotation`` (degrees),
    then ``init_shifts`` is added. All quantities share the same unit
    (callers in this module pass arcsec and divide the result by the
    platescale afterwards).

    Parameters
    ----------
    shift: sequence of 2 floats
        The (x, y) offset to rotate.
    rotation: float
        Rotation angle in degrees.
    init_shifts: sequence of 2 floats
        Offset added after the rotation.

    Returns
    -------
    np.ndarray
        The rotated and shifted (x, y) offset, shape (2,).
    """
    theta = np.deg2rad(rotation)
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rotated = np.array([
        shift[0] * cos_t + shift[1] * sin_t,
        -shift[0] * sin_t + shift[1] * cos_t,
    ])
    return rotated + np.array(init_shifts)


def ideal_focus_image(
        bandpass: S.spectrum.SpectralElement,
        targets: list,
        platescale,
        platesize: tuple = (1024, 1024),
        init_shifts: tuple = (0, 0),
        rotation: float = 0,
):
    """
    Build the unconvolved ("ideal") focal-plane image of the targets.

    Each target's in-band flux is ``(spectrum * bandpass).integrate()``.
    Point sources are distributed over the 2x2 neighbouring pixels by
    bilinear weights; extended sources are placed as a rotated, unit-sum
    normalised template scaled by that flux.

    Parameters
    ----------
    bandpass: SpectralElement
        Bandpass used to integrate each target's spectrum.
    targets: list
        Tuples ``(x, y, spectrum, image)``. ``x, y`` are offsets from the
        plate centre (same unit as ``init_shifts``, converted to pixels by
        dividing by ``platescale``). ``image`` is None for a point source,
        otherwise a 2D template of the source.
    platescale: float
        Offset units per pixel.
    platesize: sequence of 2 ints
        Shape of the output image (rows, columns).
    init_shifts: sequence of 2 floats
        Shift added to every target after rotation.
    rotation: float
        Rotation angle in degrees (see ``rotate_and_shift``).

    Returns
    -------
    np.ndarray
        2D image of shape ``platesize``.
    """
    focal_image = np.zeros(platesize)
    focal_shape = np.array(platesize)[::-1]  # (x, y) = (columns, rows)
    if not targets:
        return focal_image

    for sub_x, sub_y, sub_spectrum, sub_image in targets:
        sub_shift = rotate_and_shift([sub_x, sub_y], rotation, init_shifts) / platescale
        sed = (sub_spectrum * bandpass).integrate()

        if sub_image is None:
            # Point source: pixel coordinate relative to the plate centre.
            x = (focal_shape[0] - 1) / 2 + sub_shift[0]
            y = (focal_shape[1] - 1) / 2 + sub_shift[1]
            # floor (not int truncation) so that -1 < x < 0 maps to -1 and is
            # rejected, instead of producing a negative bilinear weight.
            int_x = int(np.floor(x))
            int_y = int(np.floor(y))
            # The 2x2 footprint must lie fully inside the image.
            if not (0 <= int_x < focal_shape[0] - 1 and 0 <= int_y < focal_shape[1] - 1):
                continue
            dx1 = x - int_x
            dx0 = 1 - dx1
            dy1 = y - int_y
            dy0 = 1 - dy1
            weights = np.array([
                [dx0 * dy0, dx1 * dy0],
                [dx0 * dy1, dx1 * dy1],
            ])
            focal_image[int_y:int_y + 2, int_x:int_x + 2] += weights * sed
        else:
            # Extended source: rotate the template with the focal plane; the
            # absolute value presumably removes negative values introduced by
            # spline interpolation. Normalise to unit sum so it carries `sed`.
            template = np.abs(rotate(sub_image, rotation, reshape=False))
            template = template / template.sum()
            template_shape = np.array(template.shape)[::-1]
            # Convert the centre offset into the lower-left corner position.
            sub_shift += (focal_shape - 1) / 2 - (template_shape - 1) / 2
            focal_image = region_replace(
                focal_image, template * sed, sub_shift, subpix=True
            )
    return focal_image


def sp_convole_fft(image, kernal):
    """
    Convolve ``image`` with a unit-sum normalised kernel using FFTs.

    Wraps ``scipy.signal.fftconvolve`` with ``mode='same'``, so the output has
    the shape of ``image``. Normalising the kernel to unit sum keeps the total
    flux (apart from flux spread beyond the image edges).

    The (misspelled) function and parameter names are kept for backward
    compatibility with existing callers.
    """
    kernal = kernal / kernal.sum()
    return fftconvolve(image, kernal, mode='same')


def convolve_psf(
        band: str,
        targets: list,
        psf_function: callable,
        init_shifts: tuple = (0, 0),
        rotation: float = 0,
        nsample: int = 5,
        error: float = 1,
        platesize: tuple = (1024, 1024)) -> np.ndarray:
    """
    Make a broadband focal-plane image by convolving per-wavelength images.

    The band's throughput curve is split into ``nsample`` equal-width
    wavelength bins. For each bin, the ideal focal image of the targets is
    built with ``ideal_focus_image`` using only that bin's throughput, and is
    convolved with a monochromatic PSF evaluated at the bin centre.

    Parameters
    ----------
    band: str
        The name of the band (see ``filter_throughput``).
    targets: list
        Targets in the format accepted by ``ideal_focus_image``:
        ``(x, y, spectrum, image_or_None)``.
    psf_function: callable
        Called as ``psf_function(wavelength, error=error)`` with the bin centre
        wavelength in metres; must return a 2D kernel (see
        ``example_monochromatic_psf``).
    init_shifts: sequence of 2 floats
        Shift added to every target after rotation.
    rotation: float
        Rotation angle in degrees.
    nsample: int
        Number of wavelength bins (must be >= 1).
    error: float
        Passed through to ``psf_function``.
    platesize: sequence of 2 ints
        Shape of the output image (rows, columns).

    Returns
    -------
    np.ndarray
        2D image of shape ``platesize``.

    Raises
    ------
    ValueError
        If ``nsample`` is smaller than 1.
    """
    if nsample < 1:
        raise ValueError("nsample must be a positive integer")

    config = optics_config[which_focalplane(band)]
    platescale = config['platescale']

    bandpass = filter_throughput(band)
    wave = bandpass.wave
    throughput = bandpass.throughput
    min_wave = wave[0]
    max_wave = wave[-1]
    d_wave = (max_wave - min_wave) / nsample

    summed_image = None
    for i_wave in range(nsample):
        wave0 = min_wave + i_wave * d_wave
        wave1 = min_wave + (i_wave + 1) * d_wave
        # Wavelengths are treated as Angstrom (see waveunits below); the
        # PSF function receives metres.
        center_wavelength = (wave0 + wave1) / 2 * 1e-10

        # Keep only this bin's part of the throughput curve.
        i_throughput = throughput.copy()
        i_throughput[(wave > wave1) | (wave < wave0)] = 0
        i_band = S.ArrayBandpass(wave, i_throughput, waveunits='angstrom')

        i_fp_image = ideal_focus_image(
            i_band, targets, platescale, platesize, init_shifts, rotation
        )
        psf = psf_function(center_wavelength, error=error)

        t0 = time.time()
        c_fp_image = sp_convole_fft(i_fp_image, psf)
        log.debug(f"Convolution time: {time.time()-t0}")

        # Sum (not mean) over bins: each bin holds only its share of the
        # in-band flux, so summing keeps the total flux independent of nsample.
        summed_image = c_fp_image if summed_image is None else summed_image + c_fp_image

    return summed_image


def make_focus_image(
        band: str,
        targets: list,
        psf_function: callable,
        init_shifts: tuple = (0, 0),
        rotation: float = 0,
        platesize: tuple = (1024, 1024)) -> np.ndarray:
    """
    Make the focus image of the targets.

    The first target is treated as the central star and uses a PSF frame of
    ``config['cstar_frame_size']``; the remaining targets use
    ``config['substellar_frame_size']``. Each PSF is placed on the focal
    plane with ``region_replace`` at sub-pixel precision.

    Parameters
    -----------
    band: str
        The name of the band.
    targets: list
        The list of the targets. Each element of the list is a tuple of
        (x, y, spectrum).

        - x, y: float
            The position of the target in the focal plane (arcsec).
        - spectrum: synphot.Spectrum or synphot.SourceSpectrum
            The spectrum of the target.
    psf_function: callable
        The function to generate the PSF, with the same parameters and
        return value as `example_psf_func`.
    init_shifts: sequence of 2 floats
        The initial shifts of the targets (applied after rotation).
        Unit: arcsec. The default is (0, 0).
    rotation: float
        The rotation of the focal plane. Unit: degree. The default is 0.
    platesize: sequence of 2 ints
        The size of the focal plane. Unit: pixel. The default is (1024, 1024).

    Returns
    --------
    np.ndarray
        The focus image of the targets. 2D array with the shape of platesize.
    """
    config = optics_config[which_focalplane(band)]
    platescale = config['platescale']
    focal_image = np.zeros(platesize)
    if not targets:
        return focal_image

    plate_xy = np.array(platesize)[::-1]  # (x, y) = (columns, rows)
    error_value = 0  # phase error passed to psf_function (annotated as nm)

    # Central star: the same rotation / initial shift as the other targets.
    cstar_x, cstar_y, cstar_spectrum = targets[0]
    cstar_shift = rotate_and_shift([cstar_x, cstar_y], rotation, init_shifts) / platescale
    cstar_psf = psf_function(
        band, cstar_spectrum, config['cstar_frame_size'], error=error_value
    )
    psf_shape = np.array(cstar_psf.shape)[::-1]
    # Convert the centre offset into the lower-left corner position.
    cstar_shift += (plate_xy - 1) / 2 - (psf_shape - 1) / 2
    focal_image = region_replace(
        focal_image, cstar_psf, cstar_shift,
        padded_in=False, padded_out=False, subpix=True)

    last_index = len(targets) - 1
    for i_target, (sub_x, sub_y, sub_spectrum) in enumerate(targets[1:], start=1):
        # Chain region_replace calls: the first companion takes an unpadded
        # input, the last one returns an unpadded output, and intermediate
        # calls keep the padded representation.
        pdout = i_target != last_index
        pdin = i_target != 1
        log.debug(f"input target {sub_x=:}, {sub_y=:}")
        sub_shift = rotate_and_shift([sub_x, sub_y], rotation, init_shifts) / platescale
        log.debug(f"after rotate and shift {sub_shift=:}")
        sub_psf = psf_function(
            band, sub_spectrum, config['substellar_frame_size'], error=error_value
        )
        psf_shape = np.array(sub_psf.shape)[::-1]
        sub_shift += (plate_xy - 1) / 2 - (psf_shape - 1) / 2
        log.debug(f"input shift of region_replace: {sub_shift=:}")
        focal_image = region_replace(
            focal_image, sub_psf, sub_shift,
            padded_in=pdin, padded_out=pdout, subpix=True
        )

    return focal_image


def focal_mask(image, iwa, platescale, throughtput=1e-6):
    """
    Attenuate the cross-shaped band around the image centre lines.

    A pixel is attenuated (multiplied by ``throughtput``) when its row OR
    column offset from the image centre is smaller than ``iwa / platescale``
    pixels. The attenuated region is therefore a cross (two perpendicular
    strips of half-width ``iwa``), not a disk.
    NOTE(review): an earlier docstring described masking "outside the inner
    working angle"; confirm which geometry is intended.

    Parameters
    -----------
    image: np.ndarray
        The 2D image to be masked (not modified; a copy is returned).
    iwa: float
        The inner working angle. Unit: arcsec.
    platescale: float
        The platescale of the image. Unit: arcsec/pixel.
    throughtput: float
        The throughput of the mask. The default is 1e-6.

    Returns
    --------
    np.ndarray
        The masked image.
    """
    rows, cols = np.mgrid[0:image.shape[0], 0:image.shape[1]]
    center_row = (image.shape[0] - 1) / 2
    center_col = (image.shape[1] - 1) / 2
    half_width = iwa / platescale
    mask = (np.abs(rows - center_row) < half_width) | (np.abs(cols - center_col) < half_width)
    image_out = image.copy()
    image_out[mask] *= throughtput
    return image_out