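"""Dispatch CSST Level-0 processing tasks from DFS plan and data queries."""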
import numpy as np
from astropy.table import Table
from csst_dfs_client import plan, level0, level1
from .._csst import csst

# Default query-parameter templates; caller kwargs override matching keys via override_common_keys().
PLAN_PARAMS = {
    "dataset": None,
    "instrument": None,
    "obs_type": None,
    "obs_group": None,
    "obs_id": None,
    "proposal_id": None,
}

LEVEL0_PARAMS = {
    "dataset": None,
    "instrument": None,
    "obs_type": None,
    "obs_group": None,
    "obs_id": None,
    "detector": None,
    "prc_status": -1024,
}

LEVEL1_PARAMS = {
    "dataset": None,
    "instrument": None,
    "obs_type": None,
    "obs_group": None,
    "obs_id": None,
    "detector": None,
    "prc_status": -1024,
    "data_model": None,
    "batch_id": "default_batch",
}

PROC_PARAMS = {
    "priority": 1,
    "batch_id": "default_batch",
    "pmapname": "pmapname",
    "final_prc_status": -2,
    "demo": False,
}


def override_common_keys(d1: dict, d2: dict) -> dict:
    """
    Construct a new dictionary by overriding the values of keys that exist in the
    first dictionary with the corresponding values from the second dictionary.
    Keys that appear only in the second dictionary are ignored.

    Parameters
    ----------
    d1 : dict
        The first dictionary (defines the keys and default values).
    d2 : dict
        The second dictionary (provides overriding values).

    Returns
    -------
    dict:
        The updated dictionary.
    """
    return {k: d2.get(k, d1[k]) for k in d1}
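
# Example (hypothetical values): override_common_keys(LEVEL0_PARAMS, {"instrument": "MSC", "foo": 1})
# keeps every key of LEVEL0_PARAMS, overrides only "instrument", and silently drops
# "foo", which is not part of the template.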


# def extract_basis(dlist: list[dict], basis_keys: tuple) -> np.ndarray:
#     """Extract basis key-value pairs from a list of dictionaries."""
#     return Table([{k: d.get(k) for k in basis_keys} for d in dlist]).as_array()


def extract_basis(dlist: list[dict], basis_keys: tuple) -> np.typing.NDArray:
    """Extract basis key-value pairs into an object-dtype array of dicts,
    enabling elementwise comparison against a single dict (see below)."""
    return np.array([{k: d.get(k) for k in basis_keys} for d in dlist], dtype=object)


class Dispatcher:
    """
    A class to dispatch tasks based on the observation type.
    """

    @staticmethod
    def dispatch_level0_file(**kwargs) -> dict:
        """Dispatch one task per Level-0 file matching the query."""
        # plan_recs = plan.find(**override_common_keys(PLAN_PARAMS, kwargs))
        data_recs = level0.find(**override_common_keys(LEVEL0_PARAMS, kwargs))
        assert data_recs.success, data_recs
        # construct results
        task_list = []
        for data_rec in data_recs.data:
            # construct task and append it
            task = dict(
                dataset=data_rec["dataset"],
                instrument=data_rec["instrument"],
                obs_type=data_rec["obs_type"],
                obs_group=data_rec["obs_group"],
                obs_id=data_rec["obs_id"],
                detector=data_rec["detector"],
                file_name=data_rec["file_name"],
            )
            task_list.append(task)

        return dict(
            task_list=task_list,
            relevant_data_id_list=[],
        )

    @staticmethod
    def dispatch_level0_detector(**kwargs) -> dict:
        """Dispatch one task per (plan, detector) whose Level-0 files are complete."""

        # get instrument
        assert "instrument" in kwargs, f"{kwargs} does not have key 'instrument'"
        instrument = kwargs.get("instrument")
        assert instrument in ("MSC", "MCI", "IFS", "CPIC", "HSTDM")

        # query for plan and data
        plan_recs = plan.find(**override_common_keys(PLAN_PARAMS, kwargs))
        assert plan_recs.success, plan_recs
        data_recs = level0.find(**override_common_keys(LEVEL0_PARAMS, kwargs))
        assert data_recs.success, data_recs
        print(f"{len(plan_recs.data)} plan records")
        print(f"{len(data_recs.data)} data records")

        # effective detectors for this instrument (module-level `csst` registry)
        effective_detector_names = csst[instrument].effective_detector_names

        # extract info
        plan_basis = extract_basis(
            plan_recs.data,
            (
                "dataset",
                "instrument",
                "obs_type",
                "obs_group",
                "obs_id",
            ),
        )
        data_basis = extract_basis(
            data_recs.data,
            (
                "dataset",
                "instrument",
                "obs_type",
                "obs_group",
                "obs_id",
                "detector",
            ),
        )
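        # The "basis" arrays hold only the keys needed to match plan records against
        # Level-0 data records (and, once `detector` is added, against candidate tasks)
        # via elementwise dict equality on the object-dtype arrays.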

        # select plan basis relevant to data via `obs_id`
        u_data_obsid = np.unique([_["obs_id"] for _ in data_basis])
        relevant_plan_basis = [_ for _ in plan_basis if _["obs_id"] in u_data_obsid]
        print(f"{len(relevant_plan_basis)} relevant plan records")

        # idx_selected_relevant_plan_basis = np.zeros(len(relevant_plan_basis), dtype=bool)
        # Note: the goal is not to select all plans but to enumerate all tasks;
        # there are far more detector-level tasks than plan_basis entries.

        task_list = []
        relevant_data_id_list = []
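
        # For each relevant plan and each effective detector, count the Level-0 files
        # already ingested; a task is emitted only when the count equals the number of
        # files expected from the plan.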

        # loop over plan
        for i_plan_basis, this_plan_basis in enumerate(relevant_plan_basis):
            print(f"Processing {i_plan_basis + 1}/{len(relevant_plan_basis)}")
            # span over `detector`
            for this_detector in effective_detector_names:
                # construct this_task
                this_task = dict(
                    dataset=this_plan_basis["dataset"],
                    instrument=this_plan_basis["instrument"],
                    obs_type=this_plan_basis["obs_type"],
                    obs_group=this_plan_basis["obs_group"],
                    obs_id=this_plan_basis["obs_id"],
                    detector=this_detector,
                )
                # find this plan basis
                idx_this_plan_basis = np.argwhere(
                    plan_basis == this_plan_basis
                ).flatten()[0]
                # get n_frame, calculate n_file_expected
                if instrument == "HSTDM":
                    n_file_expected = plan_recs.data[idx_this_plan_basis]["params"][
                        "num_epec_frame"
                    ]
                else:
                    n_file_expected = 1
                # count matching files in data_basis (elementwise dict equality on the object array)
                idx_files_found = np.argwhere(data_basis == this_task).flatten()
                n_file_found = len(idx_files_found)
                # if found == expected, append this task
                if n_file_found == n_file_expected:
                    task_list.append(this_task)
                    relevant_data_id_list.extend(
                        [data_recs.data[_]["_id"] for _ in idx_files_found]
                    )

        return dict(task_list=task_list, relevant_data_id_list=relevant_data_id_list)

    @staticmethod
    def dispatch_level0_obsid(**kwargs) -> list[dict]:
        """Dispatch tasks grouped by obs_id (not implemented yet)."""
        raise NotImplementedError
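

# Minimal usage sketch (hypothetical query values; requires a configured
# csst_dfs_client environment):
#
#   result = Dispatcher.dispatch_level0_detector(
#       dataset="csst-msc-c9-25sqdeg-v3",
#       instrument="MSC",
#   )
#   print(len(result["task_list"]), "tasks ready")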