Commit 0fce31a8 authored by Hu Yi's avatar Hu Yi
Browse files

Update crmask.py, add numpy style docs to methods of cr_mask and cr_train.

parent 271ff743
......@@ -9,7 +9,7 @@
#Software package dependencies, numpy, astropy, ccdproc and deepCR
#Installation dependencies on Ubuntu 20.04
#apt install python3-numpy python3-astropy python3-ccdproc python3-pip
#python3 -m pip install deepCR
#python3 -m pip install pytorch deepCR
#Version 0.2
#changelog
......@@ -66,7 +66,7 @@ class CRMask:
gpu_flag : (optional) boolean
whether use GPU, default is False
config_path : (optional) string
configuration file path, default is ``./crmask.ini``
configuration file path, default is ``../conf/MSC_crmask.ini``
"""
self.model = model
if model == 'deepCR_train':
......@@ -181,6 +181,9 @@ class CRMask:
self.config = config
def cr_mask_lacosmic(self):
"""
This method is called by `cr_mask`, do NOT use it directly.
"""
config = self.config
......@@ -315,6 +318,9 @@ class CRMask:
return masked_hdulist
def cr_mask_deepCR(self):
"""
This method is called by `cr_mask`, do NOT use it directly.
"""
from deepCR import deepCR
config = self.config
......@@ -432,9 +438,24 @@ class CRMask:
else:
return masked_hdulist
#return a hdulist of a masked image, and a hdulist of a cleaned image if fill_flag is set.
def cr_mask(self):
#here do cr mask task
"""
Cosmic ray detection and mask.
Returns
-------
masked : numpy.ndarray
cosmic ray masked image.
cleaned : numpy.ndarray, optional
Only returned if `fill_flag` is True
cosmic ray cleaned image.
Examples
-------
>>> from crmask import CRMask
>>> crobj = CRMask('xxxx.fits', 'deepCR')
>>> crobj.cr_mask()
"""
if self.model == 'lacosmic':
if self.fill_flag:
masked, cleaned = CRMask.cr_mask_lacosmic(self)
......@@ -503,6 +524,9 @@ class CRMask:
return masked
def cr_train_deepCR_image_to_ndarray(self, image_sets, patch):
"""
This method is called by `cr_train`, do NOT use it directly.
"""
if isinstance(image_sets, str):
if image_sets[-4:] == '.npy':
......@@ -557,6 +581,9 @@ class CRMask:
return input_image
def cr_train_deepCR_prepare_data(self, patch):
"""
This method is called by `cr_train`, do NOT use it directly.
"""
if self.image_sets != None:
self.training_image = CRMask.cr_train_deepCR_image_to_ndarray(self, self.image_sets, patch)
np.save('image.npy', self.training_image)
......@@ -581,6 +608,9 @@ class CRMask:
def cr_train_deepCR(self):
"""
This method is called by `cr_train`, do NOT use it directly.
"""
from deepCR import train
......@@ -679,12 +709,31 @@ class CRMask:
print('Training completes!')
def cr_train(self):
"""
Training models, only support ``deepCR_train``. It will generate pytorch's *.pth file.
The train is very painful and time consuming, do NOT use it in pipelines.
Returns
-------
No returns
Examples
-------
>>> from crmask import CRMask
>>> imglist = ['MSC_MS_210525170000_100000010_23_sci.fits', 'MSC_MS_210525171000_100000011_23_sci.fits', 'MSC_MS_210525172000_100000012_23_sci.fits', 'MSC_MS_210525173000_100000013_23_sci.fits', 'MSC_MS_210525174000_100000014_23_sci.fits', 'MSC_MS_210525175000_100000015_23_sci.fits', 'MSC_MS_210525180000_100000016_23_sci.fits', 'MSC_MS_210525181000_100000017_23_sci.fits', 'MSC_MS_210525182000_100000018_23_sci.fits', 'MSC_MS_210525183000_100000019_23_sci.fits']
>>> masklist = ['MSC_CRM_210525170000_100000010_23_raw.fits', 'MSC_CRM_210525171000_100000011_23_raw.fits', 'MSC_CRM_210525172000_100000012_23_raw.fits', 'MSC_CRM_210525173000_100000013_23_raw.fits', 'MSC_CRM_210525174000_100000014_23_raw.fits', 'MSC_CRM_210525175000_100000015_23_raw.fits', 'MSC_CRM_210525180000_100000016_23_raw.fits', 'MSC_CRM_210525181000_100000017_23_raw.fits', 'MSC_CRM_210525182000_100000018_23_raw.fits', 'MSC_CRM_210525183000_100000019_23_raw.fits']
>>> trainobj = CRMask(imglist, mask = masklist, model = 'deepCR_train')
>>> trainobj.cr_train()
"""
if self.model == 'deepCR_train':
CRMask.cr_train_deepCR(self)
else:
raise ValueError('Unsupported training model')
def cr_benchmark(self):
"""
Do NOT use this method, just for internal test.
"""
if isinstance(self.res, str) or isinstance(self.res, Path):
hdulist = pyfits.open(self.res)
res = hdulist[1].data
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment