Commit 45494b69 authored by Zhang Xin's avatar Zhang Xin
Browse files

fix position bug

parent 74c8e88d
......@@ -323,7 +323,7 @@ class MockObject(object):
# specImg.wcs = local_wcs
# specImg.setOrigin(origin_order_x, origin_order_y)
print('DEBUG: BEGIN -----------',bandNo,k)
# print('DEBUG: BEGIN -----------',bandNo,k)
img_s = v[0]
......@@ -339,7 +339,7 @@ class MockObject(object):
specImg.wcs = local_wcs
specImg.setOrigin(origin_order_x, origin_order_y)
try:
specImg = psf_model.get_PSF_AND_convolve_withsubImg(chip, cutImg=specImg, bandNo=bandNo, g_order=k, grating_split_pos=grating_split_pos)
specImg = psf_model.get_PSF_AND_convolve_withsubImg(chip, cutImg=specImg, pos_img_local=pos_img_local, bandNo=bandNo, g_order=k, grating_split_pos=grating_split_pos)
except:
psf, pos_shear = psf_model.get_PSF(chip=chip, pos_img=pos_img)
......
......@@ -481,7 +481,7 @@ class PSFInterpSLS(PSFModel):
return PSF_int_trans, PSF_int
def get_PSF_AND_convolve_withsubImg(self, chip, cutImg=None, bandNo=1, g_order='A', grating_split_pos=3685):
def get_PSF_AND_convolve_withsubImg(self, chip, cutImg=None, pos_img_local=[1000, 1000], bandNo=1, g_order='A', grating_split_pos=3685):
"""
Get the PSF at a given image position
......@@ -503,8 +503,8 @@ class PSFInterpSLS(PSFModel):
# pos_img_x = pos_img_local[0] + x_start
# pos_img_y = pos_img_local[1] + y_start
# pos_img = galsim.PositionD(pos_img_x, pos_img_y)
centerPos_local = cutImg.ncol/2.
if centerPos_local < grating_split_pos:
# centerPos_local = cutImg.ncol/2.
if pos_img_local[0] < grating_split_pos:
psf_data = self.grating1_data
else:
psf_data = self.grating2_data
......@@ -525,6 +525,7 @@ class PSFInterpSLS(PSFModel):
npc = 10
m_size = int(pcs.shape[0]**0.5)
sumImg = np.sum(cutImg.array)
tmp_img = cutImg*0
for j in np.arange(npc):
X_ = jnp.hstack((pos_p[:,1].flatten()[:, None], pos_p[:,0].flatten()[:, None]),dtype=np.float32)
......@@ -534,43 +535,58 @@ class PSFInterpSLS(PSFModel):
cy_len = int(chip.npix_y)
n_x = jnp.arange(0, cx_len, 1, dtype = int)
n_y = jnp.arange(0, cy_len, 1, dtype = int)
M, N = jnp.meshgrid(n_x, n_y)
t1=datetime.datetime.now()
# t1=datetime.datetime.now()
# U = interpolate.griddata(X_, Z_, (M[0:cy_len, 0:cx_len],N[0:cy_len, 0:cx_len]),
# method='nearest',fill_value=1.0)
ys = cutImg.ymin
if ys < 0:
ys = 0
ye = cutImg.ymin+cutImg.nrow
if ye >= cy_len-1:
ye = cy_len-1
if ye - ys <=0:
continue
xs = cutImg.xmin
if xs < 0:
xs = 0
xe = cutImg.xmin+cutImg.ncol
if xe >= cx_len-1:
xe = cx_len-1
if xe - xs <=0:
b_img = galsim.Image(cx_len, cy_len)
b_img.setOrigin(0,0)
bounds = cutImg.bounds & b_img.bounds
if bounds.area() == 0:
continue
# ys = cutImg.ymin
# if ys < 0:
# ys = 0
# ye = cutImg.ymin+cutImg.nrow
# if ye >= cy_len-1:
# ye = cy_len-1
# if ye - ys <=0:
# continue
# xs = cutImg.xmin
# if xs < 0:
# xs = 0
# xe = cutImg.xmin+cutImg.ncol
# if xe >= cx_len-1:
# xe = cx_len-1
# if xe - xs <=0:
# continue
ys = bounds.ymin
ye = bounds.ymax+1
xs = bounds.xmin
xe = bounds.xmax+1
U = interpolate.griddata(X_, Z_, (M[ys:ye, xs:xe],N[ys:ye, xs:xe]),
method='nearest',fill_value=1.0)
t2=datetime.datetime.now()
# t2=datetime.datetime.now()
print("time interpolate:", t2-t1)
# print("time interpolate:", t2-t1)
if U.shape != cutImg.array.shape:
print('DEBUG:SHAPE',cutImg.ncol,cutImg.nrow,cutImg.xmin, cutImg.ymin)
continue
img_tmp = cutImg.array*U
# if U.shape != cutImg.array.shape:
# print('DEBUG:SHAPE',cutImg.ncol,cutImg.nrow,cutImg.xmin, cutImg.ymin)
# continue
img_tmp = cutImg
img_tmp[bounds] = img_tmp[bounds]*U
psf = pcs[:, j].reshape(m_size, m_size)
tmp_img = tmp_img + signal.fftconvolve(img_tmp, psf, mode='same', axes=None)
tmp_img = tmp_img + signal.fftconvolve(img_tmp.array, psf, mode='same', axes=None)
t3=datetime.datetime.now()
print("time convole:", t3-t2)
# t3=datetime.datetime.now()
# print("time convole:", t3-t2)
del U
del img_tmp
if np.sum(tmp_img.array)==0:
tmp_img = cutImg
else:
tmp_img = tmp_img/np.sum(tmp_img.array)*sumImg
return tmp_img
......@@ -644,17 +660,17 @@ class PSFInterpSLS(PSFModel):
sub_size = 4
cx_len = int(chip.npix_x/sub_size)
cy_len = int(chip.npix_y/sub_size)
n_x = jnp.arange(0, cx_len, 1, dtype = int)
n_y = jnp.arange(0, cy_len, 1, dtype = int)
n_x = jnp.arange(0, chip.npix_x, sub_size, dtype = int)
n_y = jnp.arange(0, chip.npix_y, sub_size, dtype = int)
M, N = jnp.meshgrid(n_x, n_y)
t1=datetime.datetime.now()
# U = interpolate.griddata(X_, Z_, (M[0:cy_len, 0:cx_len],N[0:cy_len, 0:cx_len]),
# method='nearest',fill_value=1.0)
U1 = interpolate.griddata(X_, Z_, (M[0:cy_len, 0:cx_len],N[0:cy_len, 0:cx_len]),
U1 = interpolate.griddata(X_, Z_, (M, N),
method='nearest',fill_value=1.0)
U = np.zeros_like(chip.img.array, dtype=np.float32)
for mi in np.arange(cx_len):
for mi in np.arange(cy_len):
for mj in np.arange(cx_len):
U[mi*sub_size:(mi+1)*sub_size, mj*sub_size:(mj+1)*sub_size]=U1[mi,mj]
t2=datetime.datetime.now()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment