Commit 85880de1 authored by BO ZHANG's avatar BO ZHANG 🏀
Browse files

update DAGs

parent 9f1a6963
from .dfs import DFS
from .dag import CsstDAG
from ._dfs import DFS, dfs
from .dag import CsstDAGs
......@@ -7,13 +7,13 @@ from astropy import time
# Length of the random suffix appended to generated run ids.
DAG_RUN_ID_DIGITS = 6
# Directory holding the per-DAG message templates (*.yml / *.json).
DAG_MESSAGE_TEMPLATE_DIRECTORY = os.path.join(os.path.dirname(__file__), "dag_cfg")
DAG_YAML_LIST = glob.glob(DAG_MESSAGE_TEMPLATE_DIRECTORY + "/*.yml")
# DAG names are the template file basenames without extension.
DAG_LIST = [os.path.splitext(os.path.basename(p))[0] for p in DAG_YAML_LIST]
# print(DAG_LIST)
# [
#     "csst-msc-l1-ooc-bias",
#     "csst-msc-l1-ooc-flat",
......@@ -28,7 +28,7 @@ DAG_LIST = [os.path.splitext(os.path.basename(_))[0] for _ in DAG_YAML_LIST]
def gen_dag_run_id(digits=6):
"""
Generate a unique run_id for a dag.
Generate a unique run_id for a dag_cfg.
"""
now = time.Time.now()
dag_run_id = now.strftime("%Y%m%d-%H%M%S-")
......@@ -41,10 +41,10 @@ def gen_dag_run_id(digits=6):
def get_dag_message_template(dag_id):
    """
    Get the DAG message template for a given dag_id.

    Parameters
    ----------
    dag_id : str
        DAG identifier; must be one of ``DAG_LIST``.

    Returns
    -------
    dict
        The message template loaded from ``<dag_id>.json`` in
        ``DAG_MESSAGE_TEMPLATE_DIRECTORY``.

    Raises
    ------
    ValueError
        If ``dag_id`` is not a known DAG.
    """
    if dag_id not in DAG_LIST:
        raise ValueError(f"Unknown dag_cfg: {dag_id}")
    with open(os.path.join(DAG_MESSAGE_TEMPLATE_DIRECTORY, f"{dag_id}.json"), "r") as f:
        template = json.load(f)
    return template
......
......@@ -8,16 +8,16 @@ Example
python -m csst_dag.cli.msc -h
python -m csst_dag.cli.msc \
--dataset=csst-msc-c9-25sqdeg-v3 \
--obs-group=none \
--batch-id=csci-test-20250507 \
--dataset=csst-msc-c9-25sqdeg-v3 \
--obs-group=W1 \
--priority=1 \
--initial-prc-status=-1024 \
--final-prc-status=-2 \
--demo
"""
from csst_dag.dag import CsstDAG
from csst_dag.dag import CsstDAGs
import argparse
parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# data related parameters
parser.add_argument("--dataset", type=str, help="Dataset name")
# parser.add_argument("--instrument", type=str, help="Instrument name", default="MSC")
parser.add_argument("--obs-group", type=str, help="Observation group", default="none")
parser.add_argument("--obs-type", type=str, help="Observation type", default="")
# task related parameters
parser.add_argument("--batch-id", type=str, help="Batch ID", default="default_batch")
# NOTE(review): priority is used numerically downstream and had an int
# default while being declared type=str; parse it as an int.
parser.add_argument("--priority", type=int, help="Task priority", default=1)
# status related parameters
parser.add_argument(
    "--initial-prc-status", type=int, help="Initial processing status", default=-1024
)
......@@ -44,30 +47,33 @@ parser.add_argument(
args = parser.parse_args()
print("CLI parameters: ", args)

# DAGs to schedule for each observation type in this group.
DAG_MAP = {
    "WIDE": ["csst-msc-l1-mbi", "csst-msc-l1-sls", "csst-msc-l1-ast"],
    "DEEP": ["csst-msc-l1-mbi", "csst-msc-l1-sls", "csst-msc-l1-ast"],
    "BIAS": ["csst-msc-l1-qc0"],
    "DARK": ["csst-msc-l1-qc0"],
    "FLAT": ["csst-msc-l1-qc0"],
}

# If --obs-type is given, restrict scheduling to that single type.
if args.obs_type:
    assert args.obs_type in DAG_MAP.keys(), f"Unknown obs_type: {args.obs_type}"
    DAG_MAP = {args.obs_type: DAG_MAP[args.obs_type]}

for obs_type, dags in DAG_MAP.items():
    print(f"* Processing {obs_type}")
    for this_dag in dags:
        print(f"  - Scheduling `{this_dag}` -> ", end="")
        dag = CsstDAGs.get_dag(dag=this_dag)
        # schedule() returns (dag_group_run, dag_run_list); count the
        # individual runs rather than the length of the tuple.
        dag_group_run, dag_run_list = dag.schedule(
            batch_id=args.batch_id,
            priority=args.priority,
            dataset=args.dataset,
            obs_type=obs_type,
            obs_group=args.obs_group,
            initial_prc_status=args.initial_prc_status,
            final_prc_status=args.final_prc_status,
            demo=args.demo,
        )
        print(f"{len(dag_run_list)} tasks.")
from csst_dag import DFS
from csst_dfs_client import plan
import argparse
import sys
import os
......
from ._base_dag import BaseDAG
from ._dag_list import DAG_LIST
from .l1 import CsstL1
from .l1 import GeneralL1DAG
class CsstDAGs(dict):
    """
    Registry of all available DAGs.

    Behaves like a ``dict`` mapping DAG name -> ``GeneralL1DAG`` instance,
    pre-populated from the default ``dag_list``.
    """

    # Default DAG name -> instance mapping.
    dag_list = {
        "csst-msc-l1-qc0": GeneralL1DAG(dag_group="msc-l1", dag="csst-msc-l1-qc0"),
        "csst-msc-l1-mbi": GeneralL1DAG(dag_group="msc-l1", dag="csst-msc-l1-mbi"),
        "csst-msc-l1-ast": GeneralL1DAG(dag_group="msc-l1", dag="csst-msc-l1-ast"),
        "csst-msc-l1-sls": GeneralL1DAG(dag_group="msc-l1", dag="csst-msc-l1-sls"),
        "csst-msc-l1-ooc": GeneralL1DAG(dag_group="msc-l1", dag="csst-msc-l1-ooc"),
        "csst-cpic-l1": GeneralL1DAG(dag_group="cpic-l1", dag="csst-cpic-l1"),
        "csst-cpic-l1-qc0": GeneralL1DAG(dag_group="cpic-l1", dag="csst-cpic-l1-qc0"),
    }

    def __init__(self):
        super().__init__()  # initialize as an empty dict
        self.update(self.dag_list)  # populate with the default DAG instances
        # self.update(*args, **kwargs)  # user-supplied values could override defaults

    @staticmethod
    def ls():
        """Print the names of all registered DAGs."""
        # DAG_LIST (from ._dag_list) is a plain list of names; use the
        # class's own mapping so the printed names match what get_dag()
        # can actually return.
        print(list(CsstDAGs.dag_list.keys()))

    @staticmethod
    def get_dag(dag: str = ""):
        """Return the DAG instance registered under ``dag``.

        Indexing ``DAG_LIST`` (a list) with a string key would raise
        TypeError; look the instance up in ``dag_list`` instead.
        """
        assert dag in CsstDAGs.dag_list, f"{dag} not in DAG_LIST"
        return CsstDAGs.dag_list[dag]

    @staticmethod
    def get_all():
        """Return the full DAG name -> instance mapping."""
        return CsstDAGs.dag_list
from abc import ABC, abstractmethod
from ._dag_list import DAG_LIST
from ..dfs import dfs
import yaml
import os
import glob
import string
""" """
import json
import numpy as np
from astropy import time
import os
from abc import abstractmethod
import yaml
from typing import Any
from ._dag_list import DAG_LIST
from .._dfs import DFS, dfs
from ..hash import generate_sha1_from_time
DAG_RUN_ID_DIGITS = 6
DAG_CONFIG_DIR = os.path.join(
os.path.dirname(os.path.dirname(__file__)),
"dag_config",
......@@ -25,51 +26,189 @@ DAG_CONFIG_DIR = os.path.join(
class BaseDAG:
def __init__(self, dag_id: str):
self.dag_id = dag_id
assert dag_id in DAG_LIST, f"{dag_id} not in DAG_LIST"
"""Base class for all Directed Acyclic Graph (DAG) implementations.
This class provides core functionality for DAG configuration, message generation,
and execution management within the CSST data processing system.
Attributes
----------
dag : str
Name of the DAG, must exist in DAG_MAP
dag_cfg : dict
Configuration loaded from YAML file
dag_run_template : dict
Message template structure loaded from JSON file
dag_run_keys : set
Set of all valid message keys from the template
dfs : DFS
Data Flow System instance for execution
Raises
------
AssertionError
If DAG name is not in DAG_MAP or config name mismatch
"""
INSTRUMENT_ENUM = ("MSC", "MCI", "IFS", "CPIC", "HSTDM")
def __init__(self, dag_group: str, dag: str):
"""Initialize a DAG instance with configuration loading.
Parameters
----------
dag_group : str
Name of the DAG group.
dag : str
Name of the DAG.
Raises
------
AssertionError
If DAG name is invalid or config files are inconsistent
"""
# Set DAG name
self.dag_group = dag_group
self.dag = dag
assert dag in DAG_LIST, f"{dag} not in DAG_MAP"
# determine instrument
self.instrument = dag.split("-")[1] # e.g., "MSC"
assert self.instrument in self.INSTRUMENT_ENUM
# Load yaml and json config
yml_path = os.path.join(DAG_CONFIG_DIR, f"{dag}.yml")
json_path = os.path.join(DAG_CONFIG_DIR, f"{dag}.json")
yml_path = os.path.join(DAG_CONFIG_DIR, f"{dag_id}.yml")
json_path = os.path.join(DAG_CONFIG_DIR, f"{dag_id}.json")
with open(yml_path, "r") as f:
self.dag = yaml.safe_load(f)[0]
self.dag_cfg = yaml.safe_load(f)[0]
assert (
self.dag["dag_id"] == self.dag_id
), f"{self.dag['dag_id']} != {self.dag_id}" # , f"{dag_id} not consistent with definition in .yml file."
self.dag_cfg["name"] == self.dag
), f"{self.dag_cfg['name']} != {self.dag}" # , f"{dag_cfg} not consistent with definition in .yml file."
with open(json_path, "r") as f:
self.msg_template = json.load(f)
self.msg_keys = set(self.msg_template.keys())
self.dag_run_template = json.load(f)
# Summarize DAG run keys
self.dag_run_keys = set(self.dag_run_template.keys())
# DFS instance
self.dfs = dfs
def gen_msg(self, **kwargs):
"""Load message template and generate message dictionary."""
msg = self.msg_template.copy()
for k, v in kwargs.items():
assert k in self.msg_keys, f"{k} not in {self.msg_keys}"
msg[k] = v
return msg
# @abstractmethod
# def trigger(self, **kwargs) -> None:
# pass
#
def schedule(self, **kwargs):
"""Placeholder for DAG scheduling logic.
Notes
-----
This method must be implemented by concrete DAG subclasses.
Raises
------
NotImplementedError
Always raises as this is an abstract method
"""
raise NotImplementedError("Not implemented yet")
@staticmethod
def gen_dag_run_id(digits=6):
def generate_sha1():
"""Generate a unique SHA1 hash based on current timestamp.
Returns
-------
str
SHA1 hash string
"""
Generate a unique run_id for a dag.
return generate_sha1_from_time(verbose=False)
@staticmethod
def gen_dag_group_run(
dag_group: str = "-",
batch_id: str = "-",
priority: int = 1,
):
"""Generate a DAG group run configuration.
Parameters
----------
dag_group : str, optional
Group identifier (default: "-")
batch_id : str, optional
Batch identifier (default: "-")
priority : int, optional
Execution priority (default: 1)
Returns
-------
dict
Dictionary containing:
- dag_group: Original group name
- dag_group_run: Generated SHA1 identifier
- batch_id: Batch identifier
- priority: Execution priority
"""
now = time.Time.now()
dag_run_id = now.strftime("%Y%m%d-%H%M%S-")
n = len(string.ascii_lowercase)
for i in range(digits):
dag_run_id += string.ascii_lowercase[np.random.randint(low=0, high=n)]
return dag_run_id
@abstractmethod
def push(self, msg_dict: dict) -> None:
msg_str = json.dumps(msg_dict, ensure_ascii=False, indent=None)
return self.dfs.redis.push(msg_str)
return dict(
dag_group=dag_group,
dag_group_run=BaseDAG.generate_sha1(),
# dag=self.dag,
# dag_run=BaseDAG.generate_sha1(),
batch_id=batch_id,
priority=priority,
)
def gen_dag_run(self, dag_group_run: dict, **dag_run_kwargs: Any):
"""Generate a complete DAG run message.
Parameters
----------
dag_group_run : dict
Output from gen_dag_group_run()
**dag_run_kwargs : Any
Additional run-specific parameters
Returns
-------
dict
Complete DAG run message
Raises
------
AssertionError
If any key is not in the message template
"""
# copy template
dag_run = self.dag_run_template.copy()
# update dag_group_run info
for k, v in dag_group_run.items():
assert k in self.dag_run_keys, f"{k} not in {self.dag_run_keys}"
dag_run[k] = v
# update dag_run info
for k, v in dag_run_kwargs.items():
assert k in self.dag_run_keys, f"{k} not in {self.dag_run_keys}"
dag_run[k] = v
return dag_run
@staticmethod
def push_dag_group_run(
dag_group_run: dict,
dag_run_list: list[dict],
):
"""Submit a DAG group run to the DFS system.
Parameters
----------
dag_group_run : dict
Group run configuration
dag_run_list : list[dict]
List of individual DAG run messages
Returns
-------
Any
Result from dfs.dag.new_dag_group_run()
"""
return dfs.dag.new_dag_group_run(
dag_group_run=dag_group_run,
dag_run_list=dag_run_list,
)
......@@ -5,6 +5,7 @@ DAG_LIST = [
"csst-msc-l1-qc0",
"csst-msc-l1-mbi",
"csst-msc-l1-sls",
"csst-msc-l1-ast",
"csst-cpic-l1",
"csst-cpic-l1-qc0",
]
......
......@@ -2,148 +2,114 @@ import json
from ._base_dag import BaseDAG
from csst_dfs_client import plan, level0
# CHIPID_MAP = {
#     "csst-msc-l1-mbi": MSC_MBI_CHIPID,
#     "csst-msc-l1-sls": MSC_SLS_CHIPID,
#     "csst-msc-l1-qc0": MSC_CHIPID,
# }

# All 30 MSC detector IDs as zero-padded two-digit strings "01".."30".
MSC_DETECTORS = [f"{i:02d}" for i in range(1, 31)]

# Detectors used by the MBI pipeline (and by AST, which runs on the same chips).
MSC_MBI_DETECTORS = [
    "06", "07", "08", "09",
    "11", "12", "13", "14", "15",
    "16", "17", "18", "19", "20",
    "22", "23", "24", "25",
]

# Detectors used by the SLS pipeline: exactly the complement of the MBI
# set within the full focal plane (01-05, 10, 21, 26-30).
MSC_SLS_DETECTORS = [d for d in MSC_DETECTORS if d not in MSC_MBI_DETECTORS]

# Per-DAG record filters: a level0 record is selected when, for every key,
# its value is contained in the listed enumeration.
DAG_PARAMS = {
    "csst-msc-l1-mbi": {
        "detector": MSC_MBI_DETECTORS,
    },
    "csst-msc-l1-ast": {
        "detector": MSC_MBI_DETECTORS,
    },
    "csst-msc-l1-sls": {
        "detector": MSC_SLS_DETECTORS,
    },
    "csst-msc-l1-qc0": {
        "detector": MSC_DETECTORS,
    },
    "csst-cpic-l1": {
        "detector": ["VIS"],
    },
    "csst-cpic-l1-qc0": {
        "detector": ["VIS"],
    },
}

# Extra keyword arguments accepted by GeneralL1DAG.schedule().
SCHEDULE_KWARGS = {
    "priority",
    # "queue",
    # "execution_date",
}
class GeneralL1DAG(BaseDAG):
    """General Level-1 DAG: schedules per-record L1 processing runs.

    Queries level0 records from DFS, filters them with the per-DAG
    parameters in ``DAG_PARAMS``, generates one DAG run per selected
    record and, unless running in demo mode, pushes the whole group run.
    """

    def __init__(self, dag_group: str, dag: str):
        super().__init__(dag_group=dag_group, dag=dag)
        # Per-DAG record filters (e.g. the allowed detector set).
        self.params = DAG_PARAMS[dag]

    def schedule(
        self,
        dataset: str = "csst-msc-c9-25sqdeg-v3",
        obs_type: str = "WIDE",
        batch_id: str | None = "default",
        obs_group="W1",
        initial_prc_status: int = -1024,  # level0 prc_status required for level1
        final_prc_status: int = -2,
        demo=True,
        **kwargs,
    ):
        """Schedule DAG runs for matching level0 records.

        Parameters
        ----------
        dataset : str
            Dataset name to query.
        obs_type : str
            Observation type (e.g. "WIDE").
        batch_id : str or None
            Batch identifier recorded in each run.
        obs_group : str
            Observation group to query.
        initial_prc_status : int
            prc_status a level0 record must have to be picked up.
        final_prc_status : int
            prc_status written back after a run is generated.
        demo : bool
            If True, only print the generated messages; the group run is
            not pushed to DFS.
        **kwargs
            Extra scheduling options; must be a subset of SCHEDULE_KWARGS.

        Returns
        -------
        tuple of (dict, list of dict)
            The dag_group_run and the list of generated dag_run messages.
        """
        assert kwargs.keys() <= SCHEDULE_KWARGS, f"Unknown kwargs: {kwargs.keys()}"
        # No need to query plan:
        # plan.write_file(local_path="plan.json")
        # plan.find(instrument="MSC", obs_type=obs_type, project_id=project_id)

        # Generate a dag_group_run shared by all runs below.
        dag_group_run = self.gen_dag_group_run(
            dag_group=self.dag_group,
            batch_id=batch_id,
            priority=kwargs.get("priority", 1),
        )
        if demo:
            print(json.dumps(dag_group_run, indent=4))

        # Find level0 data records.
        # NOTE(review): initial_prc_status is presumably the prc_status
        # filter of this query — confirm against csst_dfs_client.level0.find.
        recs = level0.find(
            instrument=self.instrument,
            dataset=dataset,
            obs_type=obs_type,
            obs_group=obs_group,
            prc_status=initial_prc_status,
        )
        assert recs.success, recs.message

        # Generate DAG run messages.
        dag_run_list = []
        for this_rec in recs.data:
            # Filter level0 records with the per-DAG parameter enums.
            is_selected = True
            additional_keys = {}
            for k, v in self.params.items():
                is_selected = this_rec[k] in v and is_selected
                additional_keys[k] = this_rec[k]
            if not is_selected:
                continue
            this_dag_run = self.gen_dag_run(
                dag_group_run=dag_group_run,
                batch_id=batch_id,
                dag_run_id=self.generate_sha1(),
                dataset=dataset,
                obs_type=obs_type,
                obs_group=obs_group,
                obs_id=this_rec["obs_id"],
                **additional_keys,
                **kwargs,
            )
            if demo:
                print(json.dumps(this_dag_run, indent=4))
            # Update level0 prc_status so the record is not scheduled twice.
            # NOTE(review): this update now runs even in demo mode — confirm
            # that mutating prc_status during a demo is intended.
            this_update = level0.update_prc_status(
                level0_id=this_rec["level0_id"],
                dag_run_id=this_dag_run["dag_run_id"],
                prc_status=final_prc_status,
                dataset=dataset,
            )
            assert this_update.success, this_update.message
            dag_run_list.append(this_dag_run)

        if not demo:
            # Push the whole group run to DFS.
            res_push = self.push_dag_group_run(dag_group_run, dag_run_list)
            print(res_push)
            assert res_push.success, res_push.message
        return dag_group_run, dag_run_list
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment