"""Base DAG definitions for the CSST dlist processing system.

Trigger taxonomy (kept here for reference):

- BaseTrigger
- AutomaticTrigger
- ManualTrigger
    - with Parameters
    - without Parameters
"""

import copy
import json
import os
from typing import Any, Callable, Optional

import yaml
from astropy import table

from .._dfs import DFS, dfs
from ..hash import generate_sha1_from_time
from ._dispatcher import Dispatcher, override_common_keys

# Directory holding the per-DAG configuration pairs: <dag>.yml and <dag>.json.
DAG_CONFIG_DIR = os.path.join(
    os.path.dirname(os.path.dirname(__file__)),
    "dag_config",
)


class BaseDAG:
    """Base class for all Directed Acyclic Graph (DAG) implementations.

    This class provides core functionality for DAG configuration, message
    generation, and execution management within the CSST dlist processing
    system.

    Attributes
    ----------
    dag : str
        Name of the DAG; must match the ``name`` field of its YAML config.
    pattern : astropy.table.Table
        Pattern table used to filter the data basis in `filter_basis`.
    dispatcher : Callable
        Callable that turns filtered plan/data bases into a task list.
    dag_cfg : dict
        Configuration loaded from ``<dag>.yml``.
    dag_run_template : dict
        Message template structure loaded from ``<dag>.json``.
    dfs : DFS
        Data Flow System instance used for execution.

    Raises
    ------
    AssertionError
        If the YAML config's ``name`` field does not match ``dag``.
    """

    def __init__(
        self,
        dag: str,
        pattern: table.Table,
        dispatcher: Callable,
    ):
        """Initialize a DAG instance and load its configuration files.

        Parameters
        ----------
        dag : str
            DAG name; ``<dag>.yml`` and ``<dag>.json`` must exist in
            ``DAG_CONFIG_DIR``.
        pattern : astropy.table.Table
            Pattern table used to select matching rows of the data basis.
        dispatcher : Callable
            Callable mapping (plan_basis, data_basis) to a task list.
        """
        self.dag = dag
        self.pattern = pattern
        self.dispatcher = dispatcher

        # Load the YAML (settings) and JSON (message template) configs.
        yml_path = os.path.join(DAG_CONFIG_DIR, f"{dag}.yml")
        json_path = os.path.join(DAG_CONFIG_DIR, f"{dag}.json")
        with open(yml_path, "r", encoding="utf-8") as f:
            self.dag_cfg = yaml.safe_load(f)
        # Guard against a config file whose declared name drifted from its
        # filename.
        assert (
            self.dag_cfg["name"] == self.dag
        ), f"{self.dag_cfg['name']} != {self.dag}"
        with open(json_path, "r", encoding="utf-8") as f:
            self.dag_run_template = json.load(f)

        # DFS instance (module-level singleton).
        self.dfs = dfs

    def filter_basis(self, plan_basis, data_basis):
        """Filter plan and data bases down to rows matching ``self.pattern``.

        Parameters
        ----------
        plan_basis : astropy.table.Table
            Plan-level basis table; must contain ``dataset`` and ``obs_id``.
        data_basis : astropy.table.Table
            Data-level basis table; must contain the pattern's columns plus
            ``dataset`` and ``obs_id``.

        Returns
        -------
        tuple of astropy.table.Table
            ``(filtered_plan_basis, filtered_data_basis)``.
        """
        # Keep only data rows matching the pattern on all pattern columns.
        filtered_data_basis = table.join(
            self.pattern,
            data_basis,
            keys=self.pattern.colnames,
            join_type="inner",
        )
        # Deduplicate (dataset, obs_id) pairs before joining back to plans,
        # so each plan row appears at most once per observation.
        u_data_basis = table.unique(filtered_data_basis["dataset", "obs_id"])
        filtered_plan_basis = table.join(
            u_data_basis,
            plan_basis,
            keys=["dataset", "obs_id"],
            join_type="inner",
        )
        return filtered_plan_basis, filtered_data_basis

    def schedule(
        self,
        dag_group_run: dict,
        data_basis: table.Table,
        plan_basis: table.Table,
        **kwargs,
    ) -> list[dict]:
        """Build the list of DAG run messages for one group run.

        Parameters
        ----------
        dag_group_run : dict
            Group-run fields (as produced by `gen_dag_group_run`) merged
            into every run message.
        data_basis : astropy.table.Table
            Data-level basis table.
        plan_basis : astropy.table.Table
            Plan-level basis table.
        kwargs : dict
            Extra fields merged into every run message.

        Returns
        -------
        list[dict]
            One DAG run message per task produced by the dispatcher.
        """
        filtered_plan_basis, filtered_data_basis = self.filter_basis(
            plan_basis, data_basis
        )
        task_list = self.dispatcher(filtered_plan_basis, filtered_data_basis)
        dag_run_list = []
        for this_task in task_list:
            dag_run = self.gen_dag_run(
                **dag_group_run,
                **this_task["task"],
                **kwargs,
            )
            dag_run_list.append(dag_run)
        return dag_run_list

    @staticmethod
    def generate_sha1():
        """Generate a unique SHA1 hash based on the current timestamp.

        Returns
        -------
        str
            SHA1 hash string.
        """
        return generate_sha1_from_time(verbose=False)

    @staticmethod
    def gen_dag_group_run(
        dag_group: str = "-",
        batch_id: str = "-",
        priority: int = 1,
    ):
        """Generate a DAG group run configuration.

        Parameters
        ----------
        dag_group : str, optional
            Group identifier (default: "-").
        batch_id : str, optional
            Batch identifier (default: "-").
        priority : int, optional
            Execution priority (default: 1).

        Returns
        -------
        dict
            Dictionary containing:
            - dag_group: Original group name
            - dag_group_run: Generated SHA1 identifier
            - batch_id: Batch identifier
            - priority: Execution priority
        """
        return dict(
            dag_group=dag_group,
            dag_group_run=BaseDAG.generate_sha1(),
            batch_id=batch_id,
            priority=priority,
        )

    def gen_dag_run(self, **kwargs) -> dict:
        """Generate a complete DAG run message from the template.

        Parameters
        ----------
        kwargs : dict
            Keyword arguments overriding template values.

        Returns
        -------
        dict
            Complete DAG run message.

        Raises
        ------
        AssertionError
            If any key is not in the message template.
        """
        # Deep-copy the template: a shallow .copy() would alias nested
        # dicts/lists from the JSON template across every generated run,
        # so an override touching a nested value would silently corrupt
        # the shared template for all subsequent runs.
        dag_run = copy.deepcopy(self.dag_run_template)
        # Merge the caller's overrides into the template.
        dag_run = override_common_keys(dag_run, kwargs)
        return dag_run

    @staticmethod
    def push_dag_group_run(
        dag_group_run: dict,
        dag_run_list: list[dict],
    ):
        """Submit a DAG group run to the DFS system.

        Parameters
        ----------
        dag_group_run : dict
            Group run configuration.
        dag_run_list : list[dict]
            List of individual DAG run messages.

        Returns
        -------
        Any
            Result from ``dfs.dag.new_dag_group_run()``.
        """
        return dfs.dag.new_dag_group_run(
            dag_group_run=dag_group_run,
            dag_run_list=dag_run_list,
        )