Commit dc71ee77 authored by Wei Shoulin
Browse files

feat(csst_dfs_client): enhance level0 and level1 search APIs

- Add support for data_uuid and data_model parameters in level0.find
- Rename object_name parameter to object in level0, level1, and plan modules
- Support extra keyword arguments in level0.find and level1.find
- Add get_by_uuid function to retrieve level0 data by UUID
- Update tests to reflect API changes and enable previously commented tests
- Improve parameter handling for datetime ranges in search functions
parent 93de0971
Pipeline #10639 canceled with stages
in 0 seconds
import os
from typing import Optional, Tuple, Literal, List
from .common import request, Result, utils, constants
from typing import List, Literal, Optional, Tuple
from .common import Result, constants, request, utils
DateTimeTuple = Tuple[str, str]
def find(
instrument: Literal['MSC', 'IFS', 'MCI', 'HSTDM', 'CPIC'],
dataset: str,
instrument: Literal["MSC", "IFS", "MCI", "HSTDM", "CPIC"],
data_model: str = "raw",
data_uuid: Optional[str] = None,
obs_group: Optional[str] = None,
obs_id: Optional[str] = None,
detector: Optional[str] = None,
......@@ -20,13 +24,17 @@ def find(
ra_obj: Optional[int] = None,
dec_obj: Optional[int] = None,
radius: Optional[float] = None,
object_name: Optional[str] = None,
object: Optional[str] = None,
page: int = 1,
limit: int = 0) -> Result:
limit: int = 0,
**extra_kwargs,
) -> Result:
"""
根据给定的参数搜索0级数据文件记录
Args:
data_uuid (Optional[str], optional): 数据UUID. Defaults to None.
data_model (str): 数据模型. Defaults to 'raw'.
instrument (str): 设备,必需为'MSC', 'IFS', 'MCI', 'HSTDM', 'CPIC'之一.
obs_group (Optional[str], optional): 项目ID. Defaults to None.
obs_id (Optional[str], optional): 观测ID. Defaults to None.
......@@ -41,7 +49,7 @@ def find(
ra_obj (Optional[int], optional): 目标赤经. Defaults to None.
dec_obj (Optional[int], optional): 目标赤纬. Defaults to None.
radius (Optional[float], optional): 搜索半径. Defaults to None.
object_name (Optional[str], optional): 目标名称. Defaults to None.
object (Optional[str], optional): 目标名称. Defaults to None.
dataset (str): 数据集名称. Required (no default).
page (int, optional): 页码. Defaults to 1.
limit (int, optional): 每页数量. Defaults to 0,不限制.
......@@ -52,39 +60,49 @@ def find(
"""
params = {
'obs_group': obs_group,
'obs_id': obs_id,
'instrument': instrument,
'detector': detector,
'obs_type': obs_type,
'filter': filter,
'qc_status': qc_status,
'prc_status': prc_status,
'file_name': file_name,
'ra_obj': ra_obj,
'dec_obj': dec_obj,
'radius': radius,
'object_name': object_name,
'obs_date_start': None,
'obs_date_end': None,
'create_time_start': None,
'create_time_end': None,
'dataset': dataset,
'page': page,
'limit': limit,
"data_uuid": data_uuid,
"data_model": data_model,
"obs_group": obs_group,
"obs_id": obs_id,
"instrument": instrument,
"detector": detector,
"obs_type": obs_type,
"filter": filter,
"qc_status": qc_status,
"prc_status": prc_status,
"file_name": file_name,
"ra_obj": ra_obj,
"dec_obj": dec_obj,
"radius": radius,
"object": object,
"obs_date_start": None,
"obs_date_end": None,
"create_time_start": None,
"create_time_end": None,
"dataset": dataset,
"page": page,
"limit": limit,
}
params.update(extra_kwargs)
if obs_date is not None:
params['obs_date_start'], params['obs_date_end'] = obs_date
if params['obs_date_start'] and utils.is_valid_datetime_format(params['obs_date_start']):
params["obs_date_start"], params["obs_date_end"] = obs_date
if params["obs_date_start"] and utils.is_valid_datetime_format(
params["obs_date_start"]
):
pass
if params['obs_date_end'] and utils.is_valid_datetime_format(params['obs_date_end']):
if params["obs_date_end"] and utils.is_valid_datetime_format(
params["obs_date_end"]
):
pass
if create_time is not None:
params['create_time_start'], params['create_time_end'] = create_time
utils.is_valid_datetime_format(params['create_time_start']) or utils.is_valid_datetime_format(params['create_time_end'])
params["create_time_start"], params["create_time_end"] = create_time
utils.is_valid_datetime_format(
params["create_time_start"]
) or utils.is_valid_datetime_format(params["create_time_end"])
return request.post("/api/level0", params)
def get_by_id(_id: str) -> Result:
"""
根据内部ID获取0级数据
......@@ -98,6 +116,21 @@ def get_by_id(_id: str) -> Result:
"""
return request.get(f"/api/level0/_id/{_id}")
def get_by_uuid(_uuid: str) -> Result:
    """
    Fetch level-0 data by its UUID.

    Args:
        _uuid (str): UUID of the level-0 data; it is interpolated into the
            request path as-is.

    Returns:
        Result: the value returned by ``request.get`` for
            ``/api/level0/uuid/{_uuid}``.
    """
    endpoint = f"/api/level0/uuid/{_uuid}"
    return request.get(endpoint)
def find_by_level0_id(level0_id: str) -> Result:
"""
通过 level0 的 ID 查询0级数据
......@@ -111,6 +144,7 @@ def find_by_level0_id(level0_id: str) -> Result:
"""
return request.get(f"/api/level0/{level0_id}")
def update_qc_status(level0_id: str, qc_status: int, dataset: str) -> Result:
"""
更新0级数据的QC状态
......@@ -123,7 +157,11 @@ def update_qc_status(level0_id: str, qc_status: int, dataset: str) -> Result:
Returns:
Result: 更新结果
"""
return request.put(f"/api/level0/qc_status/{level0_id}", {'qc_status': qc_status, 'dataset': dataset})
return request.put(
f"/api/level0/qc_status/{level0_id}",
{"qc_status": qc_status, "dataset": dataset},
)
def update_qc_status_by_ids(ids: List[str], qc_status: int) -> Result:
"""
......@@ -136,9 +174,14 @@ def update_qc_status_by_ids(ids: List[str], qc_status: int) -> Result:
Returns:
Result: 更新结果
"""
return request.put("/api/level0/qc_status/batch/update", {'qc_status': qc_status, 'ids': ids})
return request.put(
"/api/level0/qc_status/batch/update", {"qc_status": qc_status, "ids": ids}
)
def update_prc_status(level0_id: str, dag_run: str, prc_status: int, dataset: str) -> Result:
def update_prc_status(
level0_id: str, dag_run: str, prc_status: int, dataset: str
) -> Result:
"""
更新0级数据的处理状态
......@@ -151,7 +194,11 @@ def update_prc_status(level0_id: str, dag_run: str, prc_status: int, dataset: st
Returns:
Result: 操作结果
"""
return request.put(f"/api/level0/prc_status/{level0_id}/{dag_run}", {'prc_status': prc_status, 'dataset': dataset})
return request.put(
f"/api/level0/prc_status/{level0_id}/{dag_run}",
{"prc_status": prc_status, "dataset": dataset},
)
def update_prc_status_by_ids(ids: List[str], prc_status: int) -> Result:
"""
......@@ -164,11 +211,14 @@ def update_prc_status_by_ids(ids: List[str], prc_status: int) -> Result:
Returns:
Result: 操作结果
"""
return request.put("/api/level0/prc_status/batch", {'prc_status': prc_status, 'ids': ids})
return request.put(
"/api/level0/prc_status/batch", {"prc_status": prc_status, "ids": ids}
)
def write(local_file: str,
dataset: str = constants.DEFAULT_DATASET,
**kwargs) -> Result:
def write(
local_file: str, dataset: str = constants.DEFAULT_DATASET, **kwargs
) -> Result:
"""
将本地文件写入DFS中
......@@ -182,16 +232,17 @@ def write(local_file: str,
"""
params = {
'dataset': dataset,
"dataset": dataset,
}
params.update(kwargs)
if not os.path.exists(local_file):
raise FileNotFoundError(local_file)
return request.post_file("/api/level0/file", local_file, params)
def write_cat(local_file: str,
dataset: str = constants.DEFAULT_DATASET,
**kwargs) -> Result:
def write_cat(
local_file: str, dataset: str = constants.DEFAULT_DATASET, **kwargs
) -> Result:
"""
主巡天仿真数据的星表本地文件写入DFS中
......@@ -205,14 +256,16 @@ def write_cat(local_file: str,
"""
params = {
'dataset': dataset,
"dataset": dataset,
}
params.update(kwargs)
if not os.path.exists(local_file):
raise FileNotFoundError(local_file)
return request.post_file("/api/level0/cat/file", local_file, params)
def find_process(dag: Optional[str] = None,
def find_process(
dag: Optional[str] = None,
dag_run: Optional[str] = None,
batch_id: Optional[str] = None,
level0_id: Optional[str] = None,
......@@ -221,7 +274,8 @@ def find_process(dag: Optional[str] = None,
prc_status: Optional[int] = None,
prc_date: Optional[DateTimeTuple] = None,
page: int = 1,
limit: int = 0) -> Result:
limit: int = 0,
) -> Result:
"""
查询0级数据处理过程
......@@ -242,27 +296,33 @@ def find_process(dag: Optional[str] = None,
"""
params = {
'dag': dag,
'dag_run': dag_run,
'batch_id': batch_id,
'level0_id': level0_id,
'dataset': dataset,
'prc_module': prc_module,
'prc_status': prc_status,
'prc_date_start': None,
'prc_date_end': None,
'page': page,
'limit': limit
"dag": dag,
"dag_run": dag_run,
"batch_id": batch_id,
"level0_id": level0_id,
"dataset": dataset,
"prc_module": prc_module,
"prc_status": prc_status,
"prc_date_start": None,
"prc_date_end": None,
"page": page,
"limit": limit,
}
if prc_date is not None:
params['prc_date_start'], params['prc_date_end'] = prc_date
if params['prc_date_start'] and utils.is_valid_datetime_format(params['prc_date_start']):
params["prc_date_start"], params["prc_date_end"] = prc_date
if params["prc_date_start"] and utils.is_valid_datetime_format(
params["prc_date_start"]
):
pass
if params['prc_date_end'] and utils.is_valid_datetime_format(params['prc_date_end']):
if params["prc_date_end"] and utils.is_valid_datetime_format(
params["prc_date_end"]
):
pass
return request.post("/api/level0/process", params)
def add_process(level0_id: str,
def add_process(
level0_id: str,
dag: str,
dag_run: str,
batch_id: Optional[str] = None,
......@@ -270,7 +330,8 @@ def add_process(level0_id: str,
prc_status: int = -1024,
prc_date: str = utils.get_current_time(),
prc_module: str = "",
message: str = "") -> Result:
message: str = "",
) -> Result:
"""
添加0级数据处理过程
......@@ -290,19 +351,20 @@ def add_process(level0_id: str,
"""
params = {
'level0_id': level0_id,
'dag': dag,
'dag_run': dag_run,
'dataset': dataset,
'batch_id': batch_id,
'prc_date': prc_date,
'prc_status': prc_status,
'prc_module': prc_module,
'message': message,
"level0_id": level0_id,
"dag": dag,
"dag_run": dag_run,
"dataset": dataset,
"batch_id": batch_id,
"prc_date": prc_date,
"prc_status": prc_status,
"prc_module": prc_module,
"message": message,
}
utils.is_valid_datetime_format(prc_date)
return request.post("/api/level0/prc", params)
def new(data: dict) -> Result:
"""
新建0级数据,用于仿真数据测试
......
......@@ -23,13 +23,14 @@ def find(
ra_cen: Optional[int] = None,
dec_cen: Optional[int] = None,
radius: Optional[float] = None,
object_name: Optional[str] = None,
object: Optional[str] = None,
rss_id: Optional[str] = None,
cube_id: Optional[str] = None,
build: Optional[int] = None,
pmapname: Optional[str] = None,
page: int = 1,
limit: int = 0) -> Result:
limit: int = 0,
**extra_kwargs) -> Result:
"""
根据给定的参数搜索1级数据文件记录
......@@ -49,7 +50,7 @@ def find(
ra_cen (Optional[int], optional): 中心赤经. Defaults to None.
dec_cen (Optional[int], optional): 中心赤纬. Defaults to None.
radius (Optional[float], optional): 搜索半径. Defaults to None.
object_name (Optional[str], optional): 天体名称. Defaults to None.
object (Optional[str], optional): 天体名称. Defaults to None.
rss_id (Optional[str], optional): RSS ID (IFS) Defaults to None.
cube_id (Optional[str], optional): Cube ID (IFS). Defaults to None.
dataset (Optional[str], optional): 数据集名称. Defaults to None.
......@@ -79,7 +80,7 @@ def find(
'ra_cen': ra_cen,
'dec_cen': dec_cen,
'radius': radius,
'object_name': object_name,
'object': object,
'obs_date_start': None,
'obs_date_end': None,
'create_time_start': None,
......@@ -93,6 +94,7 @@ def find(
'page': page,
'limit': limit,
}
params.update(extra_kwargs)
if obs_date is not None:
params['obs_date_start'], params['obs_date_end'] = obs_date
......
......@@ -9,7 +9,7 @@ def find(mode: Optional[str] = None,
obs_id: Optional[str] = None,
instrument: Literal['MSC', 'IFS', 'MCI', 'HSTDM', 'CPIC'] = 'MSC',
obs_type: Optional[str] = None,
object_name: Optional[str] = None,
object: Optional[str] = None,
obstime: Optional[DateTimeTuple] = None,
dataset: Optional[str] = None,
page: int = 1,
......@@ -24,7 +24,7 @@ def find(mode: Optional[str] = None,
obs_id (Optional[str], optional): 观测ID,支持模糊搜索. Defaults to None.
instrument (Optional[str], optional): 模块ID,如'MSC', 'IFS'. Defaults to 'MSC'.
obs_type (Optional[str], optional): 观测类型,如主巡天宽场、TOO观测、定标星场等等. Defaults to None.
object_name (Optional[str], optional): 目标名称. Defaults to None.
object (Optional[str], optional): 目标名称. Defaults to None.
obstime (Optional[DateTimeTuple], optional): 观测时间范围. 如("2021-08-30 00:00:00", "2024-12-30 23:59:59"),Defaults to None.
dataset (Optional[str], optional): 数据集名称. Defaults to constants.DEFAULT_DATASET.
page (int, optional): 页码. Defaults to 1.
......@@ -42,7 +42,7 @@ def find(mode: Optional[str] = None,
'obs_id': obs_id,
'instrument': instrument,
'obs_type': obs_type,
'object_name': object_name,
'object': object,
'obs_time_start': None,
'obs_time_end': None,
'create_time_start': None,
......
......@@ -7,64 +7,69 @@ class Level0TestCase(unittest.TestCase):
def setUp(self):
    """Per-test fixture hook; these tests need no shared setup."""
    return None
# def test_find(self):
# start_time = time.time()
# result = level0.find(instrument='MSC', dataset="msc-v093")
# print(f"1操作执行时间: {time.time() - start_time} 秒")
# start_time = time.time()
# result = level0.find(instrument='MSC', dataset="msc-v093",
# ra_obj = 170,
# dec_obj = -24,
# radius = 1)
# print(f"2操作执行时间: {time.time() - start_time} 秒, 数据量:{result['total_count']}")
# start_time = time.time()
# result = level0.find(instrument='MSC', file_name="CSST_MSC_MS_SCI_20240609181116_20240609181347_10109100157991_27_L0_V01.fits")
# print(f"3操作执行时间: {time.time() - start_time} 秒, 数据量:{result['total_count']}")
# self.assertEqual(result.code, 200, "error code: " + result.message)
# self.assertIsNotNone(result.data, "error message: " + result.message)
def test_find(self):
    """Run three ``level0.find`` queries and verify the last one succeeded.

    The queries are: an extra keyword filter (``CRVAL1``), a cone search
    (``ra_obj``/``dec_obj``/``radius``) and a lookup by ``file_name``.
    The elapsed time of each query is printed.
    """
    # Imported locally so the test does not depend on a module-level
    # ``import time`` that is not visible in this part of the file.
    import time

    start_time = time.time()
    result = level0.find(instrument='MSC', dataset="msc-v093", CRVAL1=170)
    print(f"1操作执行时间: {time.time() - start_time} 秒")

    start_time = time.time()
    result = level0.find(instrument='MSC', dataset="msc-v093",
                         ra_obj=170,
                         dec_obj=-24,
                         radius=1)
    print(f"2操作执行时间: {time.time() - start_time} 秒, 数据量:{result['total_count']}")

    start_time = time.time()
    # ``dataset`` has no default in ``level0.find``; omitting it raises
    # TypeError before any request is sent, so pass it here as well.
    result = level0.find(instrument='MSC', dataset="msc-v093", file_name="CSST_MSC_MS_SCI_20240609181116_20240609181347_10109100157991_27_L0_V01.fits")
    print(f"3操作执行时间: {time.time() - start_time} 秒, 数据量:{result['total_count']}")

    self.assertEqual(result.code, 200, "error code: " + result.message)
    self.assertIsNotNone(result.data, "error message: " + result.message)
# def test_find_by_level0_id(self):
# result = level0.find_by_level0_id(level0_id = "1060940003452925")
# print(result)
# self.assertEqual(result.code, 200, "error code: " + str(result.code) + ", message: " + result.message)
# self.assertIsNotNone(result.data, "error message: " + result.message)
def test_get_by_uuid(self):
    """Fetch a level-0 record by UUID and expect a 200 result with data."""
    # The parameter of ``level0.get_by_uuid`` is named ``_uuid``; calling it
    # with ``uuid=...`` raises TypeError (unexpected keyword argument).
    result = level0.get_by_uuid(_uuid="0199d622-afd7-4ff7-a70e-95378a1ca638")
    print(result)
    self.assertEqual(result.code, 200, "error code: " + str(result.code) + ", message: " + result.message)
    self.assertIsNotNone(result.data, "error message: " + result.message)
def test_find_by_level0_id(self):
    """Look up a level-0 record by ``level0_id`` and expect a 200 result with data."""
    res = level0.find_by_level0_id(level0_id="1060940003452925")
    print(res)
    failure_msg = "error code: " + str(res.code) + ", message: " + res.message
    self.assertEqual(res.code, 200, failure_msg)
    self.assertIsNotNone(res.data, "error message: " + res.message)
# def test_update_qc_status(self):
# result = level0.update_qc_status(level0_id = "1010910015799127", qc_status=1, dataset="msc-v093")
# print(result)
# self.assertEqual(result.code, 200, "error code: " + str(result.code) + ", message: " + result.message)
def test_update_qc_status(self):
    """Set the QC status of one level-0 record and expect a 200 result."""
    res = level0.update_qc_status(
        level0_id="1010910015799127",
        qc_status=1,
        dataset="msc-v093",
    )
    print(res)
    failure_msg = "error code: " + str(res.code) + ", message: " + res.message
    self.assertEqual(res.code, 200, failure_msg)
# def test_update_qc_status_by_ids(self):
# result = level0.update_qc_status_by_ids(ids = ["676ac74a530b47ca41568858"], qc_status=4)
# print(result)
# self.assertEqual(result.code, 200, "error code: " + str(result.code) + ", message: " + result.message)
def test_update_qc_status_by_ids(self):
    """Set the QC status for a list of internal ids and expect a 200 result."""
    target_ids = ["676ac74a530b47ca41568858"]
    res = level0.update_qc_status_by_ids(ids=target_ids, qc_status=4)
    print(res)
    failure_msg = "error code: " + str(res.code) + ", message: " + res.message
    self.assertEqual(res.code, 200, failure_msg)
# def test_update_prc_status(self):
# result = level0.update_prc_status(level0_id = "1010910015799127", dag_run="202411071002481234", prc_status=3)
# print(result)
# self.assertEqual(result.code, 200, "error code: " + str(result.code) + ", message: " + result.message)
def test_update_prc_status(self):
    """Set the processing status of one level-0 record and expect a 200 result."""
    # ``level0.update_prc_status`` requires ``dataset`` (it has no default);
    # omitting it raises TypeError before any request is sent. The dataset
    # matches the one test_update_qc_status uses for the same level0_id.
    result = level0.update_prc_status(
        level0_id="1010910015799127",
        dag_run="202411071002481234",
        prc_status=3,
        dataset="msc-v093",
    )
    print(result)
    self.assertEqual(result.code, 200, "error code: " + str(result.code) + ", message: " + result.message)
def test_update_prc_status_by_ids(self):
    """Set the processing status for a list of internal ids and expect a 200 result."""
    target_ids = ["676ac74a530b47ca41568858"]
    res = level0.update_prc_status_by_ids(ids=target_ids, prc_status=4)
    print(res)
    failure_msg = "error code: " + str(res.code) + ", message: " + res.message
    self.assertEqual(res.code, 200, failure_msg)
# def test_write(self):
# file_path = "/Users/wsl/temp/csst/import/CSST_MSC_MS_SCI_20240609181116_20240609181347_10109100157991_27_L0_V01.fits"
# result = level0.write(local_file = file_path, dataset= 'msc-v093')
# print(result)
# self.assertEqual(result.code, 200, "error code: " + str(result.code) + ", message: " + result.message)
def test_write(self):
    """Upload a local level-0 FITS file and expect a 200 result.

    The fixture path points at a file on the original author's machine.
    ``level0.write`` raises FileNotFoundError when the file does not exist,
    so the test is skipped (rather than reported as an error) on machines
    that do not have the fixture.
    """
    import os  # local import: module-level imports are not visible here

    file_path = "/Users/wsl/temp/csst/import/CSST_MSC_MS_SCI_20240609181116_20240609181347_10109100157991_27_L0_V01.fits"
    if not os.path.exists(file_path):
        self.skipTest(f"fixture file not available: {file_path}")
    result = level0.write(local_file=file_path, dataset='msc-v093')
    print(result)
    self.assertEqual(result.code, 200, "error code: " + str(result.code) + ", message: " + result.message)
# def test_find_process(self):
# result = level0.find_process(level0_id="1060940003452925")
# print(result)
# self.assertEqual(result.code, 200, "error code: " + str(result.code) + ", message: " + result.message)
# def test_add_process(self):
# result = level0.add_process(level0_id="1060940003452925",
# dag="csst-msc-l1-mbi",
# dag_run="202411071002481234",
# dataset="v93",
# batch_id="v930batch",
# prc_date="2024-11-07 10:24:12", prc_status=1, prc_module="MSC", message="")
# print(result)
# self.assertEqual(result.code, 200, "error code: " + str(result.code) + ", message: " + result.message)
\ No newline at end of file
def test_find_process(self):
    """Query processing records for one ``level0_id`` and expect a 200 result."""
    res = level0.find_process(level0_id="1060940003452925")
    print(res)
    failure_msg = "error code: " + str(res.code) + ", message: " + res.message
    self.assertEqual(res.code, 200, failure_msg)
def test_add_process(self):
    """Add one processing record for a level-0 file and expect a 200 result."""
    process_fields = {
        "level0_id": "1060940003452925",
        "dag": "csst-msc-l1-mbi",
        "dag_run": "202411071002481234",
        "dataset": "v93",
        "batch_id": "v930batch",
        "prc_date": "2024-11-07 10:24:12",
        "prc_status": 1,
        "prc_module": "MSC",
        "message": "",
    }
    res = level0.add_process(**process_fields)
    print(res)
    failure_msg = "error code: " + str(res.code) + ", message: " + res.message
    self.assertEqual(res.code, 200, failure_msg)
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment