"""
Identifier:     csst_common/pipeline.py
Name:           pipeline.py
Description:    pipeline operation
Author:         Bo Zhang
Created:        2023-12-13
Modified-History:
    2023-07-11, Bo Zhang, created
    2023-07-11, Bo Zhang, add Pipeline class
    2023-12-10, Bo Zhang, update Pipeline
    2023-12-15, Bo Zhang, add module header
    2025-12-27, Bo Zhang, rewrite Pipeline class
"""

import sys
import json
import subprocess
import os
import shutil
import traceback
import warnings
from typing import Callable, NamedTuple, Optional, Any, Union
from astropy.time import Time, TimeDelta
from astropy.io import fits

import csst_dfs_client
import csst_fs

from .ccds import CCDS
from .utils import retry
from .file import File
from .logger import get_logger
from .status import CsstStatus, CsstResult
from .fits import s3_options


def print_directory_tree(directory="."):
    """Print an indented listing of ``directory`` and everything below it.

    Each directory is printed as ``name/`` and the files it contains are
    listed one level deeper, using four spaces per level.

    Parameters
    ----------
    directory : str
        Root directory to walk. Defaults to the current directory.
    """
    for root, _dirs, files in os.walk(directory):
        # Depth relative to the starting directory. ``os.path.relpath`` is
        # robust against a trailing separator on ``directory`` and against the
        # root path text re-occurring deeper in the tree (``str.replace`` was not).
        rel = os.path.relpath(root, directory)
        level = 0 if rel == os.curdir else rel.count(os.sep) + 1
        indent = " " * 4 * level
        # normpath strips a trailing separator so basename is never empty.
        print(f"{indent}{os.path.basename(os.path.normpath(root))}/")
        subindent = " " * 4 * (level + 1)
        for file_name in files:
            print(f"{subindent}{file_name}")


# Exit reason -> process exit code, used by ``Pipeline.exit(reason=...)``.
# Values are grouped by tens: 0 success, 1x CCDS / reference data,
# 2x input data, 3x algorithm, 4x data products.
EXIT_CODES = dict(
    success=0,
    ccds_reference_file_error=10,  # CCDS reference file error
    ccds_pmap_error=11,  # CCDS pmap error
    reference_catalog_error=12,  # reference catalog error
    input_data_invalid=20,  # invalid input data
    input_data_not_unique=21,  # input data not unique
    algorithm_error=30,  # algorithm error
    data_product_invalid=40,  # invalid data product
    data_product_ingestion_error=41,  # data product ingestion error
)


class Pipeline:
    """CSST pipeline runtime helper.

    On construction it:

    * exports any keyword arguments into ``os.environ`` (as strings), then
      reads the settings below from environment variables;
    * exposes every setting as a lower-case attribute (``DIR_OUTPUT`` ->
      ``self.dir_output``);
    * creates a logger writing to ``DIR_OUTPUT/LOG_FILE``;
    * sets up DFS1 (``csst_dfs_client``), DFS2 (``csst_fs``) and CCDS access;
    * parses the task message from the single command-line argument, which
      must be a JSON string (``self.task`` is ``{}`` otherwise).

    Examples
    --------
    Run as ``python run.py '{"a": 1}'``:

    >>> p = Pipeline()  # doctest: +SKIP
    >>> p.task  # doctest: +SKIP
    {'a': 1}
    """

    def __init__(self, **env_vars: Any):
        # Export overrides first so they are picked up by os.getenv below
        # (this mutates the process environment).
        for k, v in env_vars.items():
            os.environ[k] = str(v)

        # get settings from environment variables
        self.settings: dict = {
            # DFS & CCDS directories
            "DFS_ROOT": os.getenv("DFS_ROOT", "/dfs_root"),
            "CCDS_ROOT": os.getenv("CCDS_ROOT", "/ccds_root"),
            "CCDS_CACHE": os.getenv("CCDS_CACHE", "/pipeline/ccds_cache"),
            # working directories
            "DIR_INPUT": os.getenv("DIR_INPUT", "/pipeline/input"),
            "DIR_OUTPUT": os.getenv("DIR_OUTPUT", "/pipeline/output"),
            "DIR_TEMP": os.getenv("DIR_TEMP", "/pipeline/temp"),
            "DIR_AUX": os.getenv("DIR_AUX", "/pipeline/aux"),
            "LOG_FILE": os.getenv("LOG_FILE", "pipeline.log"),
            # docker image information
            "DOCKER_IMAGE": os.getenv("DOCKER_IMAGE", "-"),
            "BUILD": os.getenv("BUILD", "-"),
            "CREATED": os.getenv("CREATED", "-"),
            # boolean flags: only the (case-insensitive) string "true" enables them
            "VERBOSE": os.getenv("VERBOSE", "false").lower() == "true",
            "DUMPDATA": os.getenv("DUMPDATA", "false").lower() == "true",
            "IGNORE_WARNINGS": os.getenv("IGNORE_WARNINGS", "true").lower() == "true",
            "USE_OSS": os.getenv("USE_OSS", "false").lower() == "true",
        }
        # expose each setting as a lower-case attribute, e.g. self.dir_output
        for k, v in self.settings.items():
            setattr(self, k.lower(), v)

        # set logger
        self.logger = get_logger(
            name="pipeline logger",
            filename=str(os.path.join(self.dir_output, self.log_file)),
        )
        # filter warnings
        if self.settings["IGNORE_WARNINGS"]:
            self.filter_warnings("ignore")

        # DFS1, DFS2 & CCDS
        self.dfs1 = csst_dfs_client
        self.dfs2 = csst_fs
        self.ccds = CCDS(
            use_oss=self.use_oss,
            ccds_root=self.ccds_root,
            ccds_cache=self.ccds_cache,
            logger=self.logger,
        )

        # exit codes -> p.exit(reason="ccds_reference_file_error")
        self.EXIT_CODES = EXIT_CODES

        # record start time
        self.t_start = Time.now()
        self.logger.info(f"t_start = {self.t_start.isot}")
        # The task message is an optional JSON string given as the single
        # command-line argument. Log the argument slice rather than indexing
        # sys.argv[1] so running without arguments does not raise IndexError.
        self.logger.info(f"sys.argv[1:] = {sys.argv[1:]}")
        if len(sys.argv) == 2:
            self.task = self.json2dict(sys.argv[1])
        else:
            self.task = {}
        self.logger.info(f"task = {self.task}")

        # record basic information
        self.logger.info(f"DOCKER_IMAGE={self.docker_image}")
        self.logger.info(f"BUILD={self.build}")
        self.logger.info(f"CREATED={self.created}")
        self.logger.info(f"VERBOSE={self.verbose}")
        self.logger.info(f"DUMPDATA={self.dumpdata}")

    # warning operations
    @staticmethod
    def filter_warnings(level: str = "ignore"):
        """Install a global warnings filter with action ``level``."""
        warnings.filterwarnings(level)

    @staticmethod
    def reset_warnings(*_args: Any) -> None:
        """Clear all warning filters (see ``warnings.resetwarnings``).

        Positional arguments are accepted and ignored only for backward
        compatibility: the method used to be a staticmethod that (wrongly)
        declared a ``self`` parameter, which made ``p.reset_warnings()`` fail.
        """
        warnings.resetwarnings()

    # message operations
    @staticmethod
    def dict2json(d: dict):
        """Convert `dict` to JSON format string (non-ASCII kept as-is)."""
        return json.dumps(d, ensure_ascii=False)

    @staticmethod
    def json2dict(m: str):
        """Convert JSON format string to `dict`."""
        return json.loads(m)

    # file operations
    @staticmethod
    def mkdir(d):
        """Create directory ``d`` (and parents) if it does not exist."""
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs
        os.makedirs(d, exist_ok=True)

    @staticmethod
    def clean_directory(d):
        """Remove the contents of directory ``d``, keeping ``d`` itself.

        Like the shell glob ``d/*`` this replaces, hidden entries (names
        starting with ``.``) are left alone. A missing directory is a no-op.
        Cleaning is best-effort: failures are reported on stdout, followed by
        a listing of what remains, instead of being raised.
        """
        print(f"Clean output directory '{d}'...")
        if not os.path.isdir(d):
            # nothing to clean (``rm -rf`` also ignored missing paths)
            return
        failed = False
        try:
            # Materialize the listing first so we never delete while iterating.
            with os.scandir(d) as it:
                entries = [e for e in it if not e.name.startswith(".")]
        except OSError:
            entries = []
            failed = True
        # Pure-Python removal instead of a shell ``rm -rf`` string: no quoting
        # or injection issues with spaces/metacharacters in ``d``.
        for entry in entries:
            try:
                if entry.is_dir(follow_symlinks=False):
                    shutil.rmtree(entry.path)
                else:
                    os.remove(entry.path)
            except OSError:
                # keep going so as much as possible is removed
                failed = True
        if failed:
            print("Failed to clean output directory!")
            print_directory_tree(d)

    @staticmethod
    def move(file_src: str, file_dst: str) -> str:
        """Move file `file_src` to `file_dst`; return the destination path."""
        return shutil.move(file_src, file_dst)

    @staticmethod
    def copy(file_src: str, file_dst: str) -> str:
        """Copy file `file_src` to `file_dst`; return the destination path."""
        return shutil.copy(file_src, file_dst)

    @property
    def summarize(self) -> float:
        """Log and return the wall-clock seconds elapsed since construction.

        Kept as a property (accessed as ``p.summarize``) for compatibility.
        """
        t_stop: Time = Time.now()
        # Time difference is in days; convert to seconds.
        t_cost: float = (t_stop - self.t_start).value * 86400.0
        self.logger.info(f"Total cost: {t_cost:.1f} sec")
        return t_cost

    def clean_output(self):
        """Clean output directory."""
        self.clean_directory(self.dir_output)

    def file(self, file_path):
        """Initialize a ``File`` object whose new directory is the output directory."""
        return File(file_path, new_dir=self.dir_output)

    def new(self, file_name="test.fits") -> str:
        """Return the path of ``file_name`` inside the output directory."""
        return os.path.join(self.dir_output, file_name)

    def download_oss_file(self, oss_file_path: str, dir_dst: Optional[str] = None) -> str:
        """Download an OSS file into ``dir_dst`` (default: output directory).

        Returns the local file path.
        """
        if dir_dst is None:
            dir_dst = self.dir_output
        local_file_path = os.path.join(dir_dst, os.path.basename(oss_file_path))
        csst_fs.s3_fs.get(oss_file_path, local_file_path, s3_options=s3_options)
        assert os.path.exists(
            local_file_path
        ), f"Failed to download {oss_file_path} to {local_file_path}"
        return local_file_path

    def abspath(self, file_path: str) -> str:
        """Resolve ``file_path`` to a local absolute path.

        * Paths containing ``:`` are treated as OSS paths: requires
          ``USE_OSS`` and downloads the file into the output directory.
        * Paths starting with ``CSST`` are joined onto ``DFS_ROOT``.
        * Anything else is joined onto ``CCDS_ROOT``.
        """
        if ":" in file_path:
            # it's an OSS file path
            assert self.use_oss, "USE_OSS must be True to use OSS file path!"
            # download OSS file to output directory and return the local path
            return self.download_oss_file(file_path)
        # it's a NAS file path
        if file_path.startswith("CSST"):
            # DFS
            return os.path.join(self.dfs_root, file_path)
        # CCDS
        return os.path.join(self.ccds_root, file_path)

    def download_dfs_file(self, file_path: str, dir_dst: Optional[str] = None) -> str:
        """Fetch a DFS file into ``dir_dst`` (default: output directory).

        Downloads from OSS when ``USE_OSS`` is set, otherwise copies from the
        DFS NAS root. Returns the local file path.
        """
        if dir_dst is None:
            dir_dst = self.dir_output

        if self.use_oss:
            # download OSS file to dst directory
            return self.download_oss_file(file_path, dir_dst)
        # copy DFS file to dst directory
        local_file_path = os.path.join(dir_dst, os.path.basename(file_path))
        self.copy(self.abspath(file_path), local_file_path)
        return local_file_path

    def download_ccds_refs(self, refs: dict, dir_dst: Optional[str] = None) -> dict:
        """Fetch CCDS reference files into ``dir_dst`` (default: output directory).

        Parameters
        ----------
        refs : dict
            Mapping of reference name -> CCDS path (relative NAS path, or OSS
            path containing ``:``).
        dir_dst : str, optional
            Destination directory.

        Returns
        -------
        dict
            Mapping of reference name -> local file path.
        """
        if dir_dst is None:
            dir_dst = self.dir_output

        local_refs = {}
        for ref_name, ref_path in refs.items():
            if self.use_oss and ":" in ref_path:
                # Download straight to the destination. Going through abspath()
                # would first download into the output directory and then copy
                # the file onto itself (shutil.SameFileError) when dir_dst is
                # the output directory.
                local_refs[ref_name] = self.download_oss_file(ref_path, dir_dst)
            else:
                local_file_path = os.path.join(dir_dst, os.path.basename(ref_path))
                local_refs[ref_name] = self.copy(
                    self.abspath(ref_path), local_file_path
                )
        return local_refs

    # time operations
    @staticmethod
    def now() -> str:
        """Return ISOT format datetime using `astropy`."""
        return Time.now().isot

    # call modules
    def call(self, func: Callable, *args: Any, **kwargs: Any) -> "ModuleResult":
        """Run one pipeline module and wrap its outcome in a ``ModuleResult``.

        ``func(*args, **kwargs)`` is expected to return a ``CsstResult``.
        Any ``Exception`` raised by the module (including returning a wrong
        type) is logged with its traceback and converted into a result with
        ``status=CsstStatus.ERROR``, ``files=None`` and
        ``output={"exc_info": <traceback text>}``. ``BaseException``
        subclasses such as ``KeyboardInterrupt``/``SystemExit`` propagate.

        Returns
        -------
        ModuleResult
            Module name, wall-clock cost in seconds, status, files and output.
        """
        name = func.__name__
        self.logger.info("=====================================================")
        t_start: Time = Time.now()
        self.logger.info(f"Starting Module: **{name}**")
        try:
            res = func(*args, **kwargs)
            # explicit check instead of ``assert`` so it survives ``python -O``
            if not isinstance(res, CsstResult):
                raise TypeError(
                    f"Module {name} returned {res!r}, expected a CsstResult"
                )
            status = res.status
            files = res.files
            output = res.output
        except Exception:
            # if the module raises an error
            exc_info = traceback.format_exc()  # traceback info
            self.logger.error(f"Error occurs! \n{exc_info}")
            status = CsstStatus.ERROR  # default status if exceptions occur
            files = None
            output = {"exc_info": exc_info}  # default output if exceptions occur

        # Done after try/except rather than in ``finally`` with a ``return``:
        # that pattern swallowed non-Exception errors (or masked them with an
        # UnboundLocalError because ``status`` was never assigned).
        t_stop: Time = Time.now()
        t_cost: float = (t_stop - t_start).value * 86400
        if isinstance(status, CsstStatus):
            self.logger.info(
                f"Module finished: status={status} | cost={t_cost:.1f} sec"
            )
        else:
            self.logger.error(
                f"Invalid status: {status} is not a CsstStatus object!"
            )
        self.logger.info(
            f"ModuleResult: \n"
            f"   - name: {name}\n"
            f"   - status: {status}\n"
            f"   - files: {files}\n"
            f"   - output: {output}\n"
        )
        return ModuleResult(
            module=name,
            cost=t_cost,
            status=status,
            files=files,
            output=output,
        )

    # retry operations
    @staticmethod
    def retry(*args, **kwargs):
        """Delegate to the package-level ``retry`` helper."""
        return retry(*args, **kwargs)

    @property
    def pmapname(self):
        """Final CCDS `.pmap` name.

        The task's ``pmapname`` if given and accepted by ``ccds.validate``,
        otherwise the CCDS operational context.
        """
        task_pmapname = self.task.get("pmapname", None)
        # TODO: validate this pmapname
        if task_pmapname and self.ccds.validate(task_pmapname):
            # task specified pmap
            return task_pmapname
        # CCDS recommended pmap
        return self.ccds.operational_context

    @property
    def ref_cat(self):
        """Final DFS reference catalog name.

        The task's ``ref_cat`` if it is one of DFS's catalog names; otherwise
        ``"trilegal_093"`` for the ``csst-msc-l1-mbi`` image.

        Raises
        ------
        ValueError
            If no valid catalog is given and this image has no fallback.
        """
        task_ref_cat = self.task.get("ref_cat", None)
        if task_ref_cat and task_ref_cat in self.dfs1.catalog.all_catalog_names:
            # task specified catalog
            return task_ref_cat
        # DFS recommended catalog
        if self.docker_image == "csst-msc-l1-mbi":
            return "trilegal_093"
        raise ValueError(
            f"Invalid ref_cat: {task_ref_cat} not in {self.dfs1.catalog.all_catalog_names}"
        )

    def exit(self, reason: Optional[str] = None):
        """Exit the process with the code mapped to ``reason`` in ``EXIT_CODES``.

        Examples
        --------
        >>> p.exit(reason="ccds_reference_file_error")  # doctest: +SKIP
        >>> p.exit(reason="success")  # doctest: +SKIP
        """
        assert (
            reason in self.EXIT_CODES.keys()
        ), f"Reason {reason} not in {self.EXIT_CODES.keys()}"
        sys.exit(self.EXIT_CODES[reason])

    def add_metadata(self, meta: dict):
        """Add metadata to pipeline (not implemented yet)."""
        raise NotImplementedError()


class ModuleResult(NamedTuple):
    """Outcome of one module run, as returned by ``Pipeline.call``.

    On an exception inside the module, ``status`` is ``CsstStatus.ERROR``,
    ``files`` is ``None`` and ``output`` is ``{"exc_info": <traceback>}``.
    """

    module: str  # name of the called function (func.__name__)
    cost: float  # wall-clock run time in seconds
    status: CsstStatus  # status reported by the module
    files: Union[list, None]  # files reported by the module, or None
    output: dict  # module output payload
