_base_dag.py 11.3 KB
Newer Older
BO ZHANG's avatar
tweaks    
BO ZHANG committed
1
import json
BO ZHANG's avatar
BO ZHANG committed
2
import os
3
from typing import Callable, Optional
BO ZHANG's avatar
BO ZHANG committed
4
5

import yaml
6
from astropy import table, time
7
8
9
import csst_fs
import numpy as np
import functools
BO ZHANG's avatar
BO ZHANG committed
10

BO ZHANG's avatar
BO ZHANG committed
11
12
from ._dispatcher import Dispatcher
from ..dag_utils import (
13
    force_string_and_int,
14
15
16
17
    override_common_keys,
    generate_sha1_from_time,
)
from ..dfs import DFS
BO ZHANG's avatar
tweaks    
BO ZHANG committed
18
19
20
21
22

DAG_CONFIG_DIR = os.path.join(
    os.path.dirname(os.path.dirname(__file__)),
    "dag_config",
)
BO ZHANG's avatar
tweaks  
BO ZHANG committed
23
24


BO ZHANG's avatar
BO ZHANG committed
25
class BaseDAG:
    """Base class for all Directed Acyclic Graph (DAG) implementations.

    This class provides core functionality for DAG configuration, message generation,
    and task scheduling.

    Attributes
    ----------
    dag_run_template : dict
        DAG run message template loaded from ``default-dag-run.json`` in
        ``DAG_CONFIG_DIR``.
    dag : str
        Name of the DAG. Initialised to an empty string; subclasses set it.
    """

    def __init__(self):
        """Load the default DAG run template."""
        # One template file is shared by every DAG (unified).
        json_path = os.path.join(DAG_CONFIG_DIR, "default-dag-run.json")
        # JSON is UTF-8 by specification, so do not depend on the locale default.
        with open(json_path, "r", encoding="utf-8") as f:
            self.dag_run_template = json.load(f)

        self.dag = ""

    @staticmethod
    def generate_sha1():
        """Generate a new run identifier.

        Thin wrapper around ``generate_sha1_from_time(verbose=False)``.

        Returns
        -------
        str
            Identifier returned by ``generate_sha1_from_time``
            (presumably a SHA1 hex digest derived from the current time --
            confirm against ``dag_utils``).
        """
        return generate_sha1_from_time(verbose=False)

    @staticmethod
    def generate_dag_group_run(
        dag_group: str = "default-dag-group",
        batch_id: str = "default-batch",
        priority: int | str = 1,
    ):
        """Generate a DAG group run configuration.

        Parameters
        ----------
        dag_group : str, optional
            Group identifier (default: "default-dag-group")
        batch_id : str, optional
            Batch identifier (default: "default-batch")
        priority : int | str, optional
            Execution priority (default: 1)

        Returns
        -------
        dict
            Dictionary containing:
            - dag_group: Original group name
            - dag_group_run: Identifier generated by ``generate_sha1``
            - batch_id: Batch identifier
            - priority: Execution priority
            - created_time: Creation time as an ISOT string
              (``astropy.time.Time.now().isot``)
        """
        return dict(
            dag_group=dag_group,
            dag_group_run=BaseDAG.generate_sha1(),
            batch_id=batch_id,
            priority=priority,
            created_time=time.Time.now().isot,
        )

    def generate_dag_run(self, **kwargs) -> dict:
        """Generate a complete DAG run message.

        The message starts as a copy of ``dag_run_template``, is updated with
        ``kwargs``, and finally receives this DAG's name and a fresh
        ``dag_run`` identifier. Because the identifiers are applied last,
        ``dag`` and ``dag_run`` passed in ``kwargs`` are overwritten.

        Parameters
        ----------
        **kwargs : Any
            Additional keyword arguments to override.

        Returns
        -------
        dict
            Complete DAG run message, post-processed by
            ``force_string_and_int``.

        Raises
        ------
        AssertionError
            If any key is not in the message template (presumably raised by
            ``override_common_keys`` -- confirm against ``dag_utils``).
        """
        # Local import keeps this method self-contained.
        from copy import deepcopy

        # Deep copy so nested template values (lists/dicts) are never shared
        # between the template and the generated messages.
        dag_run = deepcopy(self.dag_run_template)
        # Apply caller-supplied values.
        dag_run = override_common_keys(dag_run, kwargs)
        # Set the DAG name and a fresh run identifier.
        dag_run = override_common_keys(
            dag_run,
            {
                "dag": self.dag,
                "dag_run": self.generate_sha1(),
            },
        )
        # Normalise value types; overrides from kwargs may not match the
        # template's types even if the template itself is already normalised.
        dag_run = self.force_string_and_int(dag_run)
        return dag_run

    @staticmethod
    def force_string_and_int(d: dict):
        """Apply ``dag_utils.force_string_and_int`` to ``d`` and return the result."""
        return force_string_and_int(d)
125
126
127
128
129
130
131
132
133
134
135
136


class Level2DAG(BaseDAG):
    """Common parent of every Level 2 DAG.

    Inherits template loading and DAG run message generation from
    ``BaseDAG``. ``schedule`` is a placeholder here and does nothing.
    """

    def __init__(self):
        # Nothing Level-2 specific to set up; defer entirely to BaseDAG.
        super().__init__()

    def schedule(self, plan_basis: table.Table, data_basis: table.Table):
        """Schedule the DAG for execution (placeholder, no-op).

        Parameters
        ----------
        plan_basis : table.Table
            Plan basis table
        data_basis : table.Table
            Data basis table

        Returns
        -------
        None
            This base implementation performs no work.
        """
        return None


class Level1DAG(BaseDAG):
    """Level 1 DAG base class.

    Base class for all Level 1 Directed Acyclic Graph (DAG) implementations.

    This class provides core functionality for DAG configuration, message generation,
    and execution management.

    Attributes
    ----------
    dag : str
        Name of the DAG; a ``<dag>.yml`` file must exist in ``DAG_CONFIG_DIR``
    pattern : table.Table
        Table whose columns are used as join keys to select data records
    dispatcher : Callable
        Callable that turns filtered plan/data basis tables into tasks
    dag_cfg : dict
        Configuration loaded from YAML file
    dag_run_template : dict
        Message template structure loaded from JSON file
    """

    def __init__(
        self,
        dag: str,
        pattern: table.Table,
        dispatcher: Callable,
    ):
        """Initialize a DAG instance with configuration loading.

        Parameters
        ----------
        dag : str
            DAG name; ``<dag>.yml`` is loaded from ``DAG_CONFIG_DIR``
        pattern : table.Table
            Pattern table; ``filter_basis`` inner-joins it with the data
            basis on all of its columns
        dispatcher : Callable
            Called as ``dispatcher(plan_basis, data_basis)`` and expected to
            return a list of task dictionaries (see ``schedule``)

        Raises
        ------
        AssertionError
            If the ``name`` entry of the YAML config differs from ``dag``
        """
        super().__init__()
        # Set DAG name
        self.dag = dag
        self.pattern = pattern
        self.dispatcher = dispatcher

        # Load yaml config (the json template is loaded by BaseDAG.__init__)
        yml_path = os.path.join(DAG_CONFIG_DIR, f"{dag}.yml")
        with open(yml_path, "r", encoding="utf-8") as f:
            self.dag_cfg = yaml.safe_load(f)
        # Explicit raise instead of `assert` so the check survives `python -O`;
        # the exception type and message are unchanged.
        if self.dag_cfg["name"] != self.dag:
            raise AssertionError(f"{self.dag_cfg['name']} != {self.dag}")

    def run(
        self,
        # DAG group parameters
        dag_group: str = "default-dag-group",
        batch_id: str = "default-batch",
        priority: int | str = 1,
        # plan filter
        dataset: str | None = None,
        instrument: str | None = None,
        obs_type: str | None = None,
        obs_group: str | None = None,
        obs_id: str | None = None,
        proposal_id: str | None = None,
        # data filter
        detector: str | None = None,
        filter: str | None = None,
        prc_status: str | None = None,
        qc_status: str | None = None,
        # prc parameters
        pmapname: str = "",
        ref_cat: str = "",
        extra_kwargs: Optional[dict] = None,
        # additional parameters
        return_data_list: bool = True,
        force_success: bool = False,
        return_details: bool = False,
        # no custom_id
    ) -> tuple[dict, list]:
        """Query DFS, filter the results and generate DAG run messages.

        Parameters
        ----------
        dag_group, batch_id, priority
            Forwarded to ``generate_dag_group_run``.
        dataset, instrument, obs_type, obs_group, obs_id
            Filters passed to both ``DFS.dfs1_find_plan_basis`` and
            ``DFS.dfs1_find_level0_basis``.
        proposal_id
            Filter passed to ``DFS.dfs1_find_plan_basis`` only.
        detector, filter, prc_status, qc_status
            Filters passed to ``DFS.dfs1_find_level0_basis`` only.
        pmapname, ref_cat, extra_kwargs
            Passed through ``schedule`` into every DAG run message.
            ``extra_kwargs=None`` is replaced by an empty dict.
        return_data_list : bool, optional
            Forwarded to ``schedule`` (default: True)
        force_success : bool, optional
            Forwarded to ``schedule`` (default: False)
        return_details : bool, optional
            If True, return the full task dictionaries; otherwise return only
            each task's ``dag_run`` entry (default: False)

        Returns
        -------
        tuple[dict, list]
            ``(dag_group_run, items)``. ``items`` is the task list returned by
            ``schedule`` when ``return_details`` is True, otherwise the list
            of ``dag_run`` values (``None`` for tasks that were not
            converted). ``items`` is empty when nothing matches.

        Raises
        ------
        AssertionError
            If the dispatcher is ``Dispatcher.dispatch_obsgroup`` and
            ``obs_group`` is missing, or ``obs_id``/``detector``/``filter``
            is given.
        """
        # NOTE(review): identity comparison only matches if the very same
        # function object was passed as dispatcher -- confirm that
        # Dispatcher.dispatch_obsgroup is a staticmethod/plain function.
        if self.dispatcher is Dispatcher.dispatch_obsgroup:
            # Explicit raises instead of `assert` so that argument validation
            # is not stripped under `python -O`.
            if obs_group is None:
                raise AssertionError("obs_group is required for obsgroup dispatcher")
            if obs_id is not None:
                raise AssertionError("obs_id is not allowed for obsgroup dispatcher")
            if detector is not None:
                raise AssertionError("detector is not allowed for obsgroup dispatcher")
            if filter is not None:
                raise AssertionError("filter is not allowed for obsgroup dispatcher")
        if extra_kwargs is None:
            extra_kwargs = {}

        dag_group_run = self.generate_dag_group_run(
            dag_group=dag_group,
            batch_id=batch_id,
            priority=priority,
        )
        plan_basis = DFS.dfs1_find_plan_basis(
            dataset=dataset,
            instrument=instrument,
            obs_type=obs_type,
            obs_group=obs_group,
            obs_id=obs_id,
            proposal_id=proposal_id,
        )
        data_basis = DFS.dfs1_find_level0_basis(
            dataset=dataset,
            instrument=instrument,
            obs_type=obs_type,
            obs_group=obs_group,
            obs_id=obs_id,
            detector=detector,
            filter=filter,
            prc_status=prc_status,
            qc_status=qc_status,
        )
        # Nothing found in DFS: no tasks to generate.
        if len(plan_basis) == 0 or len(data_basis) == 0:
            return dag_group_run, []

        plan_basis, data_basis = self.filter_basis(plan_basis, data_basis)
        # The pattern may reject every record; stop here as well instead of
        # handing empty tables to the dispatcher.
        if len(plan_basis) == 0 or len(data_basis) == 0:
            return dag_group_run, []

        dag_run_list = self.schedule(
            dag_group_run=dag_group_run,
            data_basis=data_basis,
            plan_basis=plan_basis,
            force_success=force_success,
            return_data_list=return_data_list,
            # directly passed to dag_run's
            pmapname=pmapname,
            ref_cat=ref_cat,
            extra_kwargs=extra_kwargs,
        )
        if return_details:
            return dag_group_run, dag_run_list
        return dag_group_run, [task["dag_run"] for task in dag_run_list]

    def filter_basis(self, plan_basis, data_basis):
        """Restrict plan and data basis to records matching ``self.pattern``.

        Parameters
        ----------
        plan_basis : table.Table
            Plan records; joined on ``dataset`` and ``obs_id``
        data_basis : table.Table
            Data records; joined with ``self.pattern`` on all pattern columns

        Returns
        -------
        tuple
            ``(filtered_plan_basis, filtered_data_basis)``. The data basis is
            sorted by ``dataset``, ``obs_id``, ``detector``; the plan basis by
            ``dataset``, ``obs_id``. If no data record matches, an empty slice
            of ``plan_basis`` is returned together with the empty data table.
        """
        # filter data basis via pattern
        filtered_data_basis = table.join(
            self.pattern,
            data_basis,
            keys=self.pattern.colnames,
            join_type="inner",
        )
        # sort via dataset / obs_id / detector
        filtered_data_basis.sort(keys=["dataset", "obs_id", "detector"])
        if len(filtered_data_basis) == 0:
            # No data left: return an empty plan table with the same columns.
            return plan_basis[:0], filtered_data_basis
        # Unique (dataset, obs_id) pairs that still have data.
        u_data_basis = table.unique(filtered_data_basis["dataset", "obs_id"])
        # filter plan basis via data basis
        filtered_plan_basis = table.join(
            u_data_basis,
            plan_basis,
            keys=["dataset", "obs_id"],
            join_type="inner",
        )
        # sort via dataset / obs_id
        filtered_plan_basis.sort(keys=["dataset", "obs_id"])
        return filtered_plan_basis, filtered_data_basis

    def schedule(
        self,
        dag_group_run: dict,  # dag_group, dag_group_run
        data_basis: table.Table,
        plan_basis: table.Table,
        force_success: bool = False,
        return_data_list: bool = False,
        **kwargs,
    ) -> list[dict]:
        """Schedule tasks for DAG execution.

        This method filters plan and data basis, dispatches tasks, and generates
        DAG run messages for successful tasks.

        Parameters
        ----------
        dag_group_run : dict
            DAG group run configuration (as produced by
            ``generate_dag_group_run``); all of its items are passed to
            ``generate_dag_run``
        data_basis : table.Table
            Table of data records to process
        plan_basis : table.Table
            Table of plan records to execute
        force_success : bool, optional
            If True, generate DAG run messages for all tasks, even if they failed
            (default: False)
        return_data_list : bool, optional
            If True, set ``data_list`` in each generated message to the
            stringified entries of the task's ``relevant_data_id_list``
            (default: False)
        **kwargs
            Additional keyword arguments passed to ``generate_dag_run``

        Returns
        -------
        list[dict]
            The task dictionaries returned by the dispatcher, each updated in
            place with a ``dag_run`` entry: the generated message for
            converted tasks, ``None`` otherwise.
        """
        # Filter plan and data basis. `run` filters before calling this
        # method; the step is kept so direct callers get the same filtering.
        filtered_plan_basis, filtered_data_basis = self.filter_basis(
            plan_basis, data_basis
        )
        # dispatch tasks
        task_list = self.dispatcher(filtered_plan_basis, filtered_data_basis)
        for this_task in task_list:
            # only convert success tasks (or all of them if forced)
            if not (force_success or this_task["success"]):
                this_task["dag_run"] = None
                continue
            dag_run = self.generate_dag_run(
                **dag_group_run,
                **this_task["task"],
                **kwargs,
            )
            if return_data_list:
                # Set after generate_dag_run, so the list is stored as-is.
                dag_run["data_list"] = [
                    str(data_id) for data_id in this_task["relevant_data_id_list"]
                ]
            this_task["dag_run"] = dag_run
        return task_list