diff --git a/csst_common/data_manager.py b/csst_common/data_manager.py index 59d361ca587cff0ef102f1edb7fe821a22b2188f..e56ce94b694f94e7e5ce7251b7bc372e1ed81ab1 100644 --- a/csst_common/data_manager.py +++ b/csst_common/data_manager.py @@ -115,6 +115,8 @@ class CsstMsDataManager: The number of jobs. backend : str The joblib backend. + device : str + The device for neural network. "CPU" or "GPU". Examples -------- @@ -161,7 +163,8 @@ class CsstMsDataManager: clear_dir=False, verbose=True, n_jobs=18, - backend="multiprocessing" + backend="multiprocessing", + device="CPU" ): # set DFS log dir @@ -227,6 +230,7 @@ class CsstMsDataManager: self.n_jobs = n_jobs self.backend = backend + self.device = device # aXe self.set_env() @@ -331,7 +335,8 @@ class CsstMsDataManager: log_ppl="csst-l1ppl.log", log_mod="csst-l1mod.log", n_jobs=18, - backend="multiprocessing" + backend="multiprocessing", + device="CPU" ): """ initialize the multi-band imaging data manager """ @@ -376,7 +381,8 @@ class CsstMsDataManager: log_ppl=log_ppl, log_mod=log_mod, n_jobs=n_jobs, - backend=backend + backend=backend, + device=device ) @staticmethod @@ -788,7 +794,8 @@ class CsstMsDataManager: clear_l1=False, dfs_root="/share/dfs", n_jobs=18, - backend="multiprocessing" + backend="multiprocessing", + device="CPU" ): """ Initialize CsstMsDataManager from DFS. """ # (clear and) make directories @@ -830,7 +837,8 @@ class CsstMsDataManager: use_dfs=use_dfs, dfs_node=dfs_node, n_jobs=n_jobs, - backend=backend + backend=backend, + device=device ) assert dm.obs_id == obs_id