Commit edffea7b authored by Zhang Xin's avatar Zhang Xin
Browse files

delete jax lib

parent 5634e275
Pipeline #6492 passed with stage
in 0 seconds
......@@ -23,7 +23,7 @@ from astropy.modeling.models import Gaussian2D
from scipy import signal, interpolate
import datetime
import gc
from jax import numpy as jnp
# from jax import numpy as jnp
LOG_DEBUG = False # ***#
NPSF = 900 # ***# 30*30
......@@ -537,14 +537,14 @@ class PSFInterpSLS(PSFModel):
sumImg = np.sum(cutImg.array)
tmp_img = cutImg*0
for j in np.arange(npc):
X_ = jnp.hstack((pos_p[:,1].flatten()[:, None], pos_p[:,0].flatten()[:, None]),dtype=np.float32)
X_ = np.hstack((pos_p[:,1].flatten()[:, None], pos_p[:,0].flatten()[:, None]),dtype=np.float32)
Z_ = (pc_coeff[j].astype(np.float32)).flatten()
# print(pc_coeff[j].shape[0], pos_p[:,1].shape[0], pos_p[:,0].shape[0])
cx_len = int(chip.npix_x)
cy_len = int(chip.npix_y)
n_x = jnp.arange(0, cx_len, 1, dtype = int)
n_y = jnp.arange(0, cy_len, 1, dtype = int)
M, N = jnp.meshgrid(n_x, n_y)
n_x = np.arange(0, cx_len, 1, dtype = int)
n_y = np.arange(0, cy_len, 1, dtype = int)
M, N = np.meshgrid(n_x, n_y)
# t1=datetime.datetime.now()
# U = interpolate.griddata(X_, Z_, (M[0:cy_len, 0:cx_len],N[0:cy_len, 0:cx_len]),
# method='nearest',fill_value=1.0)
......@@ -663,16 +663,16 @@ class PSFInterpSLS(PSFModel):
tmp_img = np.zeros_like(img.array,dtype=np.float32)
for j in np.arange(npca):
print(gt, od, w, j)
X_ = jnp.hstack((pos_p[:,1].flatten()[:, None], pos_p[:,0].flatten()[:, None]),dtype=np.float32)
X_ = np.hstack((pos_p[:,1].flatten()[:, None], pos_p[:,0].flatten()[:, None]),dtype=np.float32)
Z_ = (pc_coeff[j].astype(np.float32)).flatten()
# print(pc_coeff[j].shape[0], pos_p[:,1].shape[0], pos_p[:,0].shape[0])
sub_size = 4
cx_len = int(chip.npix_x/sub_size)
cy_len = int(chip.npix_y/sub_size)
n_x = jnp.arange(0, chip.npix_x, sub_size, dtype = int)
n_y = jnp.arange(0, chip.npix_y, sub_size, dtype = int)
n_x = np.arange(0, chip.npix_x, sub_size, dtype = int)
n_y = np.arange(0, chip.npix_y, sub_size, dtype = int)
M, N = jnp.meshgrid(n_x, n_y)
M, N = np.meshgrid(n_x, n_y)
t1=datetime.datetime.now()
# U = interpolate.griddata(X_, Z_, (M[0:cy_len, 0:cx_len],N[0:cy_len, 0:cx_len]),
# method='nearest',fill_value=1.0)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment