"""
Identifier:     csst_common/pipeline.py
Name:           pipeline.py
Description:    pipeline operation
Author:         Bo Zhang
Created:        2023-12-13
Modified-History:
    2023-07-11, Bo Zhang, created
    2023-07-11, Bo Zhang, add Pipeline class
    2023-12-10, Bo Zhang, update Pipeline
    2023-12-15, Bo Zhang, add module header
    2025-12-27, Bo Zhang, rewrite Pipeline class
"""

import sys
import json
import subprocess
import os
import copy
import shutil
import traceback
import warnings
from typing import Callable, NamedTuple, Optional, Any, Union
from astropy.time import Time, TimeDelta
from astropy.io import fits
from astropy import table
import joblib

import csst_dfs_client
import csst_fs
from csst_fs import s3_fs

from .ccds import CCDS
from .utils import retry
from .file import File
from .logger import get_logger
from .status import CsstStatus, CsstResult
from .fits import s3_options
from . import io


# Load the S3 options through `csst_fs`. Note that this assignment rebinds the
# `s3_options` name imported from `.fits` in the import block of this module.
s3_options = csst_fs.s3_config.load_s3_options()


def print_directory_tree(directory="."):
    """
    Print the tree below `directory`, indenting four spaces per level.

    Parameters
    ----------
    directory : str
        Top directory to walk. A trailing path separator is tolerated.

    Notes
    -----
    Sub-directories and files are listed in sorted order so the output is
    reproducible. Nothing is printed if `directory` does not exist.
    """
    for root, dirs, files in os.walk(directory):
        # sorting in place also fixes the order in which os.walk descends
        dirs.sort()
        # depth relative to the top directory; relpath normalises both paths,
        # so a trailing separator or a repeated path fragment cannot skew it
        rel = os.path.relpath(root, directory)
        level = 0 if rel == os.curdir else rel.count(os.sep) + 1
        indent = " " * 4 * level
        # strip a trailing separator, otherwise basename would be empty
        name = os.path.basename(root.rstrip(os.sep))
        print(f"{indent}{name}/")
        subindent = " " * 4 * (level + 1)
        for file in sorted(files):
            print(f"{subindent}{file}")


# Mapping of exit reason -> process exit code.
EXIT_CODES = {
    "success": 0,
    "ccds_reference_file_error": 10,  # CCDS reference file error
    "ccds_pmap_error": 11,  # CCDS pmap error
    "reference_catalog_error": 12,  # reference catalog error
    "input_data_invalid": 20,  # input data invalid
    "input_data_not_unique": 21,  # input data not unique
    "algorithm_error": 30,  # algorithm error
    "data_product_invalid": 40,  # data product invalid
    "data_product_ingestion_error": 41,  # data product ingestion error
}


class Pipeline:
    """
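    Runtime context of a CSST pipeline run.

    On construction it reads its settings from environment variables, sets up
    the logger, exposes the DFS and CCDS clients and parses the task from the
    command line argument.

    Examples
    --------
    >>> p = Pipeline()
    >>> p.info()
    >>> p.task
    {'a': 1}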
    """

    def __init__(self, **env_vars: Any):
        """
        Set up the pipeline from environment variables and the command line.

        Parameters
        ----------
        **env_vars : Any
            Extra environment variables. Each value is converted with `str`
            and written to `os.environ` before the settings are read.

        Notes
        -----
        The task is parsed as JSON from `sys.argv[1]` only when exactly one
        command line argument is given; otherwise `self.task` is an empty
        dict. Every key of `self.settings` is also exposed as a lower-case
        attribute (e.g. `self.dir_output`).
        """
        # set environment variables
        for k, v in env_vars.items():
            os.environ[k] = str(v)

        # get settings from environment variables
        self.settings: dict = {
            # DFS & CCDS directories
            "DFS_ROOT": os.getenv("DFS_ROOT", "/dfs_root"),
            "CCDS_ROOT": os.getenv("CCDS_ROOT", "/ccds_root"),
            "CCDS_CACHE": os.getenv("CCDS_CACHE", "/pipeline/ccds_cache"),
            # working directories
            "DIR_INPUT": os.getenv("DIR_INPUT", "/pipeline/input"),
            "DIR_OUTPUT": os.getenv("DIR_OUTPUT", "/pipeline/output"),
            "DIR_TEMP": os.getenv("DIR_TEMP", "/pipeline/temp"),
            "DIR_AUX": os.getenv("DIR_AUX", "/pipeline/aux"),
            "LOG_FILE": os.getenv("LOG_FILE", "pipeline.log"),
            # docker image information
            "DOCKER_IMAGE": os.getenv("DOCKER_IMAGE", "-"),
            "BUILD": os.getenv("BUILD", "-"),
            "CREATED": os.getenv("CREATED", "-"),
            # additional settings (boolean flags: only "true" enables them)
            "VERBOSE": os.getenv("VERBOSE", "false").lower() == "true",
            "DUMPDATA": os.getenv("DUMPDATA", "false").lower() == "true",
            "IGNORE_WARNINGS": os.getenv("IGNORE_WARNINGS", "true").lower() == "true",
            "USE_OSS": (os.getenv("USE_OSS", "false")).lower() == "true",
        }
        # set attributes
        for k, v in self.settings.items():
            setattr(self, k.lower(), v)

        # set logger
        self.logger = get_logger(
            name="pipeline logger",
            filename=str(os.path.join(self.dir_output, self.log_file)),
        )
        # filter warnings
        if self.settings["IGNORE_WARNINGS"]:
            self.filter_warnings("ignore")

        # DFS1, DFS2 & CCDS
        self.dfs1 = csst_dfs_client
        self.dfs2 = csst_fs
        self.ccds = CCDS(
            use_oss=self.use_oss,
            ccds_root=self.ccds_root,
            ccds_cache=self.ccds_cache,
            logger=self.logger,
        )

        # exit code -> p.exit(reason="ccds_pmap_error")
        self.EXIT_CODES = EXIT_CODES

        # record start time
        self.t_start = Time.now()
        self.logger.info(f"p.t_start = {self.t_start.isot}")
        # set message: only touch sys.argv[1] after checking that it exists,
        # so that running without a task argument yields an empty task
        if len(sys.argv) == 2:
            self.logger.info(f"sys.argv[1] = {sys.argv[1]}")
            self.task = self.json2dict(sys.argv[1])
        else:
            self.logger.info(f"sys.argv[1:] = {sys.argv[1:]}")
            self.task = {}
        self.logger.info(f"p.task = {self.task}")
        self.logger.info(f"p.settings = {self.settings}")
        if self.dumpdata:
            joblib.dump(self.task, os.path.join(self.dir_output, "task.joblib"))

        # record basic information
        self.logger.info(f"DOCKER_IMAGE={self.docker_image}")
        self.logger.info(f"BUILD={self.build}")
        self.logger.info(f"CREATED={self.created}")
        self.logger.info(f"VERBOSE={self.verbose}")
        self.logger.info(f"DUMPDATA={self.dumpdata}")

    # warning operations
    @staticmethod
    def filter_warnings(level: str = "ignore"):
        # Suppress all warnings
        warnings.filterwarnings(level)

    @staticmethod
    def reset_warnings(self):
        """Reset warning filters."""
        warnings.resetwarnings()

    # message operations
    @staticmethod
    def dict2json(d: dict):
        """Convert `dict` to JSON format string."""
        return json.dumps(d, ensure_ascii=False)

    @staticmethod
    def json2dict(m: str):
        """Convert JSON format string to `dict`."""
        return json.loads(m)

    # file operations
    @staticmethod
    def mkdir(d):
        """Create a directory if it does not exist."""
        if not os.path.exists(d):
            os.makedirs(d)

    @staticmethod
    def clean_directory(d):
        """Clean a directory."""
        print(f"Clean output directory '{d}'...")
        try:
            r = subprocess.run(f"rm -rf {d}/*", shell=True, capture_output=True)
            print("> ", r)
            r.check_returncode()
        except:
            print("Failed to clean output directory!")
            print_directory_tree(d)

    # file operations
    @staticmethod
    def move(file_src: str, file_dst: str) -> str:
        """Move file `file_src` to `file_dist`."""
        return shutil.move(file_src, file_dst)

    @staticmethod
    def copy_nas_file(file_src: str, file_dst: str) -> str:
        """Copy NAS file `file_src` to `file_dist`."""
        shutil.copy(file_src, file_dst)
        assert os.path.exists(
            file_dst
        ), f"Failed to copy NAS file {file_src} to {file_dst}!"
        return file_dst

    @staticmethod
    def dump_oss_file(rpath: str, lpath: str) -> str:
        """
        Fetch the OSS object `rpath` to the local path `lpath` via `s3_fs.get`.

        Returns
        -------
        str
            `lpath`, after asserting that the path exists.
        """
        print(f"rpath={rpath}, lpath={lpath}")
        s3_fs.get(rpath, lpath)
        downloaded = os.path.exists(lpath)
        assert downloaded, f"Failed to dump OSS file {rpath} to {lpath}!"
        return lpath

    # @property
    # def summarize(self):
    #     """Summarize this run."""
    #     t_stop: Time = Time.now()
    #     t_cost: float = (t_stop - self.t_start).value * 86400.0
    #     self.logger.info(f"Total cost: {t_cost:.1f} sec")

    def clean_output(self):
        """Clean output directory."""
        self.clean_directory(self.dir_output)

    def file(self, file_path):
        """Wrap `file_path` in a `File` whose `new_dir` is the output directory."""
        out_dir = self.dir_output
        return File(file_path, new_dir=out_dir)

    def new(self, file_name="test.fits") -> str:
        """Create new file in output directory."""
        return os.path.join(self.dir_output, file_name)

    def dump_file(self, remote_file_path: str, local_file_path: str = None) -> str:
        """Copy file `remote_file_path` to `local_file_path`."""
        is_oss = remote_file_path.__contains__(":")
        self.logger.info(f"Dumping file: {remote_file_path} -> {local_file_path}")
        if is_oss:
            local_file_path = self.dump_oss_file(remote_file_path, local_file_path)
        else:
            local_file_path = self.copy_nas_file(remote_file_path, local_file_path)
        return local_file_path

    # abspath
    def convert_to_abspath_for_dfs_recs(self, dfs_rec_list: list[dict]) -> list[dict]:
        """Convert `file_path` to absolute path for DFS."""
        dfs_recs_abs = copy.deepcopy(dfs_rec_list)
        for rec in dfs_recs_abs:
            rec["file_path"] = (
                os.path.join(self.dfs_root, rec["file_path"])
                if not rec["file_path"].__contains__(":")
                else rec["file_path"]
            )
        return dfs_recs_abs

    def convert_to_abspath_for_ccds_refs(self, ccds_refs: dict) -> dict:
        """Convert `file_path` to absolute path for CCDS."""
        ccds_refs_abs = copy.deepcopy(ccds_refs)
        print(ccds_refs_abs)
        for ref_name, ref_path in ccds_refs_abs["refs"].items():
            print(ref_name, ref_path)
            ccds_refs_abs["refs"][ref_name] = (
                os.path.join(self.ccds_root, ref_path)
                if not ref_path["file_path"].__contains__(":")
                else ref_path
            )
        print(ccds_refs_abs)
        return ccds_refs_abs

    def dump_dfs_plan(self, dfs_plan):
        """
        Dump `dfs_plan` with joblib to ``<dir_output>/dfs/dfs_plan.joblib``.

        The ``dfs`` sub-directory is created if needed. Returns the path of
        the written file.
        """
        dfs_dir = os.path.join(self.dir_output, "dfs")
        self.mkdir(dfs_dir)
        plan_path = os.path.join(dfs_dir, "dfs_plan.joblib")
        joblib.dump(dfs_plan, plan_path)
        return plan_path

    def dump_dfs_recs(
        self, dfs_recs_abs: list[dict], dir_dump: Optional[str] = None
    ) -> list[dict]:
        """
        Copy the files referenced by DFS records into a dump directory.

        Parameters
        ----------
        dfs_recs_abs : list[dict]
            Records whose ``"file_path"`` is the path to copy from.
        dir_dump : str, optional
            Destination directory; defaults to ``<dir_output>/dfs``. It is
            created if it does not exist.

        Returns
        -------
        list[dict]
            Deep copy of the records with ``"file_path"`` pointing at the local
            copies. The same list is also dumped to ``dfs_rec_list.joblib``
            inside `dir_dump`.
        """
        # set default dir_dump to output directory
        if dir_dump is None:
            dir_dump = os.path.join(self.dir_output, "dfs")
        # create the directory for caller-supplied locations too
        self.mkdir(dir_dump)
        # dump data to dir_dump
        dfs_recs_dump = copy.deepcopy(dfs_recs_abs)
        for rec in dfs_recs_dump:
            remote_file_path = rec["file_path"]
            local_file_path = os.path.join(dir_dump, os.path.basename(remote_file_path))
            # copy DFS file to local_file_path
            self.dump_file(remote_file_path, local_file_path)
            rec["file_path"] = local_file_path
        joblib.dump(dfs_recs_dump, os.path.join(dir_dump, "dfs_rec_list.joblib"))
        return dfs_recs_dump

    def dump_ccds_refs_list(
        self, ccds_refs_list: list[dict], dir_dump: Optional[str] = None
    ) -> list[dict]:
        """
        Copy the CCDS reference files of each entry into a dump directory.

        Parameters
        ----------
        ccds_refs_list : list[dict]
            Entries providing ``"file_path"``, ``"pmapname"`` and a ``"refs"``
            mapping of reference name -> path.
        dir_dump : str, optional
            Destination directory; defaults to ``<dir_output>/ccds``. It is
            created if it does not exist.

        Returns
        -------
        list[dict]
            Deep copy of the input with every ``refs`` path replaced by its
            local copy under ``<dir_dump>/<basename of file_path>/<pmapname>/``.
            The same list is also dumped to ``ccds_refs_list.joblib`` inside
            `dir_dump`.
        """
        # set default dir_dump to output directory
        if dir_dump is None:
            dir_dump = os.path.join(self.dir_output, "ccds")
        # create the directory up front so that the final joblib dump also
        # works for an empty list and a caller-supplied location
        self.mkdir(dir_dump)
        # dump data to dir_dump
        ccds_refs_list_dump = copy.deepcopy(ccds_refs_list)
        # loop over recs
        for this_ccds_refs in ccds_refs_list_dump:
            this_dir_dump = os.path.join(
                dir_dump,
                os.path.basename(this_ccds_refs["file_path"]),
                this_ccds_refs["pmapname"],
            )
            self.mkdir(this_dir_dump)
            # loop over refs
            for ref_name, ref_path in this_ccds_refs["refs"].items():
                remote_file_path = ref_path
                local_file_path = os.path.join(
                    this_dir_dump,
                    os.path.basename(remote_file_path),
                )
                # copy CCDS reference file to local_file_path
                self.dump_file(remote_file_path, local_file_path)
                this_ccds_refs["refs"][ref_name] = local_file_path

        joblib.dump(
            ccds_refs_list_dump,
            os.path.join(dir_dump, "ccds_refs_list.joblib"),
        )
        return ccds_refs_list_dump

    # time operations
    @staticmethod
    def now() -> str:
        """Return the current time from `astropy` as an ISOT string."""
        current = Time.now()
        return current.isot

    # call modules
    def call(self, func: Callable, *args: Any, **kwargs: Any):
        """
        Run a pipeline module and wrap its outcome in a `ModuleResult`.

        Parameters
        ----------
        func : Callable
            Module entry point; expected to return a `CsstResult`.
        *args, **kwargs
            Forwarded to `func` unchanged.

        Returns
        -------
        ModuleResult
            Module name, elapsed time and the status/products/files/output of
            the returned `CsstResult`. If `func` raises an `Exception` or
            returns something else, the status is `CsstStatus.ERROR` and
            `output` holds the traceback under ``"exc_info"``.

        Notes
        -----
        Only `Exception` subclasses are converted into an error result;
        `KeyboardInterrupt` and `SystemExit` propagate to the caller.
        """
        self.logger.info("=====================================================")
        t_start: Time = Time.now()
        self.logger.info(f"Starting Module: **{func.__name__}**")
        try:
            # if the module works well
            res: CsstResult = func(*args, **kwargs)
            if not isinstance(res, CsstResult):
                # an explicit raise (rather than assert) survives `python -O`
                raise TypeError(
                    f"{func.__name__} returned {res!r}, expected a CsstResult"
                )
            # define results
            status = res.status
            products = res.products
            files = res.files
            output = res.output
        except Exception:
            # if the module raises error
            exc_info = traceback.format_exc()  # traceback info
            self.logger.error(f"Error occurs! \n{exc_info}")
            # define results
            status = CsstStatus.ERROR  # default status if exceptions occur
            products = None
            files = None
            output = {"exc_info": exc_info}  # default output if exceptions occur

        # The summary is built after the try block instead of returning from a
        # `finally` clause, which would swallow KeyboardInterrupt/SystemExit.
        t_stop: Time = Time.now()
        t_cost: float = (t_stop - t_start).value * 86400
        if isinstance(status, CsstStatus):
            self.logger.info(
                f"Module finished: status={status} | cost={t_cost:.1f} sec"
            )
        else:
            # invalid status
            self.logger.error(
                f"Invalid status: {status} is not a CsstStatus object!"
            )
        # record module result
        self.logger.info(
            f"ModuleResult: \n"
            f"   - name: {func.__name__}\n"
            f"   - status: {status}\n"
            f"   - products: {products}\n"
            f"   - files: {files}\n"
            f"   - output: {output}\n"
        )
        return ModuleResult(
            module=func.__name__,
            cost=t_cost,
            status=status,
            products=products,
            files=files,
            output=output,
        )

    # retry operations
    @staticmethod
    def retry(func: Callable, *args, **kwargs):
        """
        Run `func(*args, **kwargs)` through the module-level `retry` helper.

        The helper is called with ``5`` as its first argument and its return
        value is passed back unchanged.
        """
        n = 5
        return retry(n, func, *args, **kwargs)

    @property
    def pmapname(self) -> str:
        """Final CCDS `.pmap` name."""
        task_pmapname = self.task.get("pmapname", None)
        if task_pmapname:
            if self.ccds.validate(task_pmapname):
                # task specified pmap
                return task_pmapname
            else:
                self.ccds.list()
                raise ValueError(f"Invalid pmapname: {task_pmapname}")
        else:
            # CCDS recommended pmap
            return self.ccds.operational_context

    @property
    def ref_cat(self) -> None | str:
        """Final DFS catalog name."""
        task_ref_cat = self.task.get("ref_cat", None)
        if task_ref_cat and task_ref_cat in self.dfs1.catalog.all_catalog_names():
            # task specified catalog
            return task_ref_cat
        else:
            # DFS recommended catalog
            if self.docker_image in [
                "csst-msc-l1-mbi",
                "csst-msc-l1-ast-astrometry",
            ]:
                return "trilegal_093"
            else:
                return None

    def exit(self, reason: str = None):
        """Exit pipeline with reason.

        Examples
        --------
        >>> p.exit(reason="ccds_error")
        >>> p.exit(reason="success")
        """
        assert (
            reason in self.EXIT_CODES.keys()
        ), f"Reason {reason} not in {self.EXIT_CODES.keys()}"
        self.logger.info(f"Pipeline exit with reason: {reason}")
        sys.exit(self.EXIT_CODES[reason])

    def add_metadata_to_image(
        self,
        file_path,
        data_model: str,
        qc_status: int = -1024,
        filter: str = None,
        custom_id: str = None,
        object: str = None,
        proposal_id: str = None,
        ra: float = None,
        dec: float = None,
        healpix: int = None,
        obs_date: str = None,
        parent_uuids: Optional[list] = None,
    ):
        """
        Add metadata to the FITS file `file_path` in place.

        The metadata is generated with `io.generate_meta` from the task, the
        arguments given here, `self.pmapname` and `self.ref_cat`, then written
        into the file with `io.append_meta`.

        Parameters
        ----------
        file_path
            FITS file, opened in update mode.
        data_model : str
            Data product type.
        qc_status : int
            QC status.
        filter, custom_id, object, proposal_id : str, optional
            Used only when the task does not provide a value under the same key.
        ra, dec, healpix, obs_date : optional
            Passed through to the metadata unchanged.
        parent_uuids : list, optional
            UUIDs of the parent data products; defaults to an empty list.

        Returns
        -------
        dict
            The metadata returned by `io.generate_meta`.
        """
        # Use a fresh list per call instead of a shared mutable default.
        parent_uuids = [] if parent_uuids is None else list(parent_uuids)
        meta_kwargs = dict(
            # orchestration information
            dataset=self.task.get("dataset"),
            instrument=self.task.get("instrument"),
            obs_type=self.task.get("obs_type"),
            obs_group=self.task.get("obs_group"),
            obs_id=self.task.get("obs_id"),
            # detector information
            detector=self.task.get("detector"),
            filter=self.task.get("filter") or filter,
            # reference information
            pmapname=self.pmapname,
            ref_cat=self.ref_cat,
            # data processing information
            custom_id=self.task.get("custom_id") or custom_id,
            batch_id=self.task.get("batch_id"),
            dag_group=self.task.get("dag_group"),
            dag_group_run=self.task.get("dag_group_run"),
            dag=self.task.get("dag"),
            dag_run=self.task.get("dag_run"),
            priority=self.task.get("priority"),
            # data_list=self.task.get("data_list") or [],
            extra_kwargs=self.task.get("extra_kwargs") or {},
            created_time=self.task.get("created_time"),
            rerun=self.task.get("rerun"),
            # data product information
            data_model=data_model,  # data product type, set manually
            data_uuid=None,  # UUID, set automatically
            qc_status=qc_status,  # QC status
            # Docker image name and version
            docker_image=None,  # image name, set automatically
            build=None,  # image version, set automatically
            # additional observation filtering parameters
            object=self.task.get("object") or object,  # observation target
            proposal_id=self.task.get("proposal_id") or proposal_id,  # proposal ID
            ra=ra,  # right ascension
            dec=dec,  # declination
            healpix=healpix,  # HEALPix; nside may differ per data product type
            obs_date=obs_date,  # observation time
            prc_date=self.now(),  # processing time
            parent_uuids=parent_uuids,  # list of parent data product UUIDs
        )
        self.logger.info(f"Add metadata to {file_path}...")
        with fits.open(file_path, mode="update") as hdulist:
            meta: dict = io.generate_meta(**meta_kwargs)
            self.logger.info(f"Metadata: {meta}")
            hdulist: fits.HDUList = io.append_meta(hdulist, meta)
            hdulist.flush()
        return meta

    def query_ref_cat(
        self,
        file_path,
        catalog_name: Optional[str] = None,
        radius: float = 2.0,
    ) -> None | str:
        """
        Query a reference catalog around the pointing of an image.

        The pointing is read from the ``RA_OBJ`` / ``DEC_OBJ`` header keywords
        of `file_path`, and `self.dfs1.catalog.search` is queried through
        `self.retry`. The result is written, with lower-case column names, to
        the path given by the image's `derive1(ext="cat_ref.fits")`.

        Parameters
        ----------
        file_path
            Image whose header provides the pointing.
        catalog_name : str, optional
            Catalog to query; defaults to `self.ref_cat`.
        radius : float
            Search radius forwarded to the catalog search
            (units not visible here; presumably degrees, TODO confirm).

        Returns
        -------
        None | str
            Path of the written catalog, or `None` if fewer than 10 rows
            were returned.

        References
        ----------
        https://gea.esac.esa.int/archive/documentation/GDR3/Gaia_archive/chap_datamodel/sec_dm_main_source_catalogue/ssec_dm_gaia_source.html
        """
        catalog_name = catalog_name or self.ref_cat
        image = File(file_path, new_dir=self.dir_output)
        hdr = fits.getheader(image.file_path)
        pointing_ra, pointing_dec = hdr["RA_OBJ"], hdr["DEC_OBJ"]
        print(f"Query reference catalog RA={pointing_ra}, Dec={pointing_dec} ...")
        search_kwargs = dict(
            ra=pointing_ra,
            dec=pointing_dec,
            catalog_name=catalog_name,
            radius=radius,
            columns=["*"],
            min_mag=0,
            max_mag=30,
            obstime=-1,
            limit=-1,
        )
        response = self.retry(self.dfs1.catalog.search, **search_kwargs)
        assert response.success, response

        # convert pandas data frame to table
        catalog = table.Table.from_pandas(response["data"])
        n_rows = len(catalog)
        self.logger.info(f"{n_rows} rows in reference catalog")
        if n_rows < 10:
            return None
        # rename columns to lower case
        lowered = [name.lower() for name in catalog.colnames]
        catalog.rename_columns(catalog.colnames, lowered)
        out_path = image.derive1(ext="cat_ref.fits")
        self.logger.info(f"Writing to file: {out_path} ...")
        catalog.write(out_path, overwrite=True)
        return out_path


class ModuleResult(NamedTuple):
    """Outcome of one module run, as returned by `Pipeline.call`."""

    # NOTE(review): the list/dict defaults below are created once at class
    # definition and shared by every instance that relies on them; do not
    # mutate them in place.
    module: str = ""  # name of the called function
    cost: float = 0.0  # elapsed time in seconds
    status: CsstStatus = CsstStatus.ERROR  # module status
    products: Optional[list[str]] = []  # `products` of the module's CsstResult
    files: Optional[list[str]] = []  # `files` of the module's CsstResult
    output: dict = {}  # module output, or {"exc_info": traceback} on error
