import json
import os
from typing import Callable, Optional

import yaml
from astropy import table, time

from ._dag_utils import (
    force_string,
    override_common_keys,
    generate_sha1_from_time,
)
from ..dfs import DFS
from ._dispatcher import Dispatcher

# Directory holding the per-DAG YAML configs and the shared JSON run template.
DAG_CONFIG_DIR = os.path.join(
    os.path.dirname(os.path.dirname(__file__)),
    "dag_config",
)


class BaseDAG:
    """Base class for all Directed Acyclic Graph (DAG) implementations.

    This class provides core functionality for DAG configuration, message
    generation, and task scheduling.
    """

    @staticmethod
    def generate_sha1():
        """Generate a unique SHA1 hash based on current timestamp.

        Returns
        -------
        str
            SHA1 hash string.
        """
        return generate_sha1_from_time(verbose=False)

    @staticmethod
    def generate_dag_group_run(
        dag_group: str = "default-dag-group",
        batch_id: str = "default-batch",
        priority: int | str = 1,
    ):
        """Generate a DAG group run configuration.

        Parameters
        ----------
        dag_group : str, optional
            Group identifier (default: "default-dag-group").
        batch_id : str, optional
            Batch identifier (default: "default-batch").
        priority : int | str, optional
            Execution priority (default: 1).

        Returns
        -------
        dict
            Dictionary containing:
            - dag_group: Original group name
            - dag_group_run: Generated SHA1 identifier
            - batch_id: Batch identifier
            - priority: Execution priority
            - created_time: ISO-formatted creation time
        """
        return dict(
            dag_group=dag_group,
            dag_group_run=BaseDAG.generate_sha1(),
            batch_id=batch_id,
            priority=priority,
            created_time=time.Time.now().isot,
        )

    @staticmethod
    def force_string(d: dict):
        """Convert all values of ``d`` to strings via :func:`force_string`."""
        return force_string(d)


class Level2DAG(BaseDAG):
    """Level 2 DAG base class.

    Base class for all Level 2 Directed Acyclic Graph (DAG) implementations.
    This class provides core functionality for DAG configuration, message
    generation, and task scheduling.
    """

    def __init__(self):
        pass

    def schedule(self, plan_basis: table.Table, data_basis: table.Table):
        """Schedule the DAG for execution.

        Parameters
        ----------
        plan_basis : table.Table
            Plan basis table.
        data_basis : table.Table
            Data basis table.
        """
        pass

    def generate_dag_run(self):
        """Generate a DAG run configuration.

        Returns
        -------
        dict
            Dictionary containing DAG run configuration.
        """
        pass


class Level1DAG(BaseDAG):
    """Level 1 DAG base class.

    Base class for all Level 1 Directed Acyclic Graph (DAG) implementations.
    This class provides core functionality for DAG configuration, message
    generation, and execution management within the CSST dlist processing
    system.

    Attributes
    ----------
    dag : str
        Name of the DAG, must exist in DAG_MAP.
    dag_cfg : dict
        Configuration loaded from YAML file.
    dag_run_template : dict
        Message template structure loaded from JSON file.

    Raises
    ------
    AssertionError
        If DAG name is not in DAG_MAP or config name mismatch.
    """

    def __init__(
        self,
        dag: str,
        pattern: table.Table,
        dispatcher: Callable,
    ):
        """Initialize a DAG instance with configuration loading.

        Parameters
        ----------
        dag : str
            DAG name, must exist in DAG_MAP.
        pattern : table.Table
            Pattern table used to filter the data basis (joined on its
            column names).
        dispatcher : Callable
            Dispatcher function mapping (plan_basis, data_basis) to a task
            list.
        """
        # Set DAG name and scheduling collaborators.
        self.dag = dag
        self.pattern = pattern
        self.dispatcher = dispatcher

        # Load yaml and json config.
        yml_path = os.path.join(DAG_CONFIG_DIR, f"{dag}.yml")
        # All DAGs share one unified run-message template.
        json_path = os.path.join(DAG_CONFIG_DIR, "default-dag-run.json")

        with open(yml_path, "r") as f:
            self.dag_cfg = yaml.safe_load(f)
        # The YAML config must declare the same DAG name it is loaded for.
        assert (
            self.dag_cfg["name"] == self.dag
        ), f"{self.dag_cfg['name']} != {self.dag}"

        with open(json_path, "r") as f:
            self.dag_run_template = json.load(f)

    def run(
        self,
        # DAG group parameters
        dag_group: str = "default-dag-group",
        batch_id: str = "default-batch",
        priority: int | str = 1,
        # plan filter
        dataset: str | None = None,
        instrument: str | None = None,
        obs_type: str | None = None,
        obs_group: str | None = None,
        obs_id: str | None = None,
        proposal_id: str | None = None,
        # data filter
        detector: str | None = None,
        filter: str | None = None,
        prc_status: str | None = None,
        qc_status: str | None = None,
        # prc parameters
        pmapname: str = "",
        ref_cat: str = "",
        extra_kwargs: Optional[dict] = None,
        # additional parameters
        force_success: bool = False,
        return_details: bool = False,
        return_data_list: bool = False,
        # no custom_id
    ):
        """Query DFS, filter, schedule, and return DAG run messages.

        Parameters
        ----------
        dag_group, batch_id, priority
            Passed to :meth:`BaseDAG.generate_dag_group_run`.
        dataset, instrument, obs_type, obs_group, obs_id, proposal_id
            Plan-level filters forwarded to DFS queries.
        detector, filter, prc_status, qc_status
            Data-level filters forwarded to the level-0 DFS query.
        pmapname, ref_cat, extra_kwargs
            Processing parameters forwarded into each generated dag_run.
        force_success : bool, optional
            If True, generate dag_run messages even for failed tasks.
        return_details : bool, optional
            If True, return the full task list instead of only dag_runs.
        return_data_list : bool, optional
            If True, fill each dag_run's ``data_list`` from relevant data.

        Returns
        -------
        tuple
            ``(dag_group_run, dag_run_list)`` when ``return_details`` is
            True; otherwise ``(dag_group_run, [task["dag_run"], ...])``.
        """
        # The obsgroup dispatcher operates on whole observation groups, so
        # per-observation / per-detector filters are rejected up front.
        if self.dispatcher is Dispatcher.dispatch_obsgroup:
            assert (
                obs_group is not None
            ), "obs_group is required for obsgroup dispatcher"
            assert obs_id is None, "obs_id is not allowed for obsgroup dispatcher"
            assert detector is None, "detector is not allowed for obsgroup dispatcher"
            assert filter is None, "filter is not allowed for obsgroup dispatcher"

        if extra_kwargs is None:
            extra_kwargs = {}

        dag_group_run = self.generate_dag_group_run(
            dag_group=dag_group,
            batch_id=batch_id,
            priority=priority,
        )

        plan_basis = DFS.dfs1_find_plan_basis(
            dataset=dataset,
            instrument=instrument,
            obs_type=obs_type,
            obs_group=obs_group,
            obs_id=obs_id,
            proposal_id=proposal_id,
        )
        data_basis = DFS.dfs1_find_level0_basis(
            dataset=dataset,
            instrument=instrument,
            obs_type=obs_type,
            obs_group=obs_group,
            obs_id=obs_id,
            detector=detector,
            filter=filter,
            prc_status=prc_status,
            qc_status=qc_status,
        )

        plan_basis, data_basis = self.filter_basis(plan_basis, data_basis)

        dag_run_list = self.schedule(
            dag_group_run=dag_group_run,
            data_basis=data_basis,
            plan_basis=plan_basis,
            force_success=force_success,
            return_data_list=return_data_list,
            # directly passed to dag_run's
            pmapname=pmapname,
            ref_cat=ref_cat,
            extra_kwargs=extra_kwargs,
        )

        if return_details:
            return dag_group_run, dag_run_list
        else:
            return dag_group_run, [_["dag_run"] for _ in dag_run_list]

    def filter_basis(self, plan_basis, data_basis):
        """Filter plan/data basis tables against this DAG's pattern.

        Parameters
        ----------
        plan_basis : table.Table
            Plan basis table; must contain ``dataset`` and ``obs_id``.
        data_basis : table.Table
            Data basis table; must contain the pattern's columns plus
            ``dataset``, ``obs_id`` and ``detector``.

        Returns
        -------
        tuple of table.Table
            ``(filtered_plan_basis, filtered_data_basis)``, both sorted.
        """
        # Filter data basis via pattern (inner join on pattern columns).
        filtered_data_basis = table.join(
            self.pattern,
            data_basis,
            keys=self.pattern.colnames,
            join_type="inner",
        )
        # Sort via obs_id.
        filtered_data_basis.sort(keys=["dataset", "obs_id", "detector"])

        # No matching data: return an empty (zero-row) plan slice to keep
        # downstream table operations valid.
        if len(filtered_data_basis) == 0:
            return plan_basis[:0], filtered_data_basis

        u_data_basis = table.unique(filtered_data_basis["dataset", "obs_id"])

        # Keep only plans that still have matching data.
        filtered_plan_basis = table.join(
            u_data_basis,
            plan_basis,
            keys=["dataset", "obs_id"],
            join_type="inner",
        )
        # Sort via obs_id.
        filtered_plan_basis.sort(keys=["dataset", "obs_id"])
        return filtered_plan_basis, filtered_data_basis

    def schedule(
        self,
        dag_group_run: dict,  # dag_group, dag_group_run
        data_basis: table.Table,
        plan_basis: table.Table,
        force_success: bool = False,
        return_data_list: bool = False,
        **kwargs,
    ) -> list[dict]:
        """Schedule tasks for DAG execution.

        This method filters plan and data basis, dispatches tasks, and
        generates DAG run messages for successful tasks.

        Parameters
        ----------
        dag_group_run : dict
            DAG group run configuration containing:
            - dag_group: Group identifier
            - dag_group_run: SHA1 identifier for this run
            - batch_id: Batch identifier
            - priority: Execution priority
        data_basis : table.Table
            Table of data records to process.
        plan_basis : table.Table
            Table of plan records to execute.
        force_success : bool, optional
            If True, generate DAG run messages for all tasks, even if they
            failed (default: False).
        return_data_list : bool, optional
            If True, fill the data_list parameter with the data_basis
            records (default: False).
        **kwargs
            Additional keyword arguments passed to `generate_dag_run`.

        Returns
        -------
        list[dict]
            The dispatched task list; each task dict gains a ``"dag_run"``
            key holding the generated message for successful (or forced)
            tasks and ``None`` otherwise.
        """
        # Filter plan and data basis.
        filtered_plan_basis, filtered_data_basis = self.filter_basis(
            plan_basis, data_basis
        )
        # Dispatch tasks.
        task_list = self.dispatcher(filtered_plan_basis, filtered_data_basis)

        for this_task in task_list:
            # Only convert success tasks (unless forced).
            if force_success or this_task["success"]:
                dag_run = self.generate_dag_run(
                    **dag_group_run,
                    **this_task["task"],
                    **kwargs,
                )
                this_task["dag_run"] = dag_run
                if return_data_list:
                    this_task["dag_run"]["data_list"] = [
                        list(this_task["relevant_data"]["_id"])
                    ]
            else:
                this_task["dag_run"] = None
        return task_list

    def generate_dag_run(self, **kwargs) -> dict:
        """Generate a complete DAG run message.

        Parameters
        ----------
        kwargs : dict
            Additional keyword arguments to override.

        Returns
        -------
        dict
            Complete DAG run message.

        Raises
        ------
        AssertionError
            If any key is not in the message template.
        """
        # Copy template so the shared template is never mutated.
        dag_run = self.dag_run_template.copy()
        # Update values from caller-supplied overrides.
        dag_run = override_common_keys(dag_run, kwargs)
        # Set DAG identity and a fresh per-run hash.
        dag_run = override_common_keys(
            dag_run,
            {
                "dag": self.dag,
                "dag_run": self.generate_sha1(),
            },
        )
        # Force values to be strings for downstream message compatibility
        # (harmless if the template values are already strings).
        dag_run = self.force_string(dag_run)
        return dag_run