Commit 2e34f283 authored by Xie Zhou's avatar Xie Zhou
Browse files

add gpu interface

parent 1e309842
...@@ -131,3 +131,6 @@ dmypy.json ...@@ -131,3 +131,6 @@ dmypy.json
# Pyre type checker # Pyre type checker
.pyre/ .pyre/
# joblib temp
joblib*
\ No newline at end of file
...@@ -97,7 +97,7 @@ class CsstMscInstrumentProc(CsstProcessor): ...@@ -97,7 +97,7 @@ class CsstMscInstrumentProc(CsstProcessor):
if self._switches['deepcr']: if self._switches['deepcr']:
clean_model = DEEPCR_MODEL_PATH clean_model = DEEPCR_MODEL_PATH
inpaint_model = 'ACS-WFC-F606W-2-32' inpaint_model = 'ACS-WFC-F606W-2-32'
model = deepCR(clean_model, inpaint_model, device='CPU', hidden=50) model = deepCR(clean_model, inpaint_model, device=self.device, hidden=50)
if self.n_jobs > 1: if self.n_jobs > 1:
masked, cleaned = model.clean( masked, cleaned = model.clean(
self.__img, threshold=0.5, inpaint=True, binary=True, segment=True, patch=256, parallel=True, self.__img, threshold=0.5, inpaint=True, binary=True, segment=True, patch=256, parallel=True,
...@@ -126,10 +126,7 @@ class CsstMscInstrumentProc(CsstProcessor): ...@@ -126,10 +126,7 @@ class CsstMscInstrumentProc(CsstProcessor):
psfbeta=4.765, psfbeta=4.765,
verbose=False, verbose=False,
gain_apply=True) gain_apply=True)
masked = masked.astype(np.uint16)
print(self.__flg, type(self.__flg), self.__flg.dtype)
print(masked, type(masked), masked.dtype)
self.__flg = self.__flg | (masked * 16) self.__flg = self.__flg | (masked * 16)
if self._switches['clean']: if self._switches['clean']:
self.__img = cleaned self.__img = cleaned
...@@ -151,9 +148,10 @@ class CsstMscInstrumentProc(CsstProcessor): ...@@ -151,9 +148,10 @@ class CsstMscInstrumentProc(CsstProcessor):
weight[self.__flg > 0] = 0 weight[self.__flg > 0] = 0
self.__wht = weight self.__wht = weight
def prepare(self, n_jobs=2, n_threads=1, **kwargs): def prepare(self, n_jobs=2, n_threads=1, device='CPU', **kwargs):
self.n_jobs = n_jobs self.n_jobs = n_jobs
self.set_num_threads(n_threads) self.set_num_threads(n_threads)
self.device = device
for name in kwargs: for name in kwargs:
self._switches[name] = kwargs[name] self._switches[name] = kwargs[name]
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment