Commit 3ddcd5a6 authored by BO ZHANG

update dispatcher.py

parent c1df3efa
import numpy as np
from astropy.table import Table
import joblib
from astropy import table
from csst_dfs_client import plan, level0, level1
from tqdm import trange
from .._csst import csst
# from csst_dag._csst import csst
# THESE ARE GENERAL PARAMETERS!
PLAN_PARAMS = {
"dataset": None,
@@ -41,6 +47,7 @@ PROC_PARAMS = {
"pmapname": "pmapname",
"final_prc_status": -2,
"demo": False,
# should be extensible
}
@@ -64,14 +71,62 @@ def override_common_keys(d1: dict, d2: dict) -> dict:
return {k: d2[k] if k in d2.keys() else d1[k] for k in d1.keys()}
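# e.g. override_common_keys({"a": 1, "b": 2}, {"b": 3, "c": 4}) -> {"a": 1, "b": 3};
# keys absent from d1 (here "c") are dropped.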
# def extract_basis(dlist: list[dict], basis_keys: tuple) -> np.ndarray:
# """Extract basis key-value pairs from a list of dictionaries."""
# return Table([{k: d.get(k) for k in basis_keys} for d in dlist]).as_array()
def extract_basis_table(dlist: list[dict], basis_keys: tuple) -> table.Table:
"""Extract basis key-value pairs from a list of dictionaries."""
return table.Table([{k: d.get(k) for k in basis_keys} for d in dlist])
def extract_basis(dlist: list[dict], basis_keys: tuple) -> np.typing.NDArray:
"""Extract basis key-value pairs from a list of dictionaries."""
return np.array([{k: d.get(k) for k in basis_keys} for d in dlist], dtype=dict)
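# e.g. extract_basis_table([{"x": 1, "y": 2}], ("x",)) yields a one-row Table
# with the single column "x"; keys missing from a dict become None via dict.get.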
def split_data_basis(data_basis: table.Table, n_split: int = 1) -> list[table.Table]:
"""Split data basis into n_split parts."""
assert (
np.unique(data_basis["dataset"]).size == 1
), "Only one dataset is allowed for splitting."
# sort
data_basis.sort(keys=["dataset", "obs_id"])
# get unique obsid
u_obsid, i_obsid, c_obsid = np.unique(
data_basis["obs_id"].data, return_index=True, return_counts=True
)
# set chunk size
chunk_size = int(np.fix(len(u_obsid) / n_split))
# initialize chunks
chunks = []
for i_split in range(n_split):
if i_split < n_split - 1:
chunks.append(
data_basis[
i_obsid[i_split * chunk_size] : i_obsid[(i_split + 1) * chunk_size]
]
)
else:
chunks.append(data_basis[i_obsid[i_split * chunk_size] :])
# np.unique(table.vstack(chunks)["_id"])
# np.unique(table.vstack(chunks)["obs_id"])
return chunks
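# e.g. with 10 unique obs_ids and n_split=3, chunk_size = 3 and the chunks
# cover 3, 3, and 4 obs_ids respectively; rows sharing an obs_id are never
# split across chunks because slicing happens at obs_id boundaries (i_obsid).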
# plan basis keys
PLAN_BASIS_KEYS = (
"dataset",
"instrument",
"obs_type",
"obs_group",
"obs_id",
"n_frame",
"_id",
)
# data basis keys
DATA_BASIS_KEYS = (
"dataset",
"instrument",
"obs_type",
"obs_group",
"obs_id",
"detector",
"file_name",
"_id",
)
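# A minimal sketch (hypothetical values) of hand-built basis tables matching
# the key tuples above, e.g. for exercising the dispatchers offline:
#   plan_basis = table.Table([
#       {"dataset": "demo", "instrument": "MSC", "obs_type": "sci",
#        "obs_group": "g1", "obs_id": "100000001", "n_frame": 1, "_id": "p1"},
#   ])
#   data_basis = table.Table([
#       {"dataset": "demo", "instrument": "MSC", "obs_type": "sci",
#        "obs_group": "g1", "obs_id": "100000001", "detector": "01",
#        "file_name": "demo.fits", "_id": "d1"},
#   ])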
class Dispatcher:
@@ -80,125 +135,420 @@ class Dispatcher:
    """
    @staticmethod
    def find_plan_basis(**kwargs) -> table.Table:
        """
        Find plan records.
        """
        # query
        qr = plan.find(**override_common_keys(PLAN_PARAMS, kwargs))
        assert qr.success, qr
        # plan basis / obsid basis
        for _ in qr.data:
            _["n_frame"] = (
                _["params"]["n_epec_frame"] if _["instrument"] == "HSTDM" else 1
            )
        plan_basis = extract_basis_table(
            qr.data,
            PLAN_BASIS_KEYS,
        )
        return plan_basis
@staticmethod
def find_level0_basis(**kwargs) -> table.Table:
"""
Find level0 records.
"""
# query
qr = level0.find(**override_common_keys(LEVEL0_PARAMS, kwargs))
assert qr.success, qr
# data basis
data_basis = extract_basis_table(
qr.data,
DATA_BASIS_KEYS,
)
return data_basis
    @staticmethod
    def find_level1_basis(**kwargs) -> table.Table:
        """
        Find level1 records.
        """
        # query
        qr = level1.find(**override_common_keys(LEVEL1_PARAMS, kwargs))
        assert qr.success, qr
        # data basis
        data_basis = extract_basis_table(
            qr.data,
            DATA_BASIS_KEYS,
        )
        return data_basis
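    # Usage sketch (hypothetical filter values; any kwarg overrides the
    # matching key in the corresponding *_PARAMS dict):
    #   plan_basis = Dispatcher.find_plan_basis(instrument="MSC")
    #   data_basis = Dispatcher.find_level0_basis(instrument="MSC")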
    @staticmethod
    def dispatch_file(
        plan_basis: table.Table,
        data_basis: table.Table,
    ) -> list[dict]:
        # initialize task list
        task_list = []
        # loop over data (one task per file)
        for i_data_basis in trange(
            len(data_basis),
            unit="task",
            dynamic_ncols=True,
        ):
            # i_data_basis = 1
            this_data_basis = data_basis[i_data_basis : i_data_basis + 1]
            # match this file's (dataset, obs_id) against the plan
            this_relevant_plan = table.join(
                this_data_basis["dataset", "obs_id"],
                plan_basis,
                keys=["dataset", "obs_id"],
                join_type="inner",
            )
            # append this task
            task_list.append(
                dict(
                    task=this_data_basis,
                    success=True,
                    relevant_plan=this_relevant_plan,
                    relevant_data=this_data_basis,
                )
            )
        return task_list
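    # Each returned item is a dict of the form (sketch):
    #   {"task": <1-row Table>, "success": bool,
    #    "relevant_plan": <Table>, "relevant_data": <1-row Table>}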
    @staticmethod
    def dispatch_detector(
        plan_basis: table.Table,
        data_basis: table.Table,
        n_jobs: int = 1,
    ) -> list[dict]:
        """
        Dispatch one task per unique (obs_id, detector) pair.

        Parameters
        ----------
        plan_basis : table.Table
            Plan basis table (columns as in PLAN_BASIS_KEYS).
        data_basis : table.Table
            Data basis table (columns as in DATA_BASIS_KEYS).
        n_jobs : int
            Number of parallel jobs; the data basis is split by obs_id.

        Returns
        -------
        list[dict]
            One task dict per (obs_id, detector) pair.
        """
        if n_jobs != 1:
            task_list = joblib.Parallel(n_jobs=n_jobs)(
                joblib.delayed(Dispatcher.dispatch_detector)(plan_basis, _)
                for _ in split_data_basis(data_basis, n_split=n_jobs)
            )
            return sum(task_list, [])
        # unique obsid
        u_obsid = table.unique(data_basis["dataset", "obs_id"])
        relevant_plan = table.join(
            u_obsid,
            plan_basis,
            keys=["dataset", "obs_id"],
            join_type="left",
        )
        print(f"{len(relevant_plan)} relevant plan records")
        u_data_detector = table.unique(
            data_basis[
                "dataset",
                "instrument",
                "obs_type",
                "obs_group",
                "obs_id",
                "detector",
            ]
        )
# initialize task list
task_list = []
        # loop over unique (obs_id, detector) pairs
for i_data_detector in trange(
len(u_data_detector),
unit="task",
dynamic_ncols=True,
):
# i_data_detector = 1
this_task = dict(u_data_detector[i_data_detector])
this_data_detector = u_data_detector[i_data_detector : i_data_detector + 1]
# join data and plan
this_data_detector_files = table.join(
this_data_detector,
data_basis,
keys=[
"dataset",
"instrument",
"obs_type",
"obs_group",
"obs_id",
"detector",
],
join_type="inner",
)
this_data_detector_plan = table.join(
this_data_detector,
relevant_plan,
keys=[
"dataset",
"instrument",
"obs_type",
"obs_group",
"obs_id",
],
join_type="left",
)
# whether detector effective
this_detector = this_data_detector["detector"][0]
this_instrument = this_data_detector["instrument"][0]
this_detector_effective = (
this_detector in csst[this_instrument].effective_detector_names
)
n_files_expected = this_data_detector_plan["n_frame"][0]
n_files_found = len(this_data_detector_files)
# append this task
task_list.append(
dict(
task=this_task,
success=(
len(this_data_detector_plan) == 1
and len(this_data_detector_files) == 1
and this_detector_effective
and n_files_found == n_files_expected
),
relevant_plan=this_data_detector_plan,
relevant_data=this_data_detector_files,
)
)
return task_list
@staticmethod
def dispatch_obsid(
plan_basis: table.Table,
data_basis: table.Table,
n_jobs: int = 1,
) -> list[dict]:
if n_jobs != 1:
task_list = joblib.Parallel(n_jobs=n_jobs)(
joblib.delayed(Dispatcher.dispatch_obsid)(plan_basis, _)
for _ in split_data_basis(data_basis, n_split=n_jobs)
)
return sum(task_list, [])
        # unique obsid
        u_obsid = table.unique(data_basis["dataset", "obs_id"])
        relevant_plan = table.join(
            u_obsid,
            plan_basis,
            keys=["dataset", "obs_id"],
            join_type="left",
        )
        print(f"{len(relevant_plan)} relevant plan records")
        # Note: the point is not to find all plans, but to find all tasks;
        # detector-level tasks far outnumber plan_basis records.
u_data_obsid = table.unique(
data_basis[
"dataset",
"instrument",
"obs_type",
"obs_group",
"obs_id",
]
)
# initialize task list
task_list = []
        # loop over unique obsid
        for i_data_obsid in trange(
            len(u_data_obsid),
            unit="task",
            dynamic_ncols=True,
        ):
            # i_data_obsid = 2
            this_task = dict(u_data_obsid[i_data_obsid])
            this_data_obsid = u_data_obsid[i_data_obsid : i_data_obsid + 1]
# join data and plan
this_data_obsid_files = table.join(
this_data_obsid,
data_basis,
keys=[
"dataset",
"instrument",
"obs_type",
"obs_group",
"obs_id",
],
join_type="inner",
)
this_data_obsid_plan = table.join(
this_data_obsid,
relevant_plan,
keys=[
"dataset",
"instrument",
"obs_type",
"obs_group",
"obs_id",
],
join_type="left",
)
# whether effective detectors all there
this_instrument = this_data_obsid["instrument"][0]
this_success = set(csst[this_instrument].effective_detector_names).issubset(
set(this_data_obsid_files["detector"])
)
# append this task
task_list.append(
dict(
task=this_task,
success=this_success,
relevant_plan=this_data_obsid_plan,
relevant_data=this_data_obsid_files,
)
)
return task_list
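    # Post-processing sketch (illustrative only): list obs_ids whose
    # effective detectors are incomplete:
    #   incomplete = [t["task"]["obs_id"] for t in task_list if not t["success"]]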
@staticmethod
def dispatch_obsgroup(
plan_basis: table.Table,
data_basis: table.Table,
# n_jobs: int = 1,
) -> list[dict]:
# unique obsgroup basis
obsgroup_basis = table.unique(
plan_basis[
"dataset",
"instrument",
"obs_type",
"obs_group",
]
)
# initialize task list
task_list = []
# loop over obsgroup
for i_obsgroup in trange(
len(obsgroup_basis),
unit="task",
dynamic_ncols=True,
):
# i_obsgroup = 1
this_task = dict(obsgroup_basis[i_obsgroup])
this_success = True
this_obsgroup_obsid = table.join(
obsgroup_basis[i_obsgroup : i_obsgroup + 1], # this obsgroup
plan_basis,
keys=["dataset", "instrument", "obs_type", "obs_group"],
join_type="left",
)
this_obsgroup_file = table.join(
this_obsgroup_obsid,
data_basis,
keys=["dataset", "instrument", "obs_type", "obs_group", "obs_id"],
join_type="inner",
table_names=["plan", "data"],
)
# loop over obsid
for i_obsid in range(len(this_obsgroup_obsid)):
# i_obsid = 1
# print(i_obsid)
instrument = this_obsgroup_obsid[i_obsid]["instrument"]
n_frame = this_obsgroup_obsid[i_obsid]["n_frame"]
effective_detector_names = csst[instrument].effective_detector_names
this_obsgroup_obsid_file = table.join(
this_obsgroup_obsid[i_obsid : i_obsid + 1], # this obsid
data_basis,
keys=["dataset", "instrument", "obs_type", "obs_group", "obs_id"],
join_type="inner",
table_names=["plan", "data"],
)
            if instrument == "HSTDM":  # unclear how the terahertz data should be handled
                # this_success &= (
                #     len(this_obsgroup_obsid_file) == n_frame
                #     or len(this_obsgroup_obsid_file) == n_frame * 2
                # )
                # or simply
                this_success &= len(this_obsgroup_obsid_file) % n_frame == 0
            else:
                # n_detector == n_file
                # this_success &= len(this_obsgroup_obsid_file) == len(
                #     effective_detector_names
                # )
                # or more strictly, each detector matches
                this_success &= set(this_obsgroup_obsid_file["detector"]) == set(
                    effective_detector_names
                )
# append this task
task_list.append(
dict(
task=this_task,
success=this_success,
relevant_plan=this_obsgroup_obsid,
relevant_data=this_obsgroup_file,
)
)
return task_list
    @staticmethod
    def load_test_data() -> tuple:
plan_recs = joblib.load("dagtest/csst-msc-c9-25sqdeg-v3.plan.dump")
data_recs = joblib.load("dagtest/csst-msc-c9-25sqdeg-v3.level0.dump")
print(f"{len(plan_recs.data)} plan records")
print(f"{len(data_recs.data)} data records")
for _ in plan_recs.data:
_["n_frame"] = (
_["params"]["n_epec_frame"] if _["instrument"] == "HSTDM" else 1
)
plan_basis = extract_basis_table(
plan_recs.data,
PLAN_BASIS_KEYS,
)
data_basis = extract_basis_table(
data_recs.data,
DATA_BASIS_KEYS,
)
return plan_basis, data_basis
# # 1221 plan recs, 36630 data recs
# plan_basis, data_basis = Dispatcher.load_test_data()
#
# # 430 task/s
# task_list_via_file = Dispatcher.dispatch_file(plan_basis, data_basis)
#
# # 13 task/s @n_jobs=1, 100*10 task/s @n_jobs=10 (max)
# task_list_via_detector = Dispatcher.dispatch_detector(plan_basis, data_basis, n_jobs=10)
#
# # 16 task/s @n_jobs=1, 130*10 tasks/s @n_jobs=10 (max) 🔼
# task_list_via_obsid = Dispatcher.dispatch_obsid(plan_basis, data_basis, n_jobs=10)
#
# # 13s/task
# task_list_via_obsgroup = Dispatcher.dispatch_obsgroup(plan_basis, data_basis)
# print(
# sum(_["success"] for _ in task_list_via_obsgroup),
# "/",
# len(task_list_via_obsgroup),
# )
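# # A possible follow-up (sketch; names are illustrative): keep only the
# # dispatchable tasks and collect their file names for submission.
# ready = [t for t in task_list_via_detector if t["success"]]
# file_names = [list(t["relevant_data"]["file_name"]) for t in ready]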