Commit 04092ce1 authored by BO ZHANG's avatar BO ZHANG 🏀
Browse files

add support for fits: write test FITS file in setUp

parent 6e696d3b
Pipeline #10813 passed with stage
......@@ -15,37 +15,43 @@ from csst_fs.s3_config import load_s3_options
s3_options = load_s3_options()
def open(name, **kwargs) -> fits.HDUList:
if name.startswith("s3://"):
def open(filename, **kwargs) -> fits.HDUList:
if filename.startswith("s3://"):
# read FITS file from s3
return fits.open(name, use_fsspec=True, fsspec_kwargs=s3_options, **kwargs)
return fits.open(filename, use_fsspec=True, fsspec_kwargs=s3_options, **kwargs)
else:
# read FITS file from local
return fits.open(name, **kwargs)
return fits.open(filename, **kwargs)
def getheader(name, **kwargs) -> fits.HDUList:
if name.startswith("s3://"):
def getheader(filename, **kwargs) -> fits.HDUList:
if filename.startswith("s3://"):
# read FITS file from s3
return fits.getheader(name, use_fsspec=True, fsspec_kwargs=s3_options, **kwargs)
return fits.getheader(
filename, use_fsspec=True, fsspec_kwargs=s3_options, **kwargs
)
else:
# read FITS file from local
return fits.getheader(name, **kwargs)
return fits.getheader(filename, **kwargs)
def getval(name, **kwargs) -> fits.HDUList:
if name.startswith("s3://"):
def getval(filename, **kwargs) -> fits.HDUList:
if filename.startswith("s3://"):
# read FITS file from s3
return fits.getval(name, use_fsspec=True, fsspec_kwargs=s3_options, **kwargs)
return fits.getval(
filename, use_fsspec=True, fsspec_kwargs=s3_options, **kwargs
)
else:
# read FITS file from local
return fits.getval(name, **kwargs)
return fits.getval(filename, **kwargs)
def getdata(name, **kwargs) -> fits.HDUList:
if name.startswith("s3://"):
def getdata(filename, **kwargs) -> fits.HDUList:
if filename.startswith("s3://"):
# read FITS file from s3
return fits.getdata(name, use_fsspec=True, fsspec_kwargs=s3_options, **kwargs)
return fits.getdata(
filename, use_fsspec=True, fsspec_kwargs=s3_options, **kwargs
)
else:
# read FITS file from local
return fits.getdata(name, **kwargs)
return fits.getdata(filename, **kwargs)
......@@ -44,6 +44,6 @@ class TestFitsHeaderOps(unittest.TestCase):
self.assertEqual(val, True)
def test_fits_getdata(self):
data = fits.getdata(test_fits_file, ext=0)
data = fits.getdata(test_fits_file, ext=1)
self.assertIsInstance(data, np.ndarray)
self.assertEqual(data.shape, (5, 5))
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment