From c99d540959b8f4dad799f0e34c180342847b0273 Mon Sep 17 00:00:00 2001 From: Bo Zhang Date: Wed, 13 Apr 2022 15:46:07 +0800 Subject: [PATCH] inherit CsstData from HDUList --- csst/__init__.py | 2 +- csst/core/data.py | 163 +++++++++++++++++++----------------- csst/msc/__init__.py | 2 +- csst/msc/data.py | 186 +++++++++++++++-------------------------- csst/msc/instrument.py | 123 ++++++++++++--------------- csst/msc/pipeline.py | 14 ++-- 6 files changed, 214 insertions(+), 276 deletions(-) diff --git a/csst/__init__.py b/csst/__init__.py index 7079475..877e3d8 100644 --- a/csst/__init__.py +++ b/csst/__init__.py @@ -1,4 +1,4 @@ import os -__version__ = "0.0.4" +__version__ = "0.0.5" PACKAGE_PATH = os.path.dirname(__file__) diff --git a/csst/core/data.py b/csst/core/data.py index 8689978..373db3a 100644 --- a/csst/core/data.py +++ b/csst/core/data.py @@ -5,43 +5,43 @@ from astropy.io.fits import HDUList, PrimaryHDU import numpy as np from csst.core.exception import CsstException - +from astropy.io import fits +from copy import deepcopy __all__ = ["CsstData", "INSTRUMENT_LIST"] INSTRUMENT_LIST = ["MSC", ] -class CsstData: - """ General CSST data class """ - _primary_hdu = [] - _l0data = [] # HDUList - _l1hdr_global = [] - _l1data = OrderedDict() # dict object - _l2data = OrderedDict() # - _auxdata = OrderedDict() +class CsstData(fits.HDUList): + """ General CSST raw class """ + # hdu_pri = fits.PrimaryHDU() + # hdu_l0 = fits.ImageHDU(raw=None, name="raw") + # hdu_l1 = fits.ImageHDU(raw=None, name="l1") + + """ + header methods: + 'add_blank', 'add_comment', 'add_history', 'append', 'cards', 'clear', 'comments', 'copy', 'count', 'extend', + 'fromfile', 'fromkeys', 'fromstring', 'fromtextfile', 'get', 'index', 'insert', 'items', 'keys', 'pop', 'popitem', + 'remove', 'rename_keyword', 'set', 'setdefault', 'strip', 'tofile', 'tostring', 'totextfile', 'update', 'values' + """ - def __init__(self, priHDU, imgHDU, instrument=None, detector=None): + def __init__(self, hdus=None, file=None): """ Parameters ---------- - priHDU: - primary HDU - imgHDU: - image HDU - instrument: - instrument - detector: - detector + hdus: + a alist of HDUs + file: + open file object """ - self._primary_hdu = priHDU - self._l0data = imgHDU - self.instrument = instrument - self.detector = detector + if hdus is None: + hdus = [] + super(CsstData, self).__init__(hdus=hdus, file=file) - def get_l0data(self, copy=True): - """ get level 0 data from CsstData class + def get_data(self, copy=True, hdu=1): + """ get level 0 raw from CsstData class Parameters ---------- @@ -49,72 +49,79 @@ class CsstData: if True, return a copy. """ if copy: - return self._l0data.data.copy() + return self[hdu].data.copy() else: - return self._l0data.data + return self[hdu].data - def get_l0keyword(self, ext="pri", key="INSTRUME"): - """ get a specific keyword from fits header of level 0 image data + def get_keyword(self, key="INSTRUME", hdu=0): + """ get keyword from fits header Parameters ---------- - ext: {"pri"| "img"} - the HDU extension key: the key """ - if ext == 'pri': - try: - return self._primary_hdu.header.get(key) - except Exception as e: - print(e) - elif ext == 'img': - try: - return self._l0data.header.get(key) - except Exception as e: - print(e) - else: - raise CsstException + return self[hdu].header.get(key) - def set_l1keyword(self, key, value): - """ set L1 keyword """ - raise NotImplementedError("Well, not implemented...") - - def set_l1data(self, *args, **kwargs): - print('save image data to l2data') - raise NotImplementedError - - def get_auxdata(self, name): - """ get aux data + def set_keyword(self, key, value, hdu=1): + """ set keyword Parameters ---------- - """ - print('Parent class returns zero image.') - # return np.zeros_like(self.get_l0data()) - raise NotImplementedError - - def save_l1data(self, imgtype, filename): - """ save L1 image and auxilary data to file + key: + key + value: + value + hdu: + 0 for primary hdu, 1+ for raw hdu - Parameters - ---------- - imgtype: {} - image type """ - print("save L1 image to a fits file with name " + filename) - try: - self._l1hdr_global.set('TYPE', imgtype, 'Type of Level 1 data') - hdulist = fits.HDUList( - [ - fits.PrimaryHDU(header=self._l1hdr_global), - fits.ImageHDU(header=self._l1data[imgtype].header, data=self._l1data[imgtype].data), - ] - ) - hdulist.writeto(filename, overwrite=True) - except Exception as e: - print(e) - - def read(self, **kwargs): - """ read data from fits file """ + self[hdu].header[key] = value + return + + def set_data(self, data, hdu=1): + """ set image raw """ + self[hdu].data = data + return + + # def writeto(self, fp, overwrite=False): + # """ save L1 image and aux raw to file + # + # Parameters + # ---------- + # fp: str + # image type + # overwrite : bool + # if True, overwrite file + # """ + # self.writeto(fp, overwrite=overwrite) + + def get_auxdata(self): + """ get aux raw + In future, this is to automatically get aux raw from database. + """ raise NotImplementedError + + @classmethod + def read(cls, name, ignore_missing_simple=True): + """ read raw from fits file, should be implemented in child classes """ + return cls.fromfile(name, ignore_missing_simple=ignore_missing_simple) + + def deepcopy(self, name=None, data=None): + """ generate a deep copy of self """ + cp = self.__class__(deepcopy(self)) + if name is not None: + cp[1].name = name + if data is not None: + cp[1].data = data + return cp + + @property + def data(self): + if len(self) == 1: + return self[1].data + return self[1].data + + @property + def exptime(self): + return self[0].header["EXPTIME"] diff --git a/csst/msc/__init__.py b/csst/msc/__init__.py index 9485bf9..1e96140 100644 --- a/csst/msc/__init__.py +++ b/csst/msc/__init__.py @@ -1 +1 @@ -from .data import CsstMscData, CsstMscImgData \ No newline at end of file +from .data import CsstMscImgData \ No newline at end of file diff --git a/csst/msc/data.py b/csst/msc/data.py index 5b20b7d..c19c087 100644 --- a/csst/msc/data.py +++ b/csst/msc/data.py @@ -1,3 +1,4 @@ +# from abc import ABC from collections import OrderedDict import astropy.io.fits as fits from astropy.io.fits import HDUList, PrimaryHDU, ImageHDU @@ -6,143 +7,88 @@ from ..core.data import CsstData, INSTRUMENT_LIST import numpy as np -__all__ = ["CsstMscData", "CsstMscImgData"] +__all__ = ["CsstMscImgData", ] -class CsstMscData(CsstData): +class CsstMscImgData(CsstData): _l1img_types = {'sci': True, 'weight': True, 'flag': True} - def __init__(self, priHDU, imgHDU, **kwargs): - super(CsstMscData, self).__init__(priHDU, imgHDU, **kwargs) - self._l1hdr_global = priHDU.header.copy() - self._l1data['sci'] = ImageHDU() - self._l1data['weight'] = ImageHDU() - self._l1data['flag'] = ImageHDU() - - def set_flat(self, flat): - """ set flat + def __init__(self, hdus=None, file=None): + """ Parameters ---------- - flat: - flat image - - Returns - ------- - + hdus: + a list of HDUs + file: + open file object """ - self._auxdata['flat'] = flat - - def set_bias(self, biasimg): - """ set bias """ - self._auxdata['bias'] = biasimg + if hdus is None: + hdus = [] + super(CsstMscImgData, self).__init__(hdus=hdus, file=file) - def set_dark(self, darkimg): - """ set dark """ - self._auxdata['dark'] = darkimg + # meta info + self.instrument = self[0].header["INSTRUME"] + self.detector = self[0].header["DETECTOR"] - def set_badpixel(self, badpixelimg): - """ set badpixel """ - self._auxdata['badpixel'] = badpixelimg + # self._l1hdr_global = self[0].header.copy() + # self._l1data = dict() + # self._l1data['sci'] = ImageHDU() + # self._l1data['weight'] = ImageHDU() + # self._l1data['flag'] = ImageHDU() - def get_flat(self): + def get_flat(self, fp): """ get flat """ - return self._auxdata['flat'] + return fits.getdata(fp) - def get_bias(self): + def get_bias(self, fp): """ get bias """ - return self._auxdata['bias'] + return fits.getdata(fp) - def get_dark(self): + def get_dark(self, fp): """ get dark """ - return self._auxdata['dark'] - - def get_badpixel(self): - """ get badpixel """ - return self._auxdata['badpixel'] - - def init_l0data(self): - """ initialize L0 data """ - pass - - def set_l1keyword(self, key, value, comment=''): - """ set L1 keyword """ - print('check out whether ' + key + " is a valid key and " + value + " is valid value") - self._l1hdr_global.set(key, value, comment) - - def set_l1data(self, imgtype, img): - """ set L1 data """ - try: - if imgtype == 'sci': - self._l1data[imgtype].header['EXTNAME'] = 'img' - self._l1data[imgtype].header['BUNIT'] = 'e/s' - self._l1data[imgtype].data = img.astype(np.float32) / self._l1hdr_global['exptime'] - elif imgtype == 'weight': - self._l1data[imgtype].header['EXTNAME'] = 'wht' - self._l1data[imgtype].data = img.astype(np.float32) - elif imgtype == 'flag': - self._l1data[imgtype].header['EXTNAME'] = 'flg' - self._l1data[imgtype].data = img.astype(np.uint16) - else: - raise TypeError('unknow type image') - except Exception as e: - print(e) - print('save image data to l1data') - - def save_l1data(self, imgtype, filename): - """ save L1 data """ - print('check ' + imgtype + ' is validate') - try: - if self._l1img_types[imgtype]: - super().save_l1data(imgtype, filename) - except Exception as e: - print(e) - - def get_l1data(self, imgtype): - assert imgtype in ["sci", "flag", "weight"] - return self._l1img_types[imgtype] - - -class CsstMscImgData(CsstMscData): - def __init__(self, priHDU, imgHDU, **kwargs): - # print('create CsstMscImgData') - super(CsstMscImgData, self).__init__(priHDU, imgHDU, **kwargs) + return fits.getdata(fp) + + def get_l1data(self): + """ get L1 raw """ + imgdata = self.get_data(hdu=1) + exptime = self.get_keyword("EXPTIME", hdu=0) + # image + img = self.deepcopy(name="img", data=imgdata.astype(np.float32) / exptime) + img[1].header['BUNIT'] = 'e/s' + # weight + wht = self.deepcopy(name="wht", data=imgdata.astype(np.float32)) + wht[1].header.remove('BUNIT') + # flag + flg = self.deepcopy(name="flg", data=imgdata.astype(np.uint16)) + flg[1].header.remove('BUNIT') + return img, wht, flg def __repr__(self): return "".format(self.instrument, self.detector) - @staticmethod - def read(fp): - """ read data from fits file - - Parameters - ---------- - fp: - the file path of fits file - - Returns - ------- - CsstMscImgData - - Example - ------- - - >>> fp = "MSC_MS_210527171000_100000279_16_raw.fits" - >>> from csst.msc import CsstMscImgData - >>> data = CsstMscImgData.read(fp) - >>> # print some info - >>> print("data: ", data) - >>> print("instrument: ", data.get_l0keyword("pri", "INSTRUME")) - >>> print("object: ", data.get_l0keyword("pri", "OBJECT")) - """ - with fits.open(fp) as hdulist: - instrument = hdulist[0].header.get('INSTRUME') # strip or not? - detector = hdulist[0].header.get('DETECTOR') # strip or not? - print("@CsstMscImgData: reading data {} ...".format(fp)) - assert instrument in INSTRUMENT_LIST - if instrument == 'MSC' and 6 <= int(detector[3:5]) <= 25: - # multi-band imaging - hdu0 = hdulist[0].copy() - hdu1 = hdulist[1].copy() - data = CsstMscImgData(hdu0, hdu1, instrument=instrument, detector=detector) - return data + # @staticmethod + # def read(fp): + # """ read raw from fits file + # + # Parameters + # ---------- + # fp: + # the file path of fits file + # + # Returns + # ------- + # CsstMscImgData + # + # Example + # ------- + # + # >>> fp = "MSC_MS_210527171000_100000279_16_raw.fits" + # >>> from csst.msc import CsstMscImgData + # >>> raw = CsstMscImgData.read(fp) + # >>> # print some info + # >>> print("raw: ", raw) + # >>> print("instrument: ", raw.get_l0keyword("pri", "INSTRUME")) + # >>> print("object: ", raw.get_l0keyword("pri", "OBJECT")) + # """ + # return CsstMscImgData.fromfile(fp) diff --git a/csst/msc/instrument.py b/csst/msc/instrument.py index 81a24d1..a002a11 100644 --- a/csst/msc/instrument.py +++ b/csst/msc/instrument.py @@ -1,10 +1,15 @@ from pathlib import Path -from ccdproc import cosmicray_lacosmic import numpy as np +from ccdproc import cosmicray_lacosmic from deepCR import deepCR from ..core.processor import CsstProcessor, CsstProcStatus +from ..msc import CsstMscImgData +from .. import PACKAGE_PATH + + +DEEPCR_MODEL_PATH = PACKAGE_PATH + "/msc/deepcr_model/CSST_2021-12-30_CCD23_epoch20.pth" class CsstMscInstrumentProc(CsstProcessor): @@ -12,9 +17,9 @@ class CsstMscInstrumentProc(CsstProcessor): _switches = {'deepcr': True, 'clean': False} def __init__(self): - pass + super(CsstMscInstrumentProc).__init__() - def _do_fix(self, raw, bias, dark, flat, exptime): + def _do_fix(self, raw, bias, dark, flat): '''仪器效应改正 将raw扣除本底, 暗场, 平场. 并且避免了除0 @@ -25,10 +30,10 @@ class CsstMscInstrumentProc(CsstProcessor): flat: 平场 exptime: 曝光时长 ''' - self.__l1img = np.divide( - raw - bias - dark * exptime, flat, - out=np.zeros_like(raw, float), - where=(flat != 0), + self.__img = np.divide( + raw.data - bias.data - dark.data * raw.exptime, flat.data, + out=np.zeros_like(raw.data, float), + where=(flat.data != 0), ) def _do_badpix(self, flat): @@ -39,9 +44,9 @@ class CsstMscInstrumentProc(CsstProcessor): Args: flat: 平场 ''' - med = np.median(flat) - flg = (flat < 0.5 * med) | (1.5 * med < flat) - self.__flagimg = self.__flagimg | (flg * 1) + med = np.median(flat.data) + flg = (flat.data < 0.5 * med) | (1.5 * med < flat.data) + self.__flg = self.__flg | (flg * 1) def _do_hot_and_warm_pix(self, dark, exptime, rdnoise): '''热像元与暖像元标记 @@ -55,12 +60,12 @@ class CsstMscInstrumentProc(CsstProcessor): exptime: 曝光时长 rdnoise: 读出噪声 ''' - tmp = dark * exptime + tmp = dark.data * exptime tmp[tmp < 0] = 0 flg = 1 * rdnoise ** 2 <= tmp # 不确定是否包含 暂定包含 - self.__flagimg = self.__flagimg | (flg * 2) + self.__flg = self.__flg | (flg * 2) flg = (0.5 * rdnoise ** 2 < tmp) & (tmp < 1 * rdnoise ** 2) - self.__flagimg = self.__flagimg | (flg * 4) + self.__flg = self.__flg | (flg * 4) def _do_over_saturation(self, raw): '''饱和溢出像元标记 @@ -70,8 +75,8 @@ class CsstMscInstrumentProc(CsstProcessor): Args: raw: 科学图生图 ''' - flg = raw == 65535 - self.__flagimg = self.__flagimg | (flg * 8) + flg = raw.data == 65535 + self.__flg = self.__flg | (flg * 8) def _do_cray(self, gain, rdnoise): '''宇宙线像元标记 @@ -80,25 +85,16 @@ class CsstMscInstrumentProc(CsstProcessor): ''' if self._switches['deepcr']: - clean_model = str(Path(__file__).parent / - 'CSST_2021-12-30_CCD23_epoch20.pth') + clean_model = DEEPCR_MODEL_PATH inpaint_model = 'ACS-WFC-F606W-2-32' - model = deepCR(clean_model, - inpaint_model, - device='CPU', - hidden=50) - masked, cleaned = model.clean(self.__l1img, - threshold=0.5, - inpaint=True, - segment=True, - patch=256, - parallel=True, - n_jobs=2) + model = deepCR(clean_model, inpaint_model, device='CPU', hidden=50) + masked, cleaned = model.clean( + self.__img, threshold=0.5, inpaint=True, segment=True, patch=256, parallel=True, n_jobs=2) else: - cleaned, masked = cosmicray_lacosmic(ccd=self.__l1img, - sigclip=3., # cr_threshold - sigfrac=0.5, # neighbor_threshold - objlim=5., # constrast + cleaned, masked = cosmicray_lacosmic(ccd=self.__img, + sigclip=3., # cr_threshold + sigfrac=0.5, # neighbor_threshold + objlim=5., # constrast gain=gain, readnoise=rdnoise, satlevel=65535.0, @@ -115,9 +111,9 @@ class CsstMscInstrumentProc(CsstProcessor): verbose=False, gain_apply=True) - self.__flagimg = self.__flagimg | (masked * 16) + self.__flg = self.__flg | (masked * 16) if self._switches['clean']: - self.__l1img = cleaned + self.__img = cleaned def _do_weight(self, bias, gain, rdnoise, exptime): '''权重图 @@ -128,55 +124,44 @@ class CsstMscInstrumentProc(CsstProcessor): rdnoise: 读出噪声 exptime: 曝光时长 ''' - data = self.__l1img.copy() - data[self.__l1img < 0] = 0 + data = self.__img.copy() + data[self.__img < 0] = 0 weight_raw = 1. / (gain * data + rdnoise ** 2) bias_weight = np.std(bias) weight = 1. / (1. / weight_raw + 1. / bias_weight) * exptime ** 2 - weight[self.__flagimg > 0] = 0 - self.__weightimg = weight + weight[self.__flg > 0] = 0 + self.__wht = weight def prepare(self, **kwargs): for name in kwargs: self._switches[name] = kwargs[name] - def run(self, data): - if type(data).__name__ == 'CsstMscImgData' or type(data).__name__ == 'CsstMscSlsData': - raw = data.get_l0data() - self.__l1img = raw.copy() - self.__weightimg = np.zeros_like(raw) - self.__flagimg = np.zeros_like(raw, dtype=np.uint16) - - exptime = data.get_l0keyword('pri', 'EXPTIME') - gain = data.get_l0keyword('img', 'GAIN1') - rdnoise = data.get_l0keyword('img', 'RDNOISE1') - flat = data.get_flat() - bias = data.get_bias() - dark = data.get_dark() + def run(self, raw: CsstMscImgData, bias, dark, flat): - print('Flat and bias correction') + assert isinstance(raw, CsstMscImgData) + self.__img = np.copy(raw.data) + self.__wht = np.zeros_like(raw.data, dtype=float) + self.__flg = np.zeros_like(raw.data, dtype=np.uint16) - self._do_fix(raw, bias, dark, flat, exptime) - self._do_badpix(flat) - self._do_hot_and_warm_pix(dark, exptime, rdnoise) - self._do_over_saturation(raw) - self._do_cray(gain, rdnoise) - self._do_weight(bias, gain, rdnoise, exptime) + exptime = raw.get_keyword('EXPTIME', hdu=0) + gain = raw.get_keyword('GAIN1', hdu=1) + rdnoise = raw.get_keyword('RDNOISE1', hdu=1) - print('finish the run and save the results back to CsstData') + # Flat and bias correction + self._do_fix(raw, bias, dark, flat) + self._do_badpix(flat) + self._do_hot_and_warm_pix(dark, exptime, rdnoise) + self._do_over_saturation(raw) + self._do_cray(gain, rdnoise) + self._do_weight(bias, gain, rdnoise, exptime) - data.set_l1data('sci', self.__l1img) - data.set_l1data('weight', self.__weightimg) - data.set_l1data('flag', self.__flagimg) + print('finish the run and save the results back to CsstData') - print('Update keywords') - data.set_l1keyword('SOMEKEY', 'some value', - 'Test if I can append the header') + img = raw.deepcopy(name="SCI", data=self.__img) + wht = raw.deepcopy(name="WHT", data=self.__wht) + flg = raw.deepcopy(name="FLG", data=self.__flg) - self._status = CsstProcStatus.normal - else: - self._status = CsstProcStatus.ioerror - return self._status + return img, wht, flg def cleanup(self): pass diff --git a/csst/msc/pipeline.py b/csst/msc/pipeline.py index e70c162..af3d941 100644 --- a/csst/msc/pipeline.py +++ b/csst/msc/pipeline.py @@ -36,10 +36,10 @@ from csst.msc.data import CsstMscImgData from csst.msc.instrument import CsstMscInstrumentProc from astropy.io import fits -# get aux data -bs = fits.getdata("/data/ref/MSC_CLB_210525190000_100000014_13_combine.fits") -dk = fits.getdata("/data/ref/MSC_CLD_210525192000_100000014_13_combine.fits") -ft = fits.getdata("/data/ref/MSC_CLF_210525191000_100000014_13_combine.fits") +# get aux raw +bs = fits.getdata("/raw/ref/MSC_CLB_210525190000_100000014_13_combine.fits") +dk = fits.getdata("/raw/ref/MSC_CLD_210525192000_100000014_13_combine.fits") +ft = fits.getdata("/raw/ref/MSC_CLF_210525191000_100000014_13_combine.fits") fp_img_list = [] fp_flg_list = [] @@ -47,10 +47,10 @@ fp_wht_list = [] data_list = [] for fp in fp_list: - # read image data + # read image raw data = CsstMscImgData.read(fp) - # set aux data + # set aux raw data.set_bias(bs) data.set_dark(dk) data.set_flat(ft) @@ -66,7 +66,7 @@ for fp in fp_list: fp_flg = fp.replace("raw.fits", "flg.fits") fp_wht = fp.replace("raw.fits", "wht.fits") - # save l1 data + # save l1 raw data.save_l1data('sci', fp_img) data.save_l1data('flag', fp_flg) data.save_l1data('weight', fp_wht) -- GitLab