Commit 45494b69 authored by Zhang Xin's avatar Zhang Xin
Browse files

fix position bug

parent 74c8e88d
...@@ -323,7 +323,7 @@ class MockObject(object): ...@@ -323,7 +323,7 @@ class MockObject(object):
# specImg.wcs = local_wcs # specImg.wcs = local_wcs
# specImg.setOrigin(origin_order_x, origin_order_y) # specImg.setOrigin(origin_order_x, origin_order_y)
print('DEBUG: BEGIN -----------',bandNo,k) # print('DEBUG: BEGIN -----------',bandNo,k)
img_s = v[0] img_s = v[0]
...@@ -339,7 +339,7 @@ class MockObject(object): ...@@ -339,7 +339,7 @@ class MockObject(object):
specImg.wcs = local_wcs specImg.wcs = local_wcs
specImg.setOrigin(origin_order_x, origin_order_y) specImg.setOrigin(origin_order_x, origin_order_y)
try: try:
specImg = psf_model.get_PSF_AND_convolve_withsubImg(chip, cutImg=specImg, bandNo=bandNo, g_order=k, grating_split_pos=grating_split_pos) specImg = psf_model.get_PSF_AND_convolve_withsubImg(chip, cutImg=specImg, pos_img_local=pos_img_local, bandNo=bandNo, g_order=k, grating_split_pos=grating_split_pos)
except: except:
psf, pos_shear = psf_model.get_PSF(chip=chip, pos_img=pos_img) psf, pos_shear = psf_model.get_PSF(chip=chip, pos_img=pos_img)
......
...@@ -481,7 +481,7 @@ class PSFInterpSLS(PSFModel): ...@@ -481,7 +481,7 @@ class PSFInterpSLS(PSFModel):
return PSF_int_trans, PSF_int return PSF_int_trans, PSF_int
def get_PSF_AND_convolve_withsubImg(self, chip, cutImg=None, bandNo=1, g_order='A', grating_split_pos=3685): def get_PSF_AND_convolve_withsubImg(self, chip, cutImg=None, pos_img_local=[1000, 1000], bandNo=1, g_order='A', grating_split_pos=3685):
""" """
Get the PSF at a given image position Get the PSF at a given image position
...@@ -503,8 +503,8 @@ class PSFInterpSLS(PSFModel): ...@@ -503,8 +503,8 @@ class PSFInterpSLS(PSFModel):
# pos_img_x = pos_img_local[0] + x_start # pos_img_x = pos_img_local[0] + x_start
# pos_img_y = pos_img_local[1] + y_start # pos_img_y = pos_img_local[1] + y_start
# pos_img = galsim.PositionD(pos_img_x, pos_img_y) # pos_img = galsim.PositionD(pos_img_x, pos_img_y)
centerPos_local = cutImg.ncol/2. # centerPos_local = cutImg.ncol/2.
if centerPos_local < grating_split_pos: if pos_img_local[0] < grating_split_pos:
psf_data = self.grating1_data psf_data = self.grating1_data
else: else:
psf_data = self.grating2_data psf_data = self.grating2_data
...@@ -525,6 +525,7 @@ class PSFInterpSLS(PSFModel): ...@@ -525,6 +525,7 @@ class PSFInterpSLS(PSFModel):
npc = 10 npc = 10
m_size = int(pcs.shape[0]**0.5) m_size = int(pcs.shape[0]**0.5)
sumImg = np.sum(cutImg.array)
tmp_img = cutImg*0 tmp_img = cutImg*0
for j in np.arange(npc): for j in np.arange(npc):
X_ = jnp.hstack((pos_p[:,1].flatten()[:, None], pos_p[:,0].flatten()[:, None]),dtype=np.float32) X_ = jnp.hstack((pos_p[:,1].flatten()[:, None], pos_p[:,0].flatten()[:, None]),dtype=np.float32)
...@@ -534,43 +535,58 @@ class PSFInterpSLS(PSFModel): ...@@ -534,43 +535,58 @@ class PSFInterpSLS(PSFModel):
cy_len = int(chip.npix_y) cy_len = int(chip.npix_y)
n_x = jnp.arange(0, cx_len, 1, dtype = int) n_x = jnp.arange(0, cx_len, 1, dtype = int)
n_y = jnp.arange(0, cy_len, 1, dtype = int) n_y = jnp.arange(0, cy_len, 1, dtype = int)
M, N = jnp.meshgrid(n_x, n_y) M, N = jnp.meshgrid(n_x, n_y)
t1=datetime.datetime.now() # t1=datetime.datetime.now()
# U = interpolate.griddata(X_, Z_, (M[0:cy_len, 0:cx_len],N[0:cy_len, 0:cx_len]), # U = interpolate.griddata(X_, Z_, (M[0:cy_len, 0:cx_len],N[0:cy_len, 0:cx_len]),
# method='nearest',fill_value=1.0) # method='nearest',fill_value=1.0)
ys = cutImg.ymin b_img = galsim.Image(cx_len, cy_len)
if ys < 0: b_img.setOrigin(0,0)
ys = 0 bounds = cutImg.bounds & b_img.bounds
ye = cutImg.ymin+cutImg.nrow if bounds.area() == 0:
if ye >= cy_len-1:
ye = cy_len-1
if ye - ys <=0:
continue
xs = cutImg.xmin
if xs < 0:
xs = 0
xe = cutImg.xmin+cutImg.ncol
if xe >= cx_len-1:
xe = cx_len-1
if xe - xs <=0:
continue continue
# ys = cutImg.ymin
# if ys < 0:
# ys = 0
# ye = cutImg.ymin+cutImg.nrow
# if ye >= cy_len-1:
# ye = cy_len-1
# if ye - ys <=0:
# continue
# xs = cutImg.xmin
# if xs < 0:
# xs = 0
# xe = cutImg.xmin+cutImg.ncol
# if xe >= cx_len-1:
# xe = cx_len-1
# if xe - xs <=0:
# continue
ys = bounds.ymin
ye = bounds.ymax+1
xs = bounds.xmin
xe = bounds.xmax+1
U = interpolate.griddata(X_, Z_, (M[ys:ye, xs:xe],N[ys:ye, xs:xe]), U = interpolate.griddata(X_, Z_, (M[ys:ye, xs:xe],N[ys:ye, xs:xe]),
method='nearest',fill_value=1.0) method='nearest',fill_value=1.0)
t2=datetime.datetime.now() # t2=datetime.datetime.now()
print("time interpolate:", t2-t1) # print("time interpolate:", t2-t1)
if U.shape != cutImg.array.shape: # if U.shape != cutImg.array.shape:
print('DEBUG:SHAPE',cutImg.ncol,cutImg.nrow,cutImg.xmin, cutImg.ymin) # print('DEBUG:SHAPE',cutImg.ncol,cutImg.nrow,cutImg.xmin, cutImg.ymin)
continue # continue
img_tmp = cutImg.array*U img_tmp = cutImg
img_tmp[bounds] = img_tmp[bounds]*U
psf = pcs[:, j].reshape(m_size, m_size) psf = pcs[:, j].reshape(m_size, m_size)
tmp_img = tmp_img + signal.fftconvolve(img_tmp, psf, mode='same', axes=None) tmp_img = tmp_img + signal.fftconvolve(img_tmp.array, psf, mode='same', axes=None)
t3=datetime.datetime.now() # t3=datetime.datetime.now()
print("time convole:", t3-t2) # print("time convole:", t3-t2)
del U del U
del img_tmp
if np.sum(tmp_img.array)==0:
tmp_img = cutImg
else:
tmp_img = tmp_img/np.sum(tmp_img.array)*sumImg
return tmp_img return tmp_img
...@@ -644,17 +660,17 @@ class PSFInterpSLS(PSFModel): ...@@ -644,17 +660,17 @@ class PSFInterpSLS(PSFModel):
sub_size = 4 sub_size = 4
cx_len = int(chip.npix_x/sub_size) cx_len = int(chip.npix_x/sub_size)
cy_len = int(chip.npix_y/sub_size) cy_len = int(chip.npix_y/sub_size)
n_x = jnp.arange(0, cx_len, 1, dtype = int) n_x = jnp.arange(0, chip.npix_x, sub_size, dtype = int)
n_y = jnp.arange(0, cy_len, 1, dtype = int) n_y = jnp.arange(0, chip.npix_y, sub_size, dtype = int)
M, N = jnp.meshgrid(n_x, n_y) M, N = jnp.meshgrid(n_x, n_y)
t1=datetime.datetime.now() t1=datetime.datetime.now()
# U = interpolate.griddata(X_, Z_, (M[0:cy_len, 0:cx_len],N[0:cy_len, 0:cx_len]), # U = interpolate.griddata(X_, Z_, (M[0:cy_len, 0:cx_len],N[0:cy_len, 0:cx_len]),
# method='nearest',fill_value=1.0) # method='nearest',fill_value=1.0)
U1 = interpolate.griddata(X_, Z_, (M[0:cy_len, 0:cx_len],N[0:cy_len, 0:cx_len]), U1 = interpolate.griddata(X_, Z_, (M, N),
method='nearest',fill_value=1.0) method='nearest',fill_value=1.0)
U = np.zeros_like(chip.img.array, dtype=np.float32) U = np.zeros_like(chip.img.array, dtype=np.float32)
for mi in np.arange(cx_len): for mi in np.arange(cy_len):
for mj in np.arange(cx_len): for mj in np.arange(cx_len):
U[mi*sub_size:(mi+1)*sub_size, mj*sub_size:(mj+1)*sub_size]=U1[mi,mj] U[mi*sub_size:(mi+1)*sub_size, mj*sub_size:(mj+1)*sub_size]=U1[mi,mj]
t2=datetime.datetime.now() t2=datetime.datetime.now()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment