from .dict import DotDict
from .instrument import csst
from astropy.table import Table, vstack
import numpy as np

# from csst_dag.dict import DotDict
# from csst_dag.instrument import csst


class CsstPlanObsid(DotDict):
    """Plan record for a single observation (one ``obs_id``).

    The properties below read these entries through attribute access:
    ``instrument``, ``dataset``, ``obs_type``, ``obs_group``, ``obs_id`` and,
    for the ``"HSTDM"`` instrument only, ``params.detector`` and
    ``params.num_epec_frame``.
    """

    def __init__(self, *args, **kwargs):
        # Construction is delegated entirely to ``DotDict``.
        super().__init__(*args, **kwargs)

    @property
    def detectors(self):
        """Return the detectors expected to deliver data for this observation.

        For ``"MSC"``, ``"MCI"``, ``"IFS"`` and ``"CPIC"`` this is the
        instrument's ``effective_detectors``. For ``"HSTDM"`` it is the same
        when ``params.detector`` equals ``"SIS12"``; otherwise it is a
        one-element list holding the detector looked up by that name.

        Raises:
            ValueError: If ``instrument`` is none of the names above.
        """
        instrument = self.instrument
        if instrument in ("MSC", "MCI", "IFS", "CPIC"):
            return csst[instrument].effective_detectors
        if instrument == "HSTDM":
            detector_name = self.params.detector
            if detector_name != "SIS12":
                # A specific detector was requested: return only that one.
                return [csst[instrument][detector_name]]
            return csst[instrument].effective_detectors
        raise ValueError(f"Unknown instrument: {self.instrument}")

    @property
    def n_detector(self):
        """Return the number of entries in :attr:`detectors`."""
        return len(self.detectors)

    @property
    def n_file_expected(self):
        """Return the number of files expected for this observation.

        One file per detector, except for ``"HSTDM"`` where each detector
        contributes ``params.num_epec_frame`` files.

        Raises:
            ValueError: If ``instrument`` is not a recognised name.
        """
        instrument = self.instrument
        if instrument == "HSTDM":
            return self.n_detector * self.params.num_epec_frame
        if instrument in ("MSC", "MCI", "IFS", "CPIC"):
            return self.n_detector
        raise ValueError(f"Unknown instrument: {self.instrument}")

    @staticmethod
    def from_plan(plan_data: dict) -> "CsstPlanObsid":
        """Build a :class:`CsstPlanObsid` from a plan dictionary.

        The dictionary is unpacked as keyword arguments to the constructor.
        """
        return CsstPlanObsid(**plan_data)

    @property
    def expected_data_table(self):
        """Return a table with one row per expected file.

        Columns, in order: ``dataset``, ``instrument``, ``obs_type``,
        ``obs_group``, ``obs_id``, ``detector`` (the detector's ``name``).
        Rows are sorted by those columns in that order.
        """
        rows = [
            {
                "dataset": self.dataset,
                "instrument": self.instrument,
                "obs_type": self.obs_type,
                "obs_group": self.obs_group,
                "obs_id": self.obs_id,
                "detector": det.name,
            }
            for det in self.detectors
            # HSTDM yields several frames per detector; everything else one.
            for _ in range(
                self.params.num_epec_frame if self.instrument == "HSTDM" else 1
            )
        ]
        sort_columns = [
            "dataset",
            "instrument",
            "obs_type",
            "obs_group",
            "obs_id",
            "detector",
        ]
        table = Table(rows)
        table.sort(keys=sort_columns)
        return table


class CsstPlanObsgroup(DotDict):
    """Collection of :class:`CsstPlanObsid` records forming one obs_group.

    The mapping is keyed by ``obs_id`` (see :meth:`from_plan`). All entries
    must share the same ``dataset``, ``instrument``, ``obs_type`` and
    ``obs_group``; those shared values are exposed as read-only properties.
    """

    def __init__(self, *args, **kwargs):
        """Store the plan entries and validate that they form a single group.

        Raises:
            AssertionError: If the entries span more than one
                ``(dataset, instrument, obs_type, obs_group)`` combination,
                or if the number of distinct ``obs_id`` values differs from
                the number of entries.
        """
        super().__init__(*args, **kwargs)

        # construct plan table
        self._plan_table = Table(list(self.values()))

        # The checks below raise explicitly instead of using ``assert`` so
        # they still run under ``python -O``. ``AssertionError`` is kept so
        # existing callers see the same exception type and message.

        # assert unique obs_group
        u_keys = np.unique(
            self._plan_table["dataset", "instrument", "obs_type", "obs_group"]
        )
        if len(u_keys) != 1:
            raise AssertionError("Multiple `instruments/obs_types/datasets` found.")

        # no duplicated obs_id's
        u_obsid = np.unique(self._plan_table["obs_id"])
        if len(u_obsid) != len(self):
            raise AssertionError(f"n_obsid {len(u_obsid)} != n_plan {len(self)}")

        # assign parameters
        self._dataset = str(u_keys[0]["dataset"])
        self._instrument = str(u_keys[0]["instrument"])
        self._obs_type = str(u_keys[0]["obs_type"])
        self._obs_group = str(u_keys[0]["obs_group"])

    @property
    def dataset(self):
        """Dataset name shared by every entry (as ``str``)."""
        return self._dataset

    @property
    def instrument(self):
        """Instrument name shared by every entry (as ``str``)."""
        return self._instrument

    @property
    def obs_type(self):
        """Observation type shared by every entry (as ``str``)."""
        return self._obs_type

    @property
    def obs_group(self):
        """Observation group shared by every entry (as ``str``)."""
        return self._obs_group

    @property
    def obs_id_list(self):
        """List of the mapping's keys (the obs_ids when built by ``from_plan``)."""
        return list(self.keys())

    @property
    def n_file_expected(self):
        """Total ``n_file_expected`` summed over all entries."""
        return sum(self[obs_id].n_file_expected for obs_id in self.obs_id_list)

    def __repr__(self):
        return (
            f"<CsstPlanObsgroup '{self.obs_group}' (n_obs_id={len(self)}): "
            f"instrument='{self.instrument}' "
            f"dataset='{self.dataset}' "
            f"obs_type='{self.obs_type}'>"
        )

    @staticmethod
    def from_plan(plan_data: list[dict]) -> "CsstPlanObsgroup":
        """Build a group from a list of plan dictionaries.

        Each dictionary becomes a :class:`CsstPlanObsid` stored under its
        ``"obs_id"`` value.

        NOTE(review): entries that repeat an ``obs_id`` overwrite each other
        in the dict comprehension before the constructor's duplicate check
        runs, so duplicates in ``plan_data`` are not reported here.
        """
        return CsstPlanObsgroup(
            **{_["obs_id"]: CsstPlanObsid.from_plan(_) for _ in plan_data}
        )

    @property
    def expected_data_table(self):
        """Stack the ``expected_data_table`` of every entry into one table.

        Rows are sorted by ``dataset``, ``instrument``, ``obs_type``,
        ``obs_group``, ``obs_id`` and ``detector``, in that order.
        """
        t = vstack([self[k].expected_data_table for k in self])
        t.sort(
            keys=[
                "dataset",
                "instrument",
                "obs_type",
                "obs_group",
                "obs_id",
                "detector",
            ]
        )
        return t


class CsstLevel0(DotDict):
    """Dot-accessible container for level-0 data, plus table helpers."""

    def __init__(self, *args, **kwargs):
        # Construction is delegated entirely to ``DotDict``.
        super().__init__(*args, **kwargs)

    @property
    def n_file_expected(self):
        """Sum ``n_file_expected`` over the entries named by ``obs_id_list``.

        NOTE(review): ``obs_id_list`` is not defined in this class. Unless
        ``DotDict`` or the stored data supplies it, accessing this property
        fails. Confirm where the obs_id list is meant to come from.
        """
        return sum(self[obs_id].n_file_expected for obs_id in self.obs_id_list)

    @staticmethod
    def extract_data_table(level0_data: list[dict]) -> Table:
        """Build a table of identifying columns from level-0 records.

        Args:
            level0_data: Records that each provide the keys ``dataset``,
                ``instrument``, ``obs_type``, ``obs_group``, ``obs_id`` and
                ``detector``. Any other keys are ignored.

        Returns:
            A table with one row per input record and the six columns above,
            in that order.

        Raises:
            KeyError: If a record lacks one of the required keys.
        """
        columns = (
            "dataset",
            "instrument",
            "obs_type",
            "obs_group",
            "obs_id",
            "detector",
        )
        rows = [{name: record[name] for name in columns} for record in level0_data]

        # Return the table itself, as the name and annotation state. Wrapping
        # the rows in a CsstPlanObsgroup keyed by obs_id would keep only one
        # record per obs_id and reject data spanning several obs_groups.
        return Table(rows)
