Commit 9f071275 authored by BO ZHANG's avatar BO ZHANG 🏀
Browse files

add dm.device

parent 78a17d26
...@@ -115,6 +115,8 @@ class CsstMsDataManager: ...@@ -115,6 +115,8 @@ class CsstMsDataManager:
The number of jobs. The number of jobs.
backend : str backend : str
The joblib backend. The joblib backend.
device : str
The device for neural network. "CPU" or "GPU".
Examples Examples
-------- --------
...@@ -161,7 +163,8 @@ class CsstMsDataManager: ...@@ -161,7 +163,8 @@ class CsstMsDataManager:
clear_dir=False, clear_dir=False,
verbose=True, verbose=True,
n_jobs=18, n_jobs=18,
backend="multiprocessing" backend="multiprocessing",
device="CPU"
): ):
# set DFS log dir # set DFS log dir
...@@ -227,6 +230,7 @@ class CsstMsDataManager: ...@@ -227,6 +230,7 @@ class CsstMsDataManager:
self.n_jobs = n_jobs self.n_jobs = n_jobs
self.backend = backend self.backend = backend
self.device = device
# aXe # aXe
self.set_env() self.set_env()
...@@ -331,7 +335,8 @@ class CsstMsDataManager: ...@@ -331,7 +335,8 @@ class CsstMsDataManager:
log_ppl="csst-l1ppl.log", log_ppl="csst-l1ppl.log",
log_mod="csst-l1mod.log", log_mod="csst-l1mod.log",
n_jobs=18, n_jobs=18,
backend="multiprocessing" backend="multiprocessing",
device="CPU"
): ):
""" initialize the multi-band imaging data manager """ """ initialize the multi-band imaging data manager """
...@@ -376,7 +381,8 @@ class CsstMsDataManager: ...@@ -376,7 +381,8 @@ class CsstMsDataManager:
log_ppl=log_ppl, log_ppl=log_ppl,
log_mod=log_mod, log_mod=log_mod,
n_jobs=n_jobs, n_jobs=n_jobs,
backend=backend backend=backend,
device=device
) )
@staticmethod @staticmethod
...@@ -788,7 +794,8 @@ class CsstMsDataManager: ...@@ -788,7 +794,8 @@ class CsstMsDataManager:
clear_l1=False, clear_l1=False,
dfs_root="/share/dfs", dfs_root="/share/dfs",
n_jobs=18, n_jobs=18,
backend="multiprocessing" backend="multiprocessing",
device="CPU"
): ):
""" Initialize CsstMsDataManager from DFS. """ """ Initialize CsstMsDataManager from DFS. """
# (clear and) make directories # (clear and) make directories
...@@ -830,7 +837,8 @@ class CsstMsDataManager: ...@@ -830,7 +837,8 @@ class CsstMsDataManager:
use_dfs=use_dfs, use_dfs=use_dfs,
dfs_node=dfs_node, dfs_node=dfs_node,
n_jobs=n_jobs, n_jobs=n_jobs,
backend=backend backend=backend,
device=device
) )
assert dm.obs_id == obs_id assert dm.obs_id == obs_id
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment