# CTI_modeling.py (3.93 KB)
# Last committed by Wei Chengliang
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from ctypes import CDLL, POINTER, c_int, c_double,c_float,c_long,c_char_p
from numpy.ctypeslib import ndpointer
import numpy.ctypeslib as clb
import numpy as np
from astropy.io import fits
from scipy.stats import randint
from glob import glob
from datetime import datetime
import os

# Shared library with the C implementation; it is looked up in the same
# directory as this module.
lib_path = os.path.dirname(os.path.realpath(__file__)) + "/libmoduleCTI.so"
lib = CDLL(lib_path)

# Argument order as used by CTI_sim:
#   CTI_simul(image, nx, ny, noverscan, nsp, rho_trap, t, beta, w, c, nmax,
#             trap_seeds, release_seed, image_cti)
CTI_simul = lib.CTI_simul
CTI_simul.argtypes = [
    POINTER(POINTER(c_int)),  # image (array of row pointers)
    c_int,                    # nx
    c_int,                    # ny
    c_int,                    # noverscan
    c_int,                    # nsp
    POINTER(c_float),         # rho_trap
    POINTER(c_float),         # t
    c_float,                  # beta
    c_float,                  # w
    c_float,                  # c
    c_int,                    # nmax
    POINTER(c_int),           # trap_seeds
    c_int,                    # release_seed
    POINTER(POINTER(c_int)),  # image_cti (row pointers, read back after the call)
]

# Argument order as used by get_trap_map:
#   save_trap_map(seeds, nsp, nx, ny, nmax, rho_trap, beta, c, filename)
get_trap_h = lib.save_trap_map
get_trap_h.argtypes = [
    POINTER(c_int),    # seeds
    c_int,             # nsp
    c_int,             # nx
    c_int,             # ny
    c_int,             # nmax
    POINTER(c_float),  # rho_trap
    c_float,           # beta
    c_float,           # c
    c_char_p,          # output file name (encoded bytes)
]

def get_trap_map(seeds,nx,ny,nmax,rho_trap,beta,c,out_dir):
    """Run the C ``save_trap_map`` routine, which is given ``<out_dir>/trap.bin``.

    Parameters
    ----------
    seeds : array_like of int
        Seeds forwarded to the C routine (copied to a contiguous int32 array).
    nx, ny, nmax : int
        Dimensions forwarded to the C routine as C ints.
    rho_trap : array_like of float
        One value per trap species; its length is passed as ``nsp``.
        Copied to a contiguous float32 array.
    beta, c : float
        Parameters forwarded to the C routine as C floats.
    out_dir : str or os.PathLike
        Directory in which the file name ``trap.bin`` is placed.

    Returns
    -------
    str
        The file path handed to the C routine.
    """
    # Fresh C-contiguous copies in the dtypes the prototype declares
    # (int* seeds, float* rho_trap); this also accepts plain Python lists.
    seeds1 = np.array(seeds, dtype=np.int32, order="C")
    rho_trap1 = np.array(rho_trap, dtype=np.float32, order="C")
    nsp = len(rho_trap1)
    # The ctypes views keep a reference to their numpy buffers.
    seeds_p = np.ctypeslib.as_ctypes(seeds1)
    rho_trap_p = np.ctypeslib.as_ctypes(rho_trap1)
    filename = os.path.join(out_dir, "trap.bin")
    get_trap_h(seeds_p, c_int(int(nsp)), c_int(int(nx)), c_int(int(ny)),
               c_int(int(nmax)), rho_trap_p, c_float(beta),
               c_float(c), os.fsencode(filename))
    return filename
        
def bin2fits(bin_file,fits_dir,nsp,nx,ny,nmax):
    """Convert the binary trap map into one FITS cube per trap species.

    The file is read as a flat float32 array, reshaped to
    ``(nx, nsp, ny, nmax)`` and transposed to ``(nsp, nmax, ny, nx)``.
    For every species, plane 0 is treated as a per-pixel trap count, and
    each later plane ``k`` (``1 <= k < nmax``) is zeroed wherever that count
    is below ``k``. Species ``i`` is written to
    ``<fits_dir>/trap_<i+1>.fits``, overwriting any existing file.

    Raises
    ------
    ValueError
        From ``numpy.reshape`` if the file does not contain exactly
        ``nx*nsp*ny*nmax`` float32 values.
    """
    data = np.fromfile(bin_file, dtype=np.float32)
    data = data.reshape(nx, nsp, ny, nmax).transpose(1, 3, 2, 0)
    # Plane indices 1..nmax-1, shaped to broadcast against a (ny, nx) plane.
    levels = np.arange(1, nmax).reshape(-1, 1, 1)
    for i in range(nsp):
        print("transfering trap type "+str(i+1))
        datai = data[i]
        ntrap = datai[0]
        # The mask is fully computed before the write, and only planes 1..
        # are modified, so the count plane (datai[0]) is left untouched.
        datai[1:][ntrap < levels] = 0
        out_path = os.path.join(fits_dir, f"trap_{i + 1}.fits")
        fits.writeto(out_path, datai, overwrite=True)
        
def numpy_matrix_to_int_pointer(arr):
    """Copy a 2-D array into a ctypes ``int**`` (an array of row pointers).

    Each row is copied into its own C-contiguous int32 buffer. The C side
    therefore never sees or writes the caller's memory, and non-contiguous
    inputs such as transposed views are handled.

    Parameters
    ----------
    arr : numpy.ndarray
        2-D array. Values are cast to int32: floats are truncated toward
        zero, and out-of-range values are not checked.

    Returns
    -------
    ctypes array of ``POINTER(c_int)``
        One pointer per row of ``arr``. The row buffers stay alive as long
        as the returned object does.
    """
    rows = [np.array(row, dtype=np.int32, order="C") for row in arr]
    int_pointer_array = (POINTER(c_int) * len(rows))()
    for i, row in enumerate(rows):
        int_pointer_array[i] = np.ctypeslib.as_ctypes(row)
    # ctypes should already hold references to the assigned buffers. Storing
    # them explicitly makes the lifetime guarantee independent of that detail.
    int_pointer_array._numpy_rows = rows
    return int_pointer_array
def pointer_to_numpy_matrix(arr_pointer,row,col):
    """Copy a ``row`` x ``col`` block out of a row-pointer structure.

    Parameters
    ----------
    arr_pointer : indexable of rows
        Typically a ctypes array of ``POINTER(c_int)``, as returned by
        ``numpy_matrix_to_int_pointer``. Any object for which
        ``arr_pointer[i][:col]`` yields ``col`` numbers works.
    row, col : int
        Number of rows and columns to copy.

    Returns
    -------
    numpy.ndarray
        float64 array of shape ``(row, col)``.
    """
    result = np.zeros((row, col))
    for i in range(row):
        # One slice per row replaces ``col`` separate ctypes index lookups.
        result[i, :] = arr_pointer[i][:col]
    return result
def CTI_sim(im,nx,ny,noverscan,nsp,nmax,beta,w,c,t,rho_trap,trap_seeds,release_seed=0):
    """Run the C ``CTI_simul`` routine on an image and return its output.

    Parameters
    ----------
    im : array_like
        2-D input image. It is transposed before being handed to C and
        copied into int32 row buffers. It is presumably shaped ``(ny, nx)``;
        TODO confirm against the C implementation.
    nx, ny, noverscan, nsp, nmax : int
        Passed to ``CTI_simul`` as C ints.
    beta, w, c : float
        Passed to ``CTI_simul`` as C floats.
    t, rho_trap : array_like of float
        Per-species arrays, converted to contiguous float32 when needed.
    trap_seeds : array_like of int
        Copied to a contiguous int32 array.
    release_seed : int, optional
        Passed as a C int.

    Returns
    -------
    numpy.ndarray
        float64 array of shape ``(ny + noverscan, nx)``: the ``(nx, ny +
        noverscan)`` int32 output buffer filled by C, transposed.
    """
    image = np.asarray(im).T
    ntotal = ny+noverscan
    # The prototype declares float* for these. ascontiguousarray passes
    # through arrays that are already float32 and contiguous, and converts
    # anything else (lists, float64), which ctypes would otherwise reject.
    t1 = np.ascontiguousarray(t, dtype=np.float32)
    rho_trap1 = np.ascontiguousarray(rho_trap, dtype=np.float32)
    trap_seeds1 = np.array(trap_seeds, dtype=np.int32, order="C")
    t_p = np.ctypeslib.as_ctypes(t1)
    rho_trap_p = np.ctypeslib.as_ctypes(rho_trap1)
    trap_seeds_p = np.ctypeslib.as_ctypes(trap_seeds1)
    image_p = numpy_matrix_to_int_pointer(image)
    # Zero-initialised output buffer: nx rows of ntotal int32 values.
    image_cti_p = numpy_matrix_to_int_pointer(np.zeros((nx, ntotal), dtype=np.int32))
    print(datetime.now())
    CTI_simul(image_p, nx, ny, noverscan, nsp, rho_trap_p, t_p, beta, w, c,
              nmax, trap_seeds_p, c_int(release_seed), image_cti_p)
    print(datetime.now())
    # Copy each output row through a zero-copy numpy view of the C buffer
    # instead of nx*ntotal individual ctypes index operations.
    image_cti_result = np.empty((nx, ntotal))
    for i in range(nx):
        image_cti_result[i] = np.ctypeslib.as_array(image_cti_p[i], shape=(ntotal,))
    return image_cti_result.T

if __name__ =='__main__':
    # Example run: build the trap map, export it to FITS, then apply CTI to
    # the input image and save the result.
    nx, ny, noverscan, nsp, nmax = 4608, 4616, 84, 3, 10
    beta, w, c = 0.478, 84700, 0
    t = np.array([0.74, 7.7, 37], dtype=np.float32)
    rho_trap = np.array([0.6, 1.6, 1.4], dtype=np.float32)
    trap_seeds = np.array([0, 100, 1000], dtype=np.int32)
    release_seed = 500
    trap_dir = "."
    # Create the output directory up front so that a missing directory does
    # not make the final write fail after the long simulation.
    os.makedirs("output", exist_ok=True)
    image = fits.getdata("inputdata/image.fits").astype(np.int32)
    get_trap_map(trap_seeds, nx, ny, nmax, rho_trap, beta, c, trap_dir)
    bin2fits(os.path.join(trap_dir, "trap.bin"), trap_dir, nsp, nx, ny, nmax)
    image_cti = CTI_sim(image, nx, ny, noverscan, nsp, nmax, beta, w, c,
                        t, rho_trap, trap_seeds, release_seed)
    fits.writeto("output/image_CTI.fits", data=image_cti, overwrite=True)