diff --git a/crmask.py b/crmask.py index 6c9280ea9055c13cb2562c9728927cc01a60abe5..e3ef2eca61bb2af4e11aedf35413c7c2e1c65ed1 100644 --- a/crmask.py +++ b/crmask.py @@ -9,7 +9,7 @@ #Software package dependencies, numpy, astropy, ccdproc and deepCR #Installation dependencies on Ubuntu 20.04 #apt install python3-numpy python3-astropy python3-ccdproc python3-pip -#python3 -m pip install deepCR +#python3 -m pip install pytorch deepCR #Version 0.2 #changelog @@ -66,7 +66,7 @@ class CRMask: gpu_flag : (optional) boolean whether use GPU, default is False config_path : (optional) string - configuration file path, default is ``./crmask.ini`` + configuration file path, default is ``../conf/MSC_crmask.ini`` """ self.model = model if model == 'deepCR_train': @@ -181,6 +181,9 @@ class CRMask: self.config = config def cr_mask_lacosmic(self): + """ + This method is called by `cr_mask`, do NOT use it directly. + """ config = self.config @@ -315,6 +318,9 @@ class CRMask: return masked_hdulist def cr_mask_deepCR(self): + """ + This method is called by `cr_mask`, do NOT use it directly. + """ from deepCR import deepCR config = self.config @@ -432,9 +438,24 @@ class CRMask: else: return masked_hdulist - #return a hdulist of a masked image, and a hdulist of a cleaned image if fill_flag is set. def cr_mask(self): - #here do cr mask task + """ + Cosmic ray detection and mask. + + Returns + ------- + masked : numpy.ndarray + cosmic ray masked image. + cleaned : numpy.ndarray, optional + Only returned if `fill_flag` is True + + cosmic ray cleaned image. + Examples + ------- + >>> from crmask import CRMask + >>> crobj = CRMask('xxxx.fits', 'deepCR') + >>> crobj.cr_mask() + """ if self.model == 'lacosmic': if self.fill_flag: masked, cleaned = CRMask.cr_mask_lacosmic(self) @@ -503,6 +524,9 @@ class CRMask: return masked def cr_train_deepCR_image_to_ndarray(self, image_sets, patch): + """ + This method is called by `cr_train`, do NOT use it directly. + """ if isinstance(image_sets, str): if image_sets[-4:] == '.npy': @@ -557,6 +581,9 @@ class CRMask: return input_image def cr_train_deepCR_prepare_data(self, patch): + """ + This method is called by `cr_train`, do NOT use it directly. + """ if self.image_sets != None: self.training_image = CRMask.cr_train_deepCR_image_to_ndarray(self, self.image_sets, patch) np.save('image.npy', self.training_image) @@ -581,6 +608,9 @@ class CRMask: def cr_train_deepCR(self): + """ + This method is called by `cr_train`, do NOT use it directly. + """ from deepCR import train @@ -679,12 +709,31 @@ class CRMask: print('Training completes!') def cr_train(self): + """ + Training models, only support ``deepCR_train``. It will generate pytorch's *.pth file. + The train is very painful and time consuming, do NOT use it in pipelines. + + Returns + ------- + No returns + + Examples + ------- + >>> from crmask import CRMask + >>> imglist = ['MSC_MS_210525170000_100000010_23_sci.fits', 'MSC_MS_210525171000_100000011_23_sci.fits', 'MSC_MS_210525172000_100000012_23_sci.fits', 'MSC_MS_210525173000_100000013_23_sci.fits', 'MSC_MS_210525174000_100000014_23_sci.fits', 'MSC_MS_210525175000_100000015_23_sci.fits', 'MSC_MS_210525180000_100000016_23_sci.fits', 'MSC_MS_210525181000_100000017_23_sci.fits', 'MSC_MS_210525182000_100000018_23_sci.fits', 'MSC_MS_210525183000_100000019_23_sci.fits'] + >>> masklist = ['MSC_CRM_210525170000_100000010_23_raw.fits', 'MSC_CRM_210525171000_100000011_23_raw.fits', 'MSC_CRM_210525172000_100000012_23_raw.fits', 'MSC_CRM_210525173000_100000013_23_raw.fits', 'MSC_CRM_210525174000_100000014_23_raw.fits', 'MSC_CRM_210525175000_100000015_23_raw.fits', 'MSC_CRM_210525180000_100000016_23_raw.fits', 'MSC_CRM_210525181000_100000017_23_raw.fits', 'MSC_CRM_210525182000_100000018_23_raw.fits', 'MSC_CRM_210525183000_100000019_23_raw.fits'] + >>> trainobj = CRMask(imglist, mask = masklist, model = 'deepCR_train') + >>> trainobj.cr_train() + """ if self.model == 'deepCR_train': CRMask.cr_train_deepCR(self) else: raise ValueError('Unsupported training model') def cr_benchmark(self): + """ + Do NOT use this method, just for internal test. + """ if isinstance(self.res, str) or isinstance(self.res, Path): hdulist = pyfits.open(self.res) res = hdulist[1].data