Commit 05c74432 authored by GZhao's avatar GZhao
Browse files

update test_io and test_optics

parent 11de1d29
Pipeline #4460 passed with stage
in 0 seconds
...@@ -10,8 +10,8 @@ cpism_refdata/ ...@@ -10,8 +10,8 @@ cpism_refdata/
*.egg-info *.egg-info
example/example_output example/example_output
refdata/starmodel refdata
refdata/target_model
# Other files and folders # Other files and folders
.settings/ .settings/
...@@ -31,6 +31,7 @@ docs/notebooks/image_files/_* ...@@ -31,6 +31,7 @@ docs/notebooks/image_files/_*
tests/.coverage tests/.coverage
tests/htmlcov/ tests/htmlcov/
tests/*.xml tests/*.xml
tests/test_output
# Executables # Executables
*.swf *.swf
......
...@@ -689,6 +689,7 @@ class CpicVisEmccd(object): ...@@ -689,6 +689,7 @@ class CpicVisEmccd(object):
# return img_line[:shape[0]*shape[1]].reshape(shape) # return img_line[:shape[0]*shape[1]].reshape(shape)
def readout(self, image_focal, em_set, expt_set, image_cosmic_ray=False, emgain=None): def readout(self, image_focal, em_set, expt_set, image_cosmic_ray=False, emgain=None):
expt = expt_set expt = expt_set
if expt_set == 0: if expt_set == 0:
expt = 0.001 expt = 0.001
......
...@@ -6,28 +6,48 @@ import numpy as np ...@@ -6,28 +6,48 @@ import numpy as np
config_aim = os.path.dirname(os.path.dirname(__file__)) config_aim = os.path.dirname(os.path.dirname(__file__))
config_aim = os.path.join(config_aim, 'data/refdata_path.yaml') config_aim = os.path.join(config_aim, 'data/refdata_path.yaml')
config_set = False
def set_config(refdata_path=None):
if refdata_path is None: # def set_config(refdata_path=None):
print("input cpism refencence data folder") # if refdata_path is None:
refdata_path = input() # print("input cpism refencence data folder")
refdata_path = os.path.abspath(refdata_path) # refdata_path = input()
with open(config_aim, 'w') as f: # refdata_path = os.path.abspath(refdata_path)
yaml.dump(refdata_path, f) # with open(config_aim, 'w') as f:
return refdata_path # yaml.dump(refdata_path, f)
# return refdata_path
try: # try:
# with open(config_aim, 'r') as f:
# cpism_refdata = yaml.load(f, Loader=yaml.FullLoader)
# if not os.path.isdir(cpism_refdata):
# raise FileNotFoundError('cpism refdata path not found')
# config_set = True
# except FileNotFoundError:
# warnings.warn(f'refdata not setup yet, set it before use')
# cpism_refdata = set_config()
def load_refdata_path(config_aim):
with open(config_aim, 'r') as f: with open(config_aim, 'r') as f:
cpism_refdata = yaml.load(f, Loader=yaml.FullLoader) refdata_list = yaml.load(f, Loader=yaml.FullLoader)
if not os.path.isdir(cpism_refdata):
raise FileNotFoundError('cpism refdata path not found') for refdata in refdata_list:
config_set = True if os.path.isdir(refdata):
except FileNotFoundError: return refdata
warnings.warn(f'refdata not setup yet, set it before use')
cpism_refdata = set_config() print("csst_cpic_sim refdata folder not found, please input cpism refencence data folder")
refdata = input()
refdata = os.path.abspath(refdata)
if os.path.isdir(refdata):
refdata_list.append(refdata)
with open(config_aim, 'w') as f:
yaml.dump(refdata_list, f)
exit()
cpism_refdata = load_refdata_path(config_aim)
config = {} config = {}
config['cpism_refdata'] = cpism_refdata config['cpism_refdata'] = cpism_refdata
...@@ -121,15 +141,15 @@ def which_focalplane(band): ...@@ -121,15 +141,15 @@ def which_focalplane(band):
ValueError ValueError
If the band is not in ['f565', 'f661', 'f743', 'f883', 'f940', 'f1265', 'f1425', 'f1542', 'wfs'] If the band is not in ['f565', 'f661', 'f743', 'f883', 'f940', 'f1265', 'f1425', 'f1542', 'wfs']
""" """
band = band.lower() # band = band.lower()
if band in ['f565', 'f661', 'f743', 'f883']: # if band in ['f565', 'f661', 'f743', 'f883']:
return 'vis' # return 'vis'
if band in ['f940', 'f1265', 'f1425', 'f1542']: # if band in ['f940', 'f1265', 'f1425', 'f1542']:
return 'nir' # return 'nir'
if band in ['wfs']: # if band in ['wfs']:
return 'wfs' # return 'wfs'
return 'vis'
raise ValueError(f"未知的波段{band}") # raise ValueError(f"未知的波段{band}")
def iso_time(time): def iso_time(time):
if isinstance(time, str): if isinstance(time, str):
...@@ -145,7 +165,7 @@ def relative_time(time): ...@@ -145,7 +165,7 @@ def relative_time(time):
if isinstance(time, float): if isinstance(time, float):
return time return time
if isinstance(time, int): if isinstance(time, int):
return float(int) return float(time)
utc0 = config['utc0'] utc0 = config['utc0']
time0 = datetime.timestamp(datetime.fromisoformat(utc0)) time0 = datetime.timestamp(datetime.fromisoformat(utc0))
......
...@@ -13,12 +13,13 @@ from .config import config, iso_time ...@@ -13,12 +13,13 @@ from .config import config, iso_time
default_output_dir = config['output'] default_output_dir = config['output']
log_level = config['log_level'] log_level = config['log_level']
header_check = config['check_fits_header'] header_check = config['check_fits_header']
log_dir = config['log_dir']
if not os.path.exists(log_dir):
os.makedirs(log_dir)
log = Logger(log_dir+'/cpism_pack.log', log_level).logger def set_up_logger(log_dir):
if not os.path.exists(log_dir):
os.makedirs(log_dir)
return Logger(log_dir+'/cpism_pack.log', log_level).logger
log = set_up_logger(config['log_dir'])
def check_and_update_fits_header(header): def check_and_update_fits_header(header):
""" """
...@@ -140,7 +141,6 @@ def obsid_parser( ...@@ -140,7 +141,6 @@ def obsid_parser(
'01': 'SCI', '01': 'SCI',
'02': 'DSF', '02': 'DSF',
'10': 'CALS', '10': 'CALS',
'00': 'TEMP'
} }
obstype = obstype_dict.get(obsid[1:3], 'DEFT') obstype = obstype_dict.get(obsid[1:3], 'DEFT')
return obstype return obstype
...@@ -196,7 +196,7 @@ def primary_hdu( ...@@ -196,7 +196,7 @@ def primary_hdu(
obsid = obs_info['obsid'] obsid = obs_info['obsid']
exp_start = obs_info.get('EXPSTART') exp_start = obs_info['EXPSTART']
exp_start = datetime.fromisoformat(exp_start) exp_start = datetime.fromisoformat(exp_start)
exp_end = obs_info['EXPEND'] exp_end = obs_info['EXPEND']
...@@ -283,7 +283,7 @@ def primary_hdu( ...@@ -283,7 +283,7 @@ def primary_hdu(
cabend = gnc_info.get('CABEND', exp_end.isoformat(timespec='seconds')) cabend = gnc_info.get('CABEND', exp_end.isoformat(timespec='seconds'))
cabend = iso_time(cabend) cabend = iso_time(cabend)
cabend_mjd = datetime_obj_to_mjd(datetime.fromisoformat(cabend)) cabend_mjd = datetime_obj_to_mjd(datetime.fromisoformat(cabend))
header['CABEND'] = gnc_info.get('CABEDN', header['EXPEND']) header['CABEND'] = cabend_mjd
header['SUNANGL1'] = gnc_info.get('SUNANGL1', header['SUNANGL0']) header['SUNANGL1'] = gnc_info.get('SUNANGL1', header['SUNANGL0'])
header['MOONANG1'] = gnc_info.get('MOONANG1', header['MOONANG0']) header['MOONANG1'] = gnc_info.get('MOONANG1', header['MOONANG0'])
header['TEL_ALT1'] = gnc_info.get('TEL_ALT1', header['TEL_ALT0']) header['TEL_ALT1'] = gnc_info.get('TEL_ALT1', header['TEL_ALT0'])
...@@ -318,7 +318,7 @@ def primary_hdu( ...@@ -318,7 +318,7 @@ def primary_hdu(
return hdu return hdu
def frame_header(obs_info, index, primary_header, camera_dict={}): def frame_header(obs_info, index, primary_header, camera_dict):
""" """
Generate the header for a single frame. Generate the header for a single frame.
...@@ -466,7 +466,7 @@ def frame_header(obs_info, index, primary_header, camera_dict={}): ...@@ -466,7 +466,7 @@ def frame_header(obs_info, index, primary_header, camera_dict={}):
return header return header
def save_fits_simple(images, obs_info, output_folder=None): def save_fits_simple(images, obs_info, output_folder='./'):
""" """
Save the image to a fits file with a simple header to TMP directory. Save the image to a fits file with a simple header to TMP directory.
...@@ -512,20 +512,19 @@ def save_fits_simple(images, obs_info, output_folder=None): ...@@ -512,20 +512,19 @@ def save_fits_simple(images, obs_info, output_folder=None):
shift = obs_info['shift'] shift = obs_info['shift']
header['shift'] = f"x:{shift[0]},y:{shift[1]}" header['shift'] = f"x:{shift[0]},y:{shift[1]}"
if output_folder is None:
fullname = f"{tmp_folder_path}/{filename}" fullname = os.path.join(output_folder, filename)
else: print(fullname)
fullname = f"{output_folder}/{filename}" if not os.path.exists(output_folder):
if os.path.exists(output_folder) is False: os.makedirs(output_folder)
os.makedirs(output_folder) log.debug(f"Output folder {output_folder} is created.")
log.debug(f"Output folder {output_folder} is created.")
log.debug(f"save fits file to {fullname}") log.debug(f"save fits file to {fullname}")
fits.writeto(fullname, images, overwrite=True, header=header) fits.writeto(fullname, images, overwrite=True, header=header)
return fullname return fullname
def save_fits(images, obs_info, gnc_info, camera_dict={}, csst_format=True, output_folder=None): def save_fits(images, obs_info, gnc_info, camera_dict={}, csst_format=True, output_folder='./'):
""" """
Save the image to a fits file. Save the image to a fits file.
...@@ -581,10 +580,7 @@ def save_fits(images, obs_info, gnc_info, camera_dict={}, csst_format=True, outp ...@@ -581,10 +580,7 @@ def save_fits(images, obs_info, gnc_info, camera_dict={}, csst_format=True, outp
frame_hdu.add_checksum() frame_hdu.add_checksum()
hdu_list.append(frame_hdu) hdu_list.append(frame_hdu)
if output_folder is None: folder = f"{output_folder}/{folder}"
folder = f"{default_output_dir}/{folder}"
else:
folder = f"{output_folder}/{folder}"
if not os.path.exists(folder): if not os.path.exists(folder):
os.makedirs(folder) os.makedirs(folder)
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
import argparse, sys, tqdm, time, os, yaml import argparse, sys, tqdm, time, os, yaml
from glob import glob from glob import glob
from datetime import datetime from datetime import datetime
import traceback
import numpy as np import numpy as np
from .target import spectrum_generator, target_file_load from .target import spectrum_generator, target_file_load
...@@ -123,7 +124,8 @@ def vis_observation( ...@@ -123,7 +124,8 @@ def vis_observation(
params['EXPSTART'] = expt_start_iso.isoformat() params['EXPSTART'] = expt_start_iso.isoformat()
params['EXPEND'] = expt_end_iso.isoformat() params['EXPEND'] = expt_end_iso.isoformat()
params['frame_info'] = all_frame_info params['frame_info'] = all_frame_info
save_fits(image_cube, params, gnc_info, camera_dict=camera.__dict__.copy(), csst_format=csst_format, output_folder=output)
save_fits(image_cube, params, gnc_info, camera.__dict__.copy(), csst_format=csst_format, output_folder=output)
if prograss_bar: if prograss_bar:
pg_bar.close() pg_bar.close()
print(f' Done [{time.time() - start_time:.1f}s] ') print(f' Done [{time.time() - start_time:.1f}s] ')
...@@ -248,47 +250,52 @@ def observation_simulation_from_config(obs_file, config_file): ...@@ -248,47 +250,52 @@ def observation_simulation_from_config(obs_file, config_file):
for ind_target, file in enumerate(file_list): for ind_target, file in enumerate(file_list):
with open(file, 'r') as fid: try:
obs_info = yaml.load(fid, Loader=yaml.FullLoader) with open(file, 'r') as fid:
obs_info = yaml.load(fid, Loader=yaml.FullLoader)
target = target_file_load(obs_info.get('target', {}))
skybg = obs_info.get('skybg', None) target = target_file_load(obs_info.get('target', {}))
expt = obs_info['expt'] skybg = obs_info.get('skybg', None)
band = obs_info['band'] expt = obs_info['expt']
emset = obs_info['emset'] band = obs_info['band']
nframe = obs_info['nframe'] emset = obs_info['emset']
obsid = obs_info['obsid'] nframe = obs_info['nframe']
rotation = obs_info.get('rotation', 0) obsid = obs_info['obsid']
shift = obs_info.get('shift', [0, 0]) rotation = obs_info.get('rotation', 0)
gnc_info = obs_info.get('gnc_info', {}) shift = obs_info.get('shift', [0, 0])
time = obs_info.get('time', 0) gnc_info = obs_info.get('gnc_info', {})
emgain = obs_info.get('emgain', None) time = obs_info.get('time', 0)
emgain = obs_info.get('emgain', None)
time = relative_time(time) time = relative_time(time)
except Exception as e:
log.error(f"{file} is not a valid yaml file.")
log.error(f"Failed with {type(e).__name__}{e}.\n\n {traceback.format_exc()}")
continue
ind_camera = 0 ind_camera = 0
for camera_name, camera in zip(all_camera_name, all_camera): for camera_name, camera in zip(all_camera_name, all_camera):
try: ind_camera += 1
ind_camera += 1 ind_run = ind_target * len(all_camera) + ind_camera
ind_run = ind_target * len(all_camera) + ind_camera all_run = len(all_camera) * len(file_list)
all_run = len(all_camera) * len(file_list) info_text = f"({ind_run}/{all_run}) obsid[{obsid}] with {camera_name}"
info_text = f"({ind_run}/{all_run}) obsid[{obsid}]/{os.path.basename(file)[:-5]} with {camera_name}"
log.info(info_text) log.info(info_text)
if time == 0:
if time == 0: camera.time_syn(time, initial=True)
camera.time_syn(time, initial=True) else:
else: dt = time - camera.system_time
dt = time - camera.system_time if dt < 0:
if dt < 0: log.warning(f'Time is not synced. {dt} seconds are added.')
log.warning(f'Time is not synced. {dt} seconds are added.') dt = 0
dt = 0 camera.time_syn(dt, readout=False)
camera.time_syn(dt, readout=False)
if len(all_camera) > 1:
if len(all_camera) > 1: output = os.path.join(output_folder, camera_name)
output = os.path.join(output_folder, camera_name) else:
else: output = output_folder
output = output_folder
try:
vis_observation( vis_observation(
target, target,
skybg, skybg,
...@@ -307,7 +314,7 @@ def observation_simulation_from_config(obs_file, config_file): ...@@ -307,7 +314,7 @@ def observation_simulation_from_config(obs_file, config_file):
csst_format=csst_format, csst_format=csst_format,
prograss_bar=True) prograss_bar=True)
except Exception as e: except Exception as e:
raise(e) log.error(f"{info_text} failed with {type(e).__name__}{e}.\n\n {traceback.format_exc()}")
def main(argv=None): def main(argv=None):
parser = argparse.ArgumentParser(description='Cpic obsevation image simulation') parser = argparse.ArgumentParser(description='Cpic obsevation image simulation')
...@@ -315,7 +322,7 @@ def main(argv=None): ...@@ -315,7 +322,7 @@ def main(argv=None):
subparsers = parser.add_subparsers(help='type of runs') subparsers = parser.add_subparsers(help='type of runs')
parser_quickrun = subparsers.add_parser('quickrun', help='a quick observation with no configration file') parser_quickrun = subparsers.add_parser('quickrun', help='a quick observation with no configration file')
parser_quickrun.add_argument('target_string', type=str, help='example: \*5.1/25.3(1.3,1.5)/22.1(2.3,-4.5)') parser_quickrun.add_argument('target_string', type=str, help='example: *5.1/25.3(1.3,1.5)/22.1(2.3,-4.5)')
parser_quickrun.add_argument('expt', type=float, help='exposure time [ms]') parser_quickrun.add_argument('expt', type=float, help='exposure time [ms]')
parser_quickrun.add_argument('emgain', type=float, help='emgain or emgain set value if emgain_input is False') parser_quickrun.add_argument('emgain', type=float, help='emgain or emgain set value if emgain_input is False')
parser_quickrun.add_argument('nframe', type=int, help='number of frames') parser_quickrun.add_argument('nframe', type=int, help='number of frames')
...@@ -361,8 +368,8 @@ def main(argv=None): ...@@ -361,8 +368,8 @@ def main(argv=None):
args.func(args) args.func(args)
if __name__ == '__main__': # pragma: no cover # if __name__ == '__main__': # pragma: no cover
sys.exit(main()) # sys.exit(main())
# target_example = { # target_example = {
......
...@@ -684,12 +684,11 @@ def target_file_load( ...@@ -684,12 +684,11 @@ def target_file_load(
If all the above conditions are not met, an empty dict will be returned. If all the above conditions are not met, an empty dict will be returned.
""" """
if not target: # None or empty string or {}
return {}
if isinstance(target, dict): if isinstance(target, dict):
return target return target
if not target: # None or empty string
return {}
if isinstance(target, str): #filename or formatted string if isinstance(target, str): #filename or formatted string
target = target.strip() target = target.strip()
...@@ -699,14 +698,16 @@ def target_file_load( ...@@ -699,14 +698,16 @@ def target_file_load(
catalog_folder = config['catalog_folder'] catalog_folder = config['catalog_folder']
target_file = target target_file = target
target_file += '.yaml' if target_file[-5:].lower() != '.yaml' else "" target_file += '.yaml' if target_file[-5:].lower() != '.yaml' else ""
target_full_path = os.path.join(catalog_folder, target_file) target_name = os.path.basename(target_file)[:-5]
file_search = [target_file, os.path.join(catalog_folder, target_file)]
if os.path.isfile(target_full_path):
with open(target_full_path) as fid:
target = yaml.load(fid, Loader=yaml.FullLoader)
target['name'] = target_file[:-5]
return target
for file in file_search:
if os.path.isfile(file):
with open(file) as fid:
target = yaml.load(fid, Loader=yaml.FullLoader)
target['name'] = target_name
return target
target_str = target target_str = target
if (target_str[0] == '*'): if (target_str[0] == '*'):
objects = target_str[1:].split('/') objects = target_str[1:].split('/')
......
/nfsdata/share/simulation-unittest/cpic_sim - D:\workdir\Project\csst_cpic_sim\refdata
... - /nfsdata/share/simulation-unittest/cpic_sim
\ No newline at end of file
...@@ -21,4 +21,4 @@ output: ./ ...@@ -21,4 +21,4 @@ output: ./
sp2teff_model: ${cpism_refdata}/target_model/sptype2teff_lut.json sp2teff_model: ${cpism_refdata}/target_model/sptype2teff_lut.json
dm_pickle: ${cpism_refdata}/optics/dm_model.pkl dm_pickle: ${cpism_refdata}/optics/dm_model.pkl
pysyn_refdata: ${cpism_refdata}/starmodel/grp/redcat/trds pysyn_refdata: ${cpism_refdata}/starmodel/grp/redcat/trds
catalog_folder: ${cpism_refdata}/demo_catalog
\ No newline at end of file
python D:\workdir\Project\csst_cpic_sim\csst_cpic_sim\main.py $args python D:\workdir\Project\csst_cpic_sim\script\cpicsim.py $args
\ No newline at end of file \ No newline at end of file
import sys
from csst_cpic_sim.main import main
sys.exit(main())
\ No newline at end of file
[run] [run]
branch = True branch = True
source = CpicImgSim source = csst_cpic_sim
import unittest import unittest
from unittest import mock from unittest import mock
from csst_cpic_sim.io import obsid_parser, primary_hdu, frame_header, save_fits_simple
from csst_cpic_sim.config import config
from csst_cpic_sim.io import obsid_parser, primary_hdu, frame_header, set_up_logger
from csst_cpic_sim.camera import CpicVisEmccd
import csst_cpic_sim.io as io import csst_cpic_sim.io as io
from astropy.io import fits from astropy.io import fits
import numpy as np import numpy as np
import yaml import yaml
import os
cstar = { cstar = {
'magnitude': 5, 'magnitude': 5,
...@@ -22,22 +27,49 @@ params = { ...@@ -22,22 +27,49 @@ params = {
'nframe': 5, 'nframe': 5,
'band': 'f661', 'band': 'f661',
'emgain': 10, 'emgain': 10,
'obsid': '51012345678', 'obsid': '40112345678',
'rotation': 20, 'rotation': 20,
'shift': [0, 1], 'shift': [0, 1],
'EXPSTART': '2020-01-01T00:00:00.000',
'EXPEND': '2020-01-02T00:00:00.000',
'frame_info': [
{'expt_start': 0,
'expt_end': 20,
'platescale': 1.0,
'iwa': 0,
'chiptemp': 1}
]
} }
camera = CpicVisEmccd()
import logging
class TestIO(unittest.TestCase): class TestIO(unittest.TestCase):
@mock.patch("os.makedirs")
def test_set_log(self, patch):
log = set_up_logger(config['log_dir'])
self.assertIsInstance(log, logging.Logger)
self.assertRaises(FileNotFoundError, set_up_logger, 'new folder')
patch.assert_called_once_with('new folder')
def test_obsid_parser(self): def test_obsid_parser(self):
self.assertRaises(ValueError, obsid_parser, '20190101') self.assertRaises(ValueError, obsid_parser, '20190101')
self.assertRaises(ValueError, obsid_parser, '123456789012') self.assertRaises(ValueError, obsid_parser, '123456789012')
self.assertRaises(ValueError, obsid_parser, '51012345678')
self.assertEqual(obsid_parser('50012345678'), 'BIAS') self.assertEqual(obsid_parser('42012345678'), 'BIAS')
self.assertEqual(obsid_parser('50112345678'), 'DARK') self.assertEqual(obsid_parser('42112345678'), 'DARK')
self.assertEqual(obsid_parser('50212345678'), 'FLAT') self.assertEqual(obsid_parser('42212345678'), 'FLAT')
self.assertEqual(obsid_parser('50312345678'), 'BKGD') self.assertEqual(obsid_parser('42312345678'), 'BKG')
self.assertEqual(obsid_parser('51012345678'), 'SCIE') self.assertEqual(obsid_parser('42412345678'), 'LASER')
self.assertEqual(obsid_parser('40112345678'), 'SCI')
self.assertEqual(obsid_parser('40212345678'), 'DSF')
self.assertEqual(obsid_parser('41012345678'), 'CALS')
self.assertEqual(obsid_parser('40312345678'), 'DEFT')
def test_primary_hdu(self): def test_primary_hdu(self):
hdu1 = primary_hdu(params, {}, filename_output=False) hdu1 = primary_hdu(params, {}, filename_output=False)
...@@ -48,14 +80,17 @@ class TestIO(unittest.TestCase): ...@@ -48,14 +80,17 @@ class TestIO(unittest.TestCase):
self.assertIsInstance(filename, str) self.assertIsInstance(filename, str)
def test_frame_header(self): def test_frame_header(self):
header = frame_header(params, 1, '2021-01-01T00:00:00') hdu1 = primary_hdu(params, {}, filename_output=False)
header = frame_header(params, 0, hdu1.header, camera.__dict__)
self.assertEqual(header['IMGINDEX'], 1) self.assertEqual(header['IMGINDEX'], 1)
self.assertIsInstance(header, fits.Header) self.assertIsInstance(header, fits.Header)
def test_write_fits(self): @mock.patch("os.makedirs")
@mock.patch("astropy.io.fits.writeto")
def test_write_fits(self, patch_fits, patch_mkdir):
images = np.zeros((5, 10, 10)) images = np.zeros((5, 10, 10))
yaml_str = """ yaml_str = """
obsid: 51012345678 obsid: 42012345678
expt: 300 expt: 300
nframe: 10 nframe: 10
band: "f661" band: "f661"
...@@ -63,6 +98,7 @@ shift: [0, 0] ...@@ -63,6 +98,7 @@ shift: [0, 0]
rotation: 0 rotation: 0
emgain: 100 emgain: 100
skybg: 21 skybg: 21
expstart:
target: target:
cstar: cstar:
...@@ -70,22 +106,19 @@ target: ...@@ -70,22 +106,19 @@ target:
dec: "41.26917d" dec: "41.26917d"
sptype: "M0.5" sptype: "M0.5"
magnitude: 3.4 magnitude: 3.4
planets: objects:
- ra: 10.684792 - ra: 10.684792
dec: 41.26917 dec: 41.26917
sptype: "M0.5" sptype: "M0.5"
magnitude: 3.4 magnitude: 3.4
""" """
print(io.tmp_folder_path)
parameters = yaml.load(yaml_str, Loader=yaml.FullLoader) parameters = yaml.load(yaml_str, Loader=yaml.FullLoader)
mock_fits_writeto = mock.MagicMock() tmp_folder_path = 'test_folder_for_unit_test'
io.fits.writeto = mock_fits_writeto io.save_fits_simple(images, parameters, output_folder=tmp_folder_path)
io.tmp_folder_path = 'test_folder_for_unit_test' patch_mkdir.assert_called_once_with(tmp_folder_path)
io.save_fits_simple(images, parameters) output_name = patch_fits.call_args[0][0]
output_name = mock_fits_writeto.call_args[0][0]
self.assertEqual( self.assertEqual(
output_name[:len(io.tmp_folder_path)], io.tmp_folder_path) output_name[:len(tmp_folder_path)], tmp_folder_path)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -3,11 +3,31 @@ import unittest ...@@ -3,11 +3,31 @@ import unittest
import time import time
from csst_cpic_sim.target import star_photlam from csst_cpic_sim.target import star_photlam
from csst_cpic_sim.optics import make_focus_image, focal_mask, filter_throughput, ideal_focus_image from csst_cpic_sim.optics import ideal_focus_image, focal_convolve, focal_mask, filter_throughput, ideal_focus_image
from csst_cpic_sim.config import which_focalplane, S from csst_cpic_sim.config import which_focalplane, S
import numpy as np import numpy as np
from astropy.io import fits from astropy.io import fits
from csst_cpic_sim.optics import FILTERS
def make_test_sub_image(size, shape):
shape = np.array([shape, shape])
sub_image = np.zeros(shape)
center = (shape-1)/2
xx, yy = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]))
xx = xx - center[1]
yy = yy - center[0]
sub_image[np.abs(xx) < 0.6] = 1
sub_image[np.abs(yy) < 0.6] = 1
sub_image[(np.abs(xx) < 0.6) & (np.abs(yy) < 0.6)] = 0
sub_image[yy > 7] = 0
sub_image[yy < -3] = 0
sub_image[np.abs(xx) > 3] = 0
sub_image2 = (np.sqrt(xx**2*2 + yy**2) < size).astype(int)
sub_image = sub_image2 * (1 - sub_image)
return sub_image
def gaussian_psf(band, spectrum, shape, error=0.1): def gaussian_psf(band, spectrum, shape, error=0.1):
psf_shape = [shape, shape] psf_shape = [shape, shape]
...@@ -22,119 +42,37 @@ def gaussian_psf(band, spectrum, shape, error=0.1): ...@@ -22,119 +42,37 @@ def gaussian_psf(band, spectrum, shape, error=0.1):
return psf * (spectrum * filter).integrate() return psf * (spectrum * filter).integrate()
class TestOptics(unittest.TestCase): class TestOptics(unittest.TestCase):
def test_filter_throughtput(self): def test_filter_throughtput(self):
bandpass = filter_throughput('f661')
self.assertIsInstance(bandpass, S.spectrum.SpectralElement)
bandpass = filter_throughput('F1265') bands = list(FILTERS.keys())
bandpass = filter_throughput(bands[0])
self.assertIsInstance(bandpass, S.spectrum.SpectralElement) self.assertIsInstance(bandpass, S.spectrum.SpectralElement)
bandpass = filter_throughput('deFault') bandpass = filter_throughput('deFault')
self.assertIsInstance(bandpass, S.spectrum.SpectralElement) self.assertIsInstance(bandpass, S.spectrum.SpectralElement)
bandpass = filter_throughput('none')
self.assertIsInstance(bandpass, S.spectrum.SpectralElement)
def test_which_focalpalne(self): def test_which_focalpalne(self):
self.assertEqual(which_focalplane('f565'), 'vis') self.assertEqual(which_focalplane('f565'), 'vis')
self.assertEqual(which_focalplane('F661'), 'vis')
self.assertEqual(which_focalplane('f743'), 'vis')
self.assertEqual(which_focalplane('f883'), 'vis')
self.assertEqual(which_focalplane('F940'), 'nir')
self.assertEqual(which_focalplane('f1265'), 'nir')
self.assertEqual(which_focalplane('F1425'), 'nir')
self.assertEqual(which_focalplane('f1542'), 'nir')
self.assertEqual(which_focalplane('wfs'), 'wfs')
self.assertRaises(ValueError, which_focalplane, 'what')
def test_make_focus_image(self):
# test fuction to generate psf
# test targets
cstar = star_photlam(0, 'F2V', is_blackbody=True)
targets = [
[0, 0, cstar],
[1, 1, star_photlam(10-5, 'B2V', is_blackbody=True)],
[-2, 2, star_photlam(11-5, 'A2V', is_blackbody=True)],
[3, -3, star_photlam(12-5, 'G2V', is_blackbody=True)],
[100, 100, star_photlam(12, 'K2V', is_blackbody=True)],
[100, 100, star_photlam(12, 'K2V', is_blackbody=True)],
]
focus_image = make_focus_image( def test_ideal_focus_image(self):
'f661',
targets,
gaussian_psf,
init_shifts=[1, 1],
rotation=45,
platesize=[1024, 1024]
)
self.assertIsNotNone(focus_image)
focus_image = make_focus_image(
'f661',
[],
gaussian_psf,
init_shifts=[1, 1],
rotation=45,
platesize=[1024, 1024]
)
self.assertIsNotNone(focus_image)
def test_focal_mask(self):
image = np.zeros((100, 100)) + 1
image_out = focal_mask(image, 1, 0.1, throughtput=0)
self.assertEqual((image - image_out).sum(), 2000+2000-400)
if __name__ == '__main__':
# unittest.main()
import time
from CpicImgSim.target import star_photlam
def make_test_sub_image(size, shape):
shape = np.array([shape, shape])
sub_image = np.zeros(shape)
center = (shape-1)/2
xx, yy = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]))
xx = xx - center[1]
yy = yy - center[0]
sub_image[np.abs(xx) < 0.6] = 1
sub_image[np.abs(yy) < 0.6] = 1
sub_image[(np.abs(xx) < 0.6) & (np.abs(yy) < 0.6)] = 0
sub_image[yy > 7] = 0
sub_image[yy < -3] = 0
sub_image[np.abs(xx) > 3] = 0
sub_image2 = (np.sqrt(xx**2*2 + yy**2) < size).astype(int)
sub_image = sub_image2 * (1 - sub_image)
return sub_image
def test_ideal_focus_image():
targets = [ targets = [
[0, 0, star_photlam(2, 'G2'), None], [-20, 0, star_photlam(2, 'G2'), None],
[5, 3, star_photlam(0, 'G2'), make_test_sub_image(4, 20)], [5, 3, star_photlam(0, 'G2'), make_test_sub_image(4, 20)],
[8, 0, star_photlam(-5, 'G2'), make_test_sub_image(10, 100)], [8, 0, star_photlam(-5, 'G2'), make_test_sub_image(10, 100)],
] ]
bandpass = S.Box(6000, 500) bandpass = S.Box(6000, 500)
start_time = time.time()
foc = ideal_focus_image( foc = ideal_focus_image(
bandpass, bandpass,
targets, targets,
0.0165, 0.0165,
[1024, 1024], [1024, 1024],
) )
end_time = time.time() # fits.writeto('foc.fits', foc, overwrite=True)
execution_time = end_time - start_time
fits.writeto('foc.fits', foc, overwrite=True)
start_time = time.time()
foc = ideal_focus_image( foc = ideal_focus_image(
bandpass, bandpass,
targets, targets,
...@@ -142,37 +80,53 @@ if __name__ == '__main__': ...@@ -142,37 +80,53 @@ if __name__ == '__main__':
[1024, 1024], [1024, 1024],
rotation=30, rotation=30,
) )
end_time = time.time()
execution_time = end_time - start_time
fits.writeto('foc_rot30.fits', foc, overwrite=True) foc = ideal_focus_image(
bandpass,
{},
0.0165,
[1024, 1024],
rotation=30,
)
# fits.writeto('foc_rot30.fits', foc, overwrite=True)
def test_focal_mask(self):
image = np.zeros((100, 100)) + 1
image_out = focal_mask(image, 1, 0.1, throughtput=0)
self.assertEqual((image - image_out).sum(), 2000+2000-400)
def test_convolve_psf(): def test_convolve_psf(self):
targets = [ targets = [
[0, 0, star_photlam(2, 'G2'), None], [0, 0, star_photlam(2, 'G2'), None],
[5, 3, star_photlam(0, 'G2'), make_test_sub_image(4, 20)], [5, 3, star_photlam(0, 'G2'), make_test_sub_image(4, 20)],
[8, 0, star_photlam(-5, 'G2'), make_test_sub_image(10, 100)], [8, 0, star_photlam(-5, 'G2'), make_test_sub_image(10, 100)],
] ]
def cov_psf_func(wave, error=0.1): img_final = focal_convolve('f661', {})
psf_shape = [1024, 1024]
xx, yy = np.mgrid[0:psf_shape[0], 0:psf_shape[1]] img_final = focal_convolve('f661', targets)
center = np.array([(psf_shape[0]-1)/2, (psf_shape[1]-1)/2])
sigma = 10 img_final = focal_convolve('f661', targets, init_shifts=[10, 10])
psf = np.exp(-((xx-center[0])**2 +
(yy-center[1])**2) / (2*sigma**2)) # fits.writeto('cov.fits', img_final, overwrite=True)
psf = psf / psf.sum()
return psf
# if __name__ == '__main__':
from CpicImgSim.optics import focal_convolve # # # unittest.main()
img_final = focal_convolve('f661', targets, cov_psf_func) # # import time
# # from CpicImgSim.target import star_photlam
fits.writeto('cov.fits', img_final, overwrite=True)
test_convolve_psf() # test_convolve_psf()
# import matplotlib.pyplot as plt # # import matplotlib.pyplot as plt
# plt.imshow(make_test_sub_image(5, 6)) # # plt.imshow(make_test_sub_image(5, 6))
# plt.show() # # plt.show()
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment