Commit 3d2380f0 authored by Yan Zhaojun's avatar Yan Zhaojun
Browse files

update

parent 9dc4a371
...@@ -15,7 +15,7 @@ parameters in parallel and serial direction. ...@@ -15,7 +15,7 @@ parameters in parallel and serial direction.
import numpy as np import numpy as np
#CDM03bidir # CDM03bidir
class CDM03bidir(): class CDM03bidir():
""" """
Class to run CDM03 CTI model, class Fortran routine to perform the actual CDM03 calculations. Class to run CDM03 CTI model, class Fortran routine to perform the actual CDM03 calculations.
...@@ -27,6 +27,7 @@ class CDM03bidir(): ...@@ -27,6 +27,7 @@ class CDM03bidir():
:param log: instance to Python logging :param log: instance to Python logging
:type log: logging instance :type log: logging instance
""" """
def __init__(self, settings, data, log=None): def __init__(self, settings, data, log=None):
""" """
Class constructor. Class constructor.
...@@ -39,48 +40,50 @@ class CDM03bidir(): ...@@ -39,48 +40,50 @@ class CDM03bidir():
:type log: logging instance :type log: logging instance
""" """
self.data = data self.data = data
self.values = dict(quads=(0,1,2,3), xsize=2048, ysize=2066, dob=0.0, rdose=8.0e9) self.values = dict(quads=(0, 1, 2, 3), xsize=2048,
ysize=2066, dob=0.0, rdose=8.0e9)
self.values.update(settings) self.values.update(settings)
self.log = log self.log = log
self._setupLogger() self._setupLogger()
#default CDM03 settings # default CDM03 settings
self.params = dict(beta_p=0.6, beta_s=0.6, fwc=200000., vth=1.168e7, vg=6.e-11, t=20.48e-3, self.params = dict(beta_p=0.6, beta_s=0.6, fwc=200000., vth=1.168e7, vg=6.e-11, t=20.48e-3,
sfwc=730000., svg=1.0e-10, st=5.0e-6, parallel=1., serial=0.0) sfwc=730000., svg=1.0e-10, st=5.0e-6, parallel=1., serial=0.0)
#update with inputs # update with inputs
self.params.update(self.values) self.params.update(self.values)
#read in trap information # read in trap information
trapdata = np.loadtxt(self.values['dir_path']+self.values['paralleltrapfile']) trapdata = np.loadtxt(
self.values['dir_path']+self.values['paralleltrapfile'])
if trapdata.ndim > 1: if trapdata.ndim > 1:
self.nt_p = trapdata[:, 0] self.nt_p = trapdata[:, 0]
self.sigma_p = trapdata[:, 1] self.sigma_p = trapdata[:, 1]
self.taur_p = trapdata[:, 2] self.taur_p = trapdata[:, 2]
else: else:
#only one trap species # only one trap species
self.nt_p = [trapdata[0],] self.nt_p = [trapdata[0],]
self.sigma_p = [trapdata[1],] self.sigma_p = [trapdata[1],]
self.taur_p = [trapdata[2],] self.taur_p = [trapdata[2],]
trapdata = np.loadtxt(self.values['dir_path']+self.values['serialtrapfile']) trapdata = np.loadtxt(
self.values['dir_path']+self.values['serialtrapfile'])
if trapdata.ndim > 1: if trapdata.ndim > 1:
self.nt_s = trapdata[:, 0] self.nt_s = trapdata[:, 0]
self.sigma_s = trapdata[:, 1] self.sigma_s = trapdata[:, 1]
self.taur_s = trapdata[:, 2] self.taur_s = trapdata[:, 2]
else: else:
#only one trap species # only one trap species
self.nt_s = [trapdata[0],] self.nt_s = [trapdata[0],]
self.sigma_s = [trapdata[1],] self.sigma_s = [trapdata[1],]
self.taur_s = [trapdata[2],] self.taur_s = [trapdata[2],]
#scale thibaut's values # scale thibaut's values
if 'thibaut' in self.values['parallelTrapfile']: if 'thibaut' in self.values['parallelTrapfile']:
self.nt_p /= 0.576 #thibaut's values traps / pixel self.nt_p /= 0.576 # thibaut's values traps / pixel
self.sigma_p *= 1.e4 #thibaut's values in m**2 self.sigma_p *= 1.e4 # thibaut's values in m**2
if 'thibaut' in self.values['serialTrapfile']: if 'thibaut' in self.values['serialTrapfile']:
self.nt_s *= 0.576 #thibaut's values traps / pixel #should be division? self.nt_s *= 0.576 # thibaut's values traps / pixel #should be division?
self.sigma_s *= 1.e4 #thibaut's values in m**2 self.sigma_s *= 1.e4 # thibaut's values in m**2
def _setupLogger(self): def _setupLogger(self):
""" """
...@@ -90,7 +93,6 @@ class CDM03bidir(): ...@@ -90,7 +93,6 @@ class CDM03bidir():
# if self.log is None: # if self.log is None:
# self.logger = False # self.logger = False
def applyRadiationDamage(self, data, iquadrant=0): def applyRadiationDamage(self, data, iquadrant=0):
""" """
Apply radian damage based on FORTRAN CDM03 model. The method assumes that Apply radian damage based on FORTRAN CDM03 model. The method assumes that
...@@ -127,9 +129,9 @@ class CDM03bidir(): ...@@ -127,9 +129,9 @@ class CDM03bidir():
array will be laid out in memory in C-style (row-major order). array will be laid out in memory in C-style (row-major order).
:return: image that has been run through the CDM03 model :return: image that has been run through the CDM03 model
:rtype: ndarray :rtype: ndarray """""
"""""
#return data # return data
iflip = iquadrant / 2 iflip = iquadrant / 2
jflip = iquadrant % 2 jflip = iquadrant % 2
...@@ -154,8 +156,8 @@ class CDM03bidir(): ...@@ -154,8 +156,8 @@ class CDM03bidir():
self.log.info('jflip=%i' % jflip) self.log.info('jflip=%i' % jflip)
################################################################################# #################################################################################
###modify # modify
#sys.path.append('../so') # sys.path.append('../so')
# from ifs_so.cdm03.cpython-38-x86_64-linux-gnu import cdm03bidir # from ifs_so.cdm03.cpython-38-x86_64-linux-gnu import cdm03bidir
# import cdm03bidir # import cdm03bidir
...@@ -169,9 +171,7 @@ class CDM03bidir(): ...@@ -169,9 +171,7 @@ class CDM03bidir():
params, params,
[data.shape[0], data.shape[1], len(self.nt_p), len(self.nt_s), len(self.params)]) [data.shape[0], data.shape[1], len(self.nt_p), len(self.nt_s), len(self.params)])
return np.asanyarray(CTIed) return np.asanyarray(CTIed)
################################################################################################################# #################################################################################################################
################################################################################################################# #################################################################################################################
This diff is collapsed.
...@@ -207,7 +207,7 @@ class StrayLight(object): ...@@ -207,7 +207,7 @@ class StrayLight(object):
############################################################################### ###############################################################################
### test # test
# path='/home/yan/MCI/' # path='/home/yan/MCI/'
# time_jd = 2460417.59979167 # time_jd = 2460417.59979167
# x_sat = -4722.543136 # x_sat = -4722.543136
......
...@@ -205,7 +205,7 @@ class StrayLight(object): ...@@ -205,7 +205,7 @@ class StrayLight(object):
############################################################################### ###############################################################################
## test # test
# path='/home/yan/MCI/' # path='/home/yan/MCI/'
# time_jd = 2460417.59979167 # time_jd = 2460417.59979167
# x_sat = -4722.543136 # x_sat = -4722.543136
......
...@@ -40,11 +40,11 @@ def MCIinformation(): ...@@ -40,11 +40,11 @@ def MCIinformation():
""" """
######################################################################################################### #########################################################################################################
out=dict() out = dict()
out.update({'dob' : 0, 'rdose' : 8.0e9, out.update({'dob': 0, 'rdose': 8.0e9,
'parallelTrapfile' : 'cdm_euclid_parallel.dat', 'serialTrapfile' : 'cdm_euclid_serial.dat', 'parallelTrapfile': 'cdm_euclid_parallel.dat', 'serialTrapfile': 'cdm_euclid_serial.dat',
'beta_s' : 0.6, 'beta_p': 0.6, 'fwc' : 90000, 'vth' : 1.168e7, 't' : 20.48e-3, 'vg' : 6.e-11, 'beta_s': 0.6, 'beta_p': 0.6, 'fwc': 90000, 'vth': 1.168e7, 't': 20.48e-3, 'vg': 6.e-11,
'st' : 5.0e-6, 'sfwc' : 730000., 'svg' : 1.0e-10}) 'st': 5.0e-6, 'sfwc': 730000., 'svg': 1.0e-10})
return out return out
......
...@@ -31,6 +31,7 @@ class cosmicrays(): ...@@ -31,6 +31,7 @@ class cosmicrays():
:param information: cosmic ray track information (file containing track length and energy information) and :param information: cosmic ray track information (file containing track length and energy information) and
exposure time. exposure time.
""" """
def __init__(self, log, image, exptime, crInfo=None, information=None): def __init__(self, log, image, exptime, crInfo=None, information=None):
""" """
Cosmic ray generation class. Can either draw events from distributions or Cosmic ray generation class. Can either draw events from distributions or
...@@ -43,7 +44,7 @@ class cosmicrays(): ...@@ -43,7 +44,7 @@ class cosmicrays():
exposure time. exposure time.
""" """
self.exptime=exptime self.exptime = exptime
self.log = log self.log = log
...@@ -56,11 +57,11 @@ class cosmicrays(): ...@@ -56,11 +57,11 @@ class cosmicrays():
else: else:
self._readCosmicrayInformation() self._readCosmicrayInformation()
################################################## ##################################################
##################################################
############
def _cosmicRayIntercepts(self, lum, x0, y0, l, phi): ############
def _cosmicRayIntercepts(self, lum, x0, y0, dl, phi):
""" """
Derive cosmic ray streak intercept points. Derive cosmic ray streak intercept points.
...@@ -73,18 +74,18 @@ class cosmicrays(): ...@@ -73,18 +74,18 @@ class cosmicrays():
:return: cosmic ray map (image) :return: cosmic ray map (image)
:rtype: nd-array :rtype: nd-array
""" """
#create empty array # create empty array
crImage = np.zeros((self.ysize, self.xsize), dtype=np.float64) crImage = np.zeros((self.ysize, self.xsize), dtype=np.float64)
#x and y shifts # x and y shifts
dx = l * np.cos(phi) / 2. dx = dl * np.cos(phi) / 2.
dy = l * np.sin(phi) / 2. dy = dl * np.sin(phi) / 2.
mskdx = np.abs(dx) < 1e-8 mskdx = np.abs(dx) < 1e-8
mskdy = np.abs(dy) < 1e-8 mskdy = np.abs(dy) < 1e-8
dx[mskdx] = 0. dx[mskdx] = 0.
dy[mskdy] = 0. dy[mskdy] = 0.
#pixels in x-direction # pixels in x-direction
ilo = np.round(x0.copy() - dx) ilo = np.round(x0.copy() - dx)
msk = ilo < 0. msk = ilo < 0.
ilo[msk] = 0 ilo[msk] = 0
...@@ -95,7 +96,7 @@ class cosmicrays(): ...@@ -95,7 +96,7 @@ class cosmicrays():
ihi[msk] = self.xsize ihi[msk] = self.xsize
ihi = ihi.astype(int) ihi = ihi.astype(int)
#pixels in y-directions # pixels in y-directions
jlo = np.round(y0.copy() - dy) jlo = np.round(y0.copy() - dy)
msk = jlo < 0. msk = jlo < 0.
jlo[msk] = 0 jlo[msk] = 0
...@@ -106,7 +107,7 @@ class cosmicrays(): ...@@ -106,7 +107,7 @@ class cosmicrays():
jhi[msk] = self.ysize jhi[msk] = self.ysize
jhi = jhi.astype(int) jhi = jhi.astype(int)
#loop over the individual events # loop over the individual events
for i, luminosity in enumerate(lum): for i, luminosity in enumerate(lum):
n = 0 # count the intercepts n = 0 # count the intercepts
...@@ -114,7 +115,7 @@ class cosmicrays(): ...@@ -114,7 +115,7 @@ class cosmicrays():
x = [] x = []
y = [] y = []
#Compute X intercepts on the pixel grid # Compute X intercepts on the pixel grid
if ilo[i] < ihi[i]: if ilo[i] < ihi[i]:
for xcoord in range(ilo[i], ihi[i]): for xcoord in range(ilo[i], ihi[i]):
ok = (xcoord - x0[i]) / dx[i] ok = (xcoord - x0[i]) / dx[i]
...@@ -132,7 +133,7 @@ class cosmicrays(): ...@@ -132,7 +133,7 @@ class cosmicrays():
x.append(xcoord) x.append(xcoord)
y.append(y0[i] + ok * dy[i]) y.append(y0[i] + ok * dy[i])
#Compute Y intercepts on the pixel grid # Compute Y intercepts on the pixel grid
if jlo[i] < jhi[i]: if jlo[i] < jhi[i]:
for ycoord in range(jlo[i], jhi[i]): for ycoord in range(jlo[i], jhi[i]):
ok = (ycoord - y0[i]) / dy[i] ok = (ycoord - y0[i]) / dy[i]
...@@ -150,13 +151,13 @@ class cosmicrays(): ...@@ -150,13 +151,13 @@ class cosmicrays():
x.append(x0[i] + ok * dx[i]) x.append(x0[i] + ok * dx[i])
y.append(ycoord) y.append(ycoord)
#check if no intercepts were found # check if no intercepts were found
if n < 1: if n < 1:
xc = int(np.floor(x0[i])) xc = int(np.floor(x0[i]))
yc = int(np.floor(y0[i])) yc = int(np.floor(y0[i]))
crImage[yc, xc] += luminosity crImage[yc, xc] += luminosity
#Find the arguments that sort the intersections along the track # Find the arguments that sort the intersections along the track
u = np.asarray(u) u = np.asarray(u)
x = np.asarray(x) x = np.asarray(x)
y = np.asarray(y) y = np.asarray(y)
...@@ -167,7 +168,7 @@ class cosmicrays(): ...@@ -167,7 +168,7 @@ class cosmicrays():
x = x[args] x = x[args]
y = y[args] y = y[args]
#Decide which cell each interval traverses, and the path length # Decide which cell each interval traverses, and the path length
for i in range(1, n - 1): for i in range(1, n - 1):
w = u[i + 1] - u[i] w = u[i + 1] - u[i]
cx = int(1 + np.floor((x[i + 1] + x[i]) / 2.)) cx = int(1 + np.floor((x[i + 1] + x[i]) / 2.))
...@@ -177,11 +178,11 @@ class cosmicrays(): ...@@ -177,11 +178,11 @@ class cosmicrays():
crImage[cy, cx] += (w * luminosity) crImage[cy, cx] += (w * luminosity)
return crImage return crImage
##################################################
############################################
############################################################################ ############################################################################
#####################################
########################################
def _drawEventsToCoveringFactor(self, coveringFraction=3.0, limit=1000, verbose=False): def _drawEventsToCoveringFactor(self, coveringFraction=3.0, limit=1000, verbose=False):
""" """
Generate cosmic ray events up to a covering fraction and include it to a cosmic ray map (self.cosmicrayMap). Generate cosmic ray events up to a covering fraction and include it to a cosmic ray map (self.cosmicrayMap).
...@@ -198,30 +199,32 @@ class cosmicrays(): ...@@ -198,30 +199,32 @@ class cosmicrays():
""" """
self.cosmicrayMap = np.zeros((self.ysize, self.xsize)) self.cosmicrayMap = np.zeros((self.ysize, self.xsize))
#how many events to draw at once, too large number leads to exceeding the covering fraction # how many events to draw at once, too large number leads to exceeding the covering fraction
####cr_n = int(295 * self.exptime / 565. * coveringFraction / 1.4) # cr_n = int(295 * self.exptime / 565. * coveringFraction / 1.4)
cr_n = int(5000 * self.exptime / 565. * coveringFraction) cr_n = int(5000 * self.exptime / 565. * coveringFraction)
covering = 0.0 covering = 0.0
while covering < coveringFraction: while covering < coveringFraction:
#pseudo-random numbers taken from a uniform distribution between 0 and 1 # pseudo-random numbers taken from a uniform distribution between 0 and 1
np.random.seed() np.random.seed()
luck = np.random.rand(cr_n) luck = np.random.rand(cr_n)
#draw the length of the tracks # draw the length of the tracks
ius = InterpolatedUnivariateSpline(self.cr['cr_cdf'], self.cr['cr_u']) ius = InterpolatedUnivariateSpline(
self.cr['cr_cdf'], self.cr['cr_u'])
self.cr['cr_l'] = ius(luck) self.cr['cr_l'] = ius(luck)
if limit is None: if limit is None:
ius = InterpolatedUnivariateSpline(self.cr['cr_cde'], self.cr['cr_v']) ius = InterpolatedUnivariateSpline(
self.cr['cr_cde'], self.cr['cr_v'])
self.cr['cr_e'] = ius(luck) self.cr['cr_e'] = ius(luck)
else: else:
#set the energy directly to the limit # set the energy directly to the limit
self.cr['cr_e'] = np.asarray([limit,]) self.cr['cr_e'] = np.asarray([limit,])
#Choose the properties such as positions and an angle from a random Uniform dist # Choose the properties such as positions and an angle from a random Uniform dist
np.random.seed() np.random.seed()
cr_x = self.xsize * np.random.rand(int(np.floor(cr_n))) cr_x = self.xsize * np.random.rand(int(np.floor(cr_n)))
...@@ -231,17 +234,19 @@ class cosmicrays(): ...@@ -231,17 +234,19 @@ class cosmicrays():
np.random.seed() np.random.seed()
cr_phi = np.pi * np.random.rand(int(np.floor(cr_n))) cr_phi = np.pi * np.random.rand(int(np.floor(cr_n)))
#find the intercepts # find the intercepts
self.cosmicrayMap += self._cosmicRayIntercepts(self.cr['cr_e'], cr_x, cr_y, self.cr['cr_l'], cr_phi) self.cosmicrayMap += self._cosmicRayIntercepts(
self.cr['cr_e'], cr_x, cr_y, self.cr['cr_l'], cr_phi)
#count the covering factor # count the covering factor
area_cr = np.count_nonzero(self.cosmicrayMap) area_cr = np.count_nonzero(self.cosmicrayMap)
covering = 100.*area_cr / (self.xsize*self.ysize) covering = 100.*area_cr / (self.xsize*self.ysize)
text = 'The cosmic ray covering factor is %i pixels i.e. %.3f per cent' % (area_cr, covering) text = 'The cosmic ray covering factor is %i pixels i.e. %.3f per cent' % (
area_cr, covering)
self.log.info(text) self.log.info(text)
###################################################33 # 33
def addUpToFraction(self, coveringFraction, limit=None, verbose=False): def addUpToFraction(self, coveringFraction, limit=None, verbose=False):
""" """
...@@ -257,13 +262,10 @@ class cosmicrays(): ...@@ -257,13 +262,10 @@ class cosmicrays():
:return: image with cosmic rays :return: image with cosmic rays
:rtype: ndarray :rtype: ndarray
""" """
self._drawEventsToCoveringFactor(coveringFraction, limit=limit, verbose=verbose) self._drawEventsToCoveringFactor(
coveringFraction, limit=limit, verbose=verbose)
#paste cosmic rays # paste cosmic rays
self.image += self.cosmicrayMap self.image += self.cosmicrayMap
return self.image return self.image
...@@ -23,9 +23,10 @@ def setUpLogger(log_filename, loggername='logger'): ...@@ -23,9 +23,10 @@ def setUpLogger(log_filename, loggername='logger'):
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
# Add the log message handler to the logger # Add the log message handler to the logger
handler = logging.handlers.RotatingFileHandler(log_filename) handler = logging.handlers.RotatingFileHandler(log_filename)
#maxBytes=20, backupCount=5) # maxBytes=20, backupCount=5)
# create formatter # create formatter
formatter = logging.Formatter('%(asctime)s - %(module)s - %(funcName)s - %(levelname)s - %(message)s') formatter = logging.Formatter(
'%(asctime)s - %(module)s - %(funcName)s - %(levelname)s - %(message)s')
# add formatter to ch # add formatter to ch
handler.setFormatter(formatter) handler.setFormatter(formatter)
# add handler to logger # add handler to logger
......
...@@ -15,7 +15,8 @@ from astropy import units as u ...@@ -15,7 +15,8 @@ from astropy import units as u
from scipy.interpolate import interp1d from scipy.interpolate import interp1d
def Calzetti_Law(wave, Rv = 4.05):
def Calzetti_Law(wave, Rv=4.05):
"""Dust Extinction Curve by Calzetti et al. (2000) """Dust Extinction Curve by Calzetti et al. (2000)
Args: Args:
...@@ -30,14 +31,15 @@ def Calzetti_Law(wave, Rv = 4.05): ...@@ -30,14 +31,15 @@ def Calzetti_Law(wave, Rv = 4.05):
reddening_curve = np.zeros(len(wave)) reddening_curve = np.zeros(len(wave))
idx = np.logical_and(wave >= 1200, wave <= 6300) idx = np.logical_and(wave >= 1200, wave <= 6300)
reddening_curve[idx] = 2.659 * ( -2.156 + 1.509 * wave_number[idx] - 0.198 * \ reddening_curve[idx] = 2.659 * (-2.156 + 1.509 * wave_number[idx] - 0.198 *
(wave_number[idx] ** 2)) + 0.011 * (wave_number[idx] **3 ) + Rv (wave_number[idx] ** 2)) + 0.011 * (wave_number[idx] ** 3) + Rv
idx = np.logical_and(wave >= 6300, wave <= 22000) idx = np.logical_and(wave >= 6300, wave <= 22000)
reddening_curve[idx] = 2.659 * ( -1.857 + 1.040 * wave_number[idx]) + Rv reddening_curve[idx] = 2.659 * (-1.857 + 1.040 * wave_number[idx]) + Rv
return reddening_curve return reddening_curve
def reddening(wave, flux, ebv = 0.0, law = 'calzetti', Rv = 4.05):
def reddening(wave, flux, ebv=0.0, law='calzetti', Rv=4.05):
""" """
Reddening an input spectra through a given reddening curve. Reddening an input spectra through a given reddening curve.
...@@ -52,10 +54,11 @@ def reddening(wave, flux, ebv = 0.0, law = 'calzetti', Rv = 4.05): ...@@ -52,10 +54,11 @@ def reddening(wave, flux, ebv = 0.0, law = 'calzetti', Rv = 4.05):
float: Flux of spectra after reddening. float: Flux of spectra after reddening.
""" """
if law == 'calzetti': if law == 'calzetti':
curve = Calzetti_Law(wave, Rv = Rv) curve = Calzetti_Law(wave, Rv=Rv)
fluxNew = flux / (10. ** (0.4 * ebv * curve)) fluxNew = flux / (10. ** (0.4 * ebv * curve))
return fluxNew return fluxNew
def flux_to_mag(wave, flux, path, band='GAIA_bp'): def flux_to_mag(wave, flux, path, band='GAIA_bp'):
"""Convert flux of given spectra to magnitude """Convert flux of given spectra to magnitude
...@@ -68,13 +71,13 @@ def flux_to_mag(wave, flux, path, band='GAIA_bp'): ...@@ -68,13 +71,13 @@ def flux_to_mag(wave, flux, path, band='GAIA_bp'):
float: value of magnitude float: value of magnitude
""" """
# /home/yan/MCI_sim/MCI_input/SED_Code/data # /home/yan/MCI_sim/MCI_input/SED_Code/data
##import os #
###parent = os.path.dirname(os.path.realpath(__file__)) # parent = os.path.dirname(os.path.realpath(__file__))
band = ascii.read(path+'MCI_inputData/SED_Code/seddata/' + band + '.dat') band = ascii.read(path+'MCI_inputData/SED_Code/seddata/' + band + '.dat')
wave0= band['col1'] wave0 = band['col1']
curv0= band['col2'] curv0 = band['col2']
# Setting the response # Setting the response
func = interp1d(wave0, curv0) func = interp1d(wave0, curv0)
...@@ -89,7 +92,8 @@ def flux_to_mag(wave, flux, path, band='GAIA_bp'): ...@@ -89,7 +92,8 @@ def flux_to_mag(wave, flux, path, band='GAIA_bp'):
return -2.5 * np.log10(Tflux) return -2.5 * np.log10(Tflux)
def calibrate(wave, flux, mag, path,band='GAIA_bp'):
def calibrate(wave, flux, mag, path, band='GAIA_bp'):
""" """
Calibrate the spectra according to the magnitude. Calibrate the spectra according to the magnitude.
...@@ -102,78 +106,83 @@ def calibrate(wave, flux, mag, path,band='GAIA_bp'): ...@@ -102,78 +106,83 @@ def calibrate(wave, flux, mag, path,band='GAIA_bp'):
Returns: Returns:
float: Flux of calibrated spectra. Units: 1e-17 erg/s/A/cm^2 float: Flux of calibrated spectra. Units: 1e-17 erg/s/A/cm^2
""" """
inst_mag = flux_to_mag(wave, flux, path,band = band) inst_mag = flux_to_mag(wave, flux, path, band=band)
instflux = 10 ** (-0.4 * inst_mag) instflux = 10 ** (-0.4 * inst_mag)
realflux = (mag * u.STmag).to(u.erg/u.s/u.cm**2/u.AA).value realflux = (mag * u.STmag).to(u.erg/u.s/u.cm**2/u.AA).value
# Normalization # Normalization
flux_ratio = realflux / instflux flux_ratio = realflux / instflux
flux_calibrate = flux * flux_ratio * 1e17 # Units: 10^-17 erg/s/A/cm^2 # Units: 10^-17 erg/s/A/cm^2
flux_calibrate = flux * flux_ratio * 1e17
return flux_calibrate return flux_calibrate
# ------------ # ------------
# SED Template # SED Template
class Gal_Temp(): class Gal_Temp():
""" """
Template of Galaxy SED Template of Galaxy SED
""" """
def __init__(self,path): def __init__(self, path):
###import os #
###parent = os.path.dirname(os.path.realpath(__file__)) # parent = os.path.dirname(os.path.realpath(__file__))
self.path=path self.path = path
hdulist = fits.open(self.path+'MCI_inputData/SED_Code/seddata/galaxy_temp.fits') hdulist = fits.open(
self.path+'MCI_inputData/SED_Code/seddata/galaxy_temp.fits')
self.wave = hdulist[1].data['wave'] self.wave = hdulist[1].data['wave']
self.flux = hdulist[2].data self.flux = hdulist[2].data
self.age_grid = hdulist[3].data['logAge'] self.age_grid = hdulist[3].data['logAge']
self.feh_grid = hdulist[3].data['FeH'] self.feh_grid = hdulist[3].data['FeH']
def toMag(self, redshift = 0): def toMag(self, redshift=0):
"""Calculating magnitude """Calculating magnitude
Args: Args:
redshift (float, optional): redshift of spectra. Defaults to 0. redshift (float, optional): redshift of spectra. Defaults to 0.
""" """
wave = self.wave * (1 + redshift) wave = self.wave * (1 + redshift)
self.umag = flux_to_mag(wave, self.flux, self.path,band='SDSS_u') self.umag = flux_to_mag(wave, self.flux, self.path, band='SDSS_u')
self.gmag = flux_to_mag(wave, self.flux, self.path,band='SDSS_g') self.gmag = flux_to_mag(wave, self.flux, self.path, band='SDSS_g')
self.rmag = flux_to_mag(wave, self.flux, self.path,band='SDSS_r') self.rmag = flux_to_mag(wave, self.flux, self.path, band='SDSS_r')
self.imag = flux_to_mag(wave, self.flux, self.path,band='SDSS_i') self.imag = flux_to_mag(wave, self.flux, self.path, band='SDSS_i')
self.zmag = flux_to_mag(wave, self.flux, self.path,band='SDSS_z') self.zmag = flux_to_mag(wave, self.flux, self.path, band='SDSS_z')
class Star_Temp(): class Star_Temp():
""" """
Template of Stellar SED Template of Stellar SED
""" """
def __init__(self,path): def __init__(self, path):
##import os #
self.path=path self.path = path
####parent = os.path.dirname(os.path.realpath(__file__)) # parent = os.path.dirname(os.path.realpath(__file__))
###print("获取其父目录——" + parent) # 从当前文件路径中获取目录 # print("获取其父目录——" + parent) # 从当前文件路径中获取目录
hdulist = fits.open(path+'MCI_inputData/SED_Code/seddata/stellar_temp.fits') hdulist = fits.open(
path+'MCI_inputData/SED_Code/seddata/stellar_temp.fits')
self.wave = hdulist[1].data['wave'] self.wave = hdulist[1].data['wave']
self.flux = hdulist[2].data self.flux = hdulist[2].data
self.Teff_grid = hdulist[3].data['Teff'] self.Teff_grid = hdulist[3].data['Teff']
self.FeH_grid = hdulist[3].data['FeH'] self.FeH_grid = hdulist[3].data['FeH']
self.bpmag = flux_to_mag(self.wave, self.flux, path,band='GAIA_bp') self.bpmag = flux_to_mag(self.wave, self.flux, path, band='GAIA_bp')
self.rpmag = flux_to_mag(self.wave, self.flux, path,band='GAIA_rp') self.rpmag = flux_to_mag(self.wave, self.flux, path, band='GAIA_rp')
def toMag(self): def toMag(self):
wave = self.wave wave = self.wave
self.bpmag = flux_to_mag(wave, self.flux, self.path,band='GAIA_bp') self.bpmag = flux_to_mag(wave, self.flux, self.path, band='GAIA_bp')
self.rpmag = flux_to_mag(wave, self.flux, self.path,band='GAIA_rp') self.rpmag = flux_to_mag(wave, self.flux, self.path, band='GAIA_rp')
# ------------- # -------------
# SED Modelling # SED Modelling
def Model_Stellar_SED(wave, bp, rp, temp): def Model_Stellar_SED(wave, bp, rp, temp):
"""Modelling stellar SED based on bp, rp magnitude """Modelling stellar SED based on bp, rp magnitude
...@@ -194,9 +203,10 @@ def Model_Stellar_SED(wave, bp, rp, temp): ...@@ -194,9 +203,10 @@ def Model_Stellar_SED(wave, bp, rp, temp):
idx = np.argmin(np.abs(colors - color0)) idx = np.argmin(np.abs(colors - color0))
flux0 = temp.flux[idx] flux0 = temp.flux[idx]
flux1 = np.interp(wave, temp.wave, flux0) flux1 = np.interp(wave, temp.wave, flux0)
flux = calibrate(wave, flux1, rp, band = 'GAIA_rp') flux = calibrate(wave, flux1, rp, band='GAIA_rp')
return flux return flux
def Model_Galaxy_SED(wave, ugriz, z, temp, path): def Model_Galaxy_SED(wave, ugriz, z, temp, path):
"""Modelling galaxy SED based on u,g,r,i,z magnitude """Modelling galaxy SED based on u,g,r,i,z magnitude
...@@ -231,12 +241,12 @@ def Model_Galaxy_SED(wave, ugriz, z, temp, path): ...@@ -231,12 +241,12 @@ def Model_Galaxy_SED(wave, ugriz, z, temp, path):
Alambda = Calzetti_Law(np.array([6213 / (1 + z), 7625 / (1 + z)])) Alambda = Calzetti_Law(np.array([6213 / (1 + z), 7625 / (1 + z)]))
eri0 = (Alambda[0] - Alambda[1]) eri0 = (Alambda[0] - Alambda[1])
ebv = dri/eri0 ebv = dri/eri0
if ebv<0: if ebv < 0:
ebv=0 ebv = 0
if ebv>0.5: if ebv > 0.5:
ebv=0.5 ebv = 0.5
flux1 = reddening(temp.wave, flux0, ebv = ebv) flux1 = reddening(temp.wave, flux0, ebv=ebv)
flux2 = np.interp(wave, temp.wave * (1 + z), flux1) flux2 = np.interp(wave, temp.wave * (1 + z), flux1)
flux = calibrate(wave, flux2, ugriz[2], path, band = 'SDSS_r') flux = calibrate(wave, flux2, ugriz[2], path, band='SDSS_r')
return flux return flux
from ctypes import * from ctypes import *
def checkInputList(input_list, n): def checkInputList(input_list, n):
if type(input_list) != type([1, 2, 3]): if not isinstance(type(input_list), list):
# if type(input_list) != type([1, 2, 3]):
raise TypeError("Input type is not list!", input_list) raise TypeError("Input type is not list!", input_list)
for i in input_list: for i in input_list:
if type(i) != type(1.1): if not isinstance(type(i), float): # type(i) != type(1.1):
if type(i) != type(1): if not isinstance(type(i), int): # type(i) != type(1):
raise TypeError("Input list's element is not float or int!", input_list) raise TypeError(
"Input list's element is not float or int!", input_list)
if len(input_list) != n: if len(input_list) != n:
raise RuntimeError("Length of input list is not equal to stars' number!", input_list) raise RuntimeError(
"Length of input list is not equal to stars' number!", input_list)
def onOrbitObsPosition(path, input_ra_list, input_dec_list, input_pmra_list, input_pmdec_list, input_rv_list, \
input_parallax_list, input_nstars, input_x, input_y, input_z, input_vx, input_vy, \ def onOrbitObsPosition(path, input_ra_list, input_dec_list, input_pmra_list, input_pmdec_list, input_rv_list,
input_parallax_list, input_nstars, input_x, input_y, input_z, input_vx, input_vy,
input_vz, input_epoch, input_date_str, input_time_str): input_vz, input_epoch, input_date_str, input_time_str):
#Check input parameters # Check input parameters
if type(input_nstars) != type(1): if not isinstance(type(input_nstars), int): # type(input_nstars) != type(1):
raise TypeError("Parameter 7 is not int!", input_nstars) raise TypeError("Parameter 7 is not int!", input_nstars)
checkInputList(input_ra_list, input_nstars) checkInputList(input_ra_list, input_nstars)
...@@ -24,20 +28,20 @@ def onOrbitObsPosition(path, input_ra_list, input_dec_list, input_pmra_list, inp ...@@ -24,20 +28,20 @@ def onOrbitObsPosition(path, input_ra_list, input_dec_list, input_pmra_list, inp
checkInputList(input_rv_list, input_nstars) checkInputList(input_rv_list, input_nstars)
checkInputList(input_parallax_list, input_nstars) checkInputList(input_parallax_list, input_nstars)
if type(input_x) != type(1.1): if not isinstance(type(input_x), float): # type(input_x) != type(1.1):
raise TypeError("Parameter 8 is not double!", input_x) raise TypeError("Parameter 8 is not double!", input_x)
if type(input_y) != type(1.1): if not isinstance(type(input_y), float): # type(input_y) != type(1.1):
raise TypeError("Parameter 9 is not double!", input_y) raise TypeError("Parameter 9 is not double!", input_y)
if type(input_z) != type(1.1): if not isinstance(type(input_z), float): # type(input_z) != type(1.1):
raise TypeError("Parameter 10 is not double!", input_z) raise TypeError("Parameter 10 is not double!", input_z)
if type(input_vx) != type(1.1): if not isinstance(type(input_vx), float): # type(input_vx) != type(1.1):
raise TypeError("Parameter 11 is not double!", input_vx) raise TypeError("Parameter 11 is not double!", input_vx)
if type(input_vy) != type(1.1): if not isinstance(type(input_vy), float): # type(input_vy) != type(1.1):
raise TypeError("Parameter 12 is not double!", input_vy) raise TypeError("Parameter 12 is not double!", input_vy)
if type(input_vz) != type(1.1): if not isinstance(type(input_vz), float): # type(input_vz) != type(1.1):
raise TypeError("Parameter 13 is not double!", input_vz) raise TypeError("Parameter 13 is not double!", input_vz)
#Convert km -> m # Convert km -> m
input_x = input_x*1000.0 input_x = input_x*1000.0
input_y = input_y*1000.0 input_y = input_y*1000.0
input_z = input_z*1000.0 input_z = input_z*1000.0
...@@ -45,49 +49,57 @@ def onOrbitObsPosition(path, input_ra_list, input_dec_list, input_pmra_list, inp ...@@ -45,49 +49,57 @@ def onOrbitObsPosition(path, input_ra_list, input_dec_list, input_pmra_list, inp
input_vy = input_vy*1000.0 input_vy = input_vy*1000.0
input_vz = input_vz*1000.0 input_vz = input_vz*1000.0
if type(input_date_str) != type("2025-03-05"): if not isinstance(type(input_date_str), str): # type(input_date_str) != type("2025-03-05"):
raise TypeError("Parameter 15 is not string!", input_date_str) raise TypeError("Parameter 15 is not string!", input_date_str)
else: else:
input_date_str = input_date_str.strip() input_date_str = input_date_str.strip()
if not (input_date_str[4]=="-" and input_date_str[7]=="-"): if not (input_date_str[4] == "-" and input_date_str[7] == "-"):
raise TypeError("Parameter 15 format error (1)!", input_date_str) raise TypeError("Parameter 15 format error (1)!", input_date_str)
else: else:
tmp = input_date_str.split("-") tmp = input_date_str.split("-")
if len(tmp) != 3: if len(tmp) != 3:
raise TypeError("Parameter 15 format error (2)!", input_date_str) raise TypeError(
"Parameter 15 format error (2)!", input_date_str)
input_year = int(tmp[0]) input_year = int(tmp[0])
input_month = int(tmp[1]) input_month = int(tmp[1])
input_day = int(tmp[2]) input_day = int(tmp[2])
if not (input_year>=1900 and input_year<=2100): if not (input_year >= 1900 and input_year <= 2100):
raise TypeError("Parameter 15 year range error [1900 ~ 2100]!", input_year) raise TypeError(
if not (input_month>=1 and input_month<=12): "Parameter 15 year range error [1900 ~ 2100]!", input_year)
raise TypeError("Parameter 15 month range error [1 ~ 12]!", input_month) if not (input_month >= 1 and input_month <= 12):
if not (input_day>=1 and input_day<=31): raise TypeError(
raise TypeError("Parameter 15 day range error [1 ~ 31]!", input_day) "Parameter 15 month range error [1 ~ 12]!", input_month)
if not (input_day >= 1 and input_day <= 31):
raise TypeError(
"Parameter 15 day range error [1 ~ 31]!", input_day)
if type(input_time_str) != type("20:15:15.15"): if not isinstance(type(input_time_str), str): # type(input_time_str) != type("20:15:15.15"):
raise TypeError("Parameter 16 is not string!", input_time_str) raise TypeError("Parameter 16 is not string!", input_time_str)
else: else:
input_time_str = input_time_str.strip() input_time_str = input_time_str.strip()
if not (input_time_str[2]==":" and input_time_str[5]==":"): if not (input_time_str[2] == ":" and input_time_str[5] == ":"):
raise TypeError("Parameter 16 format error (1)!", input_time_str) raise TypeError("Parameter 16 format error (1)!", input_time_str)
else: else:
tmp = input_time_str.split(":") tmp = input_time_str.split(":")
if len(tmp) != 3: if len(tmp) != 3:
raise TypeError("Parameter 16 format error (2)!", input_time_str) raise TypeError(
"Parameter 16 format error (2)!", input_time_str)
input_hour = int(tmp[0]) input_hour = int(tmp[0])
input_minute = int(tmp[1]) input_minute = int(tmp[1])
input_second = float(tmp[2]) input_second = float(tmp[2])
if not (input_hour>=0 and input_hour<=23): if not (input_hour >= 0 and input_hour <= 23):
raise TypeError("Parameter 16 hour range error [0 ~ 23]!", input_hour) raise TypeError(
if not (input_minute>=0 and input_minute<=59): "Parameter 16 hour range error [0 ~ 23]!", input_hour)
raise TypeError("Parameter 16 minute range error [0 ~ 59]!", input_minute) if not (input_minute >= 0 and input_minute <= 59):
if not (input_second>=0 and input_second<60.0): raise TypeError(
raise TypeError("Parameter 16 second range error [0 ~ 60)!", input_second) "Parameter 16 minute range error [0 ~ 59]!", input_minute)
#Inital dynamic lib if not (input_second >= 0 and input_second < 60.0):
raise TypeError(
"Parameter 16 second range error [0 ~ 60)!", input_second)
# Inital dynamic lib
import os import os
currfile=os.getcwd() currfile = os.getcwd()
print(currfile) print(currfile)
shao = cdll.LoadLibrary(path+'MCI_inputData/TianCe/libshao.so') shao = cdll.LoadLibrary(path+'MCI_inputData/TianCe/libshao.so')
...@@ -95,11 +107,11 @@ def onOrbitObsPosition(path, input_ra_list, input_dec_list, input_pmra_list, inp ...@@ -95,11 +107,11 @@ def onOrbitObsPosition(path, input_ra_list, input_dec_list, input_pmra_list, inp
shao.onOrbitObs.restype = c_int shao.onOrbitObs.restype = c_int
d3 = c_double * 3 d3 = c_double * 3
shao.onOrbitObs.argtypes = [c_double, c_double, c_double, c_double, c_double, c_double, \ shao.onOrbitObs.argtypes = [c_double, c_double, c_double, c_double, c_double, c_double,
c_int, c_int, c_int, c_int, c_int, c_double, \ c_int, c_int, c_int, c_int, c_int, c_double,
c_double, POINTER(d3), POINTER(d3), \ c_double, POINTER(d3), POINTER(d3),
c_int, c_int, c_int, c_int, c_int, c_double, \ c_int, c_int, c_int, c_int, c_int, c_double,
POINTER(c_double), POINTER(c_double) ] POINTER(c_double), POINTER(c_double)]
output_ra_list = list() output_ra_list = list()
output_dec_list = list() output_dec_list = list()
for i in range(input_nstars): for i in range(input_nstars):
...@@ -120,10 +132,10 @@ def onOrbitObsPosition(path, input_ra_list, input_dec_list, input_pmra_list, inp ...@@ -120,10 +132,10 @@ def onOrbitObsPosition(path, input_ra_list, input_dec_list, input_pmra_list, inp
DAT = c_double(37.0) DAT = c_double(37.0)
output_ra = c_double(0.0) output_ra = c_double(0.0)
output_dec = c_double(0.0) output_dec = c_double(0.0)
rs = shao.onOrbitObs(input_ra, input_dec, input_parallax, input_pmra, input_pmdec, input_rv, \ rs = shao.onOrbitObs(input_ra, input_dec, input_parallax, input_pmra, input_pmdec, input_rv,
input_year_c, input_month_c, input_day_c, input_hour_c, input_minute_c, input_second_c, \ input_year_c, input_month_c, input_day_c, input_hour_c, input_minute_c, input_second_c,
DAT, byref(p3), byref(v3), \ DAT, byref(p3), byref(v3),
input_year_c, input_month_c, input_day_c, input_hour_c, input_minute_c, input_second_c, \ input_year_c, input_month_c, input_day_c, input_hour_c, input_minute_c, input_second_c,
byref(output_ra), byref(output_dec)) byref(output_ra), byref(output_dec))
if rs != 0: if rs != 0:
raise RuntimeError("Calculate error!") raise RuntimeError("Calculate error!")
...@@ -131,4 +143,3 @@ def onOrbitObsPosition(path, input_ra_list, input_dec_list, input_pmra_list, inp ...@@ -131,4 +143,3 @@ def onOrbitObsPosition(path, input_ra_list, input_dec_list, input_pmra_list, inp
output_dec_list.append(output_dec.value) output_dec_list.append(output_dec.value)
return output_ra_list, output_dec_list return output_ra_list, output_dec_list
\ No newline at end of file
...@@ -9,6 +9,7 @@ from astropy import units as u ...@@ -9,6 +9,7 @@ from astropy import units as u
from scipy import interpolate from scipy import interpolate
def zodiacal(ra, dec, time, path): def zodiacal(ra, dec, time, path):
""" """
For given RA, DEC and TIME, return the interpolated zodical spectrum in Leinert-1998. For given RA, DEC and TIME, return the interpolated zodical spectrum in Leinert-1998.
...@@ -34,7 +35,8 @@ def zodiacal(ra, dec, time, path): ...@@ -34,7 +35,8 @@ def zodiacal(ra, dec, time, path):
astro_sun = get_sun(t) astro_sun = get_sun(t)
ra_sun, dec_sun = astro_sun.gcrs.ra.deg, astro_sun.gcrs.dec.deg ra_sun, dec_sun = astro_sun.gcrs.ra.deg, astro_sun.gcrs.dec.deg
radec_sun = SkyCoord(ra=ra_sun*u.degree, dec=dec_sun*u.degree, frame='gcrs') radec_sun = SkyCoord(ra=ra_sun*u.degree,
dec=dec_sun*u.degree, frame='gcrs')
lb_sun = radec_sun.transform_to('geocentrictrueecliptic') lb_sun = radec_sun.transform_to('geocentrictrueecliptic')
# get offsets between the target and sun. # get offsets between the target and sun.
...@@ -45,7 +47,8 @@ def zodiacal(ra, dec, time, path): ...@@ -45,7 +47,8 @@ def zodiacal(ra, dec, time, path):
lamda = abs(lb_obj.lon.degree - lb_sun.lon.degree) lamda = abs(lb_obj.lon.degree - lb_sun.lon.degree)
# interpolated zodical surface brightness at 0.5 um # interpolated zodical surface brightness at 0.5 um
zodi = pd.read_csv(path+'MCI_inputData/refs/zodi_map.dat', sep='\s+', header=None, comment='#') zodi = pd.read_csv(path+'MCI_inputData/refs/zodi_map.dat',
sep='\s+', header=None, comment='#')
beta_angle = np.array([0, 5, 10, 15, 20, 25, 30, 45, 60, 75]) beta_angle = np.array([0, 5, 10, 15, 20, 25, 30, 45, 60, 75])
lamda_angle = np.array([0, 5, 10, 15, 20, 25, 30, 35, 40, 45, lamda_angle = np.array([0, 5, 10, 15, 20, 25, 30, 35, 40, 45,
60, 75, 90, 105, 120, 135, 150, 165, 180]) 60, 75, 90, 105, 120, 135, 150, 165, 180])
...@@ -54,7 +57,8 @@ def zodiacal(ra, dec, time, path): ...@@ -54,7 +57,8 @@ def zodiacal(ra, dec, time, path):
zodi_obj = f(beta, lamda) # 10^�? W m�? sr�? um�? zodi_obj = f(beta, lamda) # 10^�? W m�? sr�? um�?
# read the zodical spectrum in the ecliptic # read the zodical spectrum in the ecliptic
cat_spec = pd.read_csv(path+'MCI_inputData/refs/solar_spec.dat', sep='\s+', header=None, comment='#') cat_spec = pd.read_csv(
path+'MCI_inputData/refs/solar_spec.dat', sep='\s+', header=None, comment='#')
wave = cat_spec[0].values # A wave = cat_spec[0].values # A
spec0 = cat_spec[1].values # 10^-8 W m^�? sr^�? μm^�? spec0 = cat_spec[1].values # 10^-8 W m^�? sr^�? μm^�?
zodi_norm = 252 # 10^-8 W m^�? sr^�? μm^�? zodi_norm = 252 # 10^-8 W m^�? sr^�? μm^�?
...@@ -63,9 +67,10 @@ def zodiacal(ra, dec, time, path): ...@@ -63,9 +67,10 @@ def zodiacal(ra, dec, time, path):
# convert to the commonly used unit of MJy/sr, erg/s/cm^2/A/sr # convert to the commonly used unit of MJy/sr, erg/s/cm^2/A/sr
wave_A = wave # A wave_A = wave # A
#spec_mjy = spec * 0.1 * wave_A**2 / 3e18 * 1e23 * 1e-6 # MJy/sr # spec_mjy = spec * 0.1 * wave_A**2 / 3e18 * 1e23 * 1e-6 # MJy/sr
spec_erg = spec * 0.1 # erg/s/cm^2/A/sr spec_erg = spec * 0.1 # erg/s/cm^2/A/sr
spec_erg2 = spec_erg / 4.25452e10 # erg/s/cm^2/A/arcsec^2 # erg/s/cm^2/A/arcsec^2
spec_erg2 = spec_erg / 4.25452e10
return wave_A, spec_erg2 return wave_A, spec_erg2
......
...@@ -15,9 +15,9 @@ import sys ...@@ -15,9 +15,9 @@ import sys
import faulthandler import faulthandler
from csst_mci_sim import csst_mci_sim from csst_mci_sim import csst_mci_sim
class TestDemoFunction(unittest.TestCase): class TestDemoFunction(unittest.TestCase):
def test_mci_sim_1(self): def test_mci_sim_1(self):
""" """
Aim Aim
--- ---
...@@ -35,7 +35,7 @@ class TestDemoFunction(unittest.TestCase): ...@@ -35,7 +35,7 @@ class TestDemoFunction(unittest.TestCase):
faulthandler.enable() faulthandler.enable()
# demo function test # demo function test
dir_path = os.path.join(os.environ['UNIT_TEST_DATA_ROOT'],'mci_sim/') dir_path = os.path.join(os.environ['UNIT_TEST_DATA_ROOT'], 'mci_sim/')
print(dir_path) print(dir_path)
# 获取当前工作目录 # 获取当前工作目录
...@@ -47,19 +47,20 @@ class TestDemoFunction(unittest.TestCase): ...@@ -47,19 +47,20 @@ class TestDemoFunction(unittest.TestCase):
sourcein = 'EXDF' sourcein = 'EXDF'
print(configfile) print(configfile)
debug=True debug = True
result_path = dir_path +'mci_sim_result/' result_path = dir_path + 'mci_sim_result/'
csst_mci_sim.runMCIsim(sourcein, configfile, dir_path, result_path, debug, 1) csst_mci_sim.runMCIsim(sourcein, configfile,
dir_path, result_path, debug, 1)
self.assertEqual( self.assertEqual(
1 , 1, 1, 1,
"case 1: EXDF sim passes.", "case 1: EXDF sim passes.",
) )
############################################ ############################################
def test_mci_sim_2(self):
def test_mci_sim_2(self):
""" """
Aim Aim
--- ---
...@@ -77,7 +78,7 @@ class TestDemoFunction(unittest.TestCase): ...@@ -77,7 +78,7 @@ class TestDemoFunction(unittest.TestCase):
faulthandler.enable() faulthandler.enable()
# demo function test # demo function test
dir_path = os.path.join(os.environ['UNIT_TEST_DATA_ROOT'],'mci_sim/') dir_path = os.path.join(os.environ['UNIT_TEST_DATA_ROOT'], 'mci_sim/')
print(dir_path) print(dir_path)
# 获取当前工作目录 # 获取当前工作目录
...@@ -89,20 +90,20 @@ class TestDemoFunction(unittest.TestCase): ...@@ -89,20 +90,20 @@ class TestDemoFunction(unittest.TestCase):
sourcein = 'STAR' sourcein = 'STAR'
print(configfile) print(configfile)
debug=True debug = True
result_path = dir_path +'mci_sim_result/' result_path = dir_path + 'mci_sim_result/'
csst_mci_sim.runMCIsim(sourcein, configfile, dir_path, result_path, debug, 1) csst_mci_sim.runMCIsim(sourcein, configfile,
dir_path, result_path, debug, 1)
self.assertEqual( self.assertEqual(
1 , 1, 1, 1,
"case 2: STAR sim passes.", "case 2: STAR sim passes.",
) )
######################################################### #########################################################
def test_mci_sim_3(self): def test_mci_sim_3(self):
""" """
Aim Aim
--- ---
...@@ -120,7 +121,7 @@ class TestDemoFunction(unittest.TestCase): ...@@ -120,7 +121,7 @@ class TestDemoFunction(unittest.TestCase):
faulthandler.enable() faulthandler.enable()
# demo function test # demo function test
dir_path = os.path.join(os.environ['UNIT_TEST_DATA_ROOT'],'mci_sim/') dir_path = os.path.join(os.environ['UNIT_TEST_DATA_ROOT'], 'mci_sim/')
print(dir_path) print(dir_path)
# 获取当前工作目录 # 获取当前工作目录
...@@ -132,19 +133,20 @@ class TestDemoFunction(unittest.TestCase): ...@@ -132,19 +133,20 @@ class TestDemoFunction(unittest.TestCase):
sourcein = 'BIAS' sourcein = 'BIAS'
print(configfile) print(configfile)
debug=True debug = True
result_path = dir_path +'mci_sim_result/' result_path = dir_path + 'mci_sim_result/'
csst_mci_sim.runMCIsim(sourcein, configfile, dir_path, result_path, debug, 1) csst_mci_sim.runMCIsim(sourcein, configfile,
dir_path, result_path, debug, 1)
self.assertEqual( self.assertEqual(
1 , 1, 1, 1,
"case 3: BIAS sim passes.", "case 3: BIAS sim passes.",
) )
######################################################### #########################################################
def test_mci_sim_4(self):
def test_mci_sim_4(self):
""" """
Aim Aim
--- ---
...@@ -162,7 +164,7 @@ class TestDemoFunction(unittest.TestCase): ...@@ -162,7 +164,7 @@ class TestDemoFunction(unittest.TestCase):
faulthandler.enable() faulthandler.enable()
# demo function test # demo function test
dir_path = os.path.join(os.environ['UNIT_TEST_DATA_ROOT'],'mci_sim/') dir_path = os.path.join(os.environ['UNIT_TEST_DATA_ROOT'], 'mci_sim/')
print(dir_path) print(dir_path)
# 获取当前工作目录 # 获取当前工作目录
...@@ -174,20 +176,20 @@ class TestDemoFunction(unittest.TestCase): ...@@ -174,20 +176,20 @@ class TestDemoFunction(unittest.TestCase):
sourcein = 'DARK' sourcein = 'DARK'
print(configfile) print(configfile)
debug=True debug = True
result_path = dir_path +'mci_sim_result/' result_path = dir_path + 'mci_sim_result/'
csst_mci_sim.runMCIsim(sourcein, configfile, dir_path, result_path, debug, 1) csst_mci_sim.runMCIsim(sourcein, configfile,
dir_path, result_path, debug, 1)
self.assertEqual( self.assertEqual(
1 , 1, 1, 1,
"case 4: DARK sim passes.", "case 4: DARK sim passes.",
) )
######################################################### #########################################################
def test_mci_sim_5(self): def test_mci_sim_5(self):
""" """
Aim Aim
--- ---
...@@ -205,7 +207,7 @@ class TestDemoFunction(unittest.TestCase): ...@@ -205,7 +207,7 @@ class TestDemoFunction(unittest.TestCase):
faulthandler.enable() faulthandler.enable()
# demo function test # demo function test
dir_path = os.path.join(os.environ['UNIT_TEST_DATA_ROOT'],'mci_sim/') dir_path = os.path.join(os.environ['UNIT_TEST_DATA_ROOT'], 'mci_sim/')
print(dir_path) print(dir_path)
# 获取当前工作目录 # 获取当前工作目录
...@@ -216,19 +218,17 @@ class TestDemoFunction(unittest.TestCase): ...@@ -216,19 +218,17 @@ class TestDemoFunction(unittest.TestCase):
sourcein = 'FLAT' sourcein = 'FLAT'
print(configfile) print(configfile)
debug=True debug = True
result_path = dir_path +'mci_sim_result/' result_path = dir_path + 'mci_sim_result/'
csst_mci_sim.runMCIsim(sourcein, configfile, dir_path, result_path, debug, 1) csst_mci_sim.runMCIsim(sourcein, configfile,
dir_path, result_path, debug, 1)
self.assertEqual( self.assertEqual(
1 , 1, 1, 1,
"case 5: FLAT sim passes.", "case 5: FLAT sim passes.",
) )
# ############################################################################ # ############################################################################
\ No newline at end of file
...@@ -42,7 +42,7 @@ class TestDemoFunction(unittest.TestCase): ...@@ -42,7 +42,7 @@ class TestDemoFunction(unittest.TestCase):
# current_path = os.getcwd() # current_path = os.getcwd()
# print("当前路径:", current_path) # print("当前路径:", current_path)
wave0, zodi0=zodiacal.zodiacal(10.0, 20.0, '2024-04-04', dir_path) wave0, zodi0 = zodiacal.zodiacal(10.0, 20.0, '2024-04-04', dir_path)
self.assertEqual( self.assertEqual(
1, 1, 1, 1,
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment