Commit 6f2b7509 authored by BO ZHANG's avatar BO ZHANG 🏀
Browse files

remove dag_list

parent ced55eb3
import itertools
from astropy import table
from ._base_dag import BaseDAG
from ._dag_list import DAG_LIST
from .dags import GeneralDAGViaObsid, GeneralDAGViaObsgroup
from .dispatcher import Dispatcher
from ._dispatcher import Dispatcher
from .._csst import csst
class CsstDAGs(dict):
def generate_permutations(**kwargs) -> table.Table:
    """Build a table of every combination of the keyword-argument values.

    Parameters
    ----------
    **kwargs
        Each keyword names one column. A value that is a ``list``,
        ``tuple`` or ``set`` supplies the candidate values for that column;
        any other value (including a plain string) is treated as a single
        candidate.

    Returns
    -------
    table.Table
        Table built from one ``{keyword: value}`` dict per element of the
        Cartesian product of the candidate values.
    """
    # Normalize into a fresh mapping instead of rewriting ``kwargs`` while
    # iterating over it; scalars become one-element lists.
    candidates = {
        key: values if isinstance(values, (list, tuple, set)) else [values]
        for key, values in kwargs.items()
    }

    # Column names, in the order the keywords were passed.
    keys = list(candidates)

    # One dict per element of the Cartesian product of all candidate lists.
    rows = [
        dict(zip(keys, combination))
        for combination in itertools.product(*candidates.values())
    ]
    return table.Table(rows)
CSST_DAGS = {
"csst-msc-l1-qc0": BaseDAG(
dag="csst-msc-l1-qc0",
pattern=generate_permutations(
instrument=["MSC"],
obs_type=["BIAS", "DARK", "FLAT"],
),
"csst-msc-l1-sls": GeneralDAGViaObsid(
dag_group="msc-l1", dag="csst-msc-l1-sls", use_detector=True
dispatcher=Dispatcher.dispatch_file,
),
"csst-msc-l1-mbi": BaseDAG(
dag="csst-msc-l1-mbi",
pattern=generate_permutations(
instrument=["MSC"],
obs_type=["WIDE", "DEEP"],
detector=csst["MSC"]["MBI"].effective_detector_names,
),
"csst-msc-l1-ooc": GeneralDAGViaObsgroup(
dag_group="msc-l1-ooc", dag="csst-msc-l1-ooc"
),
"csst-msc-l1-ast": BaseDAG(
dag="csst-msc-l1-ast",
pattern=generate_permutations(
instrument=["MSC"],
obs_type=["WIDE", "DEEP"],
detector=csst["MSC"]["MBI"].effective_detector_names,
),
"csst-cpic-l1": GeneralDAGViaObsid(
dag_group="cpic-l1", dag="csst-cpic-l1", use_detector=True
),
"csst-msc-l1-sls": BaseDAG(
dag="csst-msc-l1-sls",
pattern=generate_permutations(
instrument=["MSC"],
obs_type=["WIDE", "DEEP"],
detector=csst["MSC"]["SLS"].effective_detector_names,
),
"csst-cpic-l1-qc0": GeneralDAGViaObsid(
dag_group="cpic-l1", dag="csst-cpic-l1-qc0", use_detector=True
),
"csst-msc-l1-ooc": BaseDAG(
dag="csst-msc-l1-ooc",
pattern=generate_permutations(
instrument=["MSC"],
obs_type=["BIAS", "DARK", "FLAT"],
detector=csst["MSC"].effective_detector_names,
),
}
def __init__(self):
super().__init__() # 初始化空字典
self.update(self.dag_list) # 先添加默认键值对
# self.update(*args, **kwargs) # 用户传入值覆盖默认值
@staticmethod
def ls():
print(DAG_LIST.keys())
),
}
""" """
import json
import os
from abc import abstractmethod
from typing import Any, Callable, Optional
import yaml
from typing import Any
from astropy import table
from ._dag_list import DAG_LIST
from .._dfs import DFS, dfs
from ..hash import generate_sha1_from_time
from ._dispatcher import Dispatcher
DAG_CONFIG_DIR = os.path.join(
os.path.dirname(os.path.dirname(__file__)),
......@@ -29,7 +28,7 @@ class BaseDAG:
"""Base class for all Directed Acyclic Graph (DAG) implementations.
This class provides core functionality for DAG configuration, message generation,
and execution management within the CSST data processing system.
and execution management within the CSST data processing system.
Attributes
----------
......@@ -39,8 +38,6 @@ class BaseDAG:
Configuration loaded from YAML file
dag_run_template : dict
Message template structure loaded from JSON file
dag_run_keys : set
Set of all valid message keys from the template
dfs : DFS
Data Flow System instance for execution
......@@ -50,34 +47,23 @@ class BaseDAG:
If DAG name is not in DAG_MAP or config name mismatch
"""
INSTRUMENT_ENUM = ("MSC", "MCI", "IFS", "CPIC", "HSTDM")
def __init__(self, dag_group: str, dag: str, use_detector: bool = False):
def __init__(
self,
dag: str,
pattern: table.Table,
dispatcher: Callable,
):
"""Initialize a DAG instance with configuration loading.
Parameters
----------
dag_group : str
Name of the DAG group.
dag : str
Name of the DAG.
use_detector : bool, optional
Flag to determine if `detector` is used.
Raises
------
AssertionError
If DAG name is invalid or config files are inconsistent
DAG name, must exist in DAG_MAP
"""
# Set DAG name
self.dag_group = dag_group
self.dag = dag
self.use_detector = use_detector
assert dag in DAG_LIST, f"{dag} not in DAG_MAP"
# determine instrument
self.instrument = dag.split("-")[1].upper() # e.g., "MSC"
assert self.instrument in self.INSTRUMENT_ENUM, self.instrument
self.pattern = pattern
self.dispatcher = dispatcher
# Load yaml and json config
yml_path = os.path.join(DAG_CONFIG_DIR, f"{dag}.yml")
......@@ -91,25 +77,45 @@ class BaseDAG:
with open(json_path, "r") as f:
self.dag_run_template = json.load(f)
# Summarize DAG run keys
self.dag_run_keys = set(self.dag_run_template.keys())
# DFS instance
self.dfs = dfs
def schedule(self, **kwargs):
"""Placeholder for DAG scheduling logic.
def filter_basis(self, plan_basis, data_basis):
    """Restrict both basis tables to the rows selected by ``self.pattern``.

    Parameters
    ----------
    plan_basis
        Table that is inner-joined on the ``dataset`` and ``obs_id``
        columns.
    data_basis
        Table that is inner-joined with ``self.pattern`` on every column
        of the pattern; it must also provide ``dataset`` and ``obs_id``.

    Returns
    -------
    tuple
        ``(filtered_plan_basis, filtered_data_basis)``.
    """
    # Keep only the data rows whose pattern columns match a pattern row.
    pattern_columns = self.pattern.colnames
    matched_data = table.join(
        self.pattern,
        data_basis,
        keys=pattern_columns,
        join_type="inner",
    )

    # Collapse the surviving data rows to distinct (dataset, obs_id) pairs.
    distinct_obs = table.unique(matched_data["dataset", "obs_id"])

    # Keep only the plan rows that belong to one of those pairs.
    matched_plan = table.join(
        distinct_obs,
        plan_basis,
        keys=["dataset", "obs_id"],
        join_type="inner",
    )
    return matched_plan, matched_data
Notes
-----
This method must be implemented by concrete DAG subclasses.
def schedule(
    self,
    dag_group_run: dict,
    data_basis: table.Table,
    plan_basis: table.Table,
    **kwargs,
) -> list[dict]:
    """Generate one DAG run message per task dispatched for this DAG.

    Parameters
    ----------
    dag_group_run : dict
        Group-level fields passed unchanged to ``gen_dag_run`` for every
        task.
    data_basis : table.Table
        Data-level table, filtered through ``filter_basis``.
    plan_basis : table.Table
        Plan-level table, filtered through ``filter_basis``.
    **kwargs
        Accepted but not used by this method.

    Returns
    -------
    list[dict]
        The ``gen_dag_run`` result for each task, in the order the
        dispatcher returned the tasks.
    """
    # Narrow both tables to the rows this DAG's pattern selects.
    filtered_plan_basis, filtered_data_basis = self.filter_basis(
        plan_basis, data_basis
    )

    # The dispatcher turns the filtered tables into tasks; each task is
    # unpacked as keyword arguments, so it must be a mapping with str keys.
    task_list = self.dispatcher(filtered_plan_basis, filtered_data_basis)

    return [
        self.gen_dag_run(dag_group_run=dag_group_run, **task)
        for task in task_list
    ]
@staticmethod
def generate_sha1():
......@@ -151,13 +157,15 @@ class BaseDAG:
return dict(
dag_group=dag_group,
dag_group_run=BaseDAG.generate_sha1(),
# dag=self.dag,
# dag_run=BaseDAG.generate_sha1(),
batch_id=batch_id,
priority=priority,
)
def gen_dag_run(self, dag_group_run: dict, **dag_run_kwargs: Any):
def gen_dag_run(
self,
dag_group_run: dict,
**dag_run_kwargs: Any,
) -> dict:
"""Generate a complete DAG run message.
Parameters
......@@ -182,12 +190,16 @@ class BaseDAG:
# update dag_group_run info
for k, v in dag_group_run.items():
assert k in self.dag_run_keys, f"{k} not in {self.dag_run_keys}"
assert (
k in self.dag_run_template.keys()
), f"{k} not in {self.dag_run_template.keys()}"
dag_run[k] = v
# update dag_run info
for k, v in dag_run_kwargs.items():
assert k in self.dag_run_keys, f"{k} not in {self.dag_run_keys}"
assert (
k in self.dag_run_template.keys()
), f"{k} not in {self.dag_run_template.keys()}"
dag_run[k] = v
return dag_run
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment