import numpy as np
from astropy.table import Table
from csst_dfs_client import plan, level0, level1

from .._csst import csst

# THESE ARE GENERAL PARAMETERS!
PLAN_PARAMS = {
    "dataset": None,
    "instrument": None,
    "obs_type": None,
    "obs_group": None,
    "obs_id": None,
    "proposal_id": None,
}
LEVEL0_PARAMS = {
    "dataset": None,
    "instrument": None,
    "obs_type": None,
    "obs_group": None,
    "obs_id": None,
    "detector": None,
    "prc_status": -1024,
}
LEVEL1_PARAMS = {
    "dataset": None,
    "instrument": None,
    "obs_type": None,
    "obs_group": None,
    "obs_id": None,
    "detector": None,
    "prc_status": -1024,
    "data_model": None,
    "batch_id": "default_batch",
}
PROC_PARAMS = {
    "priority": 1,
    "batch_id": "default_batch",
    "pmapname": "pmapname",
    "final_prc_status": -2,
    "demo": False,
}


def override_common_keys(d1: dict, d2: dict) -> dict:
    """
    Construct a new dictionary from the keys of the first dictionary,
    taking the value from the second dictionary whenever the key exists
    there and falling back to the first dictionary's value otherwise.

    Parameters
    ----------
    d1 : dict
        The first dictionary, which defines the set of keys.
    d2 : dict
        The second dictionary, whose values take precedence.

    Returns
    -------
    dict
        The updated dictionary.
    """
    return {k: d2[k] if k in d2 else d1[k] for k in d1}


# def extract_basis(dlist: list[dict], basis_keys: tuple) -> np.ndarray:
#     """Extract basis key-value pairs from a list of dictionaries."""
#     return Table([{k: d.get(k) for k in basis_keys} for d in dlist]).as_array()


def extract_basis(dlist: list[dict], basis_keys: tuple) -> np.typing.NDArray:
    """Extract basis key-value pairs from a list of dictionaries."""
    return np.array([{k: d.get(k) for k in basis_keys} for d in dlist], dtype=dict)


class Dispatcher:
    """
    A class to dispatch tasks based on the observation type.
    """

    @staticmethod
    def dispatch_level0_file(**kwargs) -> dict:
        # plan_recs = plan.find(**override_common_keys(PLAN_PARAMS, kwargs))
        data_recs = level0.find(**override_common_keys(LEVEL0_PARAMS, kwargs))
        assert data_recs.success, data_recs

        # construct results: one task per level-0 file record
        task_list = []
        for data_rec in data_recs.data:
            # construct task
            task = dict(
                dataset=data_rec["dataset"],
                instrument=data_rec["instrument"],
                obs_type=data_rec["obs_type"],
                obs_group=data_rec["obs_group"],
                obs_id=data_rec["obs_id"],
                detector=data_rec["detector"],
                file_name=data_rec["file_name"],
            )
            task_list.append(task)
        return dict(
            task_list=task_list,
            relevant_data_id_list=[],
        )

    @staticmethod
    def dispatch_level0_detector(**kwargs) -> dict:
        # get instrument
        assert "instrument" in kwargs, f"{kwargs} does not have key 'instrument'"
        instrument = kwargs.get("instrument")
        assert instrument in ("MSC", "MCI", "IFS", "CPIC", "HSTDM")

        # query for plan and data
        plan_recs = plan.find(**override_common_keys(PLAN_PARAMS, kwargs))
        assert plan_recs.success, plan_recs
        data_recs = level0.find(**override_common_keys(LEVEL0_PARAMS, kwargs))
        assert data_recs.success, data_recs

        # NOTE: leftover local-testing override, kept commented out so that it
        # does not silently replace the DFS query results above:
        # import joblib
        # plan_recs = joblib.load("dagtest/csst-msc-c9-25sqdeg-v3.plan.dump")
        # data_recs = joblib.load("dagtest/csst-msc-c9-25sqdeg-v3.level0.dump")
        # print(f"{len(plan_recs.data)} plan records")
        # print(f"{len(data_recs.data)} data records")
        # instrument = "MSC"
        # from csst_dag._csst import csst

        effective_detector_names = csst[instrument].effective_detector_names

        # extract info
        plan_basis = extract_basis(
            plan_recs.data,
            (
                "dataset",
                "instrument",
                "obs_type",
                "obs_group",
                "obs_id",
            ),
        )
        data_basis = extract_basis(
            data_recs.data,
            (
                "dataset",
                "instrument",
                "obs_type",
                "obs_group",
                "obs_id",
                "detector",
            ),
        )

        # select plan basis relevant to data via `obs_id`
        u_data_obsid = np.unique([_["obs_id"] for _ in data_basis])
        relevant_plan_basis = [_ for _ in plan_basis if _["obs_id"] in u_data_obsid]
        print(f"{len(relevant_plan_basis)} relevant plan records")

        # idx_selected_relevant_plan_basis = np.zeros(len(relevant_plan_basis), dtype=bool)
        # It seems the goal is not to find all plans but to find all tasks,
        # and detector-level tasks are far more numerous than plan_basis entries.
        task_list = []
        relevant_data_id_list = []
        # loop over plan
        for i_plan_basis, this_plan_basis in enumerate(relevant_plan_basis):
            print(f"Processing {i_plan_basis + 1}/{len(relevant_plan_basis)}")
            # span over `detector`
            for this_detector in effective_detector_names:
                # construct this_task
                this_task = dict(
                    dataset=this_plan_basis["dataset"],
                    instrument=this_plan_basis["instrument"],
                    obs_type=this_plan_basis["obs_type"],
                    obs_group=this_plan_basis["obs_group"],
                    obs_id=this_plan_basis["obs_id"],
                    detector=this_detector,
                )
                # find this plan basis
                idx_this_plan_basis = np.argwhere(
                    plan_basis == this_plan_basis
                ).flatten()[0]
                # get n_frame, calculate n_file_expected
                if instrument == "HSTDM":
                    n_file_expected = plan_recs.data[idx_this_plan_basis]["params"][
                        "num_epec_frame"
                    ]
                else:
                    n_file_expected = 1
                # count files found in data_basis
                idx_files_found = np.argwhere(data_basis == this_task).flatten()
                n_file_found = len(idx_files_found)
                # if found == expected, append this task
                if n_file_found == n_file_expected:
                    task_list.append(this_task)
                    relevant_data_id_list.extend(
                        [data_recs.data[_]["_id"] for _ in idx_files_found]
                    )

        return dict(task_list=task_list, relevant_data_id_list=relevant_data_id_list)

    @staticmethod
    def dispatch_level0_obsid(**kwargs) -> list[dict]:
        pass
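

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): how the detector-level dispatcher might be
# invoked against a live DFS service. The dataset name is taken from the dump
# file paths in the commented-out debugging block above; all other argument
# values are hypothetical placeholders, and since this module uses relative
# imports the sketch should be run with ``python -m <package>.<module>``.
if __name__ == "__main__":
    result = Dispatcher.dispatch_level0_detector(
        dataset="csst-msc-c9-25sqdeg-v3",  # dataset name inferred from the debug dumps
        instrument="MSC",
        obs_type=None,   # None keeps the PLAN_PARAMS / LEVEL0_PARAMS defaults
        obs_group=None,
        obs_id=None,
    )
    print(f"{len(result['task_list'])} detector-level tasks dispatched")
    print(f"{len(result['relevant_data_id_list'])} level-0 records matched")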