import numpy as np
from astropy.table import Table
from csst_dag.dag import Dispatcher

# Load the bundled test data (noted at time of writing: 1221 plan records,
# 36630 data records) and print the sizes and column names for a quick check.
plan_basis, data_basis = Dispatcher.load_test_data()
print(len(plan_basis), len(data_basis))
print("plan colnames: ", plan_basis.colnames)
print("data colnames: ", data_basis.colnames)


# Run each dispatch strategy once over the full test data with default
# arguments. The returned task lists are discarded; presumably this is a
# smoke run / warm-up ahead of the sections below (confirm intent).
# NOTE(review): this includes dispatch_obsgroup on the full data, which the
# section below notes at ~13 s/task, so it roughly doubles the script runtime.
for _dispatch in (
    Dispatcher.dispatch_file,
    Dispatcher.dispatch_detector,
    Dispatcher.dispatch_obsid,
    Dispatcher.dispatch_obsgroup,
):
    _dispatch(plan_basis, data_basis)


# dispatch_file on the first 10 data records only, to keep this section fast.
# Measured throughput: ~666 task/s.
task_list_via_file = Dispatcher.dispatch_file(plan_basis, data_basis[:10])
t = Table(task_list_via_file)
# Bare expressions have no effect when run as a script, so print the summaries.
print(np.unique(t["success"]))
print(np.unique(t["n_relevant_plan"]))
print(np.unique(t["n_relevant_data"]))
print(t["task"])
print(sum(t["success"]))
print(task_list_via_file[0]["relevant_plan"].colnames)
print(task_list_via_file[0]["relevant_data"].colnames)

# dispatch_detector on every 100th data record, run serially (n_jobs=1).
# Measured throughput: ~13 task/s @n_jobs=1, 100*10 task/s @n_jobs=10 (max).
task_list_via_detector = Dispatcher.dispatch_detector(
    plan_basis, data_basis[::100], n_jobs=1
)
t = Table(task_list_via_detector)
# Bare expressions have no effect when run as a script, so print the summaries.
print(np.unique(t["n_relevant_plan"], return_counts=True))
print(np.unique(t["success"], return_counts=True))
print(task_list_via_detector[0]["relevant_plan"].colnames)
print(task_list_via_detector[0]["relevant_data"].colnames)

# dispatch_obsid on the full test data with 10 parallel jobs.
# Measured throughput: 16 task/s @n_jobs=1, 130*10 tasks/s @n_jobs=10 (max) 🔼
task_list_via_obsid = Dispatcher.dispatch_obsid(plan_basis, data_basis, n_jobs=10)
t = Table(task_list_via_obsid)
# Bare expressions have no effect when run as a script, so print the summaries.
print(np.unique(t["n_relevant_plan"], return_counts=True))
print(np.unique(t["success"], return_counts=True))
print(task_list_via_obsid[0]["relevant_plan"].colnames)
print(task_list_via_obsid[0]["relevant_data"].colnames)

# dispatch_obsgroup on the full test data with default arguments.
# Measured: ~13 s/task.
task_list_via_obsgroup = Dispatcher.dispatch_obsgroup(plan_basis, data_basis)
# Report "<n_success> / <n_tasks>".
n_success = sum(task["success"] for task in task_list_via_obsgroup)
print(n_success, "/", len(task_list_via_obsgroup))
print(task_list_via_obsgroup[0]["relevant_plan"].colnames)
print(task_list_via_obsgroup[0]["relevant_data"].colnames)

# dispatch_obsgroup_detector on the full test data with default arguments.
# Timing note: 16 task/s @n_jobs=1, 130*10 tasks/s @n_jobs=10 (max) 🔼
# NOTE(review): this timing note is identical to the dispatch_obsid one, and
# this call passes no n_jobs; confirm it was actually measured for this method.
task_list_via_obsgroup_detector = Dispatcher.dispatch_obsgroup_detector(
    plan_basis, data_basis
)
t = Table(task_list_via_obsgroup_detector)
# Bare expressions have no effect when run as a script, so print the summaries.
print(np.unique(t["n_relevant_plan"], return_counts=True))
print(np.unique(t["success"], return_counts=True))
print(task_list_via_obsgroup_detector[0]["relevant_plan"].colnames)
print(task_list_via_obsgroup_detector[0]["relevant_data"].colnames)

# Reference: columns of task["relevant_plan"] (as printed by the sections above):
# ['dataset', 'instrument', 'obs_type', 'obs_group', 'obs_id', 'detector', 'n_file', '_id']
# Reference: columns of task["relevant_data"] (as printed by the sections above):
# ['dataset', 'instrument', 'obs_type', 'obs_group', 'obs_id', 'detector', 'file_name', '_id', 'prc_status']
