# _base_dag.py (author: BO ZHANG)
import copy
import json
import os
from typing import Any, Callable, Optional

import yaml
from astropy import table

from .._dfs import DFS, dfs
from ..hash import generate_sha1_from_time
from ._dispatcher import Dispatcher, override_common_keys

# Directory holding the per-DAG ``<name>.yml`` / ``<name>.json`` config files:
# the ``dag_config`` folder inside the parent of the directory containing
# this file.
DAG_CONFIG_DIR = os.path.join(
    os.path.split(os.path.split(__file__)[0])[0],
    "dag_config",
)

# Trigger hierarchy sketch (no trigger classes are defined in this file):
# - BaseTrigger
#   - AutomaticTrigger
#   - ManualTrigger
#     - with Parameters
#     - without Parameters

class BaseDAG:
    """Base class for all Directed Acyclic Graph (DAG) implementations.

    This class provides core functionality for DAG configuration, message generation,
    and execution management within the CSST dlist processing system.

    Attributes
    ----------
    dag : str
        Name of the DAG; ``{dag}.yml`` and ``{dag}.json`` must exist in
        ``DAG_CONFIG_DIR``.
    pattern : table.Table
        Selection pattern; its column names are the join keys used by
        :meth:`filter_basis`.
    dispatcher : Callable
        Called as ``dispatcher(plan_basis, data_basis)`` by :meth:`schedule`;
        must return an iterable of dicts, each carrying a ``"task"`` mapping.
    dag_cfg : dict
        Configuration loaded from the YAML file.
    dag_run_template : dict
        Message template structure loaded from the JSON file.
    dfs : DFS
        Data Flow System instance for execution.

    Raises
    ------
    AssertionError
        If the ``name`` field of the YAML config does not match ``dag``.
    """

    def __init__(
        self,
        dag: str,
        pattern: table.Table,
        dispatcher: Callable,
    ):
        """Initialize a DAG instance with configuration loading.

        Parameters
        ----------
        dag : str
            DAG name; selects ``{dag}.yml`` and ``{dag}.json`` in
            ``DAG_CONFIG_DIR``.
        pattern : table.Table
            Table whose columns and rows select matching entries of the data
            basis in :meth:`filter_basis`.
        dispatcher : Callable
            Function mapping ``(plan_basis, data_basis)`` to a list of task
            dicts; see :meth:`schedule`.

        Raises
        ------
        FileNotFoundError
            If either config file does not exist.
        KeyError
            If the YAML config has no ``name`` field.
        AssertionError
            If the YAML ``name`` field differs from ``dag``.
        """
        # Set DAG name and selection/dispatch behaviour
        self.dag = dag
        self.pattern = pattern
        self.dispatcher = dispatcher

        # Load yaml and json config
        yml_path = os.path.join(DAG_CONFIG_DIR, f"{dag}.yml")
        json_path = os.path.join(DAG_CONFIG_DIR, f"{dag}.json")

        # Explicit encoding so reading does not depend on the platform locale.
        with open(yml_path, "r", encoding="utf-8") as f:
            self.dag_cfg = yaml.safe_load(f)
        # Explicit check instead of ``assert``: an assert statement is removed
        # under ``python -O``, which would silently skip this consistency
        # check. AssertionError is kept for callers that already catch it.
        if self.dag_cfg["name"] != self.dag:
            raise AssertionError(f"{self.dag_cfg['name']} != {self.dag}")

        with open(json_path, "r", encoding="utf-8") as f:
            self.dag_run_template = json.load(f)

        # DFS instance (the module-level ``dfs`` object imported from ``.._dfs``)
        self.dfs = dfs

    def filter_basis(self, plan_basis, data_basis):
        """Restrict the plan and data bases to entries matching ``self.pattern``.

        Parameters
        ----------
        plan_basis : table.Table
            Plan entries; must contain ``dataset`` and ``obs_id`` columns.
        data_basis : table.Table
            Data entries; must contain every column of ``self.pattern`` as
            well as ``dataset`` and ``obs_id``.

        Returns
        -------
        tuple[table.Table, table.Table]
            ``(filtered_plan_basis, filtered_data_basis)``. When no data entry
            matches the pattern, the plan part is a zero-row slice of
            ``plan_basis`` (same columns, no rows).
        """
        # Keep only data rows whose pattern columns match a row of the pattern.
        filtered_data_basis = table.join(
            self.pattern,
            data_basis,
            keys=self.pattern.colnames,
            join_type="inner",
        )
        if len(filtered_data_basis) == 0:
            return plan_basis[:0], filtered_data_basis

        # Keep only plans referenced by at least one surviving data row.
        u_data_basis = table.unique(filtered_data_basis["dataset", "obs_id"])
        filtered_plan_basis = table.join(
            u_data_basis,
            plan_basis,
            keys=["dataset", "obs_id"],
            join_type="inner",
        )
        return filtered_plan_basis, filtered_data_basis

    def schedule(
        self,
        dag_group_run: dict,
        data_basis: table.Table,
        plan_basis: table.Table,
        **kwargs,
    ) -> list[dict]:
        """Build the DAG run messages for one DAG group run.

        Parameters
        ----------
        dag_group_run : dict
            Group-level fields (e.g. from :meth:`gen_dag_group_run`) passed
            into every DAG run message.
        data_basis : table.Table
            Candidate data entries, filtered by ``self.pattern``.
        plan_basis : table.Table
            Candidate plan entries, filtered to the ``(dataset, obs_id)``
            pairs that survive data filtering.
        **kwargs
            Extra fields passed into every DAG run message.

        Returns
        -------
        list[dict]
            One message per task returned by ``self.dispatcher``.

        Notes
        -----
        ``dag_group_run``, each task's ``"task"`` mapping and ``kwargs`` are
        unpacked into the same call, so they must not share keys; a duplicate
        key raises ``TypeError``.
        """
        filtered_plan_basis, filtered_data_basis = self.filter_basis(
            plan_basis, data_basis
        )
        task_list = self.dispatcher(filtered_plan_basis, filtered_data_basis)
        return [
            self.gen_dag_run(
                **dag_group_run,
                **this_task["task"],
                **kwargs,
            )
            for this_task in task_list
        ]

    @staticmethod
    def generate_sha1():
        """Generate a SHA1 hash string derived from the current time.

        Returns
        -------
        str
            SHA1 hash string produced by ``generate_sha1_from_time``.
        """
        return generate_sha1_from_time(verbose=False)

    @staticmethod
    def gen_dag_group_run(
        dag_group: str = "-",
        batch_id: str = "-",
        priority: int = 1,
    ):
        """Generate a DAG group run configuration.

        Parameters
        ----------
        dag_group : str, optional
            Group identifier (default: "-")
        batch_id : str, optional
            Batch identifier (default: "-")
        priority : int, optional
            Execution priority (default: 1)

        Returns
        -------
        dict
            Dictionary containing:
            - dag_group: Original group name
            - dag_group_run: Generated SHA1 identifier
            - batch_id: Batch identifier
            - priority: Execution priority
        """
        return dict(
            dag_group=dag_group,
            dag_group_run=BaseDAG.generate_sha1(),
            batch_id=batch_id,
            priority=priority,
        )

    def gen_dag_run(self, **kwargs) -> dict:
        """Generate a complete DAG run message.

        Parameters
        ----------
        kwargs : dict
            Values to override in the message template.

        Returns
        -------
        dict
            Complete DAG run message, with a freshly generated ``dag_run``
            hash. The returned dict shares no objects with the template.

        Raises
        ------
        AssertionError
            If any key is not in the message template (presumably raised by
            ``override_common_keys``; not verifiable from this file).
        """
        # Deep copy: ``dict.copy()`` is shallow, so nested dicts/lists loaded
        # from the JSON template would be shared between the template and
        # every generated message, and mutating one would alter the others.
        dag_run = copy.deepcopy(self.dag_run_template)
        # update values
        dag_run = override_common_keys(dag_run, kwargs)
        # set hash
        dag_run = override_common_keys(dag_run, {"dag_run": self.generate_sha1()})
        return dag_run

    @staticmethod
    def push_dag_group_run(
        dag_group_run: dict,
        dag_run_list: list[dict],
    ):
        """Submit a DAG group run to the DFS system.

        Parameters
        ----------
        dag_group_run : dict
            Group run configuration
        dag_run_list : list[dict]
            List of individual DAG run messages

        Returns
        -------
        Any
            Result from dfs.dag.new_dag_group_run()
        """
        return dfs.dag.new_dag_group_run(
            dag_group_run=dag_group_run,
            dag_run_list=dag_run_list,
        )