Commit f819b17d authored by Wei Shoulin's avatar Wei Shoulin
Browse files

refactor(client): implement batch processing for DAG and plan operations

- Add batch processing logic to dag.py for handling large dag_run_list
- Enhance write_file function in plan.py to support batch uploads
- Update test_level0.py to include tests for new batch processing features
parent 5bdb9543
Pipeline #9222 canceled with stages
in 0 seconds
......@@ -15,7 +15,21 @@ def new_dag_group_run(dag_group_run: dict, dag_run_list: Optional[list] = None)
Result: 成功后,Result.data为写入记录,失败message为失败原因。
"""
return request.put("/api/dag/group_run", {'dag_group_run': dag_group_run, 'dag_run_list': dag_run_list})
batch_size = 512
if dag_run_list is None:
return request.put("/api/dag/group_run", {'dag_group_run': dag_group_run, 'dag_run_list': []})
results = []
for i in range(0, len(dag_run_list), batch_size):
batch = dag_run_list[i:i + batch_size]
result = request.put("/api/dag/group_run", {'dag_group_run': dag_group_run, 'dag_run_list': batch})
results.append(result)
if not result.success:
# If any batch fails, return the failed result immediately
return result
# If all batches succeed, return the last result
return results[-1]
def find_group_run(dag_group: Optional[str] = None,
batch_id: Optional[str] = None,
......
import os
import json
from typing import Optional, IO, Tuple, Literal, Union
from .common import request, Result
DateTimeTuple = Tuple[str, str]
......@@ -83,25 +83,40 @@ def find_by_opid(opid: str) -> Result:
"""
return request.get(f"/api/plan/{opid}")
def write_file(local_file: Union[IO, str], **kwargs) -> Result:
def write_file(local_file: Union[IO, str, list]) -> Result:
"""
将本地json文件或json数据写入DFS中。
将本地json文件、json数据流或json数据列表写入DFS中。
Args:
local_file (str]): 文件路径
local_file (Union[IO, str, list]): 文件路径、数据流或JSON数据列表
**kwargs: 额外的关键字参数,这些参数将传递给DFS。
Returns:
Result: 操作的结果对象,包含操作是否成功以及相关的错误信息,成功返回数据对象。
"""
batch_size = 512
"""
if local_file is None:
raise ValueError("local_file is required")
if isinstance(local_file, str):
if not os.path.exists(local_file):
raise FileNotFoundError(local_file)
return request.post_file("/api/plan/file", local_file, kwargs)
return request.post_bytesio("/api/plan/file", local_file, kwargs)
with open(local_file, 'r') as f:
data = json.load(f)
elif isinstance(local_file, IO):
data = json.load(local_file)
elif isinstance(local_file, list):
data = local_file
else:
raise ValueError("Unsupported type for local_file")
if not isinstance(data, list):
raise ValueError("Data must be a list of JSON objects")
for i in range(0, len(data), batch_size):
batch_data = data[i:i + batch_size]
response = request.post("/api/plan/file", {"plans": batch_data})
if response.code != 200:
return response
return Result.ok_data(len(data))
def new(data: dict) -> Result:
"""
......
......@@ -7,64 +7,64 @@ class Level0TestCase(unittest.TestCase):
def setUp(self):
pass
# def test_find(self):
# start_time = time.time()
# result = level0.find(instrument='MSC', dataset="msc-v093")
# print(f"1操作执行时间: {time.time() - start_time} 秒")
# start_time = time.time()
# result = level0.find(instrument='MSC',
# ra_obj = 170,
# dec_obj = -24,
# radius = 1)
# print(f"2操作执行时间: {time.time() - start_time} 秒, 数据量:{result['total_count']}")
# start_time = time.time()
# result = level0.find(instrument='MSC', file_name="CSST_MSC_MS_SCI_20240609181116_20240609181347_10109100157991_27_L0_V01.fits")
# print(f"3操作执行时间: {time.time() - start_time} 秒, 数据量:{result['total_count']}")
# self.assertEqual(result.code, 200, "error code: " + result.message)
# self.assertIsNotNone(result.data, "error message: " + result.message)
def test_find(self):
start_time = time.time()
result = level0.find(instrument='MSC', dataset="msc-v093")
print(f"1操作执行时间: {time.time() - start_time} 秒")
start_time = time.time()
result = level0.find(instrument='MSC', dataset="msc-v093",
ra_obj = 170,
dec_obj = -24,
radius = 1)
print(f"2操作执行时间: {time.time() - start_time} 秒, 数据量:{result['total_count']}")
start_time = time.time()
result = level0.find(instrument='MSC', file_name="CSST_MSC_MS_SCI_20240609181116_20240609181347_10109100157991_27_L0_V01.fits")
print(f"3操作执行时间: {time.time() - start_time} 秒, 数据量:{result['total_count']}")
self.assertEqual(result.code, 200, "error code: " + result.message)
self.assertIsNotNone(result.data, "error message: " + result.message)
# def test_find_by_level0_id(self):
# result = level0.find_by_level0_id(level0_id = "1060940003452925")
# print(result)
# self.assertEqual(result.code, 200, "error code: " + str(result.code) + ", message: " + result.message)
# self.assertIsNotNone(result.data, "error message: " + result.message)
def test_find_by_level0_id(self):
result = level0.find_by_level0_id(level0_id = "1060940003452925")
print(result)
self.assertEqual(result.code, 200, "error code: " + str(result.code) + ", message: " + result.message)
self.assertIsNotNone(result.data, "error message: " + result.message)
# def test_update_qc0_status(self):
# result = level0.update_qc0_status(level0_id = "1010910015799127", qc0_status=1)
# print(result)
# self.assertEqual(result.code, 200, "error code: " + str(result.code) + ", message: " + result.message)
def test_update_qc0_status(self):
result = level0.update_qc0_status(level0_id = "1010910015799127", qc0_status=1)
print(result)
self.assertEqual(result.code, 200, "error code: " + str(result.code) + ", message: " + result.message)
def test_update_qc0_status_by_ids(self):
result = level0.update_qc0_status_by_ids(ids = ["676ac74a530b47ca41568858"], qc0_status=4)
print(result)
self.assertEqual(result.code, 200, "error code: " + str(result.code) + ", message: " + result.message)
# def test_update_prc_status(self):
# result = level0.update_prc_status(level0_id = "1010910015799127", dag_run="202411071002481234", prc_status=3)
# print(result)
# self.assertEqual(result.code, 200, "error code: " + str(result.code) + ", message: " + result.message)
def test_update_prc_status(self):
result = level0.update_prc_status(level0_id = "1010910015799127", dag_run="202411071002481234", prc_status=3)
print(result)
self.assertEqual(result.code, 200, "error code: " + str(result.code) + ", message: " + result.message)
# def test_update_prc_status_by_ids(self):
# result = level0.update_prc_status_by_ids(ids = ["676ac74a530b47ca41568858"], prc_status=4)
# print(result)
# self.assertEqual(result.code, 200, "error code: " + str(result.code) + ", message: " + result.message)
def test_update_prc_status_by_ids(self):
result = level0.update_prc_status_by_ids(ids = ["676ac74a530b47ca41568858"], prc_status=4)
print(result)
self.assertEqual(result.code, 200, "error code: " + str(result.code) + ", message: " + result.message)
# def test_write(self):
# file_path = "/Users/wsl/temp/csst/import/CSST_MSC_MS_SCI_20240609181116_20240609181347_10109100157991_27_L0_V01.fits"
# result = level0.write(local_file = file_path, dataset= 'msc-v093')
# print(result)
# self.assertEqual(result.code, 200, "error code: " + str(result.code) + ", message: " + result.message)
def test_write(self):
file_path = "/Users/wsl/temp/csst/import/CSST_MSC_MS_SCI_20240609181116_20240609181347_10109100157991_27_L0_V01.fits"
result = level0.write(local_file = file_path, dataset= 'msc-v093')
print(result)
self.assertEqual(result.code, 200, "error code: " + str(result.code) + ", message: " + result.message)
# def test_find_process(self):
# result = level0.find_process(level0_id="1060940003452925")
# print(result)
# self.assertEqual(result.code, 200, "error code: " + str(result.code) + ", message: " + result.message)
# def test_add_process(self):
# result = level0.add_process(level0_id="1060940003452925",
# dag="csst-msc-l1-mbi",
# dag_run="202411071002481234",
# dataset="v93",
# batch_id="v930batch",
# prc_time="2024-11-07 10:24:12", prc_status=1, prc_module="MSC", message="")
# print(result)
# self.assertEqual(result.code, 200, "error code: " + str(result.code) + ", message: " + result.message)
\ No newline at end of file
def test_find_process(self):
result = level0.find_process(level0_id="1060940003452925")
print(result)
self.assertEqual(result.code, 200, "error code: " + str(result.code) + ", message: " + result.message)
def test_add_process(self):
result = level0.add_process(level0_id="1060940003452925",
dag="csst-msc-l1-mbi",
dag_run="202411071002481234",
dataset="v93",
batch_id="v930batch",
prc_time="2024-11-07 10:24:12", prc_status=1, prc_module="MSC", message="")
print(result)
self.assertEqual(result.code, 200, "error code: " + str(result.code) + ", message: " + result.message)
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment