import numpy as np
import joblib
from astropy import table
from csst_dfs_client import plan, level0, level1
from tqdm import trange

from .._csst import csst

# from csst_dag._csst import csst

TQDM_KWARGS = dict(unit="task", dynamic_ncols=False)

# THESE ARE GENERAL PARAMETERS!
# Default query parameters; ``None`` means "no constraint" for the DFS client.
PLAN_PARAMS = {
    "dataset": None,
    "instrument": None,
    "obs_type": None,
    "obs_group": None,
    "obs_id": None,
    "proposal_id": None,
}
LEVEL0_PARAMS = {
    "dataset": None,
    "instrument": None,
    "obs_type": None,
    "obs_group": None,
    "obs_id": None,
    "detector": None,
    "prc_status": None,
    "qc_status": None,
}
LEVEL1_PARAMS = {
    "dataset": None,
    "instrument": None,
    "obs_type": None,
    "obs_group": None,
    "obs_id": None,
    "detector": None,
    "prc_status": None,
    "qc_status": None,
    # special keys for data products
    "data_model": None,
    "batch_id": "default_batch",
    "build": None,
    "pmapname": None,
}
# PROC_PARAMS = {
#     "priority": 1,
#     "batch_id": "default_batch",
#     "pmapname": "pmapname",
#     "final_prc_status": -2,
#     "demo": False,
#     # should be capable to extend
# }

# plan basis keys: the minimal set of columns extracted from plan records
PLAN_BASIS_KEYS = (
    "dataset",
    "instrument",
    "obs_type",
    "obs_group",
    "obs_id",
    "n_file",
    "_id",
)
# data basis keys: the minimal set of columns extracted from level0/level1 records
DATA_BASIS_KEYS = (
    "dataset",
    "instrument",
    "obs_type",
    "obs_group",
    "obs_id",
    "detector",
    "file_name",
    "_id",
    "prc_status",
)
# join_type for data x plan
PLAN_JOIN_TYPE = "inner"
"""
References:
- https://docs.astropy.org/en/stable/api/astropy.table.join.html
- https://docs.astropy.org/en/stable/table/operations.html#join

Typical types:
- inner join: Only matching rows from both tables
- left join: All rows from left table, matching rows from right table
- right join: All rows from right table, matching rows from left table
- outer join: All rows from both tables
- cartesian join: Every combination of rows from both tables
"""


def override_common_keys(d1: dict, d2: dict) -> dict:
    """
    Construct a new dictionary by updating the values of keys that exist
    in the first dictionary with the values of the second dictionary.

    Keys present only in ``d2`` are ignored; the result always has exactly
    the keys of ``d1``.

    Parameters
    ----------
    d1 : dict
        The first dictionary (defines the key set and default values).
    d2 : dict
        The second dictionary (overrides).

    Returns
    -------
    dict
        The updated dictionary.
    """
    return {k: d2[k] if k in d2 else d1[k] for k in d1.keys()}


def extract_basis_table(dlist: list[dict], basis_keys: tuple) -> table.Table:
    """Extract basis key-value pairs from a list of dictionaries.

    Missing keys are filled with an empty string so the resulting table
    has a complete, homogeneous set of columns.
    """
    return table.Table([{k: d.get(k, "") for k in basis_keys} for d in dlist])


def split_data_basis(data_basis: table.Table, n_split: int = 1) -> list[table.Table]:
    """Split data basis into ``n_split`` parts via ``obs_id``.

    Rows sharing an ``obs_id`` are never split across chunks; the last
    chunk absorbs the remainder.

    Parameters
    ----------
    data_basis : table.Table
        Data basis table; must contain a single ``dataset``.
    n_split : int
        Number of chunks; must not exceed the number of unique ``obs_id``.

    Returns
    -------
    list[table.Table]
        The chunks, in ``obs_id`` order.
    """
    assert (
        np.unique(data_basis["dataset"]).size == 1
    ), "Only one dataset is allowed for splitting."
    # sort so that rows of the same obs_id are contiguous
    data_basis.sort(keys=["dataset", "obs_id"])
    # get unique obsid and the index of each obsid's first row
    u_obsid, i_obsid, c_obsid = np.unique(
        data_basis["obs_id"].data, return_index=True, return_counts=True
    )
    # FIX: guard against n_split > number of unique obs_ids, which would
    # previously yield a zero chunk size and out-of-range indexing
    assert 1 <= n_split <= len(u_obsid), (
        f"n_split={n_split} must be in [1, {len(u_obsid)}]"
    )
    # set chunk size (number of obs_ids per chunk, rounded toward zero)
    chunk_size = int(np.fix(len(u_obsid) / n_split))
    # slice the sorted table at obs_id boundaries
    chunks = []
    for i_split in range(n_split):
        if i_split < n_split - 1:
            chunks.append(
                data_basis[
                    i_obsid[i_split * chunk_size] : i_obsid[(i_split + 1) * chunk_size]
                ]
            )
        else:
            # last chunk takes all remaining rows
            chunks.append(data_basis[i_obsid[i_split * chunk_size] :])
    # np.unique(table.vstack(chunks)["_id"])
    # np.unique(table.vstack(chunks)["obs_id"])
    return chunks


class Dispatcher:
    """
    A class to dispatch tasks based on the observation type.

    All methods are static; the class is a namespace for query helpers
    (``find_*``) and task-construction strategies (``dispatch_*``) at
    file / detector / obsid / obsgroup granularity.
    """

    @staticmethod
    def find_plan_basis(**kwargs) -> table.Table:
        """
        Find plan records and derive the expected file count per plan.

        Returns
        -------
        table.Table
            Plan basis table with columns ``PLAN_BASIS_KEYS``.

        Raises
        ------
        KeyError
            If a plan record lacks a key needed to derive ``n_file``.
        """
        # query
        prompt = "plan"
        qr_kwargs = override_common_keys(PLAN_PARAMS, kwargs)
        qr = plan.find(**qr_kwargs)
        assert qr.success, qr
        print(f">>> [{prompt}] query kwargs: {qr_kwargs}")
        print(f">>> [{prompt}] {len(qr.data)} records found.")
        # derive n_file for each record
        for rec in qr.data:
            try:
                this_instrument = rec["instrument"]
                if this_instrument == "HSTDM":
                    # NOTE(review): SIS12 apparently doubles the file count
                    # per exposure — confirm against HSTDM conventions.
                    n_exposure = len(rec["params"]["exposure_start"])
                    if rec["params"]["detector"] == "SIS12":
                        this_n_file = n_exposure * 2
                    else:
                        this_n_file = n_exposure
                else:
                    # one file per effective detector
                    this_n_file = len(csst[this_instrument].effective_detector_names)
                rec["n_file"] = this_n_file
            except KeyError as e:
                # FIX: the previous message blamed `n_epec_frame`, a key this
                # code never reads; report the key that is actually missing.
                raise KeyError(f"missing key {e} in plan record {rec}") from e
        plan_basis = extract_basis_table(
            qr.data,
            PLAN_BASIS_KEYS,
        )
        return plan_basis

    @staticmethod
    def find_level0_basis(**kwargs) -> table.Table:
        """
        Find level0 records.

        Returns
        -------
        table.Table
            Data basis table with columns ``DATA_BASIS_KEYS``.
        """
        # query
        prompt = "level0"
        qr_kwargs = override_common_keys(LEVEL0_PARAMS, kwargs)
        qr = level0.find(**qr_kwargs)
        assert qr.success, qr
        print(f">>> [{prompt}] query kwargs: {qr_kwargs}")
        print(f">>> [{prompt}] {len(qr.data)} records found.")
        # data basis
        data_basis = extract_basis_table(
            qr.data,
            DATA_BASIS_KEYS,
        )
        return data_basis

    @staticmethod
    def find_level1_basis(**kwargs) -> table.Table:
        """
        Find level1 records.

        Returns
        -------
        table.Table
            Data basis table with columns ``DATA_BASIS_KEYS``.
        """
        # query
        prompt = "level1"
        qr_kwargs = override_common_keys(LEVEL1_PARAMS, kwargs)
        qr = level1.find(**qr_kwargs)
        assert qr.success, qr
        print(f">>> [{prompt}] query kwargs: {qr_kwargs}")
        print(f">>> [{prompt}] {len(qr.data)} records found.")
        # data basis
        data_basis = extract_basis_table(
            qr.data,
            DATA_BASIS_KEYS,
        )
        return data_basis

    @staticmethod
    def find_plan_level0_basis(**kwargs) -> tuple[table.Table, table.Table]:
        """Find level0 data basis plus the plan records relevant to it."""
        data_basis = Dispatcher.find_level0_basis(**kwargs)
        plan_basis = Dispatcher.find_plan_basis(**kwargs)
        assert len(data_basis) > 0, data_basis
        assert len(plan_basis) > 0, plan_basis
        # restrict plans to (dataset, obs_id) pairs actually present in data
        u_data_basis = table.unique(data_basis["dataset", "obs_id"])
        relevant_plan = table.join(
            u_data_basis,
            plan_basis,
            keys=["dataset", "obs_id"],
            join_type=PLAN_JOIN_TYPE,
        )
        assert len(relevant_plan) > 0, relevant_plan
        return relevant_plan, data_basis

    @staticmethod
    def find_plan_level1_basis(**kwargs) -> tuple[table.Table, table.Table]:
        """Find level1 data basis plus the plan records relevant to it."""
        data_basis = Dispatcher.find_level1_basis(**kwargs)
        plan_basis = Dispatcher.find_plan_basis(**kwargs)
        assert len(data_basis) > 0, data_basis
        assert len(plan_basis) > 0, plan_basis
        # restrict plans to (dataset, obs_id) pairs actually present in data
        u_data_basis = table.unique(data_basis["dataset", "obs_id"])
        relevant_plan = table.join(
            u_data_basis,
            plan_basis,
            keys=["dataset", "obs_id"],
            join_type=PLAN_JOIN_TYPE,
        )
        assert len(relevant_plan) > 0, relevant_plan
        return relevant_plan, data_basis

    @staticmethod
    def dispatch_file(
        plan_basis: table.Table,
        data_basis: table.Table,
    ) -> list[dict]:
        """Dispatch one task per data file (finest granularity).

        Every file yields a task with ``success=True``; the matching plan
        rows are attached for reference only.
        """
        # return an empty list if input is empty
        if len(plan_basis) == 0 or len(data_basis) == 0:
            return []
        # initialize task list
        task_list = []
        # sort data_basis before dispatching for deterministic output order
        data_basis.sort(keys=data_basis.colnames)
        # loop over data rows, one task each
        for i_data_basis in trange(len(data_basis), **TQDM_KWARGS):
            this_task = dict(data_basis[i_data_basis])
            this_data_basis = data_basis[i_data_basis : i_data_basis + 1]
            this_relevant_plan = table.join(
                this_data_basis[
                    "dataset",
                    "instrument",
                    "obs_type",
                    "obs_group",
                    "obs_id",
                ],
                plan_basis,
                keys=[
                    "dataset",
                    "instrument",
                    "obs_type",
                    "obs_group",
                    "obs_id",
                ],
                join_type="inner",
                table_names=["data", "plan"],
            )
            # set n_file_expected and n_file_found (trivially 1 per file)
            this_task["n_file_expected"] = 1
            this_task["n_file_found"] = 1
            # append this task
            task_list.append(
                dict(
                    task=this_task,
                    success=True,
                    relevant_plan=this_relevant_plan,
                    relevant_data=data_basis[i_data_basis : i_data_basis + 1],
                    n_relevant_plan=len(this_relevant_plan),
                    n_relevant_data=1,
                    relevant_data_id_list=[data_basis[i_data_basis]["_id"]],
                    n_file_expected=1,
                    n_file_found=1,
                )
            )
        return task_list

    @staticmethod
    def dispatch_detector(
        plan_basis: table.Table,
        data_basis: table.Table,
        n_jobs: int = 1,
    ) -> list[dict]:
        """Dispatch one task per (obsid, detector) combination.

        Parameters
        ----------
        plan_basis : table.Table
            Plan basis table (``PLAN_BASIS_KEYS`` columns).
        data_basis : table.Table
            Data basis table (``DATA_BASIS_KEYS`` columns).
        n_jobs : int
            If not 1, split ``data_basis`` by obs_id and recurse in parallel.

        Returns
        -------
        list[dict]
            One task record per unique (obsid, detector).
        """
        if n_jobs != 1:
            # parallel fan-out: each chunk is dispatched independently
            task_list = joblib.Parallel(n_jobs=n_jobs)(
                joblib.delayed(Dispatcher.dispatch_detector)(plan_basis, _)
                for _ in split_data_basis(data_basis, n_split=n_jobs)
            )
            return sum(task_list, [])
        # return an empty list if input is empty
        if len(plan_basis) == 0 or len(data_basis) == 0:
            return []
        # unique obsid
        u_obsid = table.unique(data_basis["dataset", "obs_id"])
        relevant_plan = table.join(
            u_obsid,
            plan_basis,
            keys=["dataset", "obs_id"],
            join_type=PLAN_JOIN_TYPE,
        )
        print(f"{len(relevant_plan)} relevant plan records")
        u_data_detector = table.unique(
            data_basis[
                "dataset",
                "instrument",
                "obs_type",
                "obs_group",
                "obs_id",
                "detector",
            ]
        )
        # initialize task list
        task_list = []
        # loop over unique (obsid, detector) combinations
        for i_data_detector in trange(len(u_data_detector), **TQDM_KWARGS):
            this_task = dict(u_data_detector[i_data_detector])
            this_data_detector = u_data_detector[i_data_detector : i_data_detector + 1]
            # files for this (obsid, detector)
            this_data_detector_files = table.join(
                this_data_detector,
                data_basis,
                keys=[
                    "dataset",
                    "instrument",
                    "obs_type",
                    "obs_group",
                    "obs_id",
                    "detector",
                ],
                join_type="inner",
            )
            # plan rows for this obsid
            this_data_detector_plan = table.join(
                this_data_detector,
                relevant_plan,
                keys=[
                    "dataset",
                    "instrument",
                    "obs_type",
                    "obs_group",
                    "obs_id",
                ],
                join_type=PLAN_JOIN_TYPE,
            )
            # whether detector effective
            this_detector = this_data_detector["detector"][0]
            this_instrument = this_data_detector["instrument"][0]
            this_detector_effective = (
                this_detector in csst[this_instrument].effective_detector_names
            )
            n_file_expected = (
                this_data_detector_plan["n_file"][0]
                if len(this_data_detector_plan) > 0
                else 0
            )
            n_file_found = len(this_data_detector_files)
            # set n_file_expected and n_file_found
            this_task["n_file_expected"] = n_file_expected
            this_task["n_file_found"] = n_file_found
            # append this task
            # FIX: reuse the guarded n_file_expected/n_file_found computed
            # above instead of recomputing with a different expression,
            # so the task dict and its envelope always agree.
            task_list.append(
                dict(
                    task=this_task,
                    success=(
                        len(this_data_detector_plan) == 1
                        and len(this_data_detector_files) == 1
                        and this_detector_effective
                        and n_file_found == n_file_expected
                    ),
                    relevant_plan=this_data_detector_plan,
                    relevant_data=this_data_detector_files,
                    n_relevant_plan=len(this_data_detector_plan),
                    n_relevant_data=len(this_data_detector_files),
                    relevant_data_id_list=(
                        []
                        if len(this_data_detector_files) == 0
                        else list(this_data_detector_files["_id_data"])
                    ),
                    n_file_expected=n_file_expected,
                    n_file_found=n_file_found,
                )
            )
        return task_list

    @staticmethod
    def dispatch_obsid(
        plan_basis: table.Table,
        data_basis: table.Table,
        n_jobs: int = 1,
    ) -> list[dict]:
        """Dispatch one task per obsid.

        Parameters
        ----------
        plan_basis : table.Table
            Plan basis table.
        data_basis : table.Table
            Data basis table.
        n_jobs : int
            If not 1, split ``data_basis`` by obs_id and recurse in parallel.

        Returns
        -------
        list[dict]
            One task record per unique obsid.
        """
        if n_jobs != 1:
            # parallel fan-out: each chunk is dispatched independently
            task_list = joblib.Parallel(n_jobs=n_jobs)(
                joblib.delayed(Dispatcher.dispatch_obsid)(plan_basis, _)
                for _ in split_data_basis(data_basis, n_split=n_jobs)
            )
            return sum(task_list, [])
        # return an empty list if input is empty
        if len(plan_basis) == 0 or len(data_basis) == 0:
            return []
        # FIX: removed `obsid_basis = data_basis.group_by([""])` — "" is not
        # a column, so this raised on every call, and the result was unused.
        # unique obsid
        u_obsid = table.unique(data_basis["dataset", "obs_id"])
        relevant_plan = table.join(
            u_obsid,
            plan_basis,
            keys=["dataset", "obs_id"],
            join_type=PLAN_JOIN_TYPE,
        )
        print(f"{len(relevant_plan)} relevant plan records")
        u_data_obsid = table.unique(
            data_basis[
                "dataset",
                "instrument",
                "obs_type",
                "obs_group",
                "obs_id",
            ]
        )
        # initialize task list
        task_list = []
        # loop over unique obsids
        for i_data_obsid in trange(len(u_data_obsid), **TQDM_KWARGS):
            this_task = dict(u_data_obsid[i_data_obsid])
            this_data_obsid = u_data_obsid[i_data_obsid : i_data_obsid + 1]
            # files for this obsid
            this_data_obsid_file = table.join(
                this_data_obsid,
                data_basis,
                keys=[
                    "dataset",
                    "instrument",
                    "obs_type",
                    "obs_group",
                    "obs_id",
                ],
                join_type="inner",
            )
            # plan rows for this obsid
            this_data_obsid_plan = table.join(
                this_data_obsid,
                relevant_plan,
                keys=[
                    "dataset",
                    "instrument",
                    "obs_type",
                    "obs_group",
                    "obs_id",
                ],
                join_type=PLAN_JOIN_TYPE,
            )
            # whether effective detectors are all there
            this_instrument = this_data_obsid["instrument"][0]
            # FIX: take a scalar (first row) rather than the whole Column,
            # consistent with dispatch_detector; the previous Column value
            # made the `in`-tuple membership test below fragile.
            this_n_file = (
                this_data_obsid_plan["n_file"][0]
                if len(this_data_obsid_plan) > 0
                else 0
            )
            this_effective_detector_names = csst[
                this_instrument
            ].effective_detector_names
            if this_instrument == "HSTDM":
                # unsure whether there will be 1 or 2 detectors in the future
                this_n_file_found = len(this_data_obsid_file)
                this_n_file_expected = (this_n_file, this_n_file * 2)
                this_success = this_n_file_found in this_n_file_expected
            else:
                # for other instruments, e.g., MSC
                # n_file_found = len(this_obsgroup_obsid_file)
                # n_file_expected = len(effective_detector_names)
                # this_success &= n_file_found == n_file_expected
                # or more strictly, expected detectors are a subset of detectors found
                this_success = set(this_effective_detector_names) <= set(
                    this_data_obsid_file["detector"]
                )
            n_file_expected = int(this_data_obsid_plan["n_file"].sum())
            n_file_found = len(this_data_obsid_file)
            # set n_file_expected and n_file_found
            this_task["n_file_expected"] = n_file_expected
            this_task["n_file_found"] = n_file_found
            # append this task
            # FIX: reuse the n_file_expected/n_file_found computed above so
            # the task dict and its envelope always agree.
            task_list.append(
                dict(
                    task=this_task,
                    success=this_success,
                    relevant_plan=this_data_obsid_plan,
                    relevant_data=this_data_obsid_file,
                    n_relevant_plan=len(this_data_obsid_plan),
                    n_relevant_data=len(this_data_obsid_file),
                    relevant_data_id_list=(
                        []
                        if len(this_data_obsid_file) == 0
                        else list(this_data_obsid_file["_id"])
                    ),
                    n_file_expected=n_file_expected,
                    n_file_found=n_file_found,
                )
            )
        return task_list

    @staticmethod
    def dispatch_obsgroup_detector(
        plan_basis: table.Table,
        data_basis: table.Table,
        # n_jobs: int = 1,
    ) -> list[dict]:
        """Dispatch one task per (obsgroup, detector) combination.

        For each obsgroup and each effective detector of its instrument,
        a task succeeds when every obsid of the group contributed the
        expected number of files for that detector.
        """
        # return an empty list if input is empty
        if len(plan_basis) == 0 or len(data_basis) == 0:
            return []
        # unique obsgroup basis (using group_by)
        obsgroup_basis = plan_basis.group_by(
            keys=[
                "dataset",
                "instrument",
                "obs_type",
                "obs_group",
            ]
        )
        # initialize task list
        task_list = []
        # loop over obsgroups
        for i_obsgroup in trange(len(obsgroup_basis.groups), **TQDM_KWARGS):
            this_obsgroup_basis = obsgroup_basis.groups[i_obsgroup]
            this_obsgroup_obsid = this_obsgroup_basis["obs_id"].data
            n_file_expected = this_obsgroup_basis["n_file"].sum()
            this_instrument = this_obsgroup_basis["instrument"][0]
            effective_detector_names = csst[this_instrument].effective_detector_names
            for this_effective_detector_name in effective_detector_names:
                this_task = dict(
                    dataset=this_obsgroup_basis["dataset"][0],
                    instrument=this_obsgroup_basis["instrument"][0],
                    obs_type=this_obsgroup_basis["obs_type"][0],
                    obs_group=this_obsgroup_basis["obs_group"][0],
                    detector=this_effective_detector_name,
                )
                # expected (obsid, detector) rows for this group/detector
                this_obsgroup_detector_expected = table.Table(
                    [
                        dict(
                            dataset=this_obsgroup_basis["dataset"][0],
                            instrument=this_obsgroup_basis["instrument"][0],
                            obs_type=this_obsgroup_basis["obs_type"][0],
                            obs_group=this_obsgroup_basis["obs_group"][0],
                            obs_id=this_obsid,
                            detector=this_effective_detector_name,
                        )
                        for this_obsid in this_obsgroup_obsid
                    ]
                )
                # actually-present files matching the expectation
                this_obsgroup_detector_found = table.join(
                    this_obsgroup_detector_expected,
                    data_basis,
                    keys=[
                        "dataset",
                        "instrument",
                        "obs_type",
                        "obs_group",
                        "obs_id",
                        "detector",
                    ],
                    join_type="inner",
                )
                n_file_found = len(this_obsgroup_detector_found)
                # success: file count matches and every obsid is covered
                this_success = n_file_found == n_file_expected and set(
                    this_obsgroup_detector_found["obs_id"]
                ) == set(this_obsgroup_obsid)
                # set n_file_expected and n_file_found
                this_task["n_file_expected"] = n_file_expected
                this_task["n_file_found"] = n_file_found
                # append this task
                task_list.append(
                    dict(
                        task=this_task,
                        success=this_success,
                        relevant_plan=this_obsgroup_basis,
                        relevant_data=this_obsgroup_detector_found,
                        n_relevant_plan=len(this_obsgroup_basis),
                        n_relevant_data=len(this_obsgroup_detector_found),
                        relevant_data_id_list=(
                            list(this_obsgroup_detector_found["_id"])
                            if n_file_found > 0
                            else []
                        ),
                        n_file_expected=n_file_expected,
                        n_file_found=n_file_found,
                    )
                )
        return task_list

    @staticmethod
    def dispatch_obsgroup(
        plan_basis: table.Table,
        data_basis: table.Table,
        # n_jobs: int = 1,
    ) -> list[dict]:
        """Dispatch one task per obsgroup.

        A task succeeds when every obsid of the group has its expected
        files (HSTDM: file count; others: all effective detectors present).
        """
        # return an empty list if input is empty
        if len(plan_basis) == 0 or len(data_basis) == 0:
            return []
        # unique obsgroup basis
        obsgroup_basis = table.unique(
            plan_basis[
                "dataset",
                "instrument",
                "obs_type",
                "obs_group",
            ]
        )
        # initialize task list
        task_list = []
        # loop over obsgroups
        for i_obsgroup in trange(len(obsgroup_basis), **TQDM_KWARGS):
            this_task = dict(obsgroup_basis[i_obsgroup])
            this_success = True
            this_obsgroup_plan = table.join(
                obsgroup_basis[i_obsgroup : i_obsgroup + 1],  # this obsgroup
                plan_basis,
                keys=["dataset", "instrument", "obs_type", "obs_group"],
                join_type=PLAN_JOIN_TYPE,
            )
            this_obsgroup_file = table.join(
                this_obsgroup_plan,
                data_basis,
                keys=["dataset", "instrument", "obs_type", "obs_group", "obs_id"],
                join_type="inner",
                table_names=["plan", "data"],
            )
            # loop over obsids of this group; all must be complete
            for i_obsid in range(len(this_obsgroup_plan)):
                this_instrument = this_obsgroup_plan[i_obsid]["instrument"]
                this_n_file = this_obsgroup_plan[i_obsid]["n_file"]
                this_effective_detector_names = csst[
                    this_instrument
                ].effective_detector_names
                this_obsgroup_obsid_file = table.join(
                    this_obsgroup_plan[i_obsid : i_obsid + 1],  # this obsid
                    data_basis,
                    keys=["dataset", "instrument", "obs_type", "obs_group", "obs_id"],
                    join_type="inner",
                    table_names=["plan", "data"],
                )
                if this_instrument == "HSTDM":
                    # unsure whether there will be 1 or 2 detectors in the future
                    this_n_file_found = len(this_obsgroup_obsid_file)
                    this_n_file_expected = this_n_file
                    this_success &= this_n_file_found == this_n_file_expected
                else:
                    # for other instruments, e.g., MSC
                    # n_file_found = len(this_obsgroup_obsid_file)
                    # n_file_expected = len(effective_detector_names)
                    # this_success &= n_file_found == n_file_expected
                    # or more strictly, expected detectors are a subset of detectors found
                    this_success &= set(this_effective_detector_names) <= set(
                        this_obsgroup_obsid_file["detector"]
                    )
            n_file_expected = int(this_obsgroup_plan["n_file"].sum())
            n_file_found = len(this_obsgroup_file)
            # set n_file_expected and n_file_found
            this_task["n_file_expected"] = n_file_expected
            this_task["n_file_found"] = n_file_found
            # append this task
            # FIX: reuse the n_file_expected/n_file_found computed above so
            # the task dict and its envelope always agree.
            task_list.append(
                dict(
                    task=this_task,
                    success=this_success,
                    relevant_plan=this_obsgroup_plan,
                    relevant_data=this_obsgroup_file,
                    n_relevant_plan=len(this_obsgroup_plan),
                    n_relevant_data=len(this_obsgroup_file),
                    relevant_data_id_list=(
                        []
                        if len(this_obsgroup_file) == 0
                        else list(this_obsgroup_file["_id_data"])
                    ),
                    n_file_expected=n_file_expected,
                    n_file_found=n_file_found,
                )
            )
        return task_list

    @staticmethod
    def load_test_data() -> tuple:
        """Load dumped plan/level0 query results for local testing.

        Returns
        -------
        tuple
            ``(plan_basis, data_basis)`` tables.
        """
        # FIX: removed redundant local `import joblib` (imported at module top)
        plan_recs = joblib.load("dagtest/csst-msc-c9-25sqdeg-v3.plan.dump")
        data_recs = joblib.load("dagtest/csst-msc-c9-25sqdeg-v3.level0.dump")
        print(f"{len(plan_recs.data)} plan records")
        print(f"{len(data_recs.data)} data records")
        for _ in plan_recs.data:
            _["n_file"] = (
                _["params"]["num_epec_frame"] if _["instrument"] == "HSTDM" else 1
            )
        plan_basis = extract_basis_table(
            plan_recs.data,
            PLAN_BASIS_KEYS,
        )
        data_basis = extract_basis_table(
            data_recs.data,
            DATA_BASIS_KEYS,
        )
        return plan_basis, data_basis