"""
Identifier:     csst_proto/test_flip_image.py
Name:           test_flip_image.py
Description:    图像翻转的单元测试
Author:         Bo Zhang
Created:        2023-10-26
Modified-History:
    2023-11-25, Bo Zhang, add module header
"""
import os
import unittest

import numpy as np

from csst_proto import flip_image, read_default_image
from csst_proto import flip_multiple_images_jl, flip_multiple_images_mp


class FlipImageTestCase(unittest.TestCase):
    """Unit tests for the single- and multi-image flip functions of `csst_proto`."""

    # Expected result of flipping the image returned by `read_default_image()`.
    # Shared by every test that uses the packaged default image.
    _DEFAULT_ANSWER = np.array([[4, 3], [2, 1]])

    def _assert_all_flipped(self, flipped_imgs, n_images):
        """Assert `flipped_imgs` holds exactly `n_images` copies of the default answer."""
        # Check the count first so missing or extra outputs are reported
        # explicitly rather than as an IndexError or silently ignored.
        self.assertEqual(len(flipped_imgs), n_images)
        for i_img, flipped in enumerate(flipped_imgs):
            # subTest reports which image failed without stopping the loop.
            with self.subTest(i_img=i_img):
                np.testing.assert_array_equal(flipped, self._DEFAULT_ANSWER)

    def test_flip_image(self):
        """
        Aim
        ---
        Test `flip_image` function with package data.

        Criteria
        --------
        The flipped image is consistent with answer (same shape and values).

        Details
        -------
        The input image is in package data. A 1D input is expected to be
        rejected with an `AssertionError`.
        """
        # assert_array_equal checks shape as well as values and prints a diff
        # on failure, unlike assertTrue(np.all(a == b)), which broadcasts.
        np.testing.assert_array_equal(
            flip_image(read_default_image()), self._DEFAULT_ANSWER
        )

        # flip_image is expected to raise AssertionError for a 1D array
        with self.assertRaises(AssertionError):
            flip_image(np.array([1, 2, 3, 4]))

    @unittest.skipUnless(
        os.environ.get("UNIT_TEST_DATA_ROOT"),
        "UNIT_TEST_DATA_ROOT is not set; server test data is unavailable",
    )
    def test_flip_image_on_server(self):
        """
        Aim
        ---
        Test `flip_image` function with server data.

        Criteria
        --------
        The flipped image has the same shape as the answer and is consistent
        with it at 1e-6 level.

        Details
        -------
        This is the same case with `test_flip_image` but the data is on server.
        The data location is read from the `UNIT_TEST_DATA_ROOT` environment
        variable; the test is skipped when that variable is not set.
        """
        data_root = os.environ["UNIT_TEST_DATA_ROOT"]
        image_input = np.loadtxt(
            os.path.join(data_root, "csst_proto/test_flip/input/image.txt"),
            delimiter=",",
            dtype=int,
        )
        image_answer = np.loadtxt(
            os.path.join(data_root, "csst_proto/test_flip/answer/flipped_image.txt"),
            delimiter=",",
            dtype=int,
        )
        image_flipped = flip_image(image_input)
        # Check shapes first: the subtraction below would broadcast and could
        # hide a shape mismatch or produce a misleading norm.
        self.assertEqual(np.shape(image_flipped), np.shape(image_answer))
        self.assertLess(
            np.linalg.norm(image_flipped - image_answer),
            1e-6,
            "Test flip image on server failed",
        )

    def test_flip_multiple_images_mp(self):
        """
        Aim
        ---
        Test `flip_multiple_images_mp` function with package data.

        Criteria
        --------
        Exactly one flipped image is returned per input and each equals answer.

        Details
        -------
        The input image is in package data.
        """
        n_jobs = 10
        imgs = [read_default_image() for _ in range(n_jobs)]
        flipped_imgs = flip_multiple_images_mp(imgs, n_jobs)
        self._assert_all_flipped(flipped_imgs, n_jobs)

    def test_flip_multiple_images_jl(self):
        """
        Aim
        ---
        Test `flip_multiple_images_jl` function with package data.

        Criteria
        --------
        Exactly one flipped image is returned per input and each equals answer.

        Details
        -------
        The input image is in package data.
        """
        n_jobs = 10
        imgs = [read_default_image() for _ in range(n_jobs)]
        flipped_imgs = flip_multiple_images_jl(imgs, n_jobs)
        self._assert_all_flipped(flipped_imgs, n_jobs)
