import numpy as np
import joblib
from astropy import table
from tqdm import trange

from .._csst import csst


# Keyword arguments shared by every tqdm progress bar in this module.
TQDM_KWARGS = {"unit": "task", "dynamic_ncols": False}

#: ``join_type`` used when joining data records with plan records.
#:
#: References:
#:     - https://docs.astropy.org/en/stable/api/astropy.table.join.html
#:     - https://docs.astropy.org/en/stable/table/operations.html#join
#:
#: Typical types:
#:     - inner: only rows whose keys match in both tables
#:     - left: all rows of the left table, plus matching rows of the right table
#:     - right: all rows of the right table, plus matching rows of the left table
#:     - outer: all rows of both tables
#:     - cartesian: every combination of rows from both tables
PLAN_JOIN_TYPE = "inner"


def split_data_basis(data_basis: table.Table, n_split: int = 1) -> list[table.Table]:
    """Split a data basis into ``n_split`` contiguous chunks by ``obs_id``.

    All rows sharing an ``obs_id`` end up in the same chunk, so every chunk can
    be dispatched independently (e.g. by a separate joblib worker).

    Note: ``data_basis`` is sorted in place by ``["dataset", "obs_id"]``.

    Parameters
    ----------
    data_basis : table.Table
        Data records with at least the ``dataset`` and ``obs_id`` columns.
        At most one distinct ``dataset`` is allowed.
    n_split : int
        Number of chunks to return; must be >= 1.

    Returns
    -------
    list[table.Table]
        Exactly ``n_split`` row slices of the sorted ``data_basis``. The unique
        obs_ids are distributed as evenly as possible over the chunks. When
        there are fewer unique obs_ids than ``n_split`` (or no rows at all),
        the surplus chunks are empty tables.

    Raises
    ------
    ValueError
        If ``n_split < 1``.
    """
    if n_split < 1:
        # Without this check a non-positive value yields no chunks at all, and
        # callers would silently dispatch nothing.
        raise ValueError(f"n_split must be >= 1, got {n_split}.")
    # An empty table has zero datasets, which is fine to split.
    assert (
        np.unique(data_basis["dataset"]).size <= 1
    ), "Only one dataset is allowed for splitting."

    # After sorting, the rows of each obs_id are contiguous.
    data_basis.sort(keys=["dataset", "obs_id"])
    # First row index of every unique obs_id ...
    _, i_obsid = np.unique(data_basis["obs_id"].data, return_index=True)
    # ... plus an end-of-table sentinel, so obs_id k spans rows
    # bounds[k]:bounds[k + 1].
    bounds = np.append(i_obsid, len(data_basis))

    chunks = []
    for group_ids in np.array_split(np.arange(len(i_obsid)), n_split):
        if len(group_ids) == 0:
            chunks.append(data_basis[0:0])
        else:
            chunks.append(data_basis[bounds[group_ids[0]] : bounds[group_ids[-1] + 1]])
    return chunks


class Dispatcher:
    """
    Dispatch processing tasks by matching a data basis against a plan basis.

    Each public ``dispatch_*`` static method works at one granularity (file,
    detector, obs_id, obs_group x detector, obs_group) and returns a list of
    task records. Every record is a dict with the keys ``task``, ``success``,
    ``relevant_plan``, ``relevant_data``, ``n_relevant_plan``,
    ``n_relevant_data``, ``relevant_data_id_list``, ``n_file_expected`` and
    ``n_file_found``.
    """

    # Key columns identifying an obs_group, an obs_id and a single detector
    # of an obs_id. They are shared by the join/group_by calls below; treat
    # them as read-only.
    _OBSGROUP_KEYS = ["dataset", "instrument", "obs_type", "obs_group"]
    _OBSID_KEYS = _OBSGROUP_KEYS + ["obs_id"]
    _DETECTOR_KEYS = _OBSID_KEYS + ["detector"]

    @staticmethod
    def _effective_n_jobs(n_jobs: int) -> int:
        """Translate a joblib-style ``n_jobs`` into a number of chunks.

        ``None`` means 1. Negative values count back from the CPU count, as in
        ``joblib.Parallel``: -1 means all CPUs, -2 all but one, and so on (at
        least 1). Any other value is returned unchanged.
        """
        if n_jobs is None:
            return 1
        if n_jobs < 0:
            return max(joblib.cpu_count() + 1 + n_jobs, 1)
        return n_jobs

    @staticmethod
    def _dispatch_parallel(
        dispatch_func,
        plan_basis: table.Table,
        data_basis: table.Table,
        n_jobs: int,
    ) -> list[dict]:
        """Run ``dispatch_func`` on obs_id chunks with joblib and concatenate the results."""
        chunk_results = joblib.Parallel(n_jobs=n_jobs)(
            joblib.delayed(dispatch_func)(plan_basis, chunk)
            for chunk in split_data_basis(data_basis, n_split=n_jobs)
        )
        # Flatten in chunk order (linear, unlike ``sum(lists, [])``).
        return [task for chunk_tasks in chunk_results for task in chunk_tasks]

    @staticmethod
    def dispatch_file(
        plan_basis: table.Table,
        data_basis: table.Table,
    ) -> list[dict]:
        """Create one task per data record (file).

        Every file becomes its own, always-successful task. The plan records
        sharing its dataset/instrument/obs_type/obs_group/obs_id are attached
        for reference.

        Note: ``data_basis`` is sorted in place by
        ``["dataset", "obs_id", "detector"]``.

        Parameters
        ----------
        plan_basis : table.Table
            Plan records.
        data_basis : table.Table
            Data records; must contain the obs_id key columns and ``_id``.

        Returns
        -------
        list[dict]
            One task record per row of ``data_basis``, or an empty list if
            either input is empty.
        """
        if len(plan_basis) == 0 or len(data_basis) == 0:
            return []

        obsid_keys = Dispatcher._OBSID_KEYS
        data_basis.sort(keys=["dataset", "obs_id", "detector"])

        task_list = []
        for i_data in trange(len(data_basis), **TQDM_KWARGS):
            this_data = data_basis[i_data : i_data + 1]
            this_task = dict(data_basis[i_data])
            this_relevant_plan = table.join(
                this_data[obsid_keys],
                plan_basis,
                keys=obsid_keys,
                join_type="inner",
                table_names=["data", "plan"],
            )
            this_task["n_file_expected"] = 1
            this_task["n_file_found"] = 1
            task_list.append(
                dict(
                    task=this_task,
                    success=True,
                    relevant_plan=this_relevant_plan,
                    relevant_data=this_data,
                    n_relevant_plan=len(this_relevant_plan),
                    n_relevant_data=1,
                    relevant_data_id_list=[data_basis[i_data]["_id"]],
                    n_file_expected=1,
                    n_file_found=1,
                )
            )
        return task_list

    @staticmethod
    def dispatch_detector(
        plan_basis: table.Table,
        data_basis: table.Table,
        n_jobs: int = 1,
    ) -> list[dict]:
        """Create one task per (obs_id, detector) found in the data.

        Prints the number of plan records relevant to the data's obs_ids.

        Parameters
        ----------
        plan_basis : table.Table
            Plan records with the obs_id key columns and ``n_file``.
        data_basis : table.Table
            Data records with the detector key columns and ``_id``.
        n_jobs : int
            Number of joblib workers. The data is split by obs_id when this is
            not 1. Negative values and ``None`` follow joblib's convention.

        Returns
        -------
        list[dict]
            One task record per unique dataset/instrument/obs_type/obs_group/
            obs_id/detector of ``data_basis``. A task succeeds when exactly one
            plan record and exactly one file match, the detector is an
            effective detector of the instrument, and the number of files found
            equals the plan's ``n_file``. Returns an empty list if either input
            is empty.
        """
        # Check emptiness before splitting; there is nothing to dispatch anyway.
        if len(plan_basis) == 0 or len(data_basis) == 0:
            return []

        n_jobs = Dispatcher._effective_n_jobs(n_jobs)
        if n_jobs != 1:
            return Dispatcher._dispatch_parallel(
                Dispatcher.dispatch_detector, plan_basis, data_basis, n_jobs
            )

        obsid_keys = Dispatcher._OBSID_KEYS
        detector_keys = Dispatcher._DETECTOR_KEYS

        # Restrict the plan to the obs_ids present in the data.
        u_obsid = table.unique(data_basis["dataset", "obs_id"])
        relevant_plan = table.join(
            u_obsid,
            plan_basis,
            keys=["dataset", "obs_id"],
            join_type=PLAN_JOIN_TYPE,
        )
        print(f"{len(relevant_plan)} relevant plan records")

        u_data_detector = table.unique(data_basis[detector_keys])

        task_list = []
        for i_detector in trange(len(u_data_detector), **TQDM_KWARGS):
            this_data_detector = u_data_detector[i_detector : i_detector + 1]
            this_task = dict(u_data_detector[i_detector])

            # Files of this detector, and the plan of its obs_id.
            this_data_detector_files = table.join(
                this_data_detector,
                data_basis,
                keys=detector_keys,
                join_type="inner",
            )
            this_data_detector_plan = table.join(
                this_data_detector,
                relevant_plan,
                keys=obsid_keys,
                join_type=PLAN_JOIN_TYPE,
            )

            this_detector = this_data_detector["detector"][0]
            this_instrument = this_data_detector["instrument"][0]
            this_detector_effective = (
                this_detector in csst[this_instrument].effective_detector_names
            )

            n_relevant_plan = len(this_data_detector_plan)
            n_file_found = len(this_data_detector_files)
            n_file_expected = (
                this_data_detector_plan["n_file"][0] if n_relevant_plan > 0 else 0
            )
            this_task["n_file_expected"] = n_file_expected
            this_task["n_file_found"] = n_file_found
            task_list.append(
                dict(
                    task=this_task,
                    # NOTE(review): together with ``n_file_found == 1`` this
                    # requires the plan's n_file to be 1. In dispatch_obsid,
                    # n_file is compared with the file count of a whole
                    # obs_id. Confirm n_file semantics for detector-level tasks.
                    success=(
                        n_relevant_plan == 1
                        and n_file_found == 1
                        and this_detector_effective
                        and n_file_found == n_file_expected
                    ),
                    relevant_plan=this_data_detector_plan,
                    relevant_data=this_data_detector_files,
                    n_relevant_plan=n_relevant_plan,
                    n_relevant_data=n_file_found,
                    relevant_data_id_list=(
                        []
                        if n_file_found == 0
                        else list(this_data_detector_files["_id"])
                    ),
                    # Same value as this_task["n_file_expected"].
                    n_file_expected=n_file_expected,
                    n_file_found=n_file_found,
                )
            )
        return task_list

    @staticmethod
    def dispatch_obsid(
        plan_basis: table.Table,
        data_basis: table.Table,
        n_jobs: int = 1,
    ) -> list[dict]:
        """Create one task per obs_id found in the data.

        Parameters
        ----------
        plan_basis : table.Table
            Plan records with the obs_id key columns and ``n_file``.
        data_basis : table.Table
            Data records with the obs_id key columns, ``detector`` and ``_id``.
        n_jobs : int
            Number of joblib workers. The data is split by obs_id when this is
            not 1. Negative values and ``None`` follow joblib's convention.

        Returns
        -------
        list[dict]
            One task record per obs_id of ``data_basis``. A task succeeds when
            exactly one plan record matches, the number of files equals its
            ``n_file``, and the detectors found are exactly the instrument's
            effective detectors. An obs_id without a plan record, or with
            several, is reported as unsuccessful with ``n_file_expected`` taken
            from the first match (0 if none). Returns an empty list if either
            input is empty.
        """
        # Check emptiness before splitting; there is nothing to dispatch anyway.
        if len(plan_basis) == 0 or len(data_basis) == 0:
            return []

        n_jobs = Dispatcher._effective_n_jobs(n_jobs)
        if n_jobs != 1:
            return Dispatcher._dispatch_parallel(
                Dispatcher.dispatch_obsid, plan_basis, data_basis, n_jobs
            )

        obsid_keys = Dispatcher._OBSID_KEYS

        task_list = []
        for this_obsid_basis in data_basis.group_by(obsid_keys).groups:
            this_obsid_keys = this_obsid_basis[obsid_keys][:1]
            this_relevant_plan_basis = table.join(
                this_obsid_keys,
                plan_basis,
                keys=obsid_keys,
                join_type=PLAN_JOIN_TYPE,
            )
            n_relevant_plan = len(this_relevant_plan_basis)

            # Task keys come from the matched plan record when there is one,
            # otherwise from the data.
            key_source = (
                this_relevant_plan_basis if n_relevant_plan > 0 else this_obsid_keys
            )
            this_task = dict(key_source[obsid_keys][0])
            n_file_expected = (
                this_relevant_plan_basis[0]["n_file"] if n_relevant_plan > 0 else 0
            )
            n_file_found = len(this_obsid_basis)
            detectors_found = set(this_obsid_basis["detector"])
            detectors_expected = set(
                csst[this_task["instrument"]].effective_detector_names
            )
            this_success = (
                n_relevant_plan == 1
                and n_file_expected == n_file_found
                and detectors_found == detectors_expected
            )
            this_task["n_file_expected"] = n_file_expected
            this_task["n_file_found"] = n_file_found
            task_list.append(
                dict(
                    task=this_task,
                    success=this_success,
                    relevant_plan=this_relevant_plan_basis,
                    relevant_data=this_obsid_basis,
                    n_relevant_plan=n_relevant_plan,
                    n_relevant_data=n_file_found,
                    relevant_data_id_list=list(this_obsid_basis["_id"]),
                    n_file_expected=n_file_expected,
                    n_file_found=n_file_found,
                )
            )
        return task_list

    @staticmethod
    def dispatch_obsgroup_detector(
        plan_basis: table.Table,
        data_basis: table.Table,
    ) -> list[dict]:
        """Create one task per (obs_group, effective detector) that is complete.

        For every obs_group in the plan and every effective detector of its
        instrument, exactly one file of that detector is expected per planned
        obs_id. Only complete combinations are returned; incomplete ones are
        dropped, not reported as failures.

        Parameters
        ----------
        plan_basis : table.Table
            Plan records with the obs_id key columns.
        data_basis : table.Table
            Data records with the detector key columns and ``_id``.

        Returns
        -------
        list[dict]
            Successful task records only, or an empty list if either input is
            empty.
        """
        if len(plan_basis) == 0 or len(data_basis) == 0:
            return []

        obsgroup_keys = Dispatcher._OBSGROUP_KEYS
        detector_keys = Dispatcher._DETECTOR_KEYS
        obsgroup_plan_basis = plan_basis.group_by(keys=obsgroup_keys)

        task_list = []
        for i_obsgroup in trange(len(obsgroup_plan_basis.groups), **TQDM_KWARGS):
            this_obsgroup_plan_basis = obsgroup_plan_basis.groups[i_obsgroup]
            this_obsgroup_obsid = this_obsgroup_plan_basis["obs_id"].data
            # One file per planned obs_id is expected for each detector.
            n_file_expected = len(this_obsgroup_obsid)
            # The group keys are constant within a group; read them from row 0.
            this_obsgroup_meta = {
                key: this_obsgroup_plan_basis[key][0] for key in obsgroup_keys
            }
            effective_detector_names = csst[
                this_obsgroup_meta["instrument"]
            ].effective_detector_names

            for this_detector in effective_detector_names:
                this_task = dict(this_obsgroup_meta, detector=this_detector)
                # One expected (obs_id, detector) row per planned obs_id.
                this_obsgroup_detector_expected = table.Table(
                    [
                        dict(
                            this_obsgroup_meta,
                            obs_id=this_obsid,
                            detector=this_detector,
                        )
                        for this_obsid in this_obsgroup_obsid
                    ]
                )
                this_obsgroup_detector_found = table.join(
                    this_obsgroup_detector_expected,
                    data_basis,
                    keys=detector_keys,
                    join_type="inner",
                )
                n_file_found = len(this_obsgroup_detector_found)
                this_success = n_file_found == n_file_expected and set(
                    this_obsgroup_detector_found["obs_id"]
                ) == set(this_obsgroup_obsid)
                if not this_success:
                    continue

                this_task["n_file_expected"] = n_file_expected
                this_task["n_file_found"] = n_file_found
                task_list.append(
                    dict(
                        task=this_task,
                        success=this_success,
                        relevant_plan=this_obsgroup_plan_basis,
                        relevant_data=this_obsgroup_detector_found,
                        n_relevant_plan=len(this_obsgroup_plan_basis),
                        n_relevant_data=n_file_found,
                        # A successful group has n_file_found == n_file_expected >= 1.
                        relevant_data_id_list=list(this_obsgroup_detector_found["_id"]),
                        n_file_expected=n_file_expected,
                        n_file_found=n_file_found,
                    )
                )
        return task_list

    @staticmethod
    def dispatch_obsgroup(
        plan_basis: table.Table,
        data_basis: table.Table,
    ) -> list[dict]:
        """Create one task per obs_group in the plan.

        Parameters
        ----------
        plan_basis : table.Table
            Plan records with the obs_id key columns and ``n_file``.
        data_basis : table.Table
            Data records with the obs_id key columns and ``detector``. Both
            tables are expected to contain ``_id``; the conflicting columns
            become ``_id_plan`` / ``_id_data`` after the join.

        Returns
        -------
        list[dict]
            One task record per obs_group. A task succeeds when every planned
            obs_id is complete. For HSTDM, complete means the number of files
            equals the plan's ``n_file``. For other instruments, it means every
            effective detector has at least one file. Returns an empty list if
            either input is empty.
        """
        if len(plan_basis) == 0 or len(data_basis) == 0:
            return []

        obsgroup_keys = Dispatcher._OBSGROUP_KEYS
        obsid_keys = Dispatcher._OBSID_KEYS
        obsgroup_basis = table.unique(plan_basis[obsgroup_keys])

        task_list = []
        for i_obsgroup in trange(len(obsgroup_basis), **TQDM_KWARGS):
            this_task = dict(obsgroup_basis[i_obsgroup])

            this_obsgroup_plan = table.join(
                obsgroup_basis[i_obsgroup : i_obsgroup + 1],  # this obsgroup
                plan_basis,
                keys=obsgroup_keys,
                join_type=PLAN_JOIN_TYPE,
            )
            this_obsgroup_file = table.join(
                this_obsgroup_plan,
                data_basis,
                keys=obsid_keys,
                join_type="inner",
                table_names=["plan", "data"],
            )
            # The group join already holds every file of the group, so each
            # obs_id's files are selected from it instead of re-joining the
            # full data basis once per obs_id.
            file_obsid = np.asarray(this_obsgroup_file["obs_id"])

            this_success = True
            for this_plan in this_obsgroup_plan:
                this_instrument = this_plan["instrument"]
                this_effective_detector_names = csst[
                    this_instrument
                ].effective_detector_names
                this_obsid_file = this_obsgroup_file[file_obsid == this_plan["obs_id"]]

                if this_instrument == "HSTDM":
                    # It is not yet settled whether HSTDM will have one or two
                    # detectors, so compare the file count with the plan.
                    this_success &= len(this_obsid_file) == this_plan["n_file"]
                else:
                    # Other instruments (e.g. MSC): every effective detector
                    # must have at least one file (extra files are allowed).
                    this_success &= set(this_effective_detector_names) <= set(
                        this_obsid_file["detector"]
                    )

            n_file_expected = int(this_obsgroup_plan["n_file"].sum())
            n_file_found = len(this_obsgroup_file)
            this_task["n_file_expected"] = n_file_expected
            this_task["n_file_found"] = n_file_found
            task_list.append(
                dict(
                    task=this_task,
                    success=this_success,
                    relevant_plan=this_obsgroup_plan,
                    relevant_data=this_obsgroup_file,
                    n_relevant_plan=len(this_obsgroup_plan),
                    n_relevant_data=n_file_found,
                    relevant_data_id_list=(
                        []
                        if n_file_found == 0
                        else list(this_obsgroup_file["_id_data"])
                    ),
                    n_file_expected=n_file_expected,
                    n_file_found=n_file_found,
                )
            )
        return task_list

    @staticmethod
    def load_test_data(
        plan_basis_path: str = "dagtest/csst-msc-c9-25sqdeg-v3.plan_basis.dump",
        data_basis_path: str = "dagtest/csst-msc-c9-25sqdeg-v3.level0_basis.dump",
    ) -> tuple:
        """Load a ``(plan_basis, data_basis)`` pair from joblib dump files.

        Warning: ``joblib.load`` unpickles the files, so only load trusted dumps.

        Parameters
        ----------
        plan_basis_path : str
            Path of the plan-basis dump.
        data_basis_path : str
            Path of the data-basis dump.

        Returns
        -------
        tuple
            ``(plan_basis, data_basis)`` as stored in the dumps.
        """
        plan_basis = joblib.load(plan_basis_path)
        data_basis = joblib.load(data_basis_path)
        return plan_basis, data_basis
