Commit dc71ee77 authored by Wei Shoulin
Browse files

feat(csst_dfs_client): enhance level0 and level1 search APIs

- Add support for data_uuid and data_model parameters in level0.find
- Rename object_name parameter to object in level0, level1, and plan modules
- Support extra keyword arguments in level0.find and level1.find
- Add get_by_uuid function to retrieve level0 data by UUID
- Update tests to reflect API changes and enable previously commented tests
- Improve parameter handling for datetime ranges in search functions
parent 93de0971
Pipeline #10639 canceled with stages
in 0 seconds
import os
from typing import Optional, Tuple, Literal, List
from .common import request, Result, utils, constants
from typing import List, Literal, Optional, Tuple
from .common import Result, constants, request, utils
# (start, end) pair of datetime strings; the search functions unpack it into
# the "<field>_start" / "<field>_end" request parameters.
DateTimeTuple = Tuple[str, str]
def find(
    dataset: str,
    instrument: Literal["MSC", "IFS", "MCI", "HSTDM", "CPIC"],
    data_model: str = "raw",
    data_uuid: Optional[str] = None,
    obs_group: Optional[str] = None,
    obs_id: Optional[str] = None,
    detector: Optional[str] = None,
    obs_type: Optional[str] = None,
    filter: Optional[str] = None,
    obs_date: Optional[DateTimeTuple] = None,
    create_time: Optional[DateTimeTuple] = None,
    qc_status: Optional[int] = None,
    prc_status: Optional[int] = None,
    file_name: Optional[str] = None,
    ra_obj: Optional[int] = None,
    dec_obj: Optional[int] = None,
    radius: Optional[float] = None,
    object: Optional[str] = None,
    page: int = 1,
    limit: int = 0,
    **extra_kwargs,
) -> Result:
    """Search level-0 data file records matching the given criteria.

    Args:
        dataset (str): Dataset name. Required.
        instrument (str): Instrument; one of 'MSC', 'IFS', 'MCI', 'HSTDM',
            'CPIC'.
        data_model (str): Data model. Defaults to 'raw'.
        data_uuid (Optional[str]): Data UUID. Defaults to None.
        obs_group (Optional[str]): Project ID. Defaults to None.
        obs_id (Optional[str]): Observation ID. Defaults to None.
        detector (Optional[str]): Detector to match. Defaults to None.
        obs_type (Optional[str]): Observation type. Defaults to None.
        filter (Optional[str]): Filter to match. Defaults to None.
        obs_date (Optional[DateTimeTuple]): (start, end) observation time
            range. Defaults to None.
        create_time (Optional[DateTimeTuple]): (start, end) creation time
            range. Defaults to None.
        qc_status (Optional[int]): QC status. Defaults to None.
        prc_status (Optional[int]): Processing status. Defaults to None.
        file_name (Optional[str]): File name. Defaults to None.
        ra_obj (Optional[int]): Target right ascension. Defaults to None.
        dec_obj (Optional[int]): Target declination. Defaults to None.
        radius (Optional[float]): Search radius. Defaults to None.
        object (Optional[str]): Target name. Defaults to None.
        page (int): Page number. Defaults to 1.
        limit (int): Page size; 0 means unlimited. Defaults to 0.
        **extra_kwargs: Additional search fields merged into the request
            parameters (they override same-named entries built from the
            named arguments, except the date-range bounds when ``obs_date``
            / ``create_time`` are supplied).

    Returns:
        Result: Search result object.
    """
    params = {
        "data_uuid": data_uuid,
        "data_model": data_model,
        "obs_group": obs_group,
        "obs_id": obs_id,
        "instrument": instrument,
        "detector": detector,
        "obs_type": obs_type,
        "filter": filter,
        "qc_status": qc_status,
        "prc_status": prc_status,
        "file_name": file_name,
        "ra_obj": ra_obj,
        "dec_obj": dec_obj,
        "radius": radius,
        "object": object,
        "obs_date_start": None,
        "obs_date_end": None,
        "create_time_start": None,
        "create_time_end": None,
        "dataset": dataset,
        "page": page,
        "limit": limit,
    }
    params.update(extra_kwargs)

    # NOTE(review): the validator's return value is discarded, exactly as in
    # the original code; presumably it raises on a malformed value -- confirm
    # in common.utils. Each non-empty bound is checked on its own so that a
    # valid start can no longer short-circuit the check of the end bound.
    if obs_date is not None:
        params["obs_date_start"], params["obs_date_end"] = obs_date
        for key in ("obs_date_start", "obs_date_end"):
            if params[key]:
                utils.is_valid_datetime_format(params[key])

    if create_time is not None:
        params["create_time_start"], params["create_time_end"] = create_time
        for key in ("create_time_start", "create_time_end"):
            if params[key]:
                utils.is_valid_datetime_format(params[key])

    return request.post("/api/level0", params)
def get_by_id(_id: str) -> Result:
    """Fetch a level-0 record by its internal ID.

    Args:
        _id (str): Internal ID of the level-0 record.

    Returns:
        Result: Query result.
    """
    endpoint = f"/api/level0/_id/{_id}"
    return request.get(endpoint)
def get_by_uuid(_uuid: str) -> Result:
    """Fetch a level-0 record by its UUID.

    Args:
        _uuid (str): UUID of the level-0 record.

    Returns:
        Result: Query result.
    """
    endpoint = f"/api/level0/uuid/{_uuid}"
    return request.get(endpoint)
def find_by_level0_id(level0_id: str) -> Result:
    """Look up a level-0 record by its level-0 ID.

    Args:
        level0_id (str): ID of the level-0 data.

    Returns:
        Result: Query result.
    """
    endpoint = f"/api/level0/{level0_id}"
    return request.get(endpoint)
def update_qc_status(level0_id: str, qc_status: int, dataset: str) -> Result:
    """Update the QC status of one level-0 record.

    Args:
        level0_id (str): ID of the level-0 data.
        qc_status (int): New QC status.
        dataset (str): Dataset name.

    Returns:
        Result: Update result.
    """
    return request.put(
        f"/api/level0/qc_status/{level0_id}",
        {"qc_status": qc_status, "dataset": dataset},
    )
def update_qc_status_by_ids(ids: List[str], qc_status: int) -> Result:
    """Update the QC status of several level-0 records by internal _id.

    Args:
        ids (List[str]): Internal _id values of the records to update.
        qc_status (int): New QC status.

    Returns:
        Result: Update result.
    """
    return request.put(
        "/api/level0/qc_status/batch/update", {"qc_status": qc_status, "ids": ids}
    )
def update_prc_status(
    level0_id: str, dag_run: str, prc_status: int, dataset: str
) -> Result:
    """Update the processing status of one level-0 record.

    Args:
        level0_id (str): ID of the level-0 data.
        dag_run (str): DAG run identifier.
        prc_status (int): New processing status.
        dataset (str): Dataset name.

    Returns:
        Result: Operation result.
    """
    return request.put(
        f"/api/level0/prc_status/{level0_id}/{dag_run}",
        {"prc_status": prc_status, "dataset": dataset},
    )
def update_prc_status_by_ids(ids: List[str], prc_status: int) -> Result:
    """Update the processing status of several level-0 records by internal _id.

    Args:
        ids (List[str]): Internal _id values of the records to update.
        prc_status (int): New processing status.

    Returns:
        Result: Operation result.
    """
    return request.put(
        "/api/level0/prc_status/batch", {"prc_status": prc_status, "ids": ids}
    )
def write(
    local_file: str, dataset: str = constants.DEFAULT_DATASET, **kwargs
) -> Result:
    """Upload a local file into DFS as level-0 data.

    Args:
        local_file (str): Path of the file to upload.
        dataset (str): Dataset name. Defaults to constants.DEFAULT_DATASET.
        **kwargs: Extra keyword arguments forwarded to DFS as request
            parameters (a ``dataset`` key here overrides the argument).

    Returns:
        Result: Operation result; carries the stored data object on success
            and the error information on failure.

    Raises:
        FileNotFoundError: If ``local_file`` does not exist.
    """
    params = {"dataset": dataset}
    params.update(kwargs)
    if not os.path.exists(local_file):
        raise FileNotFoundError(local_file)
    return request.post_file("/api/level0/file", local_file, params)
def write_cat(
    local_file: str, dataset: str = constants.DEFAULT_DATASET, **kwargs
) -> Result:
    """Upload a local catalog file of main-survey simulation data into DFS.

    Args:
        local_file (str): Path of the catalog file to upload.
        dataset (str): Dataset name. Defaults to constants.DEFAULT_DATASET.
        **kwargs: Extra keyword arguments forwarded to DFS as request
            parameters (a ``dataset`` key here overrides the argument).

    Returns:
        Result: Operation result; carries the stored data object on success
            and the error information on failure.

    Raises:
        FileNotFoundError: If ``local_file`` does not exist.
    """
    params = {"dataset": dataset}
    params.update(kwargs)
    if not os.path.exists(local_file):
        raise FileNotFoundError(local_file)
    return request.post_file("/api/level0/cat/file", local_file, params)
def find_process(
    dag: Optional[str] = None,
    dag_run: Optional[str] = None,
    batch_id: Optional[str] = None,
    level0_id: Optional[str] = None,
    dataset: Optional[str] = None,
    prc_module: Optional[str] = None,
    prc_status: Optional[int] = None,
    prc_date: Optional[DateTimeTuple] = None,
    page: int = 1,
    limit: int = 0,
) -> Result:
    """Query level-0 data processing records.

    Args:
        dag (Optional[str]): DAG identifier.
        dag_run (Optional[str]): DAG run identifier.
        batch_id (Optional[str]): Batch identifier.
        level0_id (Optional[str]): ID of the level-0 data.
        dataset (Optional[str]): Dataset name.
        prc_module (Optional[str]): Processing module.
        prc_status (Optional[int]): Processing status.
        prc_date (Optional[DateTimeTuple]): (start, end) processing time
            range.
        page (int): Page number. Defaults to 1.
        limit (int): Page size; 0 means unlimited. Defaults to 0.

    Returns:
        Result: On success ``Result.data`` is the list of records; on failure
            ``message`` holds the reason.
    """
    params = {
        "dag": dag,
        "dag_run": dag_run,
        "batch_id": batch_id,
        "level0_id": level0_id,
        "dataset": dataset,
        "prc_module": prc_module,
        "prc_status": prc_status,
        "prc_date_start": None,
        "prc_date_end": None,
        "page": page,
        "limit": limit,
    }

    if prc_date is not None:
        params["prc_date_start"], params["prc_date_end"] = prc_date
        # NOTE(review): the validator's return value is discarded, as in the
        # original; presumably it raises on a malformed value -- confirm in
        # common.utils. Empty bounds are skipped.
        for key in ("prc_date_start", "prc_date_end"):
            if params[key]:
                utils.is_valid_datetime_format(params[key])

    return request.post("/api/level0/process", params)
def add_process(
    level0_id: str,
    dag: str,
    dag_run: str,
    batch_id: Optional[str] = None,
    dataset: str = constants.DEFAULT_DATASET,
    prc_status: int = -1024,
    prc_date: Optional[str] = None,
    prc_module: str = "",
    message: str = "",
) -> Result:
    """Add a processing record for a level-0 data item.

    Args:
        level0_id (str): ID of the level-0 data.
        dag (str): DAG identifier.
        dag_run (str): DAG run identifier.
        batch_id (Optional[str]): Batch identifier. Defaults to None.
        dataset (str): Dataset name. Defaults to constants.DEFAULT_DATASET.
        prc_status (int): Processing status. Defaults to -1024.
        prc_date (Optional[str]): Processing time. When omitted (None) the
            current time is taken at call time.
        prc_module (str): Processing module. Defaults to "".
        message (str): Processing message. Defaults to "".

    Returns:
        Result: On success ``Result.data`` is the written record; on failure
            ``message`` holds the reason.
    """
    # The default used to be ``utils.get_current_time()`` in the signature,
    # which Python evaluates once at import; resolve it per call instead so
    # long-running processes do not keep reusing the import-time timestamp.
    if prc_date is None:
        prc_date = utils.get_current_time()

    # NOTE(review): return value discarded as in the original; presumably the
    # helper raises on a malformed value -- confirm in common.utils.
    utils.is_valid_datetime_format(prc_date)

    params = {
        "level0_id": level0_id,
        "dag": dag,
        "dag_run": dag_run,
        "dataset": dataset,
        "batch_id": batch_id,
        "prc_date": prc_date,
        "prc_status": prc_status,
        "prc_module": prc_module,
        "message": message,
    }
    return request.post("/api/level0/prc", params)
def new(data: dict) -> Result:
    """Create a level-0 record; intended for simulation-data testing.

    Args:
        data (dict): Dictionary representation of the level-0 data.

    Returns:
        Result: On success ``Result.data`` is the written record; on failure
            ``message`` holds the reason.
    """
    return request.post("/api/level0/new", data)
......@@ -23,13 +23,14 @@ def find(
ra_cen: Optional[int] = None,
dec_cen: Optional[int] = None,
radius: Optional[float] = None,
object_name: Optional[str] = None,
object: Optional[str] = None,
rss_id: Optional[str] = None,
cube_id: Optional[str] = None,
build: Optional[int] = None,
pmapname: Optional[str] = None,
page: int = 1,
limit: int = 0) -> Result:
limit: int = 0,
**extra_kwargs) -> Result:
"""
根据给定的参数搜索1级数据文件记录
......@@ -49,7 +50,7 @@ def find(
ra_cen (Optional[int], optional): 中心赤经. Defaults to None.
dec_cen (Optional[int], optional): 中心赤纬. Defaults to None.
radius (Optional[float], optional): 搜索半径. Defaults to None.
object_name (Optional[str], optional): 天体名称. Defaults to None.
object (Optional[str], optional): 天体名称. Defaults to None.
rss_id (Optional[str], optional): RSS ID (IFS) Defaults to None.
cube_id (Optional[str], optional): Cube ID (IFS). Defaults to None.
dataset (Optional[str], optional): 数据集名称. Defaults to None.
......@@ -79,7 +80,7 @@ def find(
'ra_cen': ra_cen,
'dec_cen': dec_cen,
'radius': radius,
'object_name': object_name,
'object': object,
'obs_date_start': None,
'obs_date_end': None,
'create_time_start': None,
......@@ -93,6 +94,7 @@ def find(
'page': page,
'limit': limit,
}
params.update(extra_kwargs)
if obs_date is not None:
params['obs_date_start'], params['obs_date_end'] = obs_date
......
......@@ -9,7 +9,7 @@ def find(mode: Optional[str] = None,
obs_id: Optional[str] = None,
instrument: Literal['MSC', 'IFS', 'MCI', 'HSTDM', 'CPIC'] = 'MSC',
obs_type: Optional[str] = None,
object_name: Optional[str] = None,
object: Optional[str] = None,
obstime: Optional[DateTimeTuple] = None,
dataset: Optional[str] = None,
page: int = 1,
......@@ -24,7 +24,7 @@ def find(mode: Optional[str] = None,
obs_id (Optional[str], optional): 观测ID,支持模糊搜索. Defaults to None.
instrument (Optional[str], optional): 模块ID,如'MSC', 'IFS'. Defaults to None.
obs_type (Optional[str], optional): 观测类型,如主巡天宽场、TOO观测、定标星场等等. Defaults to None.
object_name (Optional[str], optional): 目标名称. Defaults to None.
object (Optional[str], optional): 目标名称. Defaults to None.
obstime (Optional[DateTimeTuple], optional): 观测时间范围. 如("2021-08-30 00:00:00", "2024-12-30 23:59:59"),Defaults to None.
dataset (Optional[str], optional): 数据集名称. Defaults to constants.DEFAULT_DATASET.
page (int, optional): 页码. Defaults to 1.
......@@ -42,7 +42,7 @@ def find(mode: Optional[str] = None,
'obs_id': obs_id,
'instrument': instrument,
'obs_type': obs_type,
'object_name': object_name,
'object': object,
'obs_time_start': None,
'obs_time_end': None,
'create_time_start': None,
......
......@@ -7,64 +7,69 @@ class Level0TestCase(unittest.TestCase):
def setUp(self):
    """No per-test fixture is required."""
    return None
# def test_find(self):
# start_time = time.time()
# result = level0.find(instrument='MSC', dataset="msc-v093")
# print(f"1操作执行时间: {time.time() - start_time} 秒")
# start_time = time.time()
# result = level0.find(instrument='MSC', dataset="msc-v093",
# ra_obj = 170,
# dec_obj = -24,
# radius = 1)
# print(f"2操作执行时间: {time.time() - start_time} 秒, 数据量:{result['total_count']}")
# start_time = time.time()
# result = level0.find(instrument='MSC', file_name="CSST_MSC_MS_SCI_20240609181116_20240609181347_10109100157991_27_L0_V01.fits")
# print(f"3操作执行时间: {time.time() - start_time} 秒, 数据量:{result['total_count']}")
# self.assertEqual(result.code, 200, "error code: " + result.message)
# self.assertIsNotNone(result.data, "error message: " + result.message)
def test_find(self):
    """Run three level-0 searches, print their timings, and check the last result."""
    # 1) search with an extra keyword argument forwarded to the service
    start_time = time.time()
    result = level0.find(instrument='MSC', dataset="msc-v093", CRVAL1=170)
    print(f"1操作执行时间: {time.time() - start_time} 秒")

    # 2) cone search around (ra_obj, dec_obj)
    start_time = time.time()
    result = level0.find(
        instrument='MSC',
        dataset="msc-v093",
        ra_obj=170,
        dec_obj=-24,
        radius=1,
    )
    print(f"2操作执行时间: {time.time() - start_time} 秒, 数据量:{result['total_count']}")

    # 3) search by file name; ``dataset`` is a required argument of
    # level0.find, so it must be supplied here as well.
    start_time = time.time()
    result = level0.find(
        instrument='MSC',
        dataset="msc-v093",
        file_name="CSST_MSC_MS_SCI_20240609181116_20240609181347_10109100157991_27_L0_V01.fits",
    )
    print(f"3操作执行时间: {time.time() - start_time} 秒, 数据量:{result['total_count']}")

    self.assertEqual(result.code, 200, "error code: " + result.message)
    self.assertIsNotNone(result.data, "error message: " + result.message)
# def test_find_by_level0_id(self):
# result = level0.find_by_level0_id(level0_id = "1060940003452925")
# print(result)
# self.assertEqual(result.code, 200, "error code: " + str(result.code) + ", message: " + result.message)
# self.assertIsNotNone(result.data, "error message: " + result.message)
def test_get_by_uuid(self):
    """Fetch a level-0 record by UUID and check that data came back."""
    # The parameter of level0.get_by_uuid is named ``_uuid``; passing
    # ``uuid=`` raises TypeError before any request is made.
    result = level0.get_by_uuid(_uuid="0199d622-afd7-4ff7-a70e-95378a1ca638")
    print(result)
    self.assertEqual(result.code, 200, "error code: " + str(result.code) + ", message: " + result.message)
    self.assertIsNotNone(result.data, "error message: " + result.message)
def test_find_by_level0_id(self):
    """Look up a level-0 record by its level-0 ID and check that data came back."""
    result = level0.find_by_level0_id(level0_id="1060940003452925")
    print(result)
    code_msg = "error code: " + str(result.code) + ", message: " + result.message
    self.assertEqual(result.code, 200, code_msg)
    data_msg = "error message: " + result.message
    self.assertIsNotNone(result.data, data_msg)
# def test_update_qc_status(self):
# result = level0.update_qc_status(level0_id = "1010910015799127", qc_status=1, dataset="msc-v093")
# print(result)
# self.assertEqual(result.code, 200, "error code: " + str(result.code) + ", message: " + result.message)
def test_update_qc_status(self):
    """Update the QC status of a single record and expect code 200."""
    result = level0.update_qc_status(
        level0_id="1010910015799127",
        qc_status=1,
        dataset="msc-v093",
    )
    print(result)
    failure_msg = "error code: " + str(result.code) + ", message: " + result.message
    self.assertEqual(result.code, 200, failure_msg)
# def test_update_qc_status_by_ids(self):
# result = level0.update_qc_status_by_ids(ids = ["676ac74a530b47ca41568858"], qc_status=4)
# print(result)
# self.assertEqual(result.code, 200, "error code: " + str(result.code) + ", message: " + result.message)
def test_update_qc_status_by_ids(self):
    """Batch-update the QC status by internal _id and expect code 200."""
    target_ids = ["676ac74a530b47ca41568858"]
    result = level0.update_qc_status_by_ids(ids=target_ids, qc_status=4)
    print(result)
    failure_msg = "error code: " + str(result.code) + ", message: " + result.message
    self.assertEqual(result.code, 200, failure_msg)
# def test_update_prc_status(self):
# result = level0.update_prc_status(level0_id = "1010910015799127", dag_run="202411071002481234", prc_status=3)
# print(result)
# self.assertEqual(result.code, 200, "error code: " + str(result.code) + ", message: " + result.message)
def test_update_prc_status(self):
    """Update the processing status of a single record and expect code 200."""
    # ``dataset`` is a required argument of level0.update_prc_status;
    # omitting it raises TypeError before any request is made.
    result = level0.update_prc_status(
        level0_id="1010910015799127",
        dag_run="202411071002481234",
        prc_status=3,
        dataset="msc-v093",
    )
    print(result)
    self.assertEqual(result.code, 200, "error code: " + str(result.code) + ", message: " + result.message)
def test_update_prc_status_by_ids(self):
    """Batch-update the processing status by internal _id and expect code 200."""
    target_ids = ["676ac74a530b47ca41568858"]
    result = level0.update_prc_status_by_ids(ids=target_ids, prc_status=4)
    print(result)
    failure_msg = "error code: " + str(result.code) + ", message: " + result.message
    self.assertEqual(result.code, 200, failure_msg)
# def test_write(self):
# file_path = "/Users/wsl/temp/csst/import/CSST_MSC_MS_SCI_20240609181116_20240609181347_10109100157991_27_L0_V01.fits"
# result = level0.write(local_file = file_path, dataset= 'msc-v093')
# print(result)
# self.assertEqual(result.code, 200, "error code: " + str(result.code) + ", message: " + result.message)
def test_write(self):
    """Upload a local FITS file as level-0 data and expect code 200."""
    file_path = "/Users/wsl/temp/csst/import/CSST_MSC_MS_SCI_20240609181116_20240609181347_10109100157991_27_L0_V01.fits"
    result = level0.write(dataset='msc-v093', local_file=file_path)
    print(result)
    failure_msg = "error code: " + str(result.code) + ", message: " + result.message
    self.assertEqual(result.code, 200, failure_msg)
# def test_find_process(self):
# result = level0.find_process(level0_id="1060940003452925")
# print(result)
# self.assertEqual(result.code, 200, "error code: " + str(result.code) + ", message: " + result.message)
# def test_add_process(self):
# result = level0.add_process(level0_id="1060940003452925",
# dag="csst-msc-l1-mbi",
# dag_run="202411071002481234",
# dataset="v93",
# batch_id="v930batch",
# prc_date="2024-11-07 10:24:12", prc_status=1, prc_module="MSC", message="")
# print(result)
# self.assertEqual(result.code, 200, "error code: " + str(result.code) + ", message: " + result.message)
\ No newline at end of file
def test_find_process(self):
    """Query the processing records of one level-0 item and expect code 200."""
    result = level0.find_process(level0_id="1060940003452925")
    print(result)
    failure_msg = "error code: " + str(result.code) + ", message: " + result.message
    self.assertEqual(result.code, 200, failure_msg)
def test_add_process(self):
    """Add a processing record for a level-0 item and expect code 200."""
    process_fields = dict(
        level0_id="1060940003452925",
        dag="csst-msc-l1-mbi",
        dag_run="202411071002481234",
        dataset="v93",
        batch_id="v930batch",
        prc_date="2024-11-07 10:24:12",
        prc_status=1,
        prc_module="MSC",
        message="",
    )
    result = level0.add_process(**process_fields)
    print(result)
    failure_msg = "error code: " + str(result.code) + ", message: " + result.message
    self.assertEqual(result.code, 200, failure_msg)
\ No newline at end of file
Supports Markdown
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment