Commit 5174c719 authored by GZhao's avatar GZhao
Browse files

update pep8 for unittest code

parent 190f1129
Pipeline #7117 passed with stage
in 0 seconds
...@@ -49,6 +49,7 @@ class TestEMCCD(unittest.TestCase): ...@@ -49,6 +49,7 @@ class TestEMCCD(unittest.TestCase):
image = emccd.readout(image_focal, None, expt, iamge_cosmic_ray, emgain=emgain) image = emccd.readout(image_focal, None, expt, iamge_cosmic_ray, emgain=emgain)
self.assertEqual(image.shape[0], emccd.bias_shape[0]) self.assertEqual(image.shape[0], emccd.bias_shape[0])
self.assertEqual(image.shape[1], emccd.bias_shape[1]) self.assertEqual(image.shape[1], emccd.bias_shape[1])
def test_em_fix_fun(self): def test_em_fix_fun(self):
emccd = CpicVisEmccd() emccd = CpicVisEmccd()
emgain = emccd.em_fix_fuc_fit(-5) emgain = emccd.em_fix_fuc_fit(-5)
...@@ -57,7 +58,7 @@ class TestEMCCD(unittest.TestCase): ...@@ -57,7 +58,7 @@ class TestEMCCD(unittest.TestCase):
def test_emccd_update(self): def test_emccd_update(self):
emccd = CpicVisEmccd() emccd = CpicVisEmccd()
emccd.ccd_temp = -100 emccd.ccd_temp = -100
emgain = emccd.emgain_set(1024, ccd_temp=None, self_update=False ) emgain = emccd.emgain_set(1024, ccd_temp=None, self_update=False)
self.assertAlmostEqual(emgain, 1.23, places=2) self.assertAlmostEqual(emgain, 1.23, places=2)
...@@ -74,5 +75,3 @@ class TestEMCCD(unittest.TestCase): ...@@ -74,5 +75,3 @@ class TestEMCCD(unittest.TestCase):
# bias_images = np.array(bias_images) # bias_images = np.array(bias_images)
# from astropy.io import fits # from astropy.io import fits
# fits.writeto("bias.fits", bias_images) # fits.writeto("bias.fits", bias_images)
...@@ -9,7 +9,7 @@ import csst_cpic_sim.io as io ...@@ -9,7 +9,7 @@ import csst_cpic_sim.io as io
from astropy.io import fits from astropy.io import fits
import numpy as np import numpy as np
import yaml import yaml
import os import logging
cstar = { cstar = {
'magnitude': 5, 'magnitude': 5,
...@@ -34,20 +34,19 @@ params = { ...@@ -34,20 +34,19 @@ params = {
'EXPSTART': '2020-01-01T00:00:00.000', 'EXPSTART': '2020-01-01T00:00:00.000',
'EXPEND': '2020-01-02T00:00:00.000', 'EXPEND': '2020-01-02T00:00:00.000',
'frame_info': [ 'frame_info': [
{'expt_start': 0, {
'expt_end': 20, 'expt_start': 0,
'platescale': 1.0, 'expt_end': 20,
'iwa': 0, 'platescale': 1.0,
'chiptemp': 1} 'iwa': 0,
'chiptemp': 1}
] ]
} }
camera = CpicVisEmccd() camera = CpicVisEmccd()
import logging
class TestIO(unittest.TestCase):
class TestIO(unittest.TestCase):
def test_obsid_parser(self): def test_obsid_parser(self):
self.assertRaises(ValueError, obsid_parser, '20190101') self.assertRaises(ValueError, obsid_parser, '20190101')
self.assertRaises(ValueError, obsid_parser, '123456789012') self.assertRaises(ValueError, obsid_parser, '123456789012')
...@@ -99,7 +98,7 @@ rotation: 0 ...@@ -99,7 +98,7 @@ rotation: 0
emgain: 100 emgain: 100
emset: -1 emset: -1
skybg: 21 skybg: 21
expstart: expstart:
target: target:
cstar: cstar:
...@@ -125,5 +124,6 @@ target: ...@@ -125,5 +124,6 @@ target:
self.assertEqual( self.assertEqual(
output_name[:len(tmp_folder_path)], tmp_folder_path) output_name[:len(tmp_folder_path)], tmp_folder_path)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
import unittest import unittest
import shutil
import os
# from CpicImgSim import observation_simulation, quick_run # from CpicImgSim import observation_simulation, quick_run
from csst_cpic_sim.main import main, quick_run_v2, observation_simulation_from_config from csst_cpic_sim.main import main, quick_run_v2, observation_simulation_from_config
from csst_cpic_sim.config import config from csst_cpic_sim.config import config
import os
test_dir = os.path.dirname(__file__) test_dir = os.path.dirname(__file__)
cases = os.path.join(test_dir, 'testcases') cases = os.path.join(test_dir, 'testcases')
output = os.path.join(test_dir, 'test_output') output = os.path.join(test_dir, 'test_output')
config['output'] = output config['output'] = output
from unittest.mock import patch # from unittest.mock import patch
a = [] a = []
import io
import shutil
def clear_output(): def clear_output():
if os.path.exists(output): if os.path.exists(output):
shutil.rmtree(output) shutil.rmtree(output)
os.mkdir(output) os.mkdir(output)
class TestMain(unittest.TestCase): class TestMain(unittest.TestCase):
# def test_help(self): # def test_help(self):
# main(argv = None) # main(argv = None)
...@@ -36,8 +38,6 @@ class TestMain(unittest.TestCase): ...@@ -36,8 +38,6 @@ class TestMain(unittest.TestCase):
file = os.listdir(output) file = os.listdir(output)
self.assertEqual(len(file), 1) self.assertEqual(len(file), 1)
self.assertEqual(file[0][:9], 'demo_0_20') self.assertEqual(file[0][:9], 'demo_0_20')
def test_quick_run_func(self): def test_quick_run_func(self):
clear_output() clear_output()
...@@ -47,8 +47,6 @@ class TestMain(unittest.TestCase): ...@@ -47,8 +47,6 @@ class TestMain(unittest.TestCase):
file = os.listdir(output) file = os.listdir(output)
self.assertEqual(len(file), 1) self.assertEqual(len(file), 1)
self.assertEqual(file[0][:5], 'blank') self.assertEqual(file[0][:5], 'blank')
def test_full_run_func(self): def test_full_run_func(self):
clear_output() clear_output()
...@@ -66,7 +64,7 @@ class TestMain(unittest.TestCase): ...@@ -66,7 +64,7 @@ class TestMain(unittest.TestCase):
file = os.listdir(output) file = os.listdir(output)
self.assertEqual(len(file), 4) self.assertEqual(len(file), 4)
clear_output() clear_output()
observation_simulation_from_config( observation_simulation_from_config(
os.path.join(cases, '05_sci.yaml'), os.path.join(cases, '05_sci.yaml'),
...@@ -75,7 +73,6 @@ class TestMain(unittest.TestCase): ...@@ -75,7 +73,6 @@ class TestMain(unittest.TestCase):
file = os.listdir(output) file = os.listdir(output)
self.assertEqual(len(file), 1) self.assertEqual(len(file), 1)
self.assertEqual(file[0][:5], 'SCI') self.assertEqual(file[0][:5], 'SCI')
def test_full_run_cmdline(self): def test_full_run_cmdline(self):
clear_output() clear_output()
...@@ -83,7 +80,6 @@ class TestMain(unittest.TestCase): ...@@ -83,7 +80,6 @@ class TestMain(unittest.TestCase):
file = os.listdir(output) file = os.listdir(output)
self.assertEqual(len(file), 1) self.assertEqual(len(file), 1)
self.assertEqual(file[0][:5], 'SCI') self.assertEqual(file[0][:5], 'SCI')
# def test_main(self): # def test_main(self):
# target_example = { # target_example = {
......
import unittest import unittest
import time
from csst_cpic_sim.target import star_photlam from csst_cpic_sim.target import star_photlam
from csst_cpic_sim.optics import ideal_focus_image, focal_convolve, focal_mask, filter_throughput, ideal_focus_image from csst_cpic_sim.optics import ideal_focus_image, focal_convolve, focal_mask, filter_throughput, ideal_focus_image
from csst_cpic_sim.config import which_focalplane, S from csst_cpic_sim.config import S
import numpy as np import numpy as np
from astropy.io import fits
from csst_cpic_sim.optics import FILTERS from csst_cpic_sim.optics import FILTERS
def make_test_sub_image(size, shape): def make_test_sub_image(size, shape):
shape = np.array([shape, shape]) shape = np.array([shape, shape])
sub_image = np.zeros(shape) sub_image = np.zeros(shape)
...@@ -28,21 +25,21 @@ def make_test_sub_image(size, shape): ...@@ -28,21 +25,21 @@ def make_test_sub_image(size, shape):
sub_image2 = (np.sqrt(xx**2*2 + yy**2) < size).astype(int) sub_image2 = (np.sqrt(xx**2*2 + yy**2) < size).astype(int)
sub_image = sub_image2 * (1 - sub_image) sub_image = sub_image2 * (1 - sub_image)
return sub_image return sub_image
def gaussian_psf(band, spectrum, shape, error=0.1): def gaussian_psf(band, spectrum, shape, error=0.1):
psf_shape = [shape, shape] psf_shape = [shape, shape]
xx, yy = np.mgrid[0:psf_shape[0], 0:psf_shape[1]] xx, yy = np.mgrid[0:psf_shape[0], 0:psf_shape[1]]
center = np.array([(psf_shape[0]-1)/2, (psf_shape[1]-1)/2]) center = np.array([(psf_shape[0]-1)/2, (psf_shape[1]-1)/2])
sigma = 10 sigma = 10
psf = np.exp(-((xx-center[0])**2 + psf = np.exp(-((xx-center[0])**2 + (yy-center[1])**2) / (2*sigma**2))
(yy-center[1])**2) / (2*sigma**2))
psf = psf / psf.sum() psf = psf / psf.sum()
filter = filter_throughput(band) filter = filter_throughput(band)
return psf * (spectrum * filter).integrate() return psf * (spectrum * filter).integrate()
class TestOptics(unittest.TestCase): class TestOptics(unittest.TestCase):
def test_filter_throughtput(self): def test_filter_throughtput(self):
...@@ -53,10 +50,6 @@ class TestOptics(unittest.TestCase): ...@@ -53,10 +50,6 @@ class TestOptics(unittest.TestCase):
bandpass = filter_throughput('deFault') bandpass = filter_throughput('deFault')
self.assertIsInstance(bandpass, S.spectrum.SpectralElement) self.assertIsInstance(bandpass, S.spectrum.SpectralElement)
# def test_which_focalpalne(self):
# self.assertEqual(which_focalplane('f565'), 'vis')
def test_ideal_focus_image(self): def test_ideal_focus_image(self):
targets = [ targets = [
[-20, 0, star_photlam(2, 'G2'), None], [-20, 0, star_photlam(2, 'G2'), None],
...@@ -64,8 +57,8 @@ class TestOptics(unittest.TestCase): ...@@ -64,8 +57,8 @@ class TestOptics(unittest.TestCase):
[8, 0, star_photlam(-5, 'G2'), make_test_sub_image(10, 100)], [8, 0, star_photlam(-5, 'G2'), make_test_sub_image(10, 100)],
] ]
bandpass = S.Box(6000, 500) bandpass = S.Box(6000, 500)
foc = ideal_focus_image( ideal_focus_image(
bandpass, bandpass,
targets, targets,
0.0165, 0.0165,
...@@ -73,7 +66,7 @@ class TestOptics(unittest.TestCase): ...@@ -73,7 +66,7 @@ class TestOptics(unittest.TestCase):
) )
# fits.writeto('foc.fits', foc, overwrite=True) # fits.writeto('foc.fits', foc, overwrite=True)
foc = ideal_focus_image( ideal_focus_image(
bandpass, bandpass,
targets, targets,
0.0165, 0.0165,
...@@ -81,7 +74,7 @@ class TestOptics(unittest.TestCase): ...@@ -81,7 +74,7 @@ class TestOptics(unittest.TestCase):
rotation=30, rotation=30,
) )
foc = ideal_focus_image( ideal_focus_image(
bandpass, bandpass,
{}, {},
0.0165, 0.0165,
...@@ -89,7 +82,6 @@ class TestOptics(unittest.TestCase): ...@@ -89,7 +82,6 @@ class TestOptics(unittest.TestCase):
rotation=30, rotation=30,
) )
# fits.writeto('foc_rot30.fits', foc, overwrite=True) # fits.writeto('foc_rot30.fits', foc, overwrite=True)
def test_focal_mask(self): def test_focal_mask(self):
image = np.zeros((100, 100)) + 1 image = np.zeros((100, 100)) + 1
...@@ -103,28 +95,17 @@ class TestOptics(unittest.TestCase): ...@@ -103,28 +95,17 @@ class TestOptics(unittest.TestCase):
[8, 0, star_photlam(-5, 'G2'), make_test_sub_image(10, 100)], [8, 0, star_photlam(-5, 'G2'), make_test_sub_image(10, 100)],
] ]
img_final = focal_convolve('f661', {}) focal_convolve('f661', {})
img_final = focal_convolve('f661', targets) focal_convolve('f661', targets)
img_final = focal_convolve('f661', targets, init_shifts=[10, 10]) focal_convolve('f661', targets, init_shifts=[10, 10])
# fits.writeto('cov.fits', img_final, overwrite=True)
# if __name__ == '__main__': # if __name__ == '__main__':
# # # unittest.main() # # # unittest.main()
# # import time # # import time
# # from CpicImgSim.target import star_photlam # # from CpicImgSim.target import star_photlam
# test_convolve_psf() # test_convolve_psf()
# # import matplotlib.pyplot as plt # # import matplotlib.pyplot as plt
......
...@@ -9,9 +9,10 @@ from csst_cpic_sim.config import S ...@@ -9,9 +9,10 @@ from csst_cpic_sim.config import S
tests_folder = os.path.dirname(os.path.abspath(__file__)) tests_folder = os.path.dirname(os.path.abspath(__file__))
class TestTarget(unittest.TestCase): class TestTarget(unittest.TestCase):
def test_target_object(self): def test_target_object(self):
d_cstar= { d_cstar = {
'magnitude': 5, 'magnitude': 5,
'ra': '120d', 'ra': '120d',
'dec': '40d', 'dec': '40d',
...@@ -34,20 +35,19 @@ class TestTarget(unittest.TestCase): ...@@ -34,20 +35,19 @@ class TestTarget(unittest.TestCase):
'phase_angle': 90, 'phase_angle': 90,
'sp_model': 'hybrid_planet' 'sp_model': 'hybrid_planet'
} }
old_planet = TargetOjbect(d_planet, cstar = cstar) old_planet = TargetOjbect(d_planet, cstar=cstar)
self.assertEqual(old_planet.sp_model, 'hybrid_planet') self.assertEqual(old_planet.sp_model, 'hybrid_planet')
d_planet['sp_model'] = 'bcc_planet' d_planet['sp_model'] = 'bcc_planet'
d_planet['coe_cloud'] = 1 d_planet['coe_cloud'] = 1
d_planet['coe_metal'] = 0 d_planet['coe_metal'] = 0
old_planet = TargetOjbect(d_planet, cstar = cstar) old_planet = TargetOjbect(d_planet, cstar=cstar)
self.assertEqual(old_planet.sp_model, 'bcc_planet') self.assertEqual(old_planet.sp_model, 'bcc_planet')
def test_bcc_albedo_spectrum(self): def test_bcc_albedo_spectrum(self):
spectrum = AlbedoCat(90, 1, 0) spectrum = AlbedoCat(90, 1, 0)
self.assertIsInstance(spectrum, S.spectrum.SpectralElement) self.assertIsInstance(spectrum, S.spectrum.SpectralElement)
spectrum = bcc_spectrum(0.5, 0.5) spectrum = bcc_spectrum(0.5, 0.5)
self.assertIsInstance(spectrum, S.spectrum.SpectralElement) self.assertIsInstance(spectrum, S.spectrum.SpectralElement)
self.assertEqual(spectrum.waveunits.name, 'angstrom') self.assertEqual(spectrum.waveunits.name, 'angstrom')
...@@ -56,7 +56,6 @@ class TestTarget(unittest.TestCase): ...@@ -56,7 +56,6 @@ class TestTarget(unittest.TestCase):
# plt.plot(spectrum.wave, spectrum.throughput) # plt.plot(spectrum.wave, spectrum.throughput)
# plt.show() # plt.show()
def test_hybrid_albedo_spectrum(self): def test_hybrid_albedo_spectrum(self):
planet = hybrid_albedo_spectrum(0.5, 1) planet = hybrid_albedo_spectrum(0.5, 1)
self.assertIsInstance(planet, S.spectrum.SpectralElement) self.assertIsInstance(planet, S.spectrum.SpectralElement)
...@@ -124,7 +123,7 @@ class TestTarget(unittest.TestCase): ...@@ -124,7 +123,7 @@ class TestTarget(unittest.TestCase):
self.assertRaises(ValueError, extract_target_x_y, dict(ra='120d')) self.assertRaises(ValueError, extract_target_x_y, dict(ra='120d'))
self.assertRaises(ValueError, extract_target_x_y, self.assertRaises(ValueError, extract_target_x_y,
dict(ra='120d', dec='40d')) dict(ra='120d', dec='40d'))
def test_detect_template_path(self): def test_detect_template_path(self):
for f in os.listdir(): for f in os.listdir():
if os.path.isfile(f): if os.path.isfile(f):
...@@ -138,19 +137,16 @@ class TestTarget(unittest.TestCase): ...@@ -138,19 +137,16 @@ class TestTarget(unittest.TestCase):
self.assertRaises(FileExistsError, detect_template_path, 'demo_5_35.yaml') self.assertRaises(FileExistsError, detect_template_path, 'demo_5_35.yaml')
def test_target_file_load(self): def test_target_file_load(self):
t0 = target_file_load({0:0}) t0 = target_file_load({0: 0})
self.assertEqual(t0[0], 0) self.assertEqual(t0[0], 0)
t_error = target_file_load(['1']) t_error = target_file_load(['1'])
self.assertEqual(t_error, {}) self.assertEqual(t_error, {})
t1 = target_file_load('') t1 = target_file_load('')
self.assertEqual(t1, {}) self.assertEqual(t1, {})
t1 = target_file_load(' ') t1 = target_file_load(' ')
self.assertEqual(t1, {}) self.assertEqual(t1, {})
...@@ -172,7 +168,6 @@ class TestTarget(unittest.TestCase): ...@@ -172,7 +168,6 @@ class TestTarget(unittest.TestCase):
self.assertEqual(t3['cstar']['sp_model'], 'reference') self.assertEqual(t3['cstar']['sp_model'], 'reference')
self.assertEqual(len(t3['objects']), 0) self.assertEqual(len(t3['objects']), 0)
t4 = target_file_load('demo_0_20') t4 = target_file_load('demo_0_20')
self.assertEqual(t4['cstar']['magnitude'], 0) self.assertEqual(t4['cstar']['magnitude'], 0)
self.assertEqual(t4['cstar']['sp_model'], 'reference') self.assertEqual(t4['cstar']['sp_model'], 'reference')
...@@ -222,4 +217,3 @@ class TestTarget(unittest.TestCase): ...@@ -222,4 +217,3 @@ class TestTarget(unittest.TestCase):
template_star = target_file_load('demo_std_star') template_star = target_file_load('demo_std_star')
spectrums = spectrum_generator(template_star) spectrums = spectrum_generator(template_star)
self.assertEqual(len(spectrums), 1) self.assertEqual(len(spectrums), 1)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment