dag.py 4.19 KB
Newer Older
BO ZHANG's avatar
BO ZHANG committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import os
import json
import yaml
import glob
import toml
from typing import Optional
from .message import gen_dag_run_id


# Passed to gen_dag_run_id() when creating a DAG run id
# (presumably the number of digits in the id — confirm in .message).
RUN_ID_DIGITS = 15


# Data directories shipped next to this module: TOML matching rules, and the
# per-DAG YAML definitions / JSON message templates.
DAG_RULE_DIRECTORY = os.path.join(os.path.dirname(__file__), "dag_rules")
DAG_TEMPLATE_DIRECTORY = os.path.join(os.path.dirname(__file__), "dag_templates")


def match_string_headers(s: str, headers: set[str]) -> bool:
    """Return True if ``s`` starts with at least one of ``headers``.

    Parameters
    ----------
    s : str
        The string to test.
    headers : iterable of str, or str
        Allowed prefixes. A single ``str`` is treated as one prefix rather
        than being iterated character by character.

    Returns
    -------
    bool
        ``True`` as soon as one prefix matches; ``False`` if none match
        (including when ``headers`` is empty).
    """
    if isinstance(headers, str):
        # Iterating a bare string would test each character as a prefix,
        # e.g. "10" would wrongly match the rule "01".
        headers = (headers,)
    # any() stops at the first matching prefix instead of scanning them all.
    return any(s.startswith(header) for header in headers)


# define CsstDAG and CsstDAGList
class CsstDAG:
    """One DAG: its keyword-matching rules plus its template files.

    Parameters
    ----------
    name : str
        DAG name; also the file stem of ``<name>.yml`` (definition) and
        ``<name>.json`` (message template) in ``DAG_TEMPLATE_DIRECTORY``.
    rules : dict, optional
        Mapping from each required keyword to the prefixes its value may
        start with. ``None`` is treated as an empty mapping.

    Raises
    ------
    FileNotFoundError
        If the ``.yml`` definition or the ``.json`` message template is missing.
    """

    def __init__(self, name="", rules: Optional[dict] = None):
        self.name = name
        # The None default previously crashed on ``None.keys()``.
        self.rules = {} if rules is None else rules
        self.keys = set(self.rules.keys())

        # load the DAG definition (YAML) and the message template (JSON)
        dag_def_path = os.path.join(DAG_TEMPLATE_DIRECTORY, f"{name}.yml")
        dag_msg_path = os.path.join(DAG_TEMPLATE_DIRECTORY, f"{name}.json")
        self.definition = self._read_file(dag_def_path, yaml.safe_load)
        self.message_template = self._read_file(dag_msg_path, json.load)

    @staticmethod
    def _read_file(path: str, loader):
        """Open ``path`` as UTF-8 text and return ``loader(file_object)``.

        Raises ``FileNotFoundError`` with the message ``"<path> not found"``.
        """
        # EAFP: opening directly avoids the exists()/open() race.
        try:
            f = open(path, "r", encoding="utf-8")
        except FileNotFoundError as err:
            raise FileNotFoundError(f"{path} not found") from err
        with f:
            return loader(f)

    def match(self, **kwargs: str) -> bool:
        """Return True if ``kwargs`` satisfies this DAG's rules.

        The keyword names must equal ``self.keys`` exactly (none missing, none
        extra). Each value must be a ``str`` starting with one of the prefixes
        in ``self.rules[key]``. Non-string values never match.
        """
        # check if exactly the required keys are present
        if set(kwargs) != self.keys:
            return False
        # check every value against its allowed prefixes
        return all(
            isinstance(v, str) and match_string_headers(v, self.rules[k])
            for k, v in kwargs.items()
        )

    def __repr__(self):
        return f"CsstDAG({self.name})"

    def pprint(self):
        """Print the name, required keys, definition and message template."""
        print(f"Name: {self.name}")
        print(f"Keys: {self.keys}")
        print(f"Definition: {self.definition}")
        print(f"Message template: {self.message_template}")

    def gen_message(
        self,
        batch_id: str = "msc-v093-rdx-naoc-v1",
        dataset: str = "msc-v093",
        **kwargs,  # required keywords for DAG
    ) -> str:
        """Generate a DAG message as a single-line JSON string.

        The message has ``dag_id`` (this DAG's name) and a fresh ``dag_run_id``
        from ``gen_dag_run_id(RUN_ID_DIGITS)``. It also has ``batch_id``, and a
        nested ``message`` holding ``dataset``, ``batch_id`` and every keyword
        in ``kwargs``.

        Raises
        ------
        ValueError
            If ``kwargs`` does not match this DAG (see :meth:`match`).
        """
        if not self.match(**kwargs):
            raise ValueError(f"Cannot generate DAG message for {self.name}")

        dag_id = self.name
        this_dag_run_id = gen_dag_run_id(RUN_ID_DIGITS)

        this_message = dict(
            dag_id=dag_id,
            dag_run_id=this_dag_run_id,
            batch_id=batch_id,
            message=dict(
                dataset=dataset,
                batch_id=batch_id,
                **kwargs,
            ),
        )

        # ensure_ascii=False keeps non-ASCII characters unescaped in the output
        message_string = json.dumps(this_message, ensure_ascii=False, indent=None)
        return message_string


class CsstDAGList(list):
    """A ``list`` of CsstDAG objects with lookup by name and rule matching."""

    def __init__(self, *args):
        # Same signature as ``list``: optionally a single iterable of DAGs.
        super().__init__(*args)

    def get(self, name: str) -> "CsstDAG":
        """Return the first DAG whose ``name`` equals ``name``.

        Raises
        ------
        ValueError
            If no DAG in the list has that name.
        """
        for dag in self:
            if dag.name == name:
                return dag
        raise ValueError(f"Cannot find DAG: {name}")

    def match(self, **kwargs: str) -> list:
        """Return the names of the DAGs matching ``kwargs``, in list order."""
        return [dag.name for dag in self.match_dag(**kwargs)]

    def match_dag(self, **kwargs: str) -> list:
        """Return the DAG objects whose ``match(**kwargs)`` is True, in list order."""
        return [dag for dag in self if dag.match(**kwargs)]


# Load every DAG rule file (*.toml) and build the global CSST_DAG_LIST.
# Each top-level TOML table is one DAG: the table name is the DAG name and its
# entries (required keyword -> allowed value prefixes) become CsstDAG.rules.
# sorted() makes the load order, and therefore the order of match() results,
# independent of the filesystem's directory listing order.
DAG_RULES = sorted(glob.glob(os.path.join(DAG_RULE_DIRECTORY, "*.toml")))
CSST_DAG_LIST = CsstDAGList()
print(f"DAG_RULE_DIRECTORY: {DAG_RULE_DIRECTORY}")
print(f"DAG_TEMPLATE_DIRECTORY: {DAG_TEMPLATE_DIRECTORY}")
for dag_file in DAG_RULES:
    # glob already returns paths that include DAG_RULE_DIRECTORY; re-joining
    # them would duplicate the prefix whenever that directory is relative.
    with open(dag_file, "r", encoding="utf-8") as f:
        dags = toml.load(f)

    for name, dag in dags.items():
        print(f" - Add DAG: name={name}, required_keys={tuple(dag.keys())}")
        CSST_DAG_LIST.append(CsstDAG(name, dag))