Commit df44de32 authored by BO ZHANG's avatar BO ZHANG 🏀
Browse files

update dispatcher.py

parent 3ddcd5a6
...@@ -9,7 +9,7 @@ class DotDict(dict): ...@@ -9,7 +9,7 @@ class DotDict(dict):
def __getattr__(self, key): def __getattr__(self, key):
"""属性访问优先级:1. 内置属性 → 2. 键值 → 3. 报错""" """属性访问优先级:1. 内置属性 → 2. 键值 → 3. 报错"""
try: try:
# 优先返回内置属性(如 keys, items 等方法) # 优先返回内置属性(如 basis_keys, items 等方法)
return object.__getattribute__(self, key) return object.__getattribute__(self, key)
except AttributeError: except AttributeError:
if key in self: if key in self:
......
...@@ -137,6 +137,11 @@ class Telescope(DotDict): ...@@ -137,6 +137,11 @@ class Telescope(DotDict):
def n_instrument(self): def n_instrument(self):
return len(self.instruments) return len(self.instruments)
# def plan_to_detector(self, plan_data):
# # convert to dict
# plan_dict = dict(plan_data)
# if plan_dict["instrument"] == "HSTDM":
mbi = SimpleInstrument( mbi = SimpleInstrument(
name="MBI", name="MBI",
......
from ._base_dag import BaseDAG from ._base_dag import BaseDAG
from ._dag_list import DAG_LIST from ._dag_list import DAG_LIST
from .dags import GeneralDAGViaObsid from .dags import GeneralDAGViaObsid, GeneralDAGViaObsgroup
from .dispatcher import Dispatcher
class CsstDAGs(dict): class CsstDAGs(dict):
...@@ -21,9 +22,9 @@ class CsstDAGs(dict): ...@@ -21,9 +22,9 @@ class CsstDAGs(dict):
"csst-msc-l1-sls": GeneralDAGViaObsid( "csst-msc-l1-sls": GeneralDAGViaObsid(
dag_group="msc-l1", dag="csst-msc-l1-sls", use_detector=True dag_group="msc-l1", dag="csst-msc-l1-sls", use_detector=True
), ),
# "csst-msc-l1-ooc": GeneralDAGViaObsgroup( "csst-msc-l1-ooc": GeneralDAGViaObsgroup(
# dag_group="msc-l1", dag="csst-msc-l1-ooc" dag_group="msc-l1-ooc", dag="csst-msc-l1-ooc"
# ), ),
"csst-cpic-l1": GeneralDAGViaObsid( "csst-cpic-l1": GeneralDAGViaObsid(
dag_group="cpic-l1", dag="csst-cpic-l1", use_detector=True dag_group="cpic-l1", dag="csst-cpic-l1", use_detector=True
), ),
......
...@@ -128,6 +128,9 @@ DATA_BASIS_KEYS = ( ...@@ -128,6 +128,9 @@ DATA_BASIS_KEYS = (
"_id", "_id",
) )
# join_type for data x plan
PLAN_JOIN_TYPE = "inner"
class Dispatcher: class Dispatcher:
""" """
...@@ -188,8 +191,8 @@ class Dispatcher: ...@@ -188,8 +191,8 @@ class Dispatcher:
plan_basis: table.Table, plan_basis: table.Table,
data_basis: table.Table, data_basis: table.Table,
) -> list[dict]: ) -> list[dict]:
# unique obsid # unique obsid --> useless
u_obsid = table.unique(data_basis["dataset", "obs_id"]) # u_obsid = table.unique(data_basis["dataset", "obs_id"])
# initialize task list # initialize task list
task_list = [] task_list = []
...@@ -201,9 +204,10 @@ class Dispatcher: ...@@ -201,9 +204,10 @@ class Dispatcher:
dynamic_ncols=True, dynamic_ncols=True,
): ):
# i_data_basis = 1 # i_data_basis = 1
this_task = dict(data_basis[i_data_basis])
this_data_basis = data_basis[i_data_basis : i_data_basis + 1] this_data_basis = data_basis[i_data_basis : i_data_basis + 1]
this_relevant_plan = table.join( this_relevant_plan = table.join(
u_obsid, this_data_basis,
plan_basis, plan_basis,
keys=["dataset", "obs_id"], keys=["dataset", "obs_id"],
join_type="inner", join_type="inner",
...@@ -211,10 +215,12 @@ class Dispatcher: ...@@ -211,10 +215,12 @@ class Dispatcher:
# append this task # append this task
task_list.append( task_list.append(
dict( dict(
task=this_data_basis, task=this_task,
success=True, success=True,
relevant_plan=this_relevant_plan, relevant_plan=this_relevant_plan,
relevant_data=data_basis[i_data_basis : i_data_basis + 1], relevant_data=data_basis[i_data_basis : i_data_basis + 1],
n_relevant_plan=len(this_relevant_plan),
n_relevant_data=1,
) )
) )
...@@ -251,7 +257,7 @@ class Dispatcher: ...@@ -251,7 +257,7 @@ class Dispatcher:
u_obsid, u_obsid,
plan_basis, plan_basis,
keys=["dataset", "obs_id"], keys=["dataset", "obs_id"],
join_type="left", join_type=PLAN_JOIN_TYPE,
) )
print(f"{len(relevant_plan)} relevant plan records") print(f"{len(relevant_plan)} relevant plan records")
...@@ -303,7 +309,7 @@ class Dispatcher: ...@@ -303,7 +309,7 @@ class Dispatcher:
"obs_group", "obs_group",
"obs_id", "obs_id",
], ],
join_type="left", join_type=PLAN_JOIN_TYPE,
) )
# whether detector effective # whether detector effective
...@@ -312,7 +318,12 @@ class Dispatcher: ...@@ -312,7 +318,12 @@ class Dispatcher:
this_detector_effective = ( this_detector_effective = (
this_detector in csst[this_instrument].effective_detector_names this_detector in csst[this_instrument].effective_detector_names
) )
n_files_expected = this_data_detector_plan["n_frame"][0]
n_files_expected = (
this_data_detector_plan["n_frame"][0]
if len(this_data_detector_plan) > 0
else 0
)
n_files_found = len(this_data_detector_files) n_files_found = len(this_data_detector_files)
# append this task # append this task
task_list.append( task_list.append(
...@@ -326,6 +337,8 @@ class Dispatcher: ...@@ -326,6 +337,8 @@ class Dispatcher:
), ),
relevant_plan=this_data_detector_plan, relevant_plan=this_data_detector_plan,
relevant_data=this_data_detector_files, relevant_data=this_data_detector_files,
n_relevant_plan=len(this_data_detector_plan),
n_relevant_data=len(this_data_detector_files),
) )
) )
return task_list return task_list
...@@ -350,7 +363,7 @@ class Dispatcher: ...@@ -350,7 +363,7 @@ class Dispatcher:
u_obsid, u_obsid,
plan_basis, plan_basis,
keys=["dataset", "obs_id"], keys=["dataset", "obs_id"],
join_type="left", join_type=PLAN_JOIN_TYPE,
) )
print(f"{len(relevant_plan)} relevant plan records") print(f"{len(relevant_plan)} relevant plan records")
...@@ -373,12 +386,12 @@ class Dispatcher: ...@@ -373,12 +386,12 @@ class Dispatcher:
unit="task", unit="task",
dynamic_ncols=True, dynamic_ncols=True,
): ):
i_data_obsid = 2 # i_data_obsid = 2
this_task = dict(u_data_obsid[i_data_obsid]) this_task = dict(u_data_obsid[i_data_obsid])
this_data_obsid = u_data_obsid[i_data_obsid : i_data_obsid + 1] this_data_obsid = u_data_obsid[i_data_obsid : i_data_obsid + 1]
# join data and plan # join data and plan
this_data_obsid_files = table.join( this_data_obsid_file = table.join(
this_data_obsid, this_data_obsid,
data_basis, data_basis,
keys=[ keys=[
...@@ -400,14 +413,33 @@ class Dispatcher: ...@@ -400,14 +413,33 @@ class Dispatcher:
"obs_group", "obs_group",
"obs_id", "obs_id",
], ],
join_type="left", join_type=PLAN_JOIN_TYPE,
) )
# whether effective detectors all there # whether effective detectors all there
this_instrument = this_data_obsid["instrument"][0] this_instrument = this_data_obsid["instrument"][0]
this_success = set(csst[this_instrument].effective_detector_names).issubset( this_n_frame = (
set(this_data_obsid_files["detector"]) this_data_obsid_plan["n_frame"] if len(this_data_obsid_plan) > 0 else 0
) )
this_effective_detector_names = csst[
this_instrument
].effective_detector_names
if this_instrument == "HSTDM":
# 不确定以后是1个探测器还是2个探测器
this_n_file_found = len(this_data_obsid_file)
this_n_file_expected = (this_n_frame, this_n_frame * 2)
this_success = this_n_file_found in this_n_file_expected
else:
# for other instruments, e.g., MSC
# n_file_found = len(this_obsgroup_obsid_file)
# n_file_expected = len(effective_detector_names)
# this_success &= n_file_found == n_file_expected
# or more strictly, expected files are a subset of files found
this_success = set(this_effective_detector_names) <= set(
this_data_obsid_file["detector"]
)
# append this task # append this task
task_list.append( task_list.append(
...@@ -415,7 +447,9 @@ class Dispatcher: ...@@ -415,7 +447,9 @@ class Dispatcher:
task=this_task, task=this_task,
success=this_success, success=this_success,
relevant_plan=this_data_obsid_plan, relevant_plan=this_data_obsid_plan,
relevant_data=this_data_obsid_files, relevant_data=this_data_obsid_file,
n_relevant_plan=len(this_data_obsid_plan),
n_relevant_data=len(this_data_obsid_file),
) )
) )
...@@ -452,14 +486,14 @@ class Dispatcher: ...@@ -452,14 +486,14 @@ class Dispatcher:
this_task = dict(obsgroup_basis[i_obsgroup]) this_task = dict(obsgroup_basis[i_obsgroup])
this_success = True this_success = True
this_obsgroup_obsid = table.join( this_obsgroup_plan = table.join(
obsgroup_basis[i_obsgroup : i_obsgroup + 1], # this obsgroup obsgroup_basis[i_obsgroup : i_obsgroup + 1], # this obsgroup
plan_basis, plan_basis,
keys=["dataset", "instrument", "obs_type", "obs_group"], keys=["dataset", "instrument", "obs_type", "obs_group"],
join_type="left", join_type=PLAN_JOIN_TYPE,
) )
this_obsgroup_file = table.join( this_obsgroup_file = table.join(
this_obsgroup_obsid, this_obsgroup_plan,
data_basis, data_basis,
keys=["dataset", "instrument", "obs_type", "obs_group", "obs_id"], keys=["dataset", "instrument", "obs_type", "obs_group", "obs_id"],
join_type="inner", join_type="inner",
...@@ -467,36 +501,37 @@ class Dispatcher: ...@@ -467,36 +501,37 @@ class Dispatcher:
) )
# loop over obsid # loop over obsid
for i_obsid in range(len(this_obsgroup_obsid)): for i_obsid in range(len(this_obsgroup_plan)):
# i_obsid = 1 # i_obsid = 1
# print(i_obsid) # print(i_obsid)
instrument = this_obsgroup_obsid[i_obsid]["instrument"] this_instrument = this_obsgroup_plan[i_obsid]["instrument"]
n_frame = this_obsgroup_obsid[i_obsid]["n_frame"] this_n_frame = this_obsgroup_plan[i_obsid]["n_frame"]
effective_detector_names = csst[instrument].effective_detector_names this_effective_detector_names = csst[
this_instrument
].effective_detector_names
this_obsgroup_obsid_file = table.join( this_obsgroup_obsid_file = table.join(
this_obsgroup_obsid[i_obsid : i_obsid + 1], # this obsid this_obsgroup_plan[i_obsid : i_obsid + 1], # this obsid
data_basis, data_basis,
keys=["dataset", "instrument", "obs_type", "obs_group", "obs_id"], keys=["dataset", "instrument", "obs_type", "obs_group", "obs_id"],
join_type="inner", join_type="inner",
table_names=["plan", "data"], table_names=["plan", "data"],
) )
if instrument == "HSTDM": # 我也不知道太赫兹要怎么玩 if this_instrument == "HSTDM":
# this_success &= ( # 不确定以后是1个探测器还是2个探测器
# len(this_obsgroup_obsid_file) == n_frame this_n_file_found = len(this_obsgroup_obsid_file)
# or len(this_obsgroup_obsid_file) == n_frame * 2 this_n_file_expected = (this_n_frame, this_n_frame * 2)
# ) this_success &= this_n_file_found in this_n_file_expected
# or simply
this_success &= len(this_obsgroup_obsid_file) % n_frame == 0
else: else:
# n_detector == n_file # for other instruments, e.g., MSC
# this_success &= len(this_obsgroup_obsid_file) == len( # n_file_found = len(this_obsgroup_obsid_file)
# effective_detector_names # n_file_expected = len(effective_detector_names)
# ) # this_success &= n_file_found == n_file_expected
# or more strictly, each detector matches
this_success &= set(this_obsgroup_obsid_file["detector"]) == set( # or more strictly, expected files are a subset of files found
effective_detector_names this_success &= set(this_effective_detector_names) <= set(
this_obsgroup_obsid_file["detector"]
) )
# append this task # append this task
...@@ -504,8 +539,10 @@ class Dispatcher: ...@@ -504,8 +539,10 @@ class Dispatcher:
dict( dict(
task=this_task, task=this_task,
success=this_success, success=this_success,
relevant_plan=this_obsgroup_obsid, relevant_plan=this_obsgroup_plan,
relevant_data=this_obsgroup_file, relevant_data=this_obsgroup_file,
n_relevant_plan=len(this_obsgroup_plan),
n_relevant_data=this_obsgroup_file,
) )
) )
return task_list return task_list
...@@ -522,6 +559,7 @@ class Dispatcher: ...@@ -522,6 +559,7 @@ class Dispatcher:
_["n_frame"] = ( _["n_frame"] = (
_["params"]["n_epec_frame"] if _["instrument"] == "HSTDM" else 1 _["params"]["n_epec_frame"] if _["instrument"] == "HSTDM" else 1
) )
# 未来如果HSTDM的设定简化一些,这里n_frame可以改成n_file,更直观
plan_basis = extract_basis_table( plan_basis = extract_basis_table(
plan_recs.data, plan_recs.data,
PLAN_BASIS_KEYS, PLAN_BASIS_KEYS,
...@@ -531,24 +569,3 @@ class Dispatcher: ...@@ -531,24 +569,3 @@ class Dispatcher:
DATA_BASIS_KEYS, DATA_BASIS_KEYS,
) )
return plan_basis, data_basis return plan_basis, data_basis
# # 1221 plan recs, 36630 data recs
# plan_basis, data_basis = Dispatcher.load_test_data()
#
# # 430 task/s
# task_list_via_file = Dispatcher.dispatch_file(plan_basis, data_basis)
#
# # 13 task/s @n_jobs=1, 100*10 task/s @n_jobs=10 (max)
# task_list_via_detector = Dispatcher.dispatch_detector(plan_basis, data_basis, n_jobs=10)
#
# # 16 task/s @n_jobs=1, 130*10 tasks/s @n_jobs=10 (max) 🔼
# task_list_via_obsid = Dispatcher.dispatch_obsid(plan_basis, data_basis, n_jobs=10)
#
# # 13s/task
# task_list_via_obsgroup = Dispatcher.dispatch_obsgroup(plan_basis, data_basis)
# print(
# sum(_["success"] for _ in task_list_via_obsgroup),
# "/",
# len(task_list_via_obsgroup),
# )
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment