Commit c99d5409 authored by BO ZHANG's avatar BO ZHANG 🏀
Browse files

inherit CsstData from HDUList

parent ca026c55
import os
__version__ = "0.0.4"
__version__ = "0.0.5"
PACKAGE_PATH = os.path.dirname(__file__)
......@@ -5,43 +5,43 @@ from astropy.io.fits import HDUList, PrimaryHDU
import numpy as np
from csst.core.exception import CsstException
from astropy.io import fits
from copy import deepcopy
__all__ = ["CsstData", "INSTRUMENT_LIST"]
INSTRUMENT_LIST = ["MSC", ]
class CsstData:
""" General CSST data class """
_primary_hdu = []
_l0data = [] # HDUList
_l1hdr_global = []
_l1data = OrderedDict() # dict object
_l2data = OrderedDict() #
_auxdata = OrderedDict()
class CsstData(fits.HDUList):
""" General CSST raw class """
# hdu_pri = fits.PrimaryHDU()
# hdu_l0 = fits.ImageHDU(raw=None, name="raw")
# hdu_l1 = fits.ImageHDU(raw=None, name="l1")
"""
header methods:
'add_blank', 'add_comment', 'add_history', 'append', 'cards', 'clear', 'comments', 'copy', 'count', 'extend',
'fromfile', 'fromkeys', 'fromstring', 'fromtextfile', 'get', 'index', 'insert', 'items', 'keys', 'pop', 'popitem',
'remove', 'rename_keyword', 'set', 'setdefault', 'strip', 'tofile', 'tostring', 'totextfile', 'update', 'values'
"""
def __init__(self, priHDU, imgHDU, instrument=None, detector=None):
def __init__(self, hdus=None, file=None):
"""
Parameters
----------
priHDU:
primary HDU
imgHDU:
image HDU
instrument:
instrument
detector:
detector
hdus:
a alist of HDUs
file:
open file object
"""
self._primary_hdu = priHDU
self._l0data = imgHDU
self.instrument = instrument
self.detector = detector
if hdus is None:
hdus = []
super(CsstData, self).__init__(hdus=hdus, file=file)
def get_l0data(self, copy=True):
""" get level 0 data from CsstData class
def get_data(self, copy=True, hdu=1):
""" get level 0 raw from CsstData class
Parameters
----------
......@@ -49,72 +49,79 @@ class CsstData:
if True, return a copy.
"""
if copy:
return self._l0data.data.copy()
return self[hdu].data.copy()
else:
return self._l0data.data
return self[hdu].data
def get_l0keyword(self, ext="pri", key="INSTRUME"):
""" get a specific keyword from fits header of level 0 image data
def get_keyword(self, key="INSTRUME", hdu=0):
""" get keyword from fits header
Parameters
----------
ext: {"pri"| "img"}
the HDU extension
key:
the key
"""
if ext == 'pri':
try:
return self._primary_hdu.header.get(key)
except Exception as e:
print(e)
elif ext == 'img':
try:
return self._l0data.header.get(key)
except Exception as e:
print(e)
else:
raise CsstException
return self[hdu].header.get(key)
def set_l1keyword(self, key, value):
""" set L1 keyword """
raise NotImplementedError("Well, not implemented...")
def set_l1data(self, *args, **kwargs):
print('save image data to l2data')
raise NotImplementedError
def get_auxdata(self, name):
""" get aux data
def set_keyword(self, key, value, hdu=1):
""" set keyword
Parameters
----------
"""
print('Parent class returns zero image.')
# return np.zeros_like(self.get_l0data())
raise NotImplementedError
def save_l1data(self, imgtype, filename):
""" save L1 image and auxilary data to file
key:
key
value:
value
hdu:
0 for primary hdu, 1+ for raw hdu
Parameters
----------
imgtype: {}
image type
"""
print("save L1 image to a fits file with name " + filename)
try:
self._l1hdr_global.set('TYPE', imgtype, 'Type of Level 1 data')
hdulist = fits.HDUList(
[
fits.PrimaryHDU(header=self._l1hdr_global),
fits.ImageHDU(header=self._l1data[imgtype].header, data=self._l1data[imgtype].data),
]
)
hdulist.writeto(filename, overwrite=True)
except Exception as e:
print(e)
def read(self, **kwargs):
""" read data from fits file """
self[hdu].header[key] = value
return
def set_data(self, data, hdu=1):
""" set image raw """
self[hdu].data = data
return
# def writeto(self, fp, overwrite=False):
# """ save L1 image and aux raw to file
#
# Parameters
# ----------
# fp: str
# image type
# overwrite : bool
# if True, overwrite file
# """
# self.writeto(fp, overwrite=overwrite)
def get_auxdata(self):
""" get aux raw
In future, this is to automatically get aux raw from database.
"""
raise NotImplementedError
@classmethod
def read(cls, name, ignore_missing_simple=True):
""" read raw from fits file, should be implemented in child classes """
return cls.fromfile(name, ignore_missing_simple=ignore_missing_simple)
def deepcopy(self, name=None, data=None):
""" generate a deep copy of self """
cp = self.__class__(deepcopy(self))
if name is not None:
cp[1].name = name
if data is not None:
cp[1].data = data
return cp
@property
def data(self):
if len(self) == 1:
return self[1].data
return self[1].data
@property
def exptime(self):
return self[0].header["EXPTIME"]
from .data import CsstMscData, CsstMscImgData
\ No newline at end of file
from .data import CsstMscImgData
\ No newline at end of file
# from abc import ABC
from collections import OrderedDict
import astropy.io.fits as fits
from astropy.io.fits import HDUList, PrimaryHDU, ImageHDU
......@@ -6,143 +7,88 @@ from ..core.data import CsstData, INSTRUMENT_LIST
import numpy as np
__all__ = ["CsstMscData", "CsstMscImgData"]
__all__ = ["CsstMscImgData", ]
class CsstMscData(CsstData):
class CsstMscImgData(CsstData):
_l1img_types = {'sci': True, 'weight': True, 'flag': True}
def __init__(self, priHDU, imgHDU, **kwargs):
super(CsstMscData, self).__init__(priHDU, imgHDU, **kwargs)
self._l1hdr_global = priHDU.header.copy()
self._l1data['sci'] = ImageHDU()
self._l1data['weight'] = ImageHDU()
self._l1data['flag'] = ImageHDU()
def set_flat(self, flat):
""" set flat
def __init__(self, hdus=None, file=None):
"""
Parameters
----------
flat:
flat image
Returns
-------
hdus:
a list of HDUs
file:
open file object
"""
self._auxdata['flat'] = flat
def set_bias(self, biasimg):
""" set bias """
self._auxdata['bias'] = biasimg
if hdus is None:
hdus = []
super(CsstMscImgData, self).__init__(hdus=hdus, file=file)
def set_dark(self, darkimg):
""" set dark """
self._auxdata['dark'] = darkimg
# meta info
self.instrument = self[0].header["INSTRUME"]
self.detector = self[0].header["DETECTOR"]
def set_badpixel(self, badpixelimg):
""" set badpixel """
self._auxdata['badpixel'] = badpixelimg
# self._l1hdr_global = self[0].header.copy()
# self._l1data = dict()
# self._l1data['sci'] = ImageHDU()
# self._l1data['weight'] = ImageHDU()
# self._l1data['flag'] = ImageHDU()
def get_flat(self):
def get_flat(self, fp):
""" get flat """
return self._auxdata['flat']
return fits.getdata(fp)
def get_bias(self):
def get_bias(self, fp):
""" get bias """
return self._auxdata['bias']
return fits.getdata(fp)
def get_dark(self):
def get_dark(self, fp):
""" get dark """
return self._auxdata['dark']
def get_badpixel(self):
""" get badpixel """
return self._auxdata['badpixel']
def init_l0data(self):
""" initialize L0 data """
pass
def set_l1keyword(self, key, value, comment=''):
""" set L1 keyword """
print('check out whether ' + key + " is a valid key and " + value + " is valid value")
self._l1hdr_global.set(key, value, comment)
def set_l1data(self, imgtype, img):
""" set L1 data """
try:
if imgtype == 'sci':
self._l1data[imgtype].header['EXTNAME'] = 'img'
self._l1data[imgtype].header['BUNIT'] = 'e/s'
self._l1data[imgtype].data = img.astype(np.float32) / self._l1hdr_global['exptime']
elif imgtype == 'weight':
self._l1data[imgtype].header['EXTNAME'] = 'wht'
self._l1data[imgtype].data = img.astype(np.float32)
elif imgtype == 'flag':
self._l1data[imgtype].header['EXTNAME'] = 'flg'
self._l1data[imgtype].data = img.astype(np.uint16)
else:
raise TypeError('unknow type image')
except Exception as e:
print(e)
print('save image data to l1data')
def save_l1data(self, imgtype, filename):
""" save L1 data """
print('check ' + imgtype + ' is validate')
try:
if self._l1img_types[imgtype]:
super().save_l1data(imgtype, filename)
except Exception as e:
print(e)
def get_l1data(self, imgtype):
assert imgtype in ["sci", "flag", "weight"]
return self._l1img_types[imgtype]
class CsstMscImgData(CsstMscData):
def __init__(self, priHDU, imgHDU, **kwargs):
# print('create CsstMscImgData')
super(CsstMscImgData, self).__init__(priHDU, imgHDU, **kwargs)
return fits.getdata(fp)
def get_l1data(self):
""" get L1 raw """
imgdata = self.get_data(hdu=1)
exptime = self.get_keyword("EXPTIME", hdu=0)
# image
img = self.deepcopy(name="img", data=imgdata.astype(np.float32) / exptime)
img[1].header['BUNIT'] = 'e/s'
# weight
wht = self.deepcopy(name="wht", data=imgdata.astype(np.float32))
wht[1].header.remove('BUNIT')
# flag
flg = self.deepcopy(name="flg", data=imgdata.astype(np.uint16))
flg[1].header.remove('BUNIT')
return img, wht, flg
def __repr__(self):
return "<CsstMscImgData: {} {}>".format(self.instrument, self.detector)
@staticmethod
def read(fp):
""" read data from fits file
Parameters
----------
fp:
the file path of fits file
Returns
-------
CsstMscImgData
Example
-------
>>> fp = "MSC_MS_210527171000_100000279_16_raw.fits"
>>> from csst.msc import CsstMscImgData
>>> data = CsstMscImgData.read(fp)
>>> # print some info
>>> print("data: ", data)
>>> print("instrument: ", data.get_l0keyword("pri", "INSTRUME"))
>>> print("object: ", data.get_l0keyword("pri", "OBJECT"))
"""
with fits.open(fp) as hdulist:
instrument = hdulist[0].header.get('INSTRUME') # strip or not?
detector = hdulist[0].header.get('DETECTOR') # strip or not?
print("@CsstMscImgData: reading data {} ...".format(fp))
assert instrument in INSTRUMENT_LIST
if instrument == 'MSC' and 6 <= int(detector[3:5]) <= 25:
# multi-band imaging
hdu0 = hdulist[0].copy()
hdu1 = hdulist[1].copy()
data = CsstMscImgData(hdu0, hdu1, instrument=instrument, detector=detector)
return data
# @staticmethod
# def read(fp):
# """ read raw from fits file
#
# Parameters
# ----------
# fp:
# the file path of fits file
#
# Returns
# -------
# CsstMscImgData
#
# Example
# -------
#
# >>> fp = "MSC_MS_210527171000_100000279_16_raw.fits"
# >>> from csst.msc import CsstMscImgData
# >>> raw = CsstMscImgData.read(fp)
# >>> # print some info
# >>> print("raw: ", raw)
# >>> print("instrument: ", raw.get_l0keyword("pri", "INSTRUME"))
# >>> print("object: ", raw.get_l0keyword("pri", "OBJECT"))
# """
# return CsstMscImgData.fromfile(fp)
from pathlib import Path
from ccdproc import cosmicray_lacosmic
import numpy as np
from ccdproc import cosmicray_lacosmic
from deepCR import deepCR
from ..core.processor import CsstProcessor, CsstProcStatus
from ..msc import CsstMscImgData
from .. import PACKAGE_PATH
DEEPCR_MODEL_PATH = PACKAGE_PATH + "/msc/deepcr_model/CSST_2021-12-30_CCD23_epoch20.pth"
class CsstMscInstrumentProc(CsstProcessor):
......@@ -12,9 +17,9 @@ class CsstMscInstrumentProc(CsstProcessor):
_switches = {'deepcr': True, 'clean': False}
def __init__(self):
pass
super(CsstMscInstrumentProc).__init__()
def _do_fix(self, raw, bias, dark, flat, exptime):
def _do_fix(self, raw, bias, dark, flat):
'''仪器效应改正
将raw扣除本底, 暗场, 平场. 并且避免了除0
......@@ -25,10 +30,10 @@ class CsstMscInstrumentProc(CsstProcessor):
flat: 平场
exptime: 曝光时长
'''
self.__l1img = np.divide(
raw - bias - dark * exptime, flat,
out=np.zeros_like(raw, float),
where=(flat != 0),
self.__img = np.divide(
raw.data - bias.data - dark.data * raw.exptime, flat.data,
out=np.zeros_like(raw.data, float),
where=(flat.data != 0),
)
def _do_badpix(self, flat):
......@@ -39,9 +44,9 @@ class CsstMscInstrumentProc(CsstProcessor):
Args:
flat: 平场
'''
med = np.median(flat)
flg = (flat < 0.5 * med) | (1.5 * med < flat)
self.__flagimg = self.__flagimg | (flg * 1)
med = np.median(flat.data)
flg = (flat.data < 0.5 * med) | (1.5 * med < flat.data)
self.__flg = self.__flg | (flg * 1)
def _do_hot_and_warm_pix(self, dark, exptime, rdnoise):
'''热像元与暖像元标记
......@@ -55,12 +60,12 @@ class CsstMscInstrumentProc(CsstProcessor):
exptime: 曝光时长
rdnoise: 读出噪声
'''
tmp = dark * exptime
tmp = dark.data * exptime
tmp[tmp < 0] = 0
flg = 1 * rdnoise ** 2 <= tmp # 不确定是否包含 暂定包含
self.__flagimg = self.__flagimg | (flg * 2)
self.__flg = self.__flg | (flg * 2)
flg = (0.5 * rdnoise ** 2 < tmp) & (tmp < 1 * rdnoise ** 2)
self.__flagimg = self.__flagimg | (flg * 4)
self.__flg = self.__flg | (flg * 4)
def _do_over_saturation(self, raw):
'''饱和溢出像元标记
......@@ -70,8 +75,8 @@ class CsstMscInstrumentProc(CsstProcessor):
Args:
raw: 科学图生图
'''
flg = raw == 65535
self.__flagimg = self.__flagimg | (flg * 8)
flg = raw.data == 65535
self.__flg = self.__flg | (flg * 8)
def _do_cray(self, gain, rdnoise):
'''宇宙线像元标记
......@@ -80,25 +85,16 @@ class CsstMscInstrumentProc(CsstProcessor):
'''
if self._switches['deepcr']:
clean_model = str(Path(__file__).parent /
'CSST_2021-12-30_CCD23_epoch20.pth')
clean_model = DEEPCR_MODEL_PATH
inpaint_model = 'ACS-WFC-F606W-2-32'
model = deepCR(clean_model,
inpaint_model,
device='CPU',
hidden=50)
masked, cleaned = model.clean(self.__l1img,
threshold=0.5,
inpaint=True,
segment=True,
patch=256,
parallel=True,
n_jobs=2)
model = deepCR(clean_model, inpaint_model, device='CPU', hidden=50)
masked, cleaned = model.clean(
self.__img, threshold=0.5, inpaint=True, segment=True, patch=256, parallel=True, n_jobs=2)
else:
cleaned, masked = cosmicray_lacosmic(ccd=self.__l1img,
sigclip=3., # cr_threshold
sigfrac=0.5, # neighbor_threshold
objlim=5., # constrast
cleaned, masked = cosmicray_lacosmic(ccd=self.__img,
sigclip=3., # cr_threshold
sigfrac=0.5, # neighbor_threshold
objlim=5., # constrast
gain=gain,
readnoise=rdnoise,
satlevel=65535.0,
......@@ -115,9 +111,9 @@ class CsstMscInstrumentProc(CsstProcessor):
verbose=False,
gain_apply=True)
self.__flagimg = self.__flagimg | (masked * 16)
self.__flg = self.__flg | (masked * 16)
if self._switches['clean']:
self.__l1img = cleaned
self.__img = cleaned
def _do_weight(self, bias, gain, rdnoise, exptime):
'''权重图
......@@ -128,55 +124,44 @@ class CsstMscInstrumentProc(CsstProcessor):
rdnoise: 读出噪声
exptime: 曝光时长
'''
data = self.__l1img.copy()
data[self.__l1img < 0] = 0
data = self.__img.copy()
data[self.__img < 0] = 0
weight_raw = 1. / (gain * data + rdnoise ** 2)
bias_weight = np.std(bias)
weight = 1. / (1. / weight_raw + 1. / bias_weight) * exptime ** 2
weight[self.__flagimg > 0] = 0
self.__weightimg = weight
weight[self.__flg > 0] = 0
self.__wht = weight
def prepare(self, **kwargs):
for name in kwargs:
self._switches[name] = kwargs[name]
def run(self, data):
if type(data).__name__ == 'CsstMscImgData' or type(data).__name__ == 'CsstMscSlsData':
raw = data.get_l0data()
self.__l1img = raw.copy()
self.__weightimg = np.zeros_like(raw)
self.__flagimg = np.zeros_like(raw, dtype=np.uint16)
exptime = data.get_l0keyword('pri', 'EXPTIME')
gain = data.get_l0keyword('img', 'GAIN1')
rdnoise = data.get_l0keyword('img', 'RDNOISE1')
flat = data.get_flat()
bias = data.get_bias()
dark = data.get_dark()
def run(self, raw: CsstMscImgData, bias, dark, flat):
print('Flat and bias correction')
assert isinstance(raw, CsstMscImgData)
self.__img = np.copy(raw.data)
self.__wht = np.zeros_like(raw.data, dtype=float)
self.__flg = np.zeros_like(raw.data, dtype=np.uint16)
self._do_fix(raw, bias, dark, flat, exptime)
self._do_badpix(flat)
self._do_hot_and_warm_pix(dark, exptime, rdnoise)
self._do_over_saturation(raw)
self._do_cray(gain, rdnoise)
self._do_weight(bias, gain, rdnoise, exptime)
exptime = raw.get_keyword('EXPTIME', hdu=0)
gain = raw.get_keyword('GAIN1', hdu=1)
rdnoise = raw.get_keyword('RDNOISE1', hdu=1)
print('finish the run and save the results back to CsstData')
# Flat and bias correction
self._do_fix(raw, bias, dark, flat)
self._do_badpix(flat)
self._do_hot_and_warm_pix(dark, exptime, rdnoise)
self._do_over_saturation(raw)
self._do_cray(gain, rdnoise)
self._do_weight(bias, gain, rdnoise, exptime)
data.set_l1data('sci', self.__l1img)
data.set_l1data('weight', self.__weightimg)
data.set_l1data('flag', self.__flagimg)
print('finish the run and save the results back to CsstData')
print('Update keywords')
data.set_l1keyword('SOMEKEY', 'some value',
'Test if I can append the header')
img = raw.deepcopy(name="SCI", data=self.__img)
wht = raw.deepcopy(name="WHT", data=self.__wht)
flg = raw.deepcopy(name="FLG", data=self.__flg)
self._status = CsstProcStatus.normal
else:
self._status = CsstProcStatus.ioerror
return self._status
return img, wht, flg
def cleanup(self):
pass
......@@ -36,10 +36,10 @@ from csst.msc.data import CsstMscImgData
from csst.msc.instrument import CsstMscInstrumentProc
from astropy.io import fits
# get aux data
bs = fits.getdata("/data/ref/MSC_CLB_210525190000_100000014_13_combine.fits")
dk = fits.getdata("/data/ref/MSC_CLD_210525192000_100000014_13_combine.fits")
ft = fits.getdata("/data/ref/MSC_CLF_210525191000_100000014_13_combine.fits")
# get aux raw
bs = fits.getdata("/raw/ref/MSC_CLB_210525190000_100000014_13_combine.fits")
dk = fits.getdata("/raw/ref/MSC_CLD_210525192000_100000014_13_combine.fits")
ft = fits.getdata("/raw/ref/MSC_CLF_210525191000_100000014_13_combine.fits")
fp_img_list = []
fp_flg_list = []
......@@ -47,10 +47,10 @@ fp_wht_list = []
data_list = []
for fp in fp_list:
# read image data
# read image raw
data = CsstMscImgData.read(fp)
# set aux data
# set aux raw
data.set_bias(bs)
data.set_dark(dk)
data.set_flat(ft)
......@@ -66,7 +66,7 @@ for fp in fp_list:
fp_flg = fp.replace("raw.fits", "flg.fits")
fp_wht = fp.replace("raw.fits", "wht.fits")
# save l1 data
# save l1 raw
data.save_l1data('sci', fp_img)
data.save_l1data('flag', fp_flg)
data.save_l1data('weight', fp_wht)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment