Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
csst-cicd
csst-dag
Commits
3ddcd5a6
Commit
3ddcd5a6
authored
Jul 08, 2025
by
BO ZHANG
🏀
Browse files
update dispatcher.py
parent
c1df3efa
Changes
1
Show whitespace changes
Inline
Side-by-side
csst_dag/dag/dispatcher.py
View file @
3ddcd5a6
import
numpy
as
np
from
astropy.table
import
Table
import
joblib
from
astropy
import
table
from
csst_dfs_client
import
plan
,
level0
,
level1
from
tqdm
import
trange
from
.._csst
import
csst
# from csst_dag._csst import csst
# THESE ARE GENERAL PARAMETERS!
PLAN_PARAMS
=
{
"dataset"
:
None
,
...
...
@@ -41,6 +47,7 @@ PROC_PARAMS = {
"pmapname"
:
"pmapname"
,
"final_prc_status"
:
-
2
,
"demo"
:
False
,
# should be capable to extend
}
...
...
def override_common_keys(d1: dict, d2: dict) -> dict:
    """Return a copy of ``d1`` whose common keys are overridden by ``d2``.

    Keys that exist only in ``d2`` are ignored: the result always has
    exactly the key set of ``d1``.

    Parameters
    ----------
    d1 : dict
        Base dictionary supplying the key set and default values.
    d2 : dict
        Override dictionary; only keys shared with ``d1`` take effect.

    Returns
    -------
    dict
        New dictionary with ``d1``'s keys and the overridden values.
    """
    # dict.get with a default does one lookup instead of a membership
    # test (`k in d2.keys()`) followed by a second indexing lookup.
    return {k: d2.get(k, d1[k]) for k in d1}
#
def extract_basis(dlist: list[dict], basis_keys: tuple) ->
np.ndarray
:
#
"""Extract basis key-value pairs from a list of dictionaries."""
#
return Table([{k: d.get(k) for k in basis_keys} for d in dlist])
.as_array()
def extract_basis_table(dlist: list[dict], basis_keys: tuple) -> table.Table:
    """Build an astropy Table holding only the *basis_keys* columns.

    Each input dict contributes one row; keys missing from a record are
    stored as ``None`` (the ``dict.get`` default).

    Parameters
    ----------
    dlist : list[dict]
        Source records.
    basis_keys : tuple
        Keys to extract from each record, one column per key.

    Returns
    -------
    table.Table
        Table with ``len(dlist)`` rows and ``basis_keys`` columns.
    """
    rows = []
    for record in dlist:
        rows.append({key: record.get(key) for key in basis_keys})
    return table.Table(rows)
def
extract_basis
(
dlist
:
list
[
dict
],
basis_keys
:
tuple
)
->
np
.
typing
.
NDArray
:
"""Extract basis key-value pairs from a list of dictionaries."""
return
np
.
array
([{
k
:
d
.
get
(
k
)
for
k
in
basis_keys
}
for
d
in
dlist
],
dtype
=
dict
)
def split_data_basis(data_basis: table.Table, n_split: int = 1) -> list[table.Table]:
    """Split the data basis into ``n_split`` chunks along obs_id boundaries.

    The table is sorted in place (side effect on the caller's table) by
    ``dataset`` and ``obs_id``; chunk boundaries are taken from the first
    occurrence index of each unique ``obs_id``, so rows sharing an
    ``obs_id`` never straddle two chunks.  The last chunk absorbs the
    remainder when ``len(unique obs_ids)`` is not divisible by ``n_split``.

    NOTE(review): assumes ``n_split <= number of unique obs_ids`` —
    otherwise ``chunk_size`` becomes 0 and leading chunks are empty;
    confirm against callers.
    """
    assert (
        np.unique(data_basis["dataset"]).size == 1
    ), "Only one dataset is allowed for splitting."
    # sort in place so that each obs_id occupies one contiguous row range
    data_basis.sort(keys=["dataset", "obs_id"])
    # unique obs_ids; i_obsid holds the first row index of each obs_id group
    u_obsid, i_obsid, c_obsid = np.unique(
        data_basis["obs_id"].data,
        return_index=True,
        return_counts=True,
    )
    # chunk size in units of unique obs_ids (truncated toward zero)
    chunk_size = int(np.fix(len(u_obsid) / n_split))
    # slice the sorted table at obs_id-group starting rows
    chunks = []
    for i_split in range(n_split):
        if i_split < n_split - 1:
            chunks.append(
                data_basis[
                    i_obsid[i_split * chunk_size] : i_obsid[(i_split + 1) * chunk_size]
                ]
            )
        else:
            # final chunk: everything from its starting row to the end
            chunks.append(data_basis[i_obsid[i_split * chunk_size] :])
    # np.unique(table.vstack(chunks)["_id"])
    # np.unique(table.vstack(chunks)["obs_id"])
    return chunks
# Columns extracted from plan records into the plan basis table
# (consumed by extract_basis_table / extract_basis).
PLAN_BASIS_KEYS = (
    "dataset",
    "instrument",
    "obs_type",
    "obs_group",
    "obs_id",
    "n_frame",
    "_id",
)
# Columns extracted from level0/level1 data records into the data basis table.
DATA_BASIS_KEYS = (
    "dataset",
    "instrument",
    "obs_type",
    "obs_group",
    "obs_id",
    "detector",
    "file_name",
    "_id",
)
class
Dispatcher
:
...
...
@@ -80,125 +135,420 @@ class Dispatcher:
"""
@
staticmethod
def
dispatch_level0_file
(
**
kwargs
)
->
dict
:
# plan_recs = plan.find(**override_common_keys(PLAN_PARAMS, kwargs))
data_recs
=
level0
.
find
(
**
override_common_keys
(
LEVEL0_PARAMS
,
kwargs
))
# construct results
task_list
=
[]
for
data_rec
in
data_recs
:
# construct task
task
=
dict
(
dataset
=
data_rec
[
"dataset"
],
instrument
=
data_rec
[
"instrument"
],
obs_type
=
data_rec
[
"obs_type"
],
obs_group
=
data_rec
[
"obs_group"
],
obs_id
=
data_rec
[
"obs_id"
],
detector
=
data_rec
[
"detector"
]
,
file_name
=
data_rec
[
"file_name"
]
,
def
find_plan_basis
(
**
kwargs
)
->
table
.
Table
:
"""
Find plan records.
"""
# query
qr
=
plan
.
find
(
**
override_common_keys
(
PLAN_PARAMS
,
kwargs
))
assert
qr
.
success
,
qr
# plan basis / obsid basis
for
_
in
qr
.
data
:
_
[
"n_frame"
]
=
(
_
[
"params"
][
"n_epec_frame"
]
if
_
[
"instrument"
]
==
"HSTDM"
else
1
)
plan_basis
=
extract_basis_table
(
qr
.
data
,
PLAN_BASIS_KEYS
,
)
return
plan_basis
return
dict
(
task_list
=
task_list
,
relevant_data_id_list
=
[],
@
staticmethod
def
find_level0_basis
(
**
kwargs
)
->
table
.
Table
:
"""
Find level0 records.
"""
# query
qr
=
level0
.
find
(
**
override_common_keys
(
LEVEL0_PARAMS
,
kwargs
))
assert
qr
.
success
,
qr
# data basis
data_basis
=
extract_basis_table
(
qr
.
data
,
DATA_BASIS_KEYS
,
)
return
data_basis
@
staticmethod
def
dispatch_level0_detector
(
**
kwargs
)
->
dict
:
def
find_level1_basis
(
**
kwargs
)
->
table
.
Table
:
"""
Find level1 records.
"""
# query
qr
=
level1
.
find
(
**
override_common_keys
(
LEVEL1_PARAMS
,
kwargs
))
assert
qr
.
success
,
qr
# data basis
data_basis
=
extract_basis_table
(
qr
.
data
,
DATA_BASIS_KEYS
,
)
return
data_basis
# get instrument
assert
"instrument"
in
kwargs
.
keys
(),
f
"
{
kwargs
}
does not have key 'instrument'"
instrument
=
kwargs
.
get
(
"instrument"
)
assert
instrument
in
(
"MSC"
,
"MCI"
,
"IFS"
,
"CPIC"
,
"HSTDM"
)
@
staticmethod
def
dispatch_file
(
plan_basis
:
table
.
Table
,
data_basis
:
table
.
Table
,
)
->
list
[
dict
]:
# unique obsid
u_obsid
=
table
.
unique
(
data_basis
[
"dataset"
,
"obs_id"
])
# query for plan and data
plan_recs
=
plan
.
find
(
**
override_common_keys
(
PLAN_PARAMS
,
kwargs
))
assert
plan_recs
.
success
,
plan_recs
data_recs
=
level0
.
find
(
**
override_common_keys
(
LEVEL0_PARAMS
,
kwargs
))
assert
data_recs
.
success
,
data_recs
# initialize task list
task_list
=
[]
import
joblib
# loop over plan
for
i_data_basis
in
trange
(
len
(
data_basis
),
unit
=
"task"
,
dynamic_ncols
=
True
,
):
# i_data_basis = 1
this_data_basis
=
data_basis
[
i_data_basis
:
i_data_basis
+
1
]
this_relevant_plan
=
table
.
join
(
u_obsid
,
plan_basis
,
keys
=
[
"dataset"
,
"obs_id"
],
join_type
=
"inner"
,
)
# append this task
task_list
.
append
(
dict
(
task
=
this_data_basis
,
success
=
True
,
relevant_plan
=
this_relevant_plan
,
relevant_data
=
data_basis
[
i_data_basis
:
i_data_basis
+
1
],
)
)
plan_recs
=
joblib
.
load
(
"dagtest/csst-msc-c9-25sqdeg-v3.plan.dump"
)
data_recs
=
joblib
.
load
(
"dagtest/csst-msc-c9-25sqdeg-v3.level0.dump"
)
print
(
f
"
{
len
(
plan_recs
.
data
)
}
plan records"
)
print
(
f
"
{
len
(
data_recs
.
data
)
}
data records"
)
return
task_list
@
staticmethod
def
dispatch_detector
(
plan_basis
:
table
.
Table
,
data_basis
:
table
.
Table
,
n_jobs
:
int
=
1
,
)
->
list
[
dict
]:
"""
instrument
=
"MSC"
from
csst_dag._csst
import
csst
Parameters
----------
plan_basis
data_basis
n_jobs
effective_detector_names
=
csst
[
instrument
].
effective_detector_names
Returns
-------
# extract info
plan_basis
=
extract_basis
(
plan_recs
.
data
,
(
"""
if
n_jobs
!=
1
:
task_list
=
joblib
.
Parallel
(
n_jobs
=
n_jobs
)(
joblib
.
delayed
(
Dispatcher
.
dispatch_detector
)(
plan_basis
,
_
)
for
_
in
split_data_basis
(
data_basis
,
n_split
=
n_jobs
)
)
return
sum
(
task_list
,
[])
# unique obsid
u_obsid
=
table
.
unique
(
data_basis
[
"dataset"
,
"obs_id"
])
relevant_plan
=
table
.
join
(
u_obsid
,
plan_basis
,
keys
=
[
"dataset"
,
"obs_id"
],
join_type
=
"left"
,
)
print
(
f
"
{
len
(
relevant_plan
)
}
relevant plan records"
)
u_data_detector
=
table
.
unique
(
data_basis
[
"dataset"
,
"instrument"
,
"obs_type"
,
"obs_group"
,
"obs_id"
,
),
"detector"
,
]
)
data_basis
=
extract_basis
(
data_recs
.
data
,
(
# initialize task list
task_list
=
[]
# loop over plan
for
i_data_detector
in
trange
(
len
(
u_data_detector
),
unit
=
"task"
,
dynamic_ncols
=
True
,
):
# i_data_detector = 1
this_task
=
dict
(
u_data_detector
[
i_data_detector
])
this_data_detector
=
u_data_detector
[
i_data_detector
:
i_data_detector
+
1
]
# join data and plan
this_data_detector_files
=
table
.
join
(
this_data_detector
,
data_basis
,
keys
=
[
"dataset"
,
"instrument"
,
"obs_type"
,
"obs_group"
,
"obs_id"
,
"detector"
,
],
join_type
=
"inner"
,
)
this_data_detector_plan
=
table
.
join
(
this_data_detector
,
relevant_plan
,
keys
=
[
"dataset"
,
"instrument"
,
"obs_type"
,
"obs_group"
,
"obs_id"
,
],
join_type
=
"left"
,
)
# whether detector effective
this_detector
=
this_data_detector
[
"detector"
][
0
]
this_instrument
=
this_data_detector
[
"instrument"
][
0
]
this_detector_effective
=
(
this_detector
in
csst
[
this_instrument
].
effective_detector_names
)
n_files_expected
=
this_data_detector_plan
[
"n_frame"
][
0
]
n_files_found
=
len
(
this_data_detector_files
)
# append this task
task_list
.
append
(
dict
(
task
=
this_task
,
success
=
(
len
(
this_data_detector_plan
)
==
1
and
len
(
this_data_detector_files
)
==
1
and
this_detector_effective
and
n_files_found
==
n_files_expected
),
relevant_plan
=
this_data_detector_plan
,
relevant_data
=
this_data_detector_files
,
)
)
return
task_list
@
staticmethod
def
dispatch_obsid
(
plan_basis
:
table
.
Table
,
data_basis
:
table
.
Table
,
n_jobs
:
int
=
1
,
)
->
list
[
dict
]:
if
n_jobs
!=
1
:
task_list
=
joblib
.
Parallel
(
n_jobs
=
n_jobs
)(
joblib
.
delayed
(
Dispatcher
.
dispatch_obsid
)(
plan_basis
,
_
)
for
_
in
split_data_basis
(
data_basis
,
n_split
=
n_jobs
)
)
return
sum
(
task_list
,
[])
# select plan basis relevant to data via `obs_id`
u_data_obsid
=
np
.
unique
([
_
[
"obs_id"
]
for
_
in
data_basis
])
relevant_plan_basis
=
[
_
for
_
in
plan_basis
if
_
[
"obs_id"
]
in
u_data_obsid
]
print
(
f
"
{
len
(
relevant_plan_basis
)
}
relevant plan records"
)
# unique obsid
u_obsid
=
table
.
unique
(
data_basis
[
"dataset"
,
"obs_id"
])
relevant_plan
=
table
.
join
(
u_obsid
,
plan_basis
,
keys
=
[
"dataset"
,
"obs_id"
],
join_type
=
"left"
,
)
print
(
f
"
{
len
(
relevant_plan
)
}
relevant plan records"
)
# idx_selected_relevant_plan_basis = np.zeros(len(relevant_plan_basis), dtype=bool)
# 好像并不是要找出所有的plan,而是要找出所有的任务,而detector级的任务要比plan_basis多得多
u_data_obsid
=
table
.
unique
(
data_basis
[
"dataset"
,
"instrument"
,
"obs_type"
,
"obs_group"
,
"obs_id"
,
]
)
# initialize task list
task_list
=
[]
relevant_data_id_list
=
[]
# loop over plan
for
i_plan_basis
,
this_plan_basis
in
enumerate
(
relevant_plan_basis
):
print
(
f
"Processing
{
i_plan_basis
+
1
}
/
{
len
(
relevant_plan_basis
)
}
"
)
# span over `detector`
for
this_detector
in
effective_detector_names
:
# construct this_task
this_task
=
dict
(
dataset
=
this_plan_basis
[
"dataset"
],
instrument
=
this_plan_basis
[
"instrument"
],
obs_type
=
this_plan_basis
[
"obs_type"
],
obs_group
=
this_plan_basis
[
"obs_group"
],
obs_id
=
this_plan_basis
[
"obs_id"
],
detector
=
this_detector
,
)
# find this plan basis
idx_this_plan_basis
=
np
.
argwhere
(
plan_basis
==
this_plan_basis
).
flatten
()[
0
]
# get n_frame, calculate n_file_expected
if
instrument
==
"HSTDM"
:
n_file_expected
=
plan_recs
.
data
[
idx_this_plan_basis
][
"params"
][
"num_epec_frame"
for
i_data_obsid
in
trange
(
len
(
u_data_obsid
),
unit
=
"task"
,
dynamic_ncols
=
True
,
):
i_data_obsid
=
2
this_task
=
dict
(
u_data_obsid
[
i_data_obsid
])
this_data_obsid
=
u_data_obsid
[
i_data_obsid
:
i_data_obsid
+
1
]
# join data and plan
this_data_obsid_files
=
table
.
join
(
this_data_obsid
,
data_basis
,
keys
=
[
"dataset"
,
"instrument"
,
"obs_type"
,
"obs_group"
,
"obs_id"
,
],
join_type
=
"inner"
,
)
this_data_obsid_plan
=
table
.
join
(
this_data_obsid
,
relevant_plan
,
keys
=
[
"dataset"
,
"instrument"
,
"obs_type"
,
"obs_group"
,
"obs_id"
,
],
join_type
=
"left"
,
)
# whether effective detectors all there
this_instrument
=
this_data_obsid
[
"instrument"
][
0
]
this_success
=
set
(
csst
[
this_instrument
].
effective_detector_names
).
issubset
(
set
(
this_data_obsid_files
[
"detector"
])
)
# append this task
task_list
.
append
(
dict
(
task
=
this_task
,
success
=
this_success
,
relevant_plan
=
this_data_obsid_plan
,
relevant_data
=
this_data_obsid_files
,
)
)
return
task_list
@
staticmethod
def
dispatch_obsgroup
(
plan_basis
:
table
.
Table
,
data_basis
:
table
.
Table
,
# n_jobs: int = 1,
)
->
list
[
dict
]:
# unique obsgroup basis
obsgroup_basis
=
table
.
unique
(
plan_basis
[
"dataset"
,
"instrument"
,
"obs_type"
,
"obs_group"
,
]
)
# initialize task list
task_list
=
[]
# loop over obsgroup
for
i_obsgroup
in
trange
(
len
(
obsgroup_basis
),
unit
=
"task"
,
dynamic_ncols
=
True
,
):
# i_obsgroup = 1
this_task
=
dict
(
obsgroup_basis
[
i_obsgroup
])
this_success
=
True
this_obsgroup_obsid
=
table
.
join
(
obsgroup_basis
[
i_obsgroup
:
i_obsgroup
+
1
],
# this obsgroup
plan_basis
,
keys
=
[
"dataset"
,
"instrument"
,
"obs_type"
,
"obs_group"
],
join_type
=
"left"
,
)
this_obsgroup_file
=
table
.
join
(
this_obsgroup_obsid
,
data_basis
,
keys
=
[
"dataset"
,
"instrument"
,
"obs_type"
,
"obs_group"
,
"obs_id"
],
join_type
=
"inner"
,
table_names
=
[
"plan"
,
"data"
],
)
# loop over obsid
for
i_obsid
in
range
(
len
(
this_obsgroup_obsid
)):
# i_obsid = 1
# print(i_obsid)
instrument
=
this_obsgroup_obsid
[
i_obsid
][
"instrument"
]
n_frame
=
this_obsgroup_obsid
[
i_obsid
][
"n_frame"
]
effective_detector_names
=
csst
[
instrument
].
effective_detector_names
this_obsgroup_obsid_file
=
table
.
join
(
this_obsgroup_obsid
[
i_obsid
:
i_obsid
+
1
],
# this obsid
data_basis
,
keys
=
[
"dataset"
,
"instrument"
,
"obs_type"
,
"obs_group"
,
"obs_id"
],
join_type
=
"inner"
,
table_names
=
[
"plan"
,
"data"
],
)
if
instrument
==
"HSTDM"
:
# 我也不知道太赫兹要怎么玩
# this_success &= (
# len(this_obsgroup_obsid_file) == n_frame
# or len(this_obsgroup_obsid_file) == n_frame * 2
# )
# or simply
this_success
&=
len
(
this_obsgroup_obsid_file
)
%
n_frame
==
0
else
:
n_file_expected
=
1
# count files found in data_basis
idx_files_found
=
np
.
argwhere
(
data_basis
==
this_task
).
flatten
()
n_file_found
=
len
(
idx_files_found
)
# if found == expected, append this task
if
n_file_found
==
n_file_expected
:
task_list
.
append
(
this_task
)
relevant_data_id_list
.
extend
(
[
data_recs
.
data
[
_
][
"_id"
]
for
_
in
idx_files_found
]
# n_detector == n_file
# this_success &= len(this_obsgroup_obsid_file) == len(
# effective_detector_names
# )
# or more strictly, each detector matches
this_success
&=
set
(
this_obsgroup_obsid_file
[
"detector"
])
==
set
(
effective_detector_names
)
return
dict
(
task_list
=
task_list
,
relevant_data_id_list
=
relevant_data_id_list
)
# append this task
task_list
.
append
(
dict
(
task
=
this_task
,
success
=
this_success
,
relevant_plan
=
this_obsgroup_obsid
,
relevant_data
=
this_obsgroup_file
,
)
)
return
task_list
    @staticmethod
    def dispatch_level0_obsid(**kwargs) -> list[dict]:
        # TODO: not implemented yet — placeholder kept so the API surface
        # matches the other dispatch_* methods; currently returns None.
        pass
    def load_test_data() -> tuple:
        """Load dumped plan/level0 query results and build basis tables.

        Reads two joblib dumps from the local ``dagtest/`` directory
        (development fixture paths), annotates plan records with
        ``n_frame`` exactly as ``find_plan_basis`` does, and returns
        ``(plan_basis, data_basis)`` as astropy Tables.

        NOTE(review): dump objects are assumed to expose a ``.data``
        list of dict records like the live query results — confirm.
        """
        import joblib

        plan_recs = joblib.load("dagtest/csst-msc-c9-25sqdeg-v3.plan.dump")
        data_recs = joblib.load("dagtest/csst-msc-c9-25sqdeg-v3.level0.dump")
        print(f"{len(plan_recs.data)} plan records")
        print(f"{len(data_recs.data)} data records")
        # annotate each plan record with the expected frame count
        for _ in plan_recs.data:
            _["n_frame"] = (
                _["params"]["n_epec_frame"] if _["instrument"] == "HSTDM" else 1
            )
        plan_basis = extract_basis_table(
            plan_recs.data,
            PLAN_BASIS_KEYS,
        )
        data_basis = extract_basis_table(
            data_recs.data,
            DATA_BASIS_KEYS,
        )
        return plan_basis, data_basis
# # 1221 plan recs, 36630 data recs
# plan_basis, data_basis = Dispatcher.load_test_data()
#
# # 430 task/s
# task_list_via_file = Dispatcher.dispatch_file(plan_basis, data_basis)
#
# # 13 task/s @n_jobs=1, 100*10 task/s @n_jobs=10 (max)
# task_list_via_detector = Dispatcher.dispatch_detector(plan_basis, data_basis, n_jobs=10)
#
# # 16 task/s @n_jobs=1, 130*10 tasks/s @n_jobs=10 (max) 🔼
# task_list_via_obsid = Dispatcher.dispatch_obsid(plan_basis, data_basis, n_jobs=10)
#
# # 13s/task
# task_list_via_obsgroup = Dispatcher.dispatch_obsgroup(plan_basis, data_basis)
# print(
# sum(_["success"] for _ in task_list_via_obsgroup),
# "/",
# len(task_list_via_obsgroup),
# )
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment