"""Task dispatching for CSST pipelines.

Queries plan / level-0 / level-1 records from the DFS client and groups
them into processing tasks at file, detector, obs_id, or obs_group
granularity.
"""

import joblib
import numpy as np
from astropy import table
from csst_dfs_client import level0, level1, plan
from tqdm import trange

from .._csst import csst

# from csst_dag._csst import csst

# THESE ARE GENERAL PARAMETERS!
# Default query-parameter templates. Callers override individual values via
# ``override_common_keys``; keys absent from a template are silently dropped.
PLAN_PARAMS = {
    "dataset": None,
    "instrument": None,
    "obs_type": None,
    "obs_group": None,
    "obs_id": None,
    "proposal_id": None,
}
LEVEL0_PARAMS = {
    "dataset": None,
    "instrument": None,
    "obs_type": None,
    "obs_group": None,
    "obs_id": None,
    "detector": None,
    "prc_status": -1024,
}
LEVEL1_PARAMS = {
    "dataset": None,
    "instrument": None,
    "obs_type": None,
    "obs_group": None,
    "obs_id": None,
    "detector": None,
    "prc_status": -1024,
    "data_model": None,
    "batch_id": "default_batch",
}
PROC_PARAMS = {
    "priority": 1,
    "batch_id": "default_batch",
    "pmapname": "pmapname",
    "final_prc_status": -2,
    "demo": False,
    # should be capable to extend
}


def override_common_keys(d1: dict, d2: dict) -> dict:
    """
    Construct a new dictionary by updating the values of basis_keys that
    exists in the first dictionary with the values of the second dictionary.

    Keys of ``d2`` that are not present in ``d1`` are ignored, so ``d1``
    acts as a whitelist of accepted parameters.

    Parameters
    ----------
    d1 : dict
        The first dictionary (template; defines the accepted keys).
    d2 : dict
        The second dictionary (overrides).

    Returns
    -------
    dict:
        The updated dictionary.
    """
    return {k: d2.get(k, d1[k]) for k in d1}


def extract_basis_table(dlist: list[dict], basis_keys: tuple) -> table.Table:
    """Extract basis key-value pairs from a list of dictionaries.

    Missing keys are filled with ``None`` via ``dict.get``.
    """
    return table.Table([{k: d.get(k) for k in basis_keys} for d in dlist])


def split_data_basis(data_basis: table.Table, n_split: int = 1) -> list[table.Table]:
    """Split data basis into ``n_split`` parts.

    Rows sharing one ``obs_id`` are always kept in the same chunk, so
    parallel workers never see a partial observation.  Note that
    ``data_basis`` is sorted *in place* as a side effect.
    """
    assert (
        np.unique(data_basis["dataset"]).size == 1
    ), "Only one dataset is allowed for splitting."
    # sort so rows of each obs_id are contiguous
    data_basis.sort(keys=["dataset", "obs_id"])
    # first-occurrence row index of each unique obs_id
    u_obsid, i_obsid, c_obsid = np.unique(
        data_basis["obs_id"].data, return_index=True, return_counts=True
    )
    # number of unique obs_ids per chunk; the last chunk takes the remainder
    chunk_size = int(np.fix(len(u_obsid) / n_split))
    chunks = []
    for i_split in range(n_split):
        if i_split < n_split - 1:
            chunks.append(
                data_basis[
                    i_obsid[i_split * chunk_size] : i_obsid[(i_split + 1) * chunk_size]
                ]
            )
        else:
            # last chunk: everything from its first obs_id to the end
            chunks.append(data_basis[i_obsid[i_split * chunk_size] :])
    return chunks


# plan basis keys
PLAN_BASIS_KEYS = (
    "dataset",
    "instrument",
    "obs_type",
    "obs_group",
    "obs_id",
    "n_frame",
    "_id",
)
# data basis keys
DATA_BASIS_KEYS = (
    "dataset",
    "instrument",
    "obs_type",
    "obs_group",
    "obs_id",
    "detector",
    "file_name",
    "_id",
)


class Dispatcher:
    """
    A class to dispatch tasks based on the observation type.
    """

    @staticmethod
    def find_plan_basis(**kwargs) -> table.Table:
        """
        Find plan records.
        """
        # query; kwargs are filtered through the PLAN_PARAMS whitelist
        qr = plan.find(**override_common_keys(PLAN_PARAMS, kwargs))
        assert qr.success, qr
        # plan basis / obsid basis
        # HSTDM exposures consist of several EPEC frames per obs_id;
        # other instruments contribute one frame each.
        for rec in qr.data:
            rec["n_frame"] = (
                rec["params"]["n_epec_frame"] if rec["instrument"] == "HSTDM" else 1
            )
        plan_basis = extract_basis_table(
            qr.data,
            PLAN_BASIS_KEYS,
        )
        return plan_basis

    @staticmethod
    def find_level0_basis(**kwargs) -> table.Table:
        """
        Find level0 records.
        """
        # query; kwargs are filtered through the LEVEL0_PARAMS whitelist
        qr = level0.find(**override_common_keys(LEVEL0_PARAMS, kwargs))
        assert qr.success, qr
        # data basis
        data_basis = extract_basis_table(
            qr.data,
            DATA_BASIS_KEYS,
        )
        return data_basis

    @staticmethod
    def find_level1_basis(**kwargs) -> table.Table:
        """
        Find level1 records.
        """
        # query; kwargs are filtered through the LEVEL1_PARAMS whitelist
        qr = level1.find(**override_common_keys(LEVEL1_PARAMS, kwargs))
        assert qr.success, qr
        # data basis
        data_basis = extract_basis_table(
            qr.data,
            DATA_BASIS_KEYS,
        )
        return data_basis

    @staticmethod
    def dispatch_file(
        plan_basis: table.Table,
        data_basis: table.Table,
    ) -> list[dict]:
        """Dispatch one task per level-0 file (one per row of ``data_basis``)."""
        # unique obsid
        u_obsid = table.unique(data_basis["dataset", "obs_id"])
        # Loop-invariant join, hoisted: plan records for all observed obs_ids.
        # NOTE(review): every task receives the same (shared) relevant_plan
        # table; it looks like this was meant to be restricted to the task's
        # own obs_id — confirm with callers before narrowing.
        relevant_plan = table.join(
            u_obsid,
            plan_basis,
            keys=["dataset", "obs_id"],
            join_type="inner",
        )
        # initialize task list
        task_list = []
        # loop over files
        for i_data_basis in trange(
            len(data_basis),
            unit="task",
            dynamic_ncols=True,
        ):
            this_data_basis = data_basis[i_data_basis : i_data_basis + 1]
            # append this task; file-level tasks are unconditionally successful
            task_list.append(
                dict(
                    task=this_data_basis,
                    success=True,
                    relevant_plan=relevant_plan,
                    relevant_data=data_basis[i_data_basis : i_data_basis + 1],
                )
            )
        return task_list

    @staticmethod
    def dispatch_detector(
        plan_basis: table.Table,
        data_basis: table.Table,
        n_jobs: int = 1,
    ) -> list[dict]:
        """Dispatch one task per (obs_id, detector) combination.

        Parameters
        ----------
        plan_basis : table.Table
            Plan basis table (see ``find_plan_basis``).
        data_basis : table.Table
            Level-0 data basis table (see ``find_level0_basis``).
        n_jobs : int
            If not 1, split ``data_basis`` by obs_id and recurse in parallel.

        Returns
        -------
        list[dict]
            One dict per detector task with keys
            ``task`` / ``success`` / ``relevant_plan`` / ``relevant_data``.
        """
        if n_jobs != 1:
            # parallel fan-out: each worker handles a chunk of obs_ids
            task_list = joblib.Parallel(n_jobs=n_jobs)(
                joblib.delayed(Dispatcher.dispatch_detector)(plan_basis, _)
                for _ in split_data_basis(data_basis, n_split=n_jobs)
            )
            return sum(task_list, [])
        # unique obsid
        u_obsid = table.unique(data_basis["dataset", "obs_id"])
        relevant_plan = table.join(
            u_obsid,
            plan_basis,
            keys=["dataset", "obs_id"],
            join_type="left",
        )
        print(f"{len(relevant_plan)} relevant plan records")
        u_data_detector = table.unique(
            data_basis[
                "dataset",
                "instrument",
                "obs_type",
                "obs_group",
                "obs_id",
                "detector",
            ]
        )
        # initialize task list
        task_list = []
        # loop over detector combinations
        for i_data_detector in trange(
            len(u_data_detector),
            unit="task",
            dynamic_ncols=True,
        ):
            this_task = dict(u_data_detector[i_data_detector])
            this_data_detector = u_data_detector[i_data_detector : i_data_detector + 1]
            # join data and plan
            this_data_detector_files = table.join(
                this_data_detector,
                data_basis,
                keys=[
                    "dataset",
                    "instrument",
                    "obs_type",
                    "obs_group",
                    "obs_id",
                    "detector",
                ],
                join_type="inner",
            )
            this_data_detector_plan = table.join(
                this_data_detector,
                relevant_plan,
                keys=[
                    "dataset",
                    "instrument",
                    "obs_type",
                    "obs_group",
                    "obs_id",
                ],
                join_type="left",
            )
            # whether detector effective
            this_detector = this_data_detector["detector"][0]
            this_instrument = this_data_detector["instrument"][0]
            this_detector_effective = (
                this_detector in csst[this_instrument].effective_detector_names
            )
            n_files_expected = this_data_detector_plan["n_frame"][0]
            n_files_found = len(this_data_detector_files)
            # append this task
            # NOTE(review): requiring both exactly one file AND
            # n_files_found == n_files_expected can never hold when
            # n_frame > 1 (HSTDM) — confirm whether that is intended.
            task_list.append(
                dict(
                    task=this_task,
                    success=(
                        len(this_data_detector_plan) == 1
                        and len(this_data_detector_files) == 1
                        and this_detector_effective
                        and n_files_found == n_files_expected
                    ),
                    relevant_plan=this_data_detector_plan,
                    relevant_data=this_data_detector_files,
                )
            )
        return task_list

    @staticmethod
    def dispatch_obsid(
        plan_basis: table.Table,
        data_basis: table.Table,
        n_jobs: int = 1,
    ) -> list[dict]:
        """Dispatch one task per obs_id.

        A task succeeds when every effective detector of the instrument has
        a matching level-0 file for that obs_id.
        """
        if n_jobs != 1:
            # parallel fan-out: each worker handles a chunk of obs_ids
            task_list = joblib.Parallel(n_jobs=n_jobs)(
                joblib.delayed(Dispatcher.dispatch_obsid)(plan_basis, _)
                for _ in split_data_basis(data_basis, n_split=n_jobs)
            )
            return sum(task_list, [])
        # unique obsid
        u_obsid = table.unique(data_basis["dataset", "obs_id"])
        relevant_plan = table.join(
            u_obsid,
            plan_basis,
            keys=["dataset", "obs_id"],
            join_type="left",
        )
        print(f"{len(relevant_plan)} relevant plan records")
        u_data_obsid = table.unique(
            data_basis[
                "dataset",
                "instrument",
                "obs_type",
                "obs_group",
                "obs_id",
            ]
        )
        # initialize task list
        task_list = []
        # loop over obs_ids
        for i_data_obsid in trange(
            len(u_data_obsid),
            unit="task",
            dynamic_ncols=True,
        ):
            # BUGFIX: removed leftover debug assignment `i_data_obsid = 2`
            # which froze every iteration on index 2.
            this_task = dict(u_data_obsid[i_data_obsid])
            this_data_obsid = u_data_obsid[i_data_obsid : i_data_obsid + 1]
            # join data and plan
            this_data_obsid_files = table.join(
                this_data_obsid,
                data_basis,
                keys=[
                    "dataset",
                    "instrument",
                    "obs_type",
                    "obs_group",
                    "obs_id",
                ],
                join_type="inner",
            )
            this_data_obsid_plan = table.join(
                this_data_obsid,
                relevant_plan,
                keys=[
                    "dataset",
                    "instrument",
                    "obs_type",
                    "obs_group",
                    "obs_id",
                ],
                join_type="left",
            )
            # whether effective detectors all there
            this_instrument = this_data_obsid["instrument"][0]
            this_success = set(csst[this_instrument].effective_detector_names).issubset(
                set(this_data_obsid_files["detector"])
            )
            # append this task
            task_list.append(
                dict(
                    task=this_task,
                    success=this_success,
                    relevant_plan=this_data_obsid_plan,
                    relevant_data=this_data_obsid_files,
                )
            )
        return task_list

    @staticmethod
    def dispatch_obsgroup(
        plan_basis: table.Table,
        data_basis: table.Table,
        # n_jobs: int = 1,
    ) -> list[dict]:
        """Dispatch one task per obs_group.

        A task succeeds only when *every* planned obs_id of the group has
        a complete file set (per-detector for imaging instruments, frame
        counts for HSTDM).
        """
        # unique obsgroup basis
        obsgroup_basis = table.unique(
            plan_basis[
                "dataset",
                "instrument",
                "obs_type",
                "obs_group",
            ]
        )
        # initialize task list
        task_list = []
        # loop over obsgroup
        for i_obsgroup in trange(
            len(obsgroup_basis),
            unit="task",
            dynamic_ncols=True,
        ):
            this_task = dict(obsgroup_basis[i_obsgroup])
            this_success = True
            # all planned obs_ids of this obsgroup
            this_obsgroup_obsid = table.join(
                obsgroup_basis[i_obsgroup : i_obsgroup + 1],  # this obsgroup
                plan_basis,
                keys=["dataset", "instrument", "obs_type", "obs_group"],
                join_type="left",
            )
            # all files of this obsgroup
            this_obsgroup_file = table.join(
                this_obsgroup_obsid,
                data_basis,
                keys=["dataset", "instrument", "obs_type", "obs_group", "obs_id"],
                join_type="inner",
                table_names=["plan", "data"],
            )
            # loop over obsid
            for i_obsid in range(len(this_obsgroup_obsid)):
                instrument = this_obsgroup_obsid[i_obsid]["instrument"]
                n_frame = this_obsgroup_obsid[i_obsid]["n_frame"]
                effective_detector_names = csst[instrument].effective_detector_names
                this_obsgroup_obsid_file = table.join(
                    this_obsgroup_obsid[i_obsid : i_obsid + 1],  # this obsid
                    data_basis,
                    keys=["dataset", "instrument", "obs_type", "obs_group", "obs_id"],
                    join_type="inner",
                    table_names=["plan", "data"],
                )
                if instrument == "HSTDM":
                    # Unclear how HSTDM (terahertz) file counts should be
                    # validated; accept any whole multiple of n_frame.
                    # NOTE(review): zero files also passes (0 % n == 0) —
                    # confirm that an empty obs_id should count as success.
                    this_success &= len(this_obsgroup_obsid_file) % n_frame == 0
                else:
                    # strictly: the file set covers each effective detector
                    this_success &= set(this_obsgroup_obsid_file["detector"]) == set(
                        effective_detector_names
                    )
            # append this task
            task_list.append(
                dict(
                    task=this_task,
                    success=this_success,
                    relevant_plan=this_obsgroup_obsid,
                    relevant_data=this_obsgroup_file,
                )
            )
        return task_list

    @staticmethod
    def load_test_data() -> tuple:
        """Load dumped plan/level-0 query results and return basis tables."""
        plan_recs = joblib.load("dagtest/csst-msc-c9-25sqdeg-v3.plan.dump")
        data_recs = joblib.load("dagtest/csst-msc-c9-25sqdeg-v3.level0.dump")
        print(f"{len(plan_recs.data)} plan records")
        print(f"{len(data_recs.data)} data records")
        # same n_frame enrichment as find_plan_basis
        for rec in plan_recs.data:
            rec["n_frame"] = (
                rec["params"]["n_epec_frame"] if rec["instrument"] == "HSTDM" else 1
            )
        plan_basis = extract_basis_table(
            plan_recs.data,
            PLAN_BASIS_KEYS,
        )
        data_basis = extract_basis_table(
            data_recs.data,
            DATA_BASIS_KEYS,
        )
        return plan_basis, data_basis


# Usage / benchmark notes (1221 plan recs, 36630 data recs):
# plan_basis, data_basis = Dispatcher.load_test_data()
# task_list_via_file = Dispatcher.dispatch_file(plan_basis, data_basis)          # ~430 task/s
# task_list_via_detector = Dispatcher.dispatch_detector(plan_basis, data_basis, n_jobs=10)  # ~100*10 task/s
# task_list_via_obsid = Dispatcher.dispatch_obsid(plan_basis, data_basis, n_jobs=10)        # ~130*10 task/s
# task_list_via_obsgroup = Dispatcher.dispatch_obsgroup(plan_basis, data_basis)  # ~13 s/task
# print(
#     sum(_["success"] for _ in task_list_via_obsgroup),
#     "/",
#     len(task_list_via_obsgroup),
# )