Commit 9f1aaa66 authored by Zhang Xin's avatar Zhang Xin
Browse files

Merge branch 'wcs_test' into 'develop'

starting point of the new version

See merge request csst_sim/csst-simulation!15
parents 10e28e08 a4832bdf
...@@ -48,7 +48,7 @@ class Observation(object): ...@@ -48,7 +48,7 @@ class Observation(object):
self.filter_list.append(filt) self.filter_list.append(filt)
self.all_filter.append(filt) self.all_filter.append(filt)
def run_one_chip(self, chip, filt, pointing, chip_output, wcs_fp=None, psf_model=None, shear_cat_file=None, cat_dir=None, sed_dir=None): def run_one_chip(self, chip, filt, pointing, chip_output, wcs_fp=None, psf_model=None, cat_dir=None, sed_dir=None):
chip_output.Log_info(':::::::::::::::::::Current Pointing Information::::::::::::::::::') chip_output.Log_info(':::::::::::::::::::Current Pointing Information::::::::::::::::::')
chip_output.Log_info("RA: %f, DEC; %f" % (pointing.ra, pointing.dec)) chip_output.Log_info("RA: %f, DEC; %f" % (pointing.ra, pointing.dec))
...@@ -68,8 +68,7 @@ class Observation(object): ...@@ -68,8 +68,7 @@ class Observation(object):
chip_output.Log_error("unrecognized PSF model type!!", flush=True) chip_output.Log_error("unrecognized PSF model type!!", flush=True)
# Figure out shear fields # Figure out shear fields
if shear_cat_file is not None: self.g1_field, self.g2_field, self.nshear = get_shear_field(config=self.config)
self.g1_field, self.g2_field, self.nshear = get_shear_field(config=self.config, shear_cat_file=shear_cat_file)
# Apply astrometric simulation for pointing # Apply astrometric simulation for pointing
if self.config["obs_setting"]["enable_astrometric_model"]: if self.config["obs_setting"]["enable_astrometric_model"]:
...@@ -111,8 +110,8 @@ class Observation(object): ...@@ -111,8 +110,8 @@ class Observation(object):
if self.config["obs_setting"]["enable_straylight_model"]: if self.config["obs_setting"]["enable_straylight_model"]:
filt.setFilterStrayLightPixel(jtime = pointing.jdt, sat_pos = np.array([pointing.sat_x, pointing.sat_y, pointing.sat_z]), pointing_radec = np.array([pointing.ra,pointing.dec]), sun_pos = np.array([pointing.sun_x,pointing.sun_y,pointing.sun_z])) filt.setFilterStrayLightPixel(jtime = pointing.jdt, sat_pos = np.array([pointing.sat_x, pointing.sat_y, pointing.sat_z]), pointing_radec = np.array([pointing.ra,pointing.dec]), sun_pos = np.array([pointing.sun_x,pointing.sun_y,pointing.sun_z]))
print("========================sky pix========================") chip_output.Log_info("========================sky pix========================")
print(filt.sky_background) chip_output.Log_info(filt.sky_background)
if chip.survey_type == "photometric": if chip.survey_type == "photometric":
sky_map = None sky_map = None
...@@ -160,7 +159,7 @@ class Observation(object): ...@@ -160,7 +159,7 @@ class Observation(object):
cut_filter = temp_filter cut_filter = temp_filter
if self.config["ins_effects"]["field_dist"] == True: if self.config["ins_effects"]["field_dist"] == True:
self.fd_model = FieldDistortion(chip=chip) self.fd_model = FieldDistortion(chip=chip, img_rot=pointing.img_pa.deg)
else: else:
self.fd_model = None self.fd_model = None
...@@ -188,6 +187,8 @@ class Observation(object): ...@@ -188,6 +187,8 @@ class Observation(object):
timestamp = pointing.timestamp, timestamp = pointing.timestamp,
exptime = pointing.exp_time, exptime = pointing.exp_time,
readoutTime = 40.) readoutTime = 40.)
chip_wcs = galsim.FitsWCS(header=h_ext)
for j in range(self.nobj): for j in range(self.nobj):
...@@ -243,12 +244,6 @@ class Observation(object): ...@@ -243,12 +244,6 @@ class Observation(object):
obj.g1, obj.g2 = 0., 0. obj.g1, obj.g2 = 0., 0.
else: else:
obj.g1, obj.g2 = self.g1_field, self.g2_field obj.g1, obj.g2 = self.g1_field, self.g2_field
elif self.config["shear_setting"]["shear_type"] == "extra":
try:
# [TODO]: every object with individual shear from input catalog(s)
obj.g1, obj.g2 = self.g1_field[j], self.g2_field[j]
except:
chip_output.Log_error("failed to load external shear.")
elif self.config["shear_setting"]["shear_type"] == "catalog": elif self.config["shear_setting"]["shear_type"] == "catalog":
pass pass
else: else:
...@@ -256,7 +251,7 @@ class Observation(object): ...@@ -256,7 +251,7 @@ class Observation(object):
raise ValueError("Unknown shear input") raise ValueError("Unknown shear input")
# Get position of object on the focal plane # Get position of object on the focal plane
pos_img, offset, local_wcs, real_wcs, fd_shear = obj.getPosImg_Offset_WCS(img=chip.img, fdmodel=self.fd_model, chip=chip, verbose=False, img_header=h_ext) pos_img, offset, local_wcs, real_wcs, fd_shear = obj.getPosImg_Offset_WCS(img=chip.img, fdmodel=self.fd_model, chip=chip, verbose=False, chip_wcs=chip_wcs, img_header=h_ext)
# [TODO] For now, only consider objects which their centers (after field distortion) are projected within the focal plane # [TODO] For now, only consider objects which their centers (after field distortion) are projected within the focal plane
# Otherwise they will be considered missed objects # Otherwise they will be considered missed objects
...@@ -410,7 +405,7 @@ class Observation(object): ...@@ -410,7 +405,7 @@ class Observation(object):
chip_output.Log_info("check running:2: pointing-%d chip-%d pid-%d memory-%6.2fGB"%(pointing.id, chip.chipID, os.getpid(), (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024) )) chip_output.Log_info("check running:2: pointing-%d chip-%d pid-%d memory-%6.2fGB"%(pointing.id, chip.chipID, os.getpid(), (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024) ))
def runExposure_MPI_PointingList(self, pointing_list, shear_cat_file=None, chips=None, use_mpi=False): def runExposure_MPI_PointingList(self, pointing_list,chips=None, use_mpi=False):
if use_mpi: if use_mpi:
comm = MPI.COMM_WORLD comm = MPI.COMM_WORLD
ind_thread = comm.Get_rank() ind_thread = comm.Get_rank()
......
import galsim import galsim
import numpy as np import numpy as np
import cmath
class FieldDistortion(object): class FieldDistortion(object):
def __init__(self, chip, fdModel=None, fdModel_path=None): def __init__(self, chip, fdModel=None, fdModel_path=None, img_rot=0.):
if fdModel is None: if fdModel is None:
if hasattr(chip, 'fdModel'): if hasattr(chip, 'fdModel'):
self.fdModel = chip.fdModel self.fdModel = chip.fdModel
...@@ -15,6 +16,7 @@ class FieldDistortion(object): ...@@ -15,6 +16,7 @@ class FieldDistortion(object):
raise ValueError("Error: no field distortion model has been specified!") raise ValueError("Error: no field distortion model has been specified!")
else: else:
self.fdModel = fdModel self.fdModel = fdModel
self.img_rot = img_rot
self.ifdModel = self.fdModel["wave1"] self.ifdModel = self.fdModel["wave1"]
self.ixfdModel = self.ifdModel["xImagePos"] self.ixfdModel = self.ifdModel["xImagePos"]
self.iyfdModel = self.ifdModel["yImagePos"] self.iyfdModel = self.ifdModel["yImagePos"]
...@@ -42,7 +44,7 @@ class FieldDistortion(object): ...@@ -42,7 +44,7 @@ class FieldDistortion(object):
return False return False
return True return True
def get_distorted(self, chip, pos_img, bandpass=None): def get_distorted(self, chip, pos_img, bandpass=None, img_rot=None):
""" Get the distored position for an undistorted image position """ Get the distored position for an undistorted image position
Parameters: Parameters:
...@@ -58,14 +60,14 @@ class FieldDistortion(object): ...@@ -58,14 +60,14 @@ class FieldDistortion(object):
""" """
if not self.isContainObj_FD(chip=chip, pos_img=pos_img): if not self.isContainObj_FD(chip=chip, pos_img=pos_img):
return galsim.PositionD(-1, -1), None return galsim.PositionD(-1, -1), None
if not img_rot:
img_rot = np.radians(self.img_rot)
else:
img_rot = np.radians(img_rot)
x, y = pos_img.x, pos_img.y x, y = pos_img.x, pos_img.y
x = self.ixfdModel(x, y)[0][0] x = self.ixfdModel(x, y)[0][0]
y = self.iyfdModel(x, y)[0][0] y = self.iyfdModel(x, y)[0][0]
ix_dx = self.ifx_dx(x, y) if self.irsModel:
ix_dy = self.ifx_dy(x, y)
iy_dx = self.ify_dx(x, y)
iy_dy = self.ify_dy(x, y)
if self.irsModel is not None:
# x1LowI, x1UpI, y1LowI, y1UpI = self.irsModel["interpLimit"] # x1LowI, x1UpI, y1LowI, y1UpI = self.irsModel["interpLimit"]
# if (x1LowI-x)*(x1UpI-x) <=0 and (y1LowI-y)*(y1UpI-y)<=0: # if (x1LowI-x)*(x1UpI-x) <=0 and (y1LowI-y)*(y1UpI-y)<=0:
# dx = self.ixrsModel(x, y)[0][0] # dx = self.ixrsModel(x, y)[0][0]
...@@ -88,8 +90,21 @@ class FieldDistortion(object): ...@@ -88,8 +90,21 @@ class FieldDistortion(object):
ix_dy = self.ifx_dy(x, y) + self.irx_dy(x, y) ix_dy = self.ifx_dy(x, y) + self.irx_dy(x, y)
iy_dx = self.ify_dx(x, y) + self.iry_dx(x, y) iy_dx = self.ify_dx(x, y) + self.iry_dx(x, y)
iy_dy = self.ify_dy(x, y) + self.iry_dy(x, y) iy_dy = self.ify_dy(x, y) + self.iry_dy(x, y)
else:
ix_dx = self.ifx_dx(x, y)
ix_dy = self.ifx_dy(x, y)
iy_dx = self.ify_dx(x, y)
iy_dy = self.ify_dy(x, y)
g1k_fd = 0.0 + (iy_dy - ix_dx) / (iy_dy + ix_dx) g1k_fd = 0.0 + (iy_dy - ix_dx) / (iy_dy + ix_dx)
g2k_fd = 0.0 - (iy_dx + ix_dy) / (iy_dy + ix_dx) g2k_fd = 0.0 - (iy_dx + ix_dy) / (iy_dy + ix_dx)
# [TODO] [TESTING] Rotate the shear:
g_abs = np.sqrt(g1k_fd**2 + g2k_fd**2)
phi = cmath.phase(complex(g1k_fd, g2k_fd))
# g_abs = 0.7
g1k_fd = g_abs * np.cos(phi + 2*img_rot)
g2k_fd = g_abs * np.sin(phi + 2*img_rot)
fd_shear = galsim.Shear(g1=g1k_fd, g2=g2k_fd) fd_shear = galsim.Shear(g1=g1k_fd, g2=g2k_fd)
return galsim.PositionD(x, y), fd_shear return galsim.PositionD(x, y), fd_shear
...@@ -144,8 +144,8 @@ def makeSubDir_PointingList(path_dict, config, pointing_ID=0): ...@@ -144,8 +144,8 @@ def makeSubDir_PointingList(path_dict, config, pointing_ID=0):
pass pass
return subImgdir, prefix return subImgdir, prefix
def get_shear_field(config, shear_cat_file=None): def get_shear_field(config):
if not config["shear_setting"]["shear_type"] in ["constant", "extra", "catalog"]: if not config["shear_setting"]["shear_type"] in ["constant", "catalog"]:
raise ValueError("Please set a right 'shear_method' parameter.") raise ValueError("Please set a right 'shear_method' parameter.")
if config["shear_setting"]["shear_type"] == "constant": if config["shear_setting"]["shear_type"] == "constant":
...@@ -153,18 +153,6 @@ def get_shear_field(config, shear_cat_file=None): ...@@ -153,18 +153,6 @@ def get_shear_field(config, shear_cat_file=None):
g2 = config["shear_setting"]["reduced_g2"] g2 = config["shear_setting"]["reduced_g2"]
nshear = 1 nshear = 1
# TODO logging # TODO logging
elif config["shear_setting"]["shear_type"] == "extra":
# TODO logging
if not os.path.exists(shear_cat_file):
raise ValueError("Cannot find external shear catalog file.")
try:
shearCat = np.loadtxt(shear_cat_file)
nshear = shearCat.shape[0]
g1, g2 = shearCat[:, 0], shearCat[:, 1]
except:
print("Failed to read the shear catalog file.")
print("Setting to no shear.")
g1, g2 = 0., 0.
else: else:
g1, g2 = 0., 0. g1, g2 = 0., 0.
nshear = 0 nshear = 0
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment