Skip to content
x_test_optics.py 5.67 KiB
Newer Older
Chen Yili's avatar
Chen Yili committed
import unittest
GZhao's avatar
GZhao committed

import time

GZhao's avatar
GZhao committed
from csst_cpic_sim.target import star_photlam
from csst_cpic_sim.optics import make_focus_image, focal_mask, filter_throughput, ideal_focus_image
from csst_cpic_sim.config import which_focalplane, S
Chen Yili's avatar
Chen Yili committed
import numpy as np
GZhao's avatar
GZhao committed
from astropy.io import fits

def gaussian_psf(band, spectrum, shape, error=0.1):
    """Return a flux-scaled Gaussian PSF used as a stand-in PSF in the tests.

    A square ``shape`` x ``shape`` Gaussian (sigma fixed at 10 array
    elements) centred on the middle of the array is normalised to unit sum
    and then multiplied by the integral of ``spectrum`` times the
    throughput returned by ``filter_throughput(band)``.

    Parameters
    ----------
    band
        Band name forwarded to ``filter_throughput``.
    spectrum
        Spectrum object; it is multiplied by the throughput and the product
        must provide ``integrate()``.
    shape
        Side length of the square PSF array.
    error
        Accepted but not used by this function.

    Returns
    -------
    The normalised Gaussian scaled by the band-integrated flux.
    """
    sigma = 10
    middle = (shape - 1) / 2

    # Integer index grids along the first and second axes.
    rows, cols = np.mgrid[0:shape, 0:shape]
    dist_sq = (rows - middle)**2 + (cols - middle)**2

    kernel = np.exp(-dist_sq / (2*sigma**2))
    kernel = kernel / kernel.sum()

    throughput = filter_throughput(band)
    band_flux = (spectrum * throughput).integrate()
    return kernel * band_flux
Chen Yili's avatar
Chen Yili committed


class TestOptics(unittest.TestCase):
    """Checks for the optics helpers: throughput, focal plane, image, mask."""

    def test_filter_throughtput(self):
        # Every band spelling below must produce a SpectralElement.
        for band_name in ('f661', 'F1265', 'deFault', 'none'):
            self.assertIsInstance(
                filter_throughput(band_name),
                S.spectrum.SpectralElement,
            )

    def test_which_focalpalne(self):
        # (band name, expected focal plane) pairs, in mixed letter case.
        expected_planes = (
            ('f565', 'vis'),
            ('F661', 'vis'),
            ('f743', 'vis'),
            ('f883', 'vis'),
            ('F940', 'nir'),
            ('f1265', 'nir'),
            ('F1425', 'nir'),
            ('f1542', 'nir'),
            ('wfs', 'wfs'),
        )
        for band_name, plane in expected_planes:
            self.assertEqual(which_focalplane(band_name), plane)

        # An unknown band name is rejected.
        with self.assertRaises(ValueError):
            which_focalplane('what')

    def test_make_focus_image(self):
        # Exercise focus-image generation with a populated and an empty
        # target list; each entry is [x, y, spectrum].
        companion_specs = (
            (1, 1, 10-5, 'B2V'),
            (-2, 2, 11-5, 'A2V'),
            (3, -3, 12-5, 'G2V'),
            (100, 100, 12, 'K2V'),
            (100, 100, 12, 'K2V'),  # same position/type listed twice
        )
        targets = [[0, 0, star_photlam(0, 'F2V', is_blackbody=True)]]
        targets += [
            [x, y, star_photlam(mag, sptype, is_blackbody=True)]
            for x, y, mag, sptype in companion_specs
        ]

        for target_list in (targets, []):
            focus_image = make_focus_image(
                'f661',
                target_list,
                gaussian_psf,
                init_shifts=[1, 1],
                rotation=45,
                platesize=[1024, 1024]
            )
            self.assertIsNotNone(focus_image)

    def test_focal_mask(self):
        # With zero mask throughput on an all-ones image, the removed flux
        # equals the number of masked pixels.
        # NOTE(review): 2000+2000-400 presumably counts two 20-pixel-wide
        # bars across the 100x100 frame minus their overlap - confirm
        # against focal_mask.
        flat = np.ones((100, 100))
        masked = focal_mask(flat, 1, 0.1, throughtput=0)
        removed = (flat - masked).sum()
        self.assertEqual(removed, 2000+2000-400)


if __name__ == '__main__':
GZhao's avatar
GZhao committed
    # unittest.main()
    # NOTE(review): the imports at the top of this file come from
    # ``csst_cpic_sim`` while the script section below imports from
    # ``CpicImgSim`` - confirm which package name is current.
    import time
    from CpicImgSim.target import star_photlam

    def make_test_sub_image(size, shape):
        """Build a square test stamp: an elliptical disk with a cross cut out.

        ``shape`` is the side length of the returned array. ``size`` is the
        threshold applied to ``sqrt(2*x**2 + y**2)``, with x and y measured
        from the array centre. Pixels inside that region are 1, except
        along a clipped cross through the centre, where they are 0. The
        centre pixel(s) are not part of the cross.
        """
        dims = np.array([shape, shape])
        mid = (dims - 1) / 2

        # Offsets from the centre: col_off varies along the second axis,
        # row_off along the first.
        col_off, row_off = np.meshgrid(np.arange(dims[0]), np.arange(dims[1]))
        col_off = col_off - mid[1]
        row_off = row_off - mid[0]

        near_vertical = np.abs(col_off) < 0.6
        near_horizontal = np.abs(row_off) < 0.6

        # Cross = both centre lines, minus the pixels where they intersect.
        cross = np.zeros(dims)
        cross[near_vertical | near_horizontal] = 1
        cross[near_vertical & near_horizontal] = 0

        # Clip the arms asymmetrically: rows -3..7 and columns -3..3 only.
        outside_arms = (row_off > 7) | (row_off < -3) | (np.abs(col_off) > 3)
        cross[outside_arms] = 0

        disk = (np.sqrt(col_off**2*2 + row_off**2) < size).astype(int)
        return disk * (1 - cross)

    def test_ideal_focus_image():
        # Manual check: render the same three targets without and with a
        # 30-degree rotation and write each result to a FITS file.
        # Each spec is (x, y, magnitude, sub-image arguments or None).
        target_specs = (
            (0, 0, 2, None),
            (5, 3, 0, (4, 20)),
            (8, 0, -5, (10, 100)),
        )
        targets = [
            [
                x,
                y,
                star_photlam(mag, 'G2'),
                None if stamp is None else make_test_sub_image(*stamp),
            ]
            for x, y, mag, stamp in target_specs
        ]
        bandpass = S.Box(6000, 500)

        runs = (
            ('foc.fits', {}),
            ('foc_rot30.fits', {'rotation': 30}),
        )
        for out_name, extra in runs:
            started = time.time()
            foc = ideal_focus_image(
                bandpass,
                targets,
                0.0165,
                [1024, 1024],
                **extra,
            )
            # NOTE(review): the elapsed time is computed but never reported.
            execution_time = time.time() - started

            fits.writeto(out_name, foc, overwrite=True)
    

    def test_convolve_psf():
        # Manual check: convolve three targets (one bare, two with
        # sub-images) with a fixed Gaussian PSF and write the result to
        # 'cov.fits'.
        targets = [
            [0, 0, star_photlam(2, 'G2'), None],
            [5, 3, star_photlam(0, 'G2'), make_test_sub_image(4, 20)],
            [8, 0, star_photlam(-5, 'G2'), make_test_sub_image(10, 100)],
        ]

        def cov_psf_func(wave, error=0.1):
            # Unit-sum 1024x1024 Gaussian (sigma = 10 array elements)
            # centred on the array. ``wave`` and ``error`` are accepted but
            # ignored, so the same PSF is returned for every call.
            psf_shape = [1024, 1024]
            xx, yy = np.mgrid[0:psf_shape[0], 0:psf_shape[1]]
            center = np.array([(psf_shape[0]-1)/2, (psf_shape[1]-1)/2])
            sigma = 10
            psf = np.exp(-((xx-center[0])**2 +
                            (yy-center[1])**2) / (2*sigma**2))
            psf = psf / psf.sum()
            return psf

        # NOTE(review): imported from ``CpicImgSim`` although the file
        # header imports from ``csst_cpic_sim`` - confirm the package name.
        from CpicImgSim.optics import focal_convolve
        img_final = focal_convolve('f661', targets, cov_psf_func)

        fits.writeto('cov.fits', img_final, overwrite=True)


    # Only the PSF-convolution check runs when this file is executed as a
    # script; test_ideal_focus_image() is defined but never called, and
    # unittest.main() is commented out.
    test_convolve_psf()

    # import matplotlib.pyplot as plt
    # plt.imshow(make_test_sub_image(5, 6))
    # plt.show()