_base_dag.py 11.1 KB
Newer Older
BO ZHANG's avatar
tweaks    
BO ZHANG committed
1
import json
BO ZHANG's avatar
BO ZHANG committed
2
import os
3
from typing import Callable, Optional
BO ZHANG's avatar
BO ZHANG committed
4
5

import yaml
6
from astropy import table, time
BO ZHANG's avatar
BO ZHANG committed
7

8
9
10
11
12
13
14
from ._dag_utils import (
    force_string,
    override_common_keys,
    generate_sha1_from_time,
)
from ..dfs import DFS
from ._dispatcher import Dispatcher
BO ZHANG's avatar
tweaks    
BO ZHANG committed
15
16
17
18
19

# Folder named "dag_config" located one directory above the directory that
# contains this module; the per-DAG YAML files and the shared dag-run JSON
# template are read from here.
DAG_CONFIG_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "dag_config")
BO ZHANG's avatar
tweaks  
BO ZHANG committed
20
21


BO ZHANG's avatar
BO ZHANG committed
22
class BaseDAG:
    """Common base for all Directed Acyclic Graph (DAG) classes.

    Bundles the static helpers shared by every DAG level: SHA1 identifier
    generation, creation of DAG group run records, and a thin wrapper around
    the ``force_string`` utility.
    """

    @staticmethod
    def generate_sha1():
        """Create a new SHA1 identifier derived from the current timestamp.

        Delegates to ``generate_sha1_from_time`` with ``verbose=False``.

        Returns
        -------
        str
            SHA1 hash string
        """
        sha1 = generate_sha1_from_time(verbose=False)
        return sha1

    @staticmethod
    def generate_dag_group_run(
        dag_group: str = "default-dag-group",
        batch_id: str = "default-batch",
        priority: int | str = 1,
    ):
        """Build the record describing one run of a DAG group.

        Parameters
        ----------
        dag_group : str, optional
            Group identifier (default: "default-dag-group")
        batch_id : str, optional
            Batch identifier (default: "default-batch")
        priority : int | str, optional
            Execution priority (default: 1)

        Returns
        -------
        dict
            Dictionary containing:
            - dag_group: Group name as passed in
            - dag_group_run: Freshly generated SHA1 identifier
            - batch_id: Batch identifier as passed in
            - priority: Execution priority as passed in
            - created_time: Current time as an ISOT string (``time.Time.now().isot``)
        """
        # The identifier is generated before the timestamp is taken, and the
        # keys are inserted in the order listed in the docstring.
        group_run = {
            "dag_group": dag_group,
            "dag_group_run": BaseDAG.generate_sha1(),
            "batch_id": batch_id,
            "priority": priority,
            "created_time": time.Time.now().isot,
        }
        return group_run

    @staticmethod
    def force_string(d: dict):
        """Pass ``d`` through the module-level ``force_string`` helper.

        NOTE(review): the helper lives in ``_dag_utils`` and presumably converts
        the dictionary's values to strings -- confirm against its definition.
        """
        # Inside the function body the bare name resolves to the imported
        # module-level helper, not to this static method.
        stringified = force_string(d)
        return stringified


class Level2DAG(BaseDAG):
    """Level 2 DAG base class.

    Parent of all Level 2 Directed Acyclic Graph (DAG) implementations.

    Every method below is currently a placeholder that does nothing and
    returns ``None``.
    """

    def __init__(self):
        """Create the instance; no state is set up here."""

    def schedule(self, plan_basis: table.Table, data_basis: table.Table):
        """Schedule the DAG for execution (placeholder, returns ``None``).

        Parameters
        ----------
        plan_basis : table.Table
            Plan basis table
        data_basis : table.Table
            Data basis table
        """

    def generate_dag_run(self):
        """Generate a DAG run configuration (placeholder, returns ``None``).

        Returns
        -------
        None
            Nothing is generated yet.
        """


class Level1DAG(BaseDAG):
    """Level 1 DAG base class.

    Base class for all Level 1 Directed Acyclic Graph (DAG) implementations.

    This class provides core functionality for DAG configuration, message generation,
    and execution management within the CSST dlist processing system.

    Attributes
    ----------
    dag : str
        Name of the DAG; must equal the ``name`` entry of its YAML config file
    pattern : table.Table
        Pattern table that is inner-joined with the data basis in `filter_basis`
    dispatcher : Callable
        Callable invoked as ``dispatcher(plan_basis, data_basis)`` to build the
        task list in `schedule`
    dag_cfg : dict
        Configuration loaded from the ``<dag>.yml`` file in ``DAG_CONFIG_DIR``
    dag_run_template : dict
        Message template structure loaded from ``default-dag-run.json``

    Raises
    ------
    AssertionError
        If the ``name`` entry of the YAML config differs from the DAG name
    """

    def __init__(
        self,
        dag: str,
        pattern: table.Table,
        dispatcher: Callable,
    ):
        """Initialize a DAG instance with configuration loading.

        Parameters
        ----------
        dag : str
            DAG name; ``<dag>.yml`` must exist in ``DAG_CONFIG_DIR`` and its
            ``name`` entry must equal this value
        pattern : table.Table
            Pattern table used to select rows of the data basis
        dispatcher : Callable
            Callable that converts the filtered plan/data basis into tasks

        Raises
        ------
        AssertionError
            If the ``name`` entry of the YAML config differs from ``dag``
        """
        # Set DAG name, selection pattern and dispatcher
        self.dag = dag
        self.pattern = pattern
        self.dispatcher = dispatcher

        # Locate yaml (per-DAG) and json (unified for all DAGs) config
        yml_path = os.path.join(DAG_CONFIG_DIR, f"{dag}.yml")
        json_path = os.path.join(DAG_CONFIG_DIR, "default-dag-run.json")

        # Decode explicitly as UTF-8 so loading does not depend on the locale.
        with open(yml_path, "r", encoding="utf-8") as f:
            self.dag_cfg = yaml.safe_load(f)
        assert (
            self.dag_cfg["name"] == self.dag
        ), f"{self.dag_cfg['name']} != {self.dag}"
        with open(json_path, "r", encoding="utf-8") as f:
            self.dag_run_template = json.load(f)

    def run(
        self,
        # DAG group parameters
        dag_group: str = "default-dag-group",
        batch_id: str = "default-batch",
        priority: int | str = 1,
        # plan filter
        dataset: str | None = None,
        instrument: str | None = None,
        obs_type: str | None = None,
        obs_group: str | None = None,
        obs_id: str | None = None,
        proposal_id: str | None = None,
        # data filter
        detector: str | None = None,
        filter: str | None = None,
        prc_status: str | None = None,
        qc_status: str | None = None,
        # prc parameters
        pmapname: str = "",
        ref_cat: str = "",
        extra_kwargs: Optional[dict] = None,
        # additional parameters
        return_data_list: bool = True,
        force_success: bool = False,
        return_details: bool = False,
        # no custom_id
    ):
        """Query the basis tables, schedule tasks and return the DAG runs.

        Parameters
        ----------
        dag_group, batch_id, priority
            Forwarded to `generate_dag_group_run`.
        dataset, instrument, obs_type, obs_group, obs_id : str | None, optional
            Forwarded to both ``DFS.dfs1_find_plan_basis`` and
            ``DFS.dfs1_find_level0_basis``.
        proposal_id : str | None, optional
            Forwarded to ``DFS.dfs1_find_plan_basis`` only.
        detector, filter, prc_status, qc_status : str | None, optional
            Forwarded to ``DFS.dfs1_find_level0_basis`` only.
        pmapname, ref_cat : str, optional
            Passed through `schedule` to `generate_dag_run`.
        extra_kwargs : dict, optional
            Passed through `schedule` to `generate_dag_run` under the key
            ``extra_kwargs``; ``None`` is replaced by an empty dict.
        return_data_list : bool, optional
            Forwarded to `schedule` (default: True)
        force_success : bool, optional
            Forwarded to `schedule` (default: False)
        return_details : bool, optional
            If True, return the full task list from `schedule`; otherwise
            return only each task's ``"dag_run"`` entry (default: False)

        Returns
        -------
        tuple
            ``(dag_group_run, dag_runs)`` where ``dag_group_run`` is the dict
            from `generate_dag_group_run` and ``dag_runs`` is either the task
            list (``return_details=True``) or the list of ``"dag_run"`` values,
            which is ``None`` for tasks that were not converted.

        Raises
        ------
        AssertionError
            If the dispatcher is ``Dispatcher.dispatch_obsgroup`` and
            ``obs_group`` is missing, or ``obs_id``/``detector``/``filter``
            is given.
        """
        if self.dispatcher is Dispatcher.dispatch_obsgroup:
            assert (
                obs_group is not None
            ), "obs_group is required for obsgroup dispatcher"
            assert obs_id is None, "obs_id is not allowed for obsgroup dispatcher"
            assert detector is None, "detector is not allowed for obsgroup dispatcher"
            assert filter is None, "filter is not allowed for obsgroup dispatcher"
        if extra_kwargs is None:
            extra_kwargs = {}
        dag_group_run = self.generate_dag_group_run(
            dag_group=dag_group,
            batch_id=batch_id,
            priority=priority,
        )
        plan_basis = DFS.dfs1_find_plan_basis(
            dataset=dataset,
            instrument=instrument,
            obs_type=obs_type,
            obs_group=obs_group,
            obs_id=obs_id,
            proposal_id=proposal_id,
        )
        data_basis = DFS.dfs1_find_level0_basis(
            dataset=dataset,
            instrument=instrument,
            obs_type=obs_type,
            obs_group=obs_group,
            obs_id=obs_id,
            detector=detector,
            filter=filter,
            prc_status=prc_status,
            qc_status=qc_status,
        )
        # NOTE(review): `schedule` calls `filter_basis` again on these results;
        # the double filtering is kept as-is to preserve behavior.
        plan_basis, data_basis = self.filter_basis(plan_basis, data_basis)
        dag_run_list = self.schedule(
            dag_group_run=dag_group_run,
            data_basis=data_basis,
            plan_basis=plan_basis,
            force_success=force_success,
            return_data_list=return_data_list,
            # directly passed to dag_run's
            pmapname=pmapname,
            ref_cat=ref_cat,
            extra_kwargs=extra_kwargs,
        )
        if return_details:
            return dag_group_run, dag_run_list
        return dag_group_run, [task["dag_run"] for task in dag_run_list]

    def filter_basis(self, plan_basis, data_basis):
        """Restrict the data basis to the pattern and the plan basis to match.

        Parameters
        ----------
        plan_basis : table.Table
            Plan records; joined on ``dataset`` and ``obs_id``
        data_basis : table.Table
            Data records; joined with ``self.pattern`` on all pattern columns

        Returns
        -------
        tuple
            ``(filtered_plan_basis, filtered_data_basis)``; when no data record
            matches the pattern, the plan basis is returned as an empty slice.
        """
        # filter data basis via pattern
        filtered_data_basis = table.join(
            self.pattern,
            data_basis,
            keys=self.pattern.colnames,
            join_type="inner",
        )
        # sort via dataset, obs_id and detector
        filtered_data_basis.sort(keys=["dataset", "obs_id", "detector"])
        if len(filtered_data_basis) == 0:
            # nothing matched: return an empty plan basis with the same columns
            return plan_basis[:0], filtered_data_basis
        u_data_basis = table.unique(filtered_data_basis["dataset", "obs_id"])
        # filter plan basis via data basis
        filtered_plan_basis = table.join(
            u_data_basis,
            plan_basis,
            keys=["dataset", "obs_id"],
            join_type="inner",
        )
        # sort via dataset and obs_id
        filtered_plan_basis.sort(keys=["dataset", "obs_id"])
        return filtered_plan_basis, filtered_data_basis

    def schedule(
        self,
        dag_group_run: dict,  # dag_group, dag_group_run
        data_basis: table.Table,
        plan_basis: table.Table,
        force_success: bool = False,
        return_data_list: bool = False,
        **kwargs,
    ) -> list[dict]:
        """Schedule tasks for DAG execution.

        This method filters plan and data basis, dispatches tasks, and generates
        DAG run messages for successful tasks.

        Parameters
        ----------
        dag_group_run : dict
            DAG group run configuration as built by `generate_dag_group_run`,
            containing:
            - dag_group: Group identifier
            - dag_group_run: SHA1 identifier for this run
            - batch_id: Batch identifier
            - priority: Execution priority
            - created_time: Creation time
        data_basis : table.Table
            Table of data records to process
        plan_basis : table.Table
            Table of plan records to execute
        force_success : bool, optional
            If True, generate DAG run messages for all tasks, even if they failed
            (default: False)
        return_data_list : bool, optional
            If True, set ``data_list`` of each generated DAG run to the task's
            ``relevant_data_id_list`` converted to strings (default: False)
        **kwargs
            Additional keyword arguments passed to `generate_dag_run`

        Returns
        -------
        list[dict]
            The task list produced by the dispatcher, where every task has a
            ``"dag_run"`` entry: the generated DAG run message, or ``None`` for
            tasks that were not successful (unless ``force_success`` is set).
        """
        # filter plan and data basis
        filtered_plan_basis, filtered_data_basis = self.filter_basis(
            plan_basis, data_basis
        )
        # dispatch tasks
        task_list = self.dispatcher(filtered_plan_basis, filtered_data_basis)
        for this_task in task_list:
            # only convert success tasks
            if force_success or this_task["success"]:
                dag_run = self.generate_dag_run(
                    **dag_group_run,
                    **this_task["task"],
                    **kwargs,
                )
                this_task["dag_run"] = dag_run
                if return_data_list:
                    this_task["dag_run"]["data_list"] = [
                        str(_) for _ in this_task["relevant_data_id_list"]
                    ]
            else:
                this_task["dag_run"] = None
        return task_list

    def generate_dag_run(self, **kwargs) -> dict:
        """Generate a complete DAG run message.

        The message starts as a shallow copy of ``self.dag_run_template``, is
        updated from ``kwargs`` via ``override_common_keys``, then receives this
        DAG's name and a fresh SHA1 as ``dag`` and ``dag_run``, and is finally
        passed through `force_string`.

        Parameters
        ----------
        kwargs : dict
            Additional keyword arguments to override. ``dag`` and ``dag_run``
            are applied after these, so they cannot be overridden here.

        Returns
        -------
        dict
            Complete DAG run message

        Notes
        -----
        NOTE(review): the exact handling of keys that are absent from the
        template is decided by ``override_common_keys`` -- confirm there.
        """
        # copy template (shallow: nested values stay shared with the template)
        dag_run = self.dag_run_template.copy()
        # update values
        dag_run = override_common_keys(dag_run, kwargs)
        # set DAG name and hash
        dag_run = override_common_keys(
            dag_run,
            {
                "dag": self.dag,
                "dag_run": self.generate_sha1(),
            },
        )
        # force values to be string
        dag_run = self.force_string(dag_run)
        return dag_run