"""
Identifier:     csst_common/pipeline.py
Name:           pipeline.py
Description:    pipeline operation
Author:         Bo Zhang
Created:        2023-12-13
Modified-History:
    2023-07-11, Bo Zhang, created
    2023-07-11, Bo Zhang, add Pipeline class
    2023-12-10, Bo Zhang, update Pipeline
    2023-12-15, Bo Zhang, add module header
"""

import json
import subprocess
import os
import shutil
import traceback
import warnings
from typing import Callable, NamedTuple, Optional, Any, Union
from astropy.time import Time, TimeDelta
from astropy.io import fits

from .utils import retry
from .ccds import CCDS
from .dfs import DFS
from .file import File
from .io import reformat_header
from .logger import get_logger
from .status import CsstStatus, CsstResult


def print_directory_tree(directory="."):
    """Print the tree below `directory`, indenting 4 spaces per nesting level.

    Each directory is printed as ``name/`` followed by its files one level
    deeper. The nesting depth is derived with `os.path.relpath`, so the output
    is the same whether or not `directory` ends with a path separator.
    """
    for root, _dirs, files in os.walk(directory):
        # depth of `root` relative to the starting directory ("." means top)
        rel = os.path.relpath(root, directory)
        level = 0 if rel == os.curdir else rel.count(os.sep) + 1
        indent = " " * 4 * level
        # normpath drops a trailing separator so the top dir keeps its name
        print(f"{indent}{os.path.basename(os.path.normpath(root))}/")
        subindent = " " * 4 * (level + 1)
        for file in files:
            print(f"{subindent}{file}")


class Pipeline:
    """
    CSST pipeline configuration class.

    It is used by each pipeline to initialize its environment: it records the
    start time, reads pipeline information from environment variables, creates
    the pipeline and module loggers in `dir_output`, changes the working
    directory to `dir_output`, and optionally initializes DFS and CCDS clients.

    Parameters
    ----------
    dir_input : str
        Input directory.
    dir_output : str
        Output directory. The process changes its working directory here.
    dir_temp : str
        Temp directory.
    dir_aux : str
        Aux path.
    dfs_root : str
        DFS root path.
    ccds_root : str
        CCDS root path.
    ccds_cache : str
        CCDS cache path.
    pipeline_log : str
        File name of the pipeline log, joined onto `dir_output`.
    module_log : str
        File name of the module log, joined onto `dir_output`.
    filter_warnings : bool
        If `True`, filter warnings.
    dfs : bool
        If `True`, initialize DFS.
    ccds : bool
        If `True`, initialize CCDS.
    clean_output_before_run : bool
        If `True`, call `clean_output` (currently a no-op) during init.
    **kwargs : Any
        Additional keyword arguments, stored as `self.kwargs`.
    """

    def __init__(
        self,
        dir_input: str = "/pipeline/input",
        dir_output: str = "/pipeline/output",
        dir_temp: str = "/pipeline/temp",
        dir_aux: str = "/pipeline/aux",
        dfs_root: str = "/dfs_root",
        ccds_root: str = "/ccds_root",
        ccds_cache: str = "/pipeline/ccds_cache",
        pipeline_log: str = "pipeline.log",
        module_log: str = "module.log",
        filter_warnings: bool = False,
        dfs: bool = True,
        ccds: bool = False,
        clean_output_before_run: bool = True,
        **kwargs: Any,
    ):
        # record start time
        self.t_start = Time.now()

        # get pipeline information from env vars
        self.pipeline_id: str = os.getenv("PIPELINE_ID", "-")
        self.build: int = int(os.getenv("BUILD", "0"))
        self.created: str = os.getenv("CREATED", "-")
        # NOTE(review): any non-empty VERBOSE value (even "0" or "false")
        # evaluates to True here; confirm this is intended.
        self.verbose: bool = bool(os.getenv("VERBOSE", ""))

        # set directory information
        self.dir_input: str = dir_input
        self.dir_output: str = dir_output
        self.dir_temp: str = dir_temp
        self.dir_aux: str = dir_aux
        self.dfs_root: str = dfs_root
        self.ccds_root: str = ccds_root
        self.ccds_cache: str = ccds_cache

        if clean_output_before_run:
            self.clean_output()

        # additional parameters
        self.kwargs: dict = kwargs

        # set loggers (log files live in the output directory)
        self.pipeline_logger = get_logger(
            name="pipeline logger",
            filename=os.path.join(self.dir_output, pipeline_log),
        )
        self.module_logger = get_logger(
            name="module logger",
            filename=os.path.join(self.dir_output, module_log),
        )

        # change working directory
        print(f"Change directory to {self.dir_output}")
        os.chdir(self.dir_output)

        # Frequently used files
        self.message = Message(os.path.join(self.dir_output, "message.txt"))
        self.timestamp = Timestamp(os.path.join(self.dir_output, "timestamp.txt"))

        # optional service clients
        self.dfs: Optional[DFS] = DFS() if dfs else None
        self.ccds: Optional[CCDS] = (
            CCDS(ccds_root=ccds_root, ccds_cache=ccds_cache) if ccds else None
        )

        if filter_warnings:
            self.filter_warnings()

    def info(self):
        """List environment variables such as `PIPELINE_ID`, etc."""
        print(f"PIPELINE_ID={self.pipeline_id}")
        print(f"BUILD={self.build}")
        print(f"CREATED={self.created}")
        print(f"VERBOSE={self.verbose}")

    @property
    def info_header(self) -> fits.Header:
        """Summarize pipeline info into a `astropy.io.fits.Header`."""
        h = fits.Header()
        h.set("PIPELINE", self.pipeline_id, comment="pipeline ID")
        h.set("BUILD", self.build, comment="pipeline build number")
        # CREATED must carry the build time, not the pipeline ID
        h.set("CREATED", self.created, comment="pipeline build time")
        return reformat_header(h, strip=False, comment="Pipeline info")

    def summarize(self):
        """Log the total wall time (seconds) elapsed since initialization."""
        t_stop: Time = Time.now()
        # TimeDelta.value is in days by default; convert to seconds
        t_cost: float = (t_stop - self.t_start).value * 86400.0
        self.pipeline_logger.info(f"Total cost: {t_cost:.1f} sec")

    def clean_output(self):
        """Clean output directory (currently a no-op; cleaning is done in the run command)."""
        pass

    @staticmethod
    def clean_directory(d):
        """Remove the contents of directory `d`, keeping `d` itself.

        Entries whose names start with ``.`` are kept, matching the shell glob
        ``d/*`` this method previously used. A missing directory is treated as
        already clean. Failures are reported (message plus the remaining
        directory tree) instead of raised.
        """
        print(f"Clean output directory '{d}'...")
        if not os.path.isdir(d):
            # nothing to clean (the old `rm -rf d/*` was a silent no-op here)
            return
        try:
            with os.scandir(d) as it:
                # materialize first so we don't mutate the dir while iterating
                entries = [e for e in it if not e.name.startswith(".")]
            for entry in entries:
                if entry.is_dir(follow_symlinks=False):
                    shutil.rmtree(entry.path)
                else:
                    os.remove(entry.path)
        except OSError:
            print("Failed to clean output directory!")
            print_directory_tree(d)

    def now(self):
        """Return ISOT format datetime using `astropy`."""
        return Time.now().isot

    @staticmethod
    def filter_warnings():
        """Suppress all warnings."""
        warnings.filterwarnings("ignore")

    @staticmethod
    def reset_warnings(*args: Any) -> None:
        """Reset warning filters.

        Positional arguments are accepted and ignored only for compatibility
        with the previous (unusable) signature ``reset_warnings(self)``.
        """
        warnings.resetwarnings()

    def file(self, file_path):
        """Initialize File object."""
        return File(file_path, new_dir=self.dir_output)

    def new(self, file_name="test.fits"):
        """Return the path of `file_name` inside the output directory."""
        return os.path.join(self.dir_output, file_name)

    @staticmethod
    def move(file_src: str, file_dst: str):
        """Move file `file_src` to `file_dst`."""
        shutil.move(file_src, file_dst)

    @staticmethod
    def copy(file_src: str, file_dst: str):
        """Copy file `file_src` to `file_dst`."""
        shutil.copy(file_src, file_dst)

    def copy_to_output(self, file_paths: list) -> list:
        """Copy each file in `file_paths` into the output directory.

        Files are placed under their base name in `self.dir_output`. A file
        that already is that destination is not copied onto itself.

        Returns
        -------
        list
            Destination paths, in the same order as `file_paths`.
        """
        destinations = []
        for file_path in file_paths:
            dst = os.path.join(self.dir_output, os.path.basename(file_path))
            # shutil.copy raises SameFileError when src and dst coincide
            if not (os.path.exists(dst) and os.path.samefile(file_path, dst)):
                shutil.copy(file_path, dst)
            destinations.append(dst)
        return destinations

    def call(self, func: Callable, *args: Any, **kwargs: Any) -> "ModuleResult":
        """Run a module function and wrap its outcome in a `ModuleResult`.

        ``func(*args, **kwargs)`` is expected to return a `CsstResult`. Any
        ``Exception`` it raises (or a return value of the wrong type) is logged
        with its traceback and reported as ``CsstStatus.ERROR`` rather than
        propagated. Non-``Exception`` errors such as ``KeyboardInterrupt``
        propagate to the caller.

        Returns
        -------
        ModuleResult
            Module name, wall time in seconds, status, files and output.
        """
        # callables such as functools.partial have no __name__
        module_name = getattr(func, "__name__", repr(func))
        self.pipeline_logger.info(
            "====================================================="
        )
        t_start: Time = Time.now()
        self.pipeline_logger.info(f"Starting Module: **{module_name}**")
        try:
            res = func(*args, **kwargs)
            # explicit check (not assert) so it also holds under `python -O`
            if not isinstance(res, CsstResult):
                raise TypeError(
                    f"{module_name} returned {res!r}, expected a CsstResult"
                )
            status = res.status
            files = res.files
            output = res.output
        except Exception:
            # if the module raises error, record the traceback
            exc_info = traceback.format_exc()
            self.pipeline_logger.error(f"Error occurs! \n{exc_info}")
            status = CsstStatus.ERROR  # default status if exceptions occur
            files = None
            output = {"exc_info": exc_info}  # default output if exceptions occur

        # TimeDelta.value is in days by default; convert to seconds
        t_cost: float = (Time.now() - t_start).value * 86400
        if isinstance(status, CsstStatus):
            self.pipeline_logger.info(
                f"Module finished: status={status} | cost={t_cost:.1f} sec"
            )
        else:
            self.pipeline_logger.error(
                f"Invalid status: {status} is not a CsstStatus object!"
            )
        # record a summary of the module result
        self.pipeline_logger.info(
            f"ModuleResult: \n"
            f"   - name: {module_name}\n"
            f"   - status: {status}\n"
            f"   - files: {files}\n"
            f"   - output: {output}\n"
        )
        return ModuleResult(
            module=module_name,
            cost=t_cost,
            status=status,
            files=files,
            output=output,
        )

    @staticmethod
    def retry(*args, **kwargs):
        """Delegate to `csst_common.utils.retry`."""
        return retry(*args, **kwargs)

    @property
    def pmapname(self):
        """CCDS `.pmap` name (operational context).

        Raises
        ------
        ValueError
            If the CCDS client was not initialized.
        """
        if self.ccds is not None:
            return self.ccds.operational_context
        else:
            raise ValueError("CCDS client not initialized!")

    @property
    def pipeline_build(self):
        """{PIPELINE_ID}-{BUILD}"""
        return f"{self.pipeline_id}-{self.build}"


# class ErrorTrace:
#     """Write error trace to file."""
#
#     def __init__(self, file_path=""):
#         self.file_path = file_path
#
#     def __repr__(self):
#         return f"< ErrorTrace [{self.file_path}] >"
#
#     def write(self, s: str):
#         with open(self.file_path, "w+") as f:
#             f.write(s)


class Message:
    """Write JSON format messages to file, one compact JSON object per line."""

    def __init__(self, file_path: str = ""):
        self.file_path = file_path

    def __repr__(self):
        return f"< Message [{self.file_path}] >"

    def write(self, dlist: list[dict]):
        """Write messages to file, overwriting any existing content."""
        with open(self.file_path, "w") as f:
            f.writelines(self.dict2msg(d) + "\n" for d in dlist)

    def preview(self, dlist: list[dict], n: int = 10) -> None:
        """Preview top `n` messages."""
        print(f"No. of messages = {len(dlist)}")
        print(f"=========== Top {n} ===========")
        print("".join(self.dict2msg(d) + "\n" for d in dlist[:n]))
        print("")

    @staticmethod
    def dict2msg(d: dict) -> str:
        """Convert `dict` to a compact JSON string (no whitespace between tokens).

        Compact separators are used instead of stripping spaces afterwards, so
        spaces inside string values are preserved.
        """
        return json.dumps(d, separators=(",", ":"))

    @staticmethod
    def msg2dict(m: str) -> dict:
        """Convert JSON format string to `dict`."""
        return json.loads(m)


# DEPRECATED
# class ExitCode:
#     def __init__(self, file_path=""):
#         self.file_path = file_path
#
#     def __repr__(self):
#         return f"< ExitCode [{self.file_path}] >"
#
#     def truncate(self):
#         with open(self.file_path, "w") as file:
#             file.truncate(0)
#
#     def write(self, code=0):
#         with open(self.file_path, "w+") as f:
#             f.write(str(code))
#         print(f"Exit with code {code} (written to '{self.file_path}')")


class Timestamp:
    def __init__(self, file_path: str = "timestamp.txt"):
        """
        Timestamp Class.

        Create a Timestamp object bound to the time stamp file `file_path`.

        Parameters
        ----------
        file_path : str
            Path of the time stamp file.
        """
        self.file_path = file_path

    def __repr__(self):
        path = self.file_path
        return f"< Timestamp [{path}] >"

    def empty(self):
        """Truncate the time stamp file to zero length, creating it if absent."""
        # opening in "w" mode already truncates the file
        with open(self.file_path, "w"):
            pass

    def record(self):
        """Append the current time (ISOT string suffixed with ``+00:00``) as one line."""
        stamp = Time.now().isot
        with open(self.file_path, "a+") as fh:
            fh.write(stamp + "+00:00\n")


class ModuleResult(NamedTuple):
    """Outcome of one module run, as returned by `Pipeline.call`."""

    module: str  # module (function) name
    cost: float  # wall time in seconds
    status: CsstStatus  # status reported by the module, or ERROR on exception
    files: Union[list, None]  # files reported by the module; None on exception
    output: dict  # module output; {"exc_info": ...} on exception
