import json
import os
from typing import Callable

import yaml
from astropy import table

from ._dispatcher import override_common_keys
from .._dfs import DFS, dfs
from ..hash import generate_sha1_from_time

# Directory holding the per-DAG ``<dag>.yml`` / ``<dag>.json`` files; it is the
# ``dag_config`` folder located one level above this module's directory.
DAG_CONFIG_DIR = os.path.join(
    os.path.dirname(os.path.dirname(__file__)),
    "dag_config",
)


class BaseDAG:
    """Base class for all Directed Acyclic Graph (DAG) implementations.

    This class provides core functionality for DAG configuration, message
    generation, and execution management within the CSST dlist processing
    system.

    Attributes
    ----------
    dag : str
        Name of the DAG; ``<dag>.yml`` and ``<dag>.json`` must exist in
        ``DAG_CONFIG_DIR``.
    pattern : astropy.table.Table
        Table joined against the data basis (on all of its columns) to select
        the rows relevant to this DAG.
    dispatcher : Callable
        Callable invoked as ``dispatcher(plan_basis, data_basis)`` that
        returns the list of tasks to schedule.
    dag_cfg : dict
        Configuration loaded from the YAML file.
    dag_run_template : dict
        Message template structure loaded from the JSON file.
    dfs : DFS
        Data Flow System instance for execution.

    Raises
    ------
    AssertionError
        If the ``name`` entry of the YAML configuration differs from ``dag``.
    """

    def __init__(
        self,
        dag: str,
        pattern: table.Table,
        dispatcher: Callable,
    ):
        """Initialize a DAG instance with configuration loading.

        Parameters
        ----------
        dag : str
            DAG name; used to locate ``<dag>.yml`` and ``<dag>.json`` in
            ``DAG_CONFIG_DIR``.
        pattern : astropy.table.Table
            Selection pattern used by :meth:`filter_basis`.
        dispatcher : Callable
            Task dispatcher used by :meth:`schedule`.

        Raises
        ------
        AssertionError
            If the ``name`` entry of the YAML configuration differs from
            ``dag``.
        """
        # Set DAG name, selection pattern and dispatcher
        self.dag = dag
        self.pattern = pattern
        self.dispatcher = dispatcher

        # Load yaml and json config
        yml_path = os.path.join(DAG_CONFIG_DIR, f"{dag}.yml")
        json_path = os.path.join(DAG_CONFIG_DIR, f"{dag}.json")

        # Explicit encoding so parsing does not depend on the platform locale.
        with open(yml_path, "r", encoding="utf-8") as f:
            self.dag_cfg = yaml.safe_load(f)
        # The config file must describe the DAG it is named after.
        assert (
            self.dag_cfg["name"] == self.dag
        ), f"{self.dag_cfg['name']} != {self.dag}"

        with open(json_path, "r", encoding="utf-8") as f:
            self.dag_run_template = json.load(f)

        # DFS instance
        self.dfs = dfs

    def filter_basis(self, plan_basis, data_basis):
        """Restrict the plan and data basis to rows matching ``self.pattern``.

        Parameters
        ----------
        plan_basis : astropy.table.Table
            Plan basis; must contain the ``dataset`` and ``obs_id`` columns.
        data_basis : astropy.table.Table
            Data basis; must contain every column of ``self.pattern`` as well
            as ``dataset`` and ``obs_id``.

        Returns
        -------
        tuple
            ``(filtered_plan_basis, filtered_data_basis)``. When no data row
            matches the pattern, the first element is an empty slice of
            ``plan_basis`` and the second is the empty join result.
        """
        # Keep only the data rows matching the pattern on all pattern columns.
        filtered_data_basis = table.join(
            self.pattern,
            data_basis,
            keys=self.pattern.colnames,
            join_type="inner",
        )
        # Sort by dataset, then obs_id.
        filtered_data_basis.sort(keys=["dataset", "obs_id"])

        if len(filtered_data_basis) == 0:
            # Nothing matched: return an empty plan basis with the same columns.
            return plan_basis[:0], filtered_data_basis

        # Keep only the plan rows whose (dataset, obs_id) survived the filter.
        u_data_basis = table.unique(filtered_data_basis["dataset", "obs_id"])
        filtered_plan_basis = table.join(
            u_data_basis,
            plan_basis,
            keys=["dataset", "obs_id"],
            join_type="inner",
        )
        return filtered_plan_basis, filtered_data_basis

    def schedule(
        self,
        dag_group_run: dict,
        data_basis: table.Table,
        plan_basis: table.Table,
        force_success: bool = False,
        **kwargs,
    ) -> list[dict]:
        """Filter the basis tables, dispatch tasks and attach DAG run messages.

        Parameters
        ----------
        dag_group_run : dict
            Group run configuration; its items are passed as keyword
            arguments to :meth:`gen_dag_run`.
        data_basis : astropy.table.Table
            Data basis to be filtered and dispatched.
        plan_basis : astropy.table.Table
            Plan basis to be filtered and dispatched.
        force_success : bool, optional
            If True, generate a DAG run message for every task regardless of
            its ``"success"`` flag (default: False).
        **kwargs
            Additional keyword arguments forwarded to :meth:`gen_dag_run`.

        Returns
        -------
        list[dict]
            The task list returned by the dispatcher. Each task gains a
            ``"dag_run"`` entry: the generated message, or ``None`` for tasks
            that were not converted.
        """
        # filter plan and data basis
        filtered_plan_basis, filtered_data_basis = self.filter_basis(
            plan_basis, data_basis
        )
        # dispatch tasks
        task_list = self.dispatcher(filtered_plan_basis, filtered_data_basis)

        for this_task in task_list:
            # only convert successful tasks unless conversion is forced
            if force_success or this_task["success"]:
                this_task["dag_run"] = self.gen_dag_run(
                    **dag_group_run,
                    **this_task["task"],
                    **kwargs,
                )
            else:
                this_task["dag_run"] = None

        return task_list

    @staticmethod
    def generate_sha1():
        """Generate a unique SHA1 hash based on current timestamp.

        Returns
        -------
        str
            SHA1 hash string
        """
        return generate_sha1_from_time(verbose=False)

    @staticmethod
    def gen_dag_group_run(
        dag_group: str = "-",
        batch_id: str = "-",
        priority: int = 1,
    ):
        """Generate a DAG group run configuration.

        Parameters
        ----------
        dag_group : str, optional
            Group identifier (default: "-")
        batch_id : str, optional
            Batch identifier (default: "-")
        priority : int, optional
            Execution priority (default: 1)

        Returns
        -------
        dict
            Dictionary containing:
            - dag_group: Original group name
            - dag_group_run: Generated SHA1 identifier
            - batch_id: Batch identifier
            - priority: Execution priority
        """
        return dict(
            dag_group=dag_group,
            dag_group_run=BaseDAG.generate_sha1(),
            batch_id=batch_id,
            priority=priority,
        )

    @staticmethod
    def force_string(d: dict):
        """Convert every value of ``d`` to ``str`` in place.

        Parameters
        ----------
        d : dict
            Dictionary whose values are to be stringified.

        Returns
        -------
        dict
            The same dictionary object, with each value replaced by
            ``str(value)``.
        """
        # Item assignment is required here: a plain dict has no settable
        # attributes, so ``d.__setattr__(k, ...)`` raises AttributeError.
        # The stringified mapping is built first so ``d`` is not modified
        # while it is being iterated.
        d.update({k: str(v) for k, v in d.items()})
        return d

    def gen_dag_run(self, **kwargs) -> dict:
        """Generate a complete DAG run message.

        Parameters
        ----------
        kwargs : dict
            Additional keyword arguments to override.

        Returns
        -------
        dict
            Complete DAG run message with all values converted to strings.

        Notes
        -----
        Key handling is delegated to ``override_common_keys``; any exception
        it raises propagates to the caller.
        """
        # copy template (shallow) so the loaded template is not rebound
        dag_run = self.dag_run_template.copy()
        # update values
        dag_run = override_common_keys(dag_run, kwargs)
        # set DAG name and a fresh run hash
        dag_run = override_common_keys(
            dag_run,
            {"dag": self.dag, "dag_run": self.generate_sha1()},
        )
        # force values to be string
        dag_run = self.force_string(dag_run)
        return dag_run

    @staticmethod
    def push_dag_group_run(
        dag_group_run: dict,
        dag_run_list: list[dict],
    ):
        """Submit a DAG group run to the DFS system.

        Parameters
        ----------
        dag_group_run : dict
            Group run configuration
        dag_run_list : list[dict]
            List of individual DAG run messages

        Returns
        -------
        Any
            Result from dfs.dag.new_dag_group_run()
        """
        return dfs.dag.new_dag_group_run(
            dag_group_run=dag_group_run,
            dag_run_list=dag_run_list,
        )