Commit b385a61b authored by BO ZHANG's avatar BO ZHANG 🏀
Browse files

add dm.cleanup and dm.n_jobs_gpu

parent c63576c4
...@@ -111,7 +111,9 @@ class CsstMsDataManager: ...@@ -111,7 +111,9 @@ class CsstMsDataManager:
verbose : bool verbose : bool
If True, print verbose info. If True, print verbose info.
n_jobs : int n_jobs : int
The number of jobs. The number of CPU jobs.
n_jobs_gpu : int
The number of GPU jobs.
backend : str backend : str
The joblib backend. The joblib backend.
device : str device : str
...@@ -162,6 +164,7 @@ class CsstMsDataManager: ...@@ -162,6 +164,7 @@ class CsstMsDataManager:
clear_dir: bool = False, clear_dir: bool = False,
verbose: bool = True, verbose: bool = True,
n_jobs: int = 18, n_jobs: int = 18,
n_jobs_gpu: int = 1,
backend: str = "multiprocessing", backend: str = "multiprocessing",
device: str = "CPU", device: str = "CPU",
stamps: str = "", stamps: str = "",
...@@ -228,7 +231,9 @@ class CsstMsDataManager: ...@@ -228,7 +231,9 @@ class CsstMsDataManager:
# record hard code names in history # record hard code names in history
self.hardcode_history = [] self.hardcode_history = []
# parallel configuration
self.n_jobs = n_jobs self.n_jobs = n_jobs
self.n_jobs_gpu = n_jobs_gpu
self.backend = backend self.backend = backend
self.device = device self.device = device
...@@ -241,16 +246,30 @@ class CsstMsDataManager: ...@@ -241,16 +246,30 @@ class CsstMsDataManager:
if clear_dir: if clear_dir:
self.clear_dir(self.dir_l1) self.clear_dir(self.dir_l1)
# L1 whitelist
self.l1_whitelist = []
for detector in self.target_detectors:
if self.datatype == "mbi":
self.l1_whitelist.append(self.l1_detector(detector=detector, post="img_L1.fits"))
self.l1_whitelist.append(self.l1_detector(detector=detector, post="wht_L1.fits"))
self.l1_whitelist.append(self.l1_detector(detector=detector, post="flg_L1.fits"))
self.l1_whitelist.append(self.l1_detector(detector=detector, post="cat.fits"))
self.l1_whitelist.append(self.l1_detector(detector=detector, post="psf.fits"))
elif self.datatype == "sls":
self.l1_whitelist.append(self.l1_detector(detector=detector, post="L1_1.fits"))
# pipeline logger # pipeline logger
if log_ppl == "": if log_ppl == "":
self.logger_ppl = get_logger(name="CSST L1 Pipeline Logger", filename="") self.logger_ppl = get_logger(name="CSST L1 Pipeline Logger", filename="")
else: else:
self.logger_ppl = get_logger(name="CSST L1 Pipeline Logger", filename=os.path.join(dir_l1, log_ppl)) self.logger_ppl = get_logger(name="CSST L1 Pipeline Logger", filename=os.path.join(dir_l1, log_ppl))
self.l1_whitelist.append(log_ppl)
# module logger # module logger
if log_mod == "": if log_mod == "":
self.logger_mod = get_logger(name="CSST L1 Module Logger", filename="") self.logger_mod = get_logger(name="CSST L1 Module Logger", filename="")
else: else:
self.logger_mod = get_logger(name="CSST L1 Module Logger", filename=os.path.join(dir_l1, log_mod)) self.logger_mod = get_logger(name="CSST L1 Module Logger", filename=os.path.join(dir_l1, log_mod))
self.l1_whitelist.append(log_mod)
self.custom_bias = None self.custom_bias = None
self.custom_dark = None self.custom_dark = None
...@@ -891,6 +910,15 @@ class CsstMsDataManager: ...@@ -891,6 +910,15 @@ class CsstMsDataManager:
""" Query L1 data from DFS. """ """ Query L1 data from DFS. """
pass pass
def l1_cleanup(self):
filelist = glob.glob(f"{self.dir_l1}/**")
for file in filelist:
if file not in self.l1_whitelist:
try:
os.remove(file)
except:
pass
# temporarily compatible with old interface # temporarily compatible with old interface
CsstMbiDataManager = CsstMsDataManager CsstMbiDataManager = CsstMsDataManager
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment