Commit c1df3efa authored by BO ZHANG's avatar BO ZHANG 🏀
Browse files

init dispatcher.py

parent 6336cc21
import numpy as np
from astropy.table import Table
from csst_dfs_client import plan, level0, level1
from .._csst import csst
# THESE ARE GENERAL PARAMETERS!
# Default query-parameter templates. Each dict below is the full set of
# keys accepted by the corresponding DFS query; `override_common_keys`
# overlays caller kwargs onto these defaults before the query is issued.
# A value of None means "no filter on this key".

# Keys accepted by `plan.find` (observation-plan queries).
PLAN_PARAMS = {
    "dataset": None,
    "instrument": None,
    "obs_type": None,
    "obs_group": None,
    "obs_id": None,
    "proposal_id": None,
}
# Keys accepted by `level0.find` (raw Level-0 data queries).
# prc_status=-1024 appears to be a sentinel default — TODO confirm meaning.
LEVEL0_PARAMS = {
    "dataset": None,
    "instrument": None,
    "obs_type": None,
    "obs_group": None,
    "obs_id": None,
    "detector": None,
    "prc_status": -1024,
}
# Keys accepted by `level1.find` (processed Level-1 data queries).
LEVEL1_PARAMS = {
    "dataset": None,
    "instrument": None,
    "obs_type": None,
    "obs_group": None,
    "obs_id": None,
    "detector": None,
    "prc_status": -1024,
    "data_model": None,
    "batch_id": "default_batch",
}
# Default processing options attached to dispatched jobs.
PROC_PARAMS = {
    "priority": 1,
    "batch_id": "default_batch",
    "pmapname": "pmapname",
    "final_prc_status": -2,
    "demo": False,
}
def override_common_keys(d1: dict, d2: dict) -> dict:
    """
    Construct a new dictionary with the keys of ``d1``, overriding each
    value with the one from ``d2`` when the key is present there.

    Keys that exist only in ``d2`` are ignored, so the result always has
    exactly the keys of ``d1`` (used to sanitize caller kwargs against a
    parameter template such as ``PLAN_PARAMS``).

    Parameters
    ----------
    d1 : dict
        The template dictionary providing the key set and default values.
    d2 : dict
        The overriding dictionary.

    Returns
    -------
    dict
        The merged dictionary.
    """
    # dict.get with the template value as fallback replaces the
    # `k in d2.keys()` double lookup of the original
    return {k: d2.get(k, v) for k, v in d1.items()}
# def extract_basis(dlist: list[dict], basis_keys: tuple) -> np.ndarray:
# """Extract basis key-value pairs from a list of dictionaries."""
# return Table([{k: d.get(k) for k in basis_keys} for d in dlist]).as_array()
def extract_basis(dlist: list[dict], basis_keys: tuple) -> np.typing.NDArray:
"""Extract basis key-value pairs from a list of dictionaries."""
return np.array([{k: d.get(k) for k in basis_keys} for d in dlist], dtype=dict)
class Dispatcher:
    """
    Dispatch pipeline tasks based on DFS plan / Level-0 data queries.

    Each ``dispatch_level0_*`` static method queries the DFS client for
    records matching the caller's kwargs and assembles task dicts at the
    corresponding granularity (file / detector / obs_id).
    """

    @staticmethod
    def dispatch_level0_file(**kwargs) -> dict:
        """
        Dispatch one task per matching Level-0 file record.

        Parameters
        ----------
        **kwargs
            Query overrides for ``LEVEL0_PARAMS`` (e.g. ``dataset``,
            ``instrument``, ``obs_id``).

        Returns
        -------
        dict
            ``{"task_list": list[dict], "relevant_data_id_list": list}``.
            ``relevant_data_id_list`` is not populated at file granularity.
        """
        data_recs = level0.find(**override_common_keys(LEVEL0_PARAMS, kwargs))
        # consistent with dispatch_level0_detector: check the query result
        # and iterate the record payload, not the response object
        assert data_recs.success, data_recs
        # construct results
        task_list = []
        for data_rec in data_recs.data:
            # construct task
            task = dict(
                dataset=data_rec["dataset"],
                instrument=data_rec["instrument"],
                obs_type=data_rec["obs_type"],
                obs_group=data_rec["obs_group"],
                obs_id=data_rec["obs_id"],
                detector=data_rec["detector"],
                file_name=data_rec["file_name"],
            )
            # BUG FIX: the constructed task was previously discarded,
            # so task_list was always returned empty
            task_list.append(task)
        return dict(
            task_list=task_list,
            relevant_data_id_list=[],
        )

    @staticmethod
    def dispatch_level0_detector(**kwargs) -> dict:
        """
        Dispatch one task per (observation plan, effective detector) whose
        expected number of Level-0 files has fully arrived.

        Parameters
        ----------
        **kwargs
            Query overrides for ``PLAN_PARAMS`` / ``LEVEL0_PARAMS``.
            Must contain a valid ``instrument``.

        Returns
        -------
        dict
            ``{"task_list": list[dict], "relevant_data_id_list": list}``
            where ``relevant_data_id_list`` holds the ``_id`` of every
            Level-0 record consumed by the dispatched tasks.
        """
        # instrument is mandatory and must be a known CSST instrument
        assert "instrument" in kwargs.keys(), f"{kwargs} does not have key 'instrument'"
        instrument = kwargs.get("instrument")
        assert instrument in ("MSC", "MCI", "IFS", "CPIC", "HSTDM")
        # query for plan and data
        plan_recs = plan.find(**override_common_keys(PLAN_PARAMS, kwargs))
        assert plan_recs.success, plan_recs
        data_recs = level0.find(**override_common_keys(LEVEL0_PARAMS, kwargs))
        assert data_recs.success, data_recs
        # BUG FIX: removed leftover debug code that overwrote the query
        # results with local joblib dumps ("dagtest/*.dump"), hard-coded
        # instrument = "MSC", and re-imported `csst` (already imported at
        # module level).
        print(f"{len(plan_recs.data)} plan records")
        print(f"{len(data_recs.data)} data records")
        effective_detector_names = csst[instrument].effective_detector_names
        # extract the identifying key-value pairs from both record sets
        plan_basis = extract_basis(
            plan_recs.data,
            (
                "dataset",
                "instrument",
                "obs_type",
                "obs_group",
                "obs_id",
            ),
        )
        data_basis = extract_basis(
            data_recs.data,
            (
                "dataset",
                "instrument",
                "obs_type",
                "obs_group",
                "obs_id",
                "detector",
            ),
        )
        # select plan basis relevant to data via `obs_id`
        u_data_obsid = np.unique([_["obs_id"] for _ in data_basis])
        relevant_plan_basis = [_ for _ in plan_basis if _["obs_id"] in u_data_obsid]
        print(f"{len(relevant_plan_basis)} relevant plan records")
        # NOTE: tasks are enumerated per (plan, detector), so there are far
        # more detector-level tasks than plan records.
        task_list = []
        relevant_data_id_list = []
        # loop over plan
        for i_plan_basis, this_plan_basis in enumerate(relevant_plan_basis):
            print(f"Processing {i_plan_basis + 1}/{len(relevant_plan_basis)}")
            # span over `detector`
            for this_detector in effective_detector_names:
                # construct this_task
                this_task = dict(
                    dataset=this_plan_basis["dataset"],
                    instrument=this_plan_basis["instrument"],
                    obs_type=this_plan_basis["obs_type"],
                    obs_group=this_plan_basis["obs_group"],
                    obs_id=this_plan_basis["obs_id"],
                    detector=this_detector,
                )
                # locate this plan basis in the full plan table
                # (object-array elementwise dict comparison; first match)
                idx_this_plan_basis = np.argwhere(
                    plan_basis == this_plan_basis
                ).flatten()[0]
                # get n_frame, calculate n_file_expected:
                # HSTDM plans bundle several frames per exposure; other
                # instruments expect exactly one file per detector.
                # NOTE(review): "num_epec_frame" looks like a typo for
                # "num_exp_frame" — confirm against the plan schema.
                if instrument == "HSTDM":
                    n_file_expected = plan_recs.data[idx_this_plan_basis]["params"][
                        "num_epec_frame"
                    ]
                else:
                    n_file_expected = 1
                # count files found in data_basis
                idx_files_found = np.argwhere(data_basis == this_task).flatten()
                n_file_found = len(idx_files_found)
                # dispatch only complete tasks (found == expected)
                if n_file_found == n_file_expected:
                    task_list.append(this_task)
                    relevant_data_id_list.extend(
                        [data_recs.data[_]["_id"] for _ in idx_files_found]
                    )
        return dict(task_list=task_list, relevant_data_id_list=relevant_data_id_list)

    @staticmethod
    def dispatch_level0_obsid(**kwargs) -> list[dict]:
        """Dispatch tasks at obs_id granularity. Not implemented yet."""
        pass
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment