"""
Identifier:     csst_common/pipeline.py
Name:           pipeline.py
Description:    pipeline operation
Author:         Bo Zhang
Created:        2023-12-13
Modified-History:
    2023-07-11, Bo Zhang, created
    2023-07-11, Bo Zhang, add Pipeline class
    2023-12-10, Bo Zhang, update Pipeline
    2023-12-15, Bo Zhang, add module header
    2025-12-27, Bo Zhang, rewrite Pipeline class
"""

import sys
import json
import subprocess
import os
import shutil
import traceback
import warnings
from typing import Callable, NamedTuple, Optional, Any, Union
from astropy.time import Time, TimeDelta
from astropy.io import fits

import csst_dfs_client
import csst_fs

from .ccds import CCDS
from .utils import retry
from .file import File
from .logger import get_logger
from .status import CsstStatus, CsstResult
from .fits import s3_options


def print_directory_tree(directory="."):
    """Print the file tree under `directory`, indenting 4 spaces per level.

    Each directory is printed as ``name/`` and its files are listed one level
    deeper. Depth is computed relative to `directory`, so a trailing path
    separator (e.g. ``"/pipeline/output/"``) does not flatten the tree.

    Parameters
    ----------
    directory : str, optional
        Root of the tree to print (default: the current directory).
    """
    for root, _dirs, files in os.walk(directory):
        rel = os.path.relpath(root, directory)
        # the walk root itself is depth 0; each path component below adds one
        level = 0 if rel == os.curdir else rel.count(os.sep) + 1
        indent = " " * 4 * level
        # normpath drops a trailing separator so basename is never empty
        print(f"{indent}{os.path.basename(os.path.normpath(root))}/")
        subindent = " " * 4 * (level + 1)
        for file in files:
            print(f"{subindent}{file}")


# Process exit codes, grouped by decade:
# 0 = success, 1x = reference data problems, 2x = input data problems,
# 3x = data product problems.
EXIT_CODES = dict(
    success=0,
    # reference data
    ccds_reference_file_error=10,
    reference_catalog_error=11,
    # input data
    input_data_error=20,
    input_data_not_unique=21,
    input_data_invalid=22,
    # data products
    data_product_invalid=30,
    data_product_ingestion_error=31,
)


class Pipeline:
    """
    Runtime helper shared by pipeline entry points.

    On construction it reads the pipeline message (a JSON string) from
    ``sys.argv[1]`` and exports keyword arguments as environment variables.
    It then collects settings from the environment; these are available in
    ``self.settings`` and as lower-case attributes (e.g. ``self.dir_output``).
    Finally it creates a logger writing to ``<DIR_OUTPUT>/<LOG_FILE>`` and
    attaches the DFS1, DFS2 and CCDS handles.

    Examples
    --------
    When the script is launched as ``python run.py '{"a": 1}'``:

    >>> p = Pipeline()  # doctest: +SKIP
    >>> p.msg_dict  # doctest: +SKIP
    {'a': 1}
    """

    def __init__(self, **env_vars: Any):
        """
        Parameters
        ----------
        **env_vars : Any
            Written to ``os.environ`` (values converted with ``str``) before
            the settings are read, so they override the process environment.

        Raises
        ------
        IndexError
            If no message argument is given on the command line.
        json.JSONDecodeError
            If ``sys.argv[1]`` is not valid JSON.
        """
        # record start time
        self.t_start = Time.now()

        # the pipeline message is the first command-line argument (JSON)
        self.msg = sys.argv[1]
        self.msg_dict = self.json2dict(self.msg)

        # export caller-supplied variables so the settings below pick them up
        for k, v in env_vars.items():
            os.environ[k] = str(v)

        # get settings from environment variables
        self.settings: dict = {
            # DFS & CCDS directories
            "DFS_ROOT": os.getenv("DFS_ROOT", "/dfs_root"),
            "CCDS_ROOT": os.getenv("CCDS_ROOT", "/ccds_root"),
            "CCDS_CACHE": os.getenv("CCDS_CACHE", "/pipeline/ccds_cache"),
            # working directories
            "DIR_INPUT": os.getenv("DIR_INPUT", "/pipeline/input"),
            "DIR_OUTPUT": os.getenv("DIR_OUTPUT", "/pipeline/output"),
            "DIR_TEMP": os.getenv("DIR_TEMP", "/pipeline/temp"),
            "DIR_AUX": os.getenv("DIR_AUX", "/pipeline/aux"),
            "LOG_FILE": os.getenv("LOG_FILE", "pipeline.log"),
            # docker image information
            "DOCKER_IMAGE": os.getenv("DOCKER_IMAGE", "-"),
            "BUILD": os.getenv("BUILD", "-"),
            "CREATED": os.getenv("CREATED", "-"),
            # additional settings (boolean flags: only "true", any case, is True)
            "VERBOSE": os.getenv("VERBOSE", "false").lower() == "true",
            "IGNORE_WARNINGS": os.getenv("IGNORE_WARNINGS", "true").lower() == "true",
            "USE_OSS": (os.getenv("USE_OSS", "false")).lower() == "true",
        }
        # expose each setting as a lower-case attribute, e.g. self.dir_output
        for k, v in self.settings.items():
            setattr(self, k.lower(), v)

        # set logger
        self.logger = get_logger(
            name="pipeline logger",
            filename=str(os.path.join(self.dir_output, self.log_file)),
        )
        # filter warnings
        if self.settings["IGNORE_WARNINGS"]:
            self.filter_warnings("ignore")

        # DFS1, DFS2 & CCDS
        self.dfs1 = csst_dfs_client
        self.dfs2 = csst_fs
        self.ccds = CCDS(ccds_root=self.ccds_root, ccds_cache=self.ccds_cache)

        # exit code
        self.EXIT_CODES = EXIT_CODES

    @staticmethod
    def dict2json(d: dict):
        """Convert `dict` to JSON format string (non-ASCII kept as-is)."""
        return json.dumps(d, ensure_ascii=False)

    @staticmethod
    def json2dict(m: str):
        """Convert JSON format string to `dict`."""
        return json.loads(m)

    @property
    def summarize(self):
        """Log the total time elapsed since this object was constructed.

        This is a property: it is triggered by attribute access
        (``p.summarize``), not by calling it.
        """
        t_stop: Time = Time.now()
        # Time difference is expressed in days; convert to seconds
        t_cost: float = (t_stop - self.t_start).value * 86400.0
        self.logger.info(f"Total cost: {t_cost:.1f} sec")

    def clean_output(self):
        """Clean output directory (currently a no-op)."""
        # self.clean_directory(self.dir_output)
        # clean in run command
        pass

    @staticmethod
    def clean_directory(d):
        """Remove the contents of directory `d` (best effort).

        Runs ``rm -rf <d>/*`` in a shell, so, as with any shell glob, entries
        whose names start with ``.`` are not matched. Failures are printed
        together with the remaining directory tree instead of being raised.
        """
        import shlex  # only needed here

        # an empty path or "/" would expand to `rm -rf /*`
        if not d or os.path.abspath(d) == os.sep:
            print(f"Refusing to clean directory '{d}'!")
            return
        print(f"Clean output directory '{d}'...")
        try:
            # quote the path so spaces/metacharacters are taken literally, but
            # keep `/*` unquoted so the shell still expands the glob
            r = subprocess.run(
                f"rm -rf {shlex.quote(str(d))}/*", shell=True, capture_output=True
            )
            print("> ", r)
            r.check_returncode()
        except (subprocess.CalledProcessError, OSError):
            print("Failed to clean output directory!")
            print_directory_tree(d)

    def now(self):
        """Return ISOT format datetime using `astropy`."""
        return Time.now().isot

    @staticmethod
    def filter_warnings(level: str = "ignore"):
        """Install a global warnings filter with action `level` (default: ignore all)."""
        warnings.filterwarnings(level)

    @staticmethod
    def reset_warnings(*args: Any):
        """Reset all warning filters.

        Positional arguments are ignored. They are accepted only so that
        callers written against the old ``reset_warnings(self)`` signature
        keep working.
        """
        warnings.resetwarnings()

    def file(self, file_path):
        """Wrap `file_path` in a `File` whose `new_dir` is the output directory."""
        return File(file_path, new_dir=self.dir_output)

    def new(self, file_name="test.fits"):
        """Return the path of `file_name` inside the output directory."""
        return os.path.join(self.dir_output, file_name)

    @staticmethod
    def move(file_src: str, file_dst: str):
        """Move file `file_src` to `file_dst`."""
        shutil.move(file_src, file_dst)

    @staticmethod
    def copy(file_src: str, file_dst: str):
        """Copy file `file_src` to `file_dst`."""
        shutil.copy(file_src, file_dst)

    def download_oss_file(self, oss_file_path: str) -> str:
        """Download an OSS file into the output directory and return its local path.

        The local file keeps the basename of `oss_file_path`. An
        AssertionError is raised if the file is missing after the download
        (note: that check is skipped under ``python -O``).
        """
        local_file_path = os.path.join(self.dir_output, os.path.basename(oss_file_path))
        csst_fs.s3_fs.get(oss_file_path, local_file_path, s3_options=s3_options)
        assert os.path.exists(
            local_file_path
        ), f"Failed to download {oss_file_path} to {local_file_path}"
        return local_file_path

    def call(self, func: Callable, *args: Any, **kwargs: Any) -> "ModuleResult":
        """Run one pipeline module and log its outcome.

        Parameters
        ----------
        func : Callable
            Module function; expected to return a `CsstResult`.
        *args, **kwargs
            Forwarded to `func`.

        Returns
        -------
        ModuleResult
            Module name, run time in seconds, status, files and output.
            Suppose `func` raises an `Exception` or returns something other
            than a `CsstResult`. Then the traceback is logged and the result
            has ``status=CsstStatus.ERROR``, ``files=None`` and
            ``output={"exc_info": <traceback text>}``.

        Raises
        ------
        BaseException
            Non-`Exception` errors (e.g. `KeyboardInterrupt`, `SystemExit`)
            are not caught and propagate to the caller.
        """
        self.logger.info("=====================================================")
        t_start = Time.now()
        self.logger.info(f"Starting Module: **{func.__name__}**")
        try:
            res = func(*args, **kwargs)
            # explicit check rather than `assert`, which `python -O` strips
            if not isinstance(res, CsstResult):
                raise TypeError(f"Module returned {res!r}, not a CsstResult object!")
            status = res.status
            files = res.files
            output = res.output
        except Exception:
            # the module failed: log the traceback and report an ERROR result
            exc_info = traceback.format_exc()
            self.logger.error(f"Error occurs! \n{exc_info}")
            status = CsstStatus.ERROR
            files = None
            output = {"exc_info": exc_info}
        # Deliberately no `finally: return ...`. Returning from `finally`
        # would swallow KeyboardInterrupt/SystemExit, and `status` could be
        # unbound there.
        t_cost: float = (Time.now() - t_start).value * 86400  # days -> seconds
        if isinstance(status, CsstStatus):
            self.logger.info(
                f"Module finished: status={status} | cost={t_cost:.1f} sec"
            )
        else:
            # the module returned a CsstResult whose status is not a CsstStatus
            self.logger.error(
                f"Invalid status: {status} is not a CsstStatus object!"
            )
        self.logger.info(
            f"ModuleResult: \n"
            f"   - name: {func.__name__}\n"
            f"   - status: {status}\n"
            f"   - files: {files}\n"
            f"   - output: {output}\n"
        )
        return ModuleResult(
            module=func.__name__,
            cost=t_cost,
            status=status,
            files=files,
            output=output,
        )

    @staticmethod
    def retry(*args, **kwargs):
        """Forward all arguments to the `retry` helper from `.utils`."""
        return retry(*args, **kwargs)

    @property
    def pmapname(self):
        """CCDS `.pmap` name (operational context).

        Raises ValueError if the CCDS client is not initialized.
        """
        if self.ccds is not None:
            return self.ccds.operational_context
        else:
            raise ValueError("CCDS client not initialized!")

    def info(self):
        """Log pipeline information (pipeline id, build, creation time, verbosity)."""
        # `pipeline_id` is not among the settings read in __init__; fall back
        # to "-" (the default used for other image metadata) if unset
        self.logger.info(f"PIPELINE_ID={getattr(self, 'pipeline_id', '-')}")
        self.logger.info(f"BUILD={self.build}")
        self.logger.info(f"CREATED={self.created}")
        self.logger.info(f"VERBOSE={self.verbose}")


class ModuleResult(NamedTuple):
    """Outcome of one module run, as returned by `Pipeline.call`."""

    # name of the module function (``func.__name__``)
    module: str
    # run time in seconds
    cost: float
    # final status; CsstStatus.ERROR when the module raised
    status: CsstStatus
    # files reported by the module, or None when the module raised
    files: Union[list, None]
    # module output; {"exc_info": <traceback>} when the module raised
    output: dict
