# disperse.pyx
# Initial commit ("init") by xin.

from __future__ import division

import numpy as np
cimport numpy as np

# Python-level dtype objects matching the C element types declared below
# (presumably for callers that allocate arrays to pass in -- confirm).
ITYPE = np.int64
DTYPE = np.double

# C-level element types used to declare the typed ndarray buffer arguments.
ctypedef np.float32_t FTYPE_t   # flam, segm
ctypedef np.double_t DTYPE_t    # yfrac, ysens, full
ctypedef np.int64_t LINT_t      # idxl, x0, shd, shg

# Not used by the functions defined below; kept for completeness.
ctypedef np.int32_t FINT_t
ctypedef np.int_t INT_t
ctypedef np.uint_t UINT_t

import cython

cdef extern from "math.h":
    double sqrt(double x)
    double exp(double x)
    
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.embedsignature(True)
def disperse_grism_object(np.ndarray[FTYPE_t, ndim=2] flam, 
                          np.ndarray[LINT_t, ndim=1] idxl, 
                          np.ndarray[DTYPE_t, ndim=1] yfrac, 
                          np.ndarray[DTYPE_t, ndim=1] ysens, 
                          np.ndarray[DTYPE_t, ndim=1] full, 
                          np.ndarray[LINT_t, ndim=1] x0, 
                          np.ndarray[LINT_t, ndim=1] shd, 
                          np.ndarray[LINT_t, ndim=1] shg):
    """Compute a dispersed 2D spectrum, accumulating it into ``full``.

    Every non-zero pixel ``flam[x0[0]+j, x0[1]+i]`` with ``j`` in
    ``[-x0[0], x0[0])`` and ``i`` in ``[-x0[1], x0[1])`` (i.e. rows
    ``0 .. 2*x0[0]-1`` and columns ``0 .. 2*x0[1]-1``) is spread along the
    trace described by ``idxl``/``yfrac``/``ysens``, shifted by the pixel
    offset ``j*shg[1] + i`` in the flattened output.

    Parameters
    ----------
    flam : 2D float32 array, indexed ``[y, x]``
        Direct image.
    idxl : 1D int64 array
        Flat indices into ``full`` of the trace for offset ``j = i = 0``.
    yfrac : 1D float64 array, at least ``len(idxl)`` long
        Fraction of trace element ``k`` deposited at index ``idxl[k] + ...``;
        the remaining ``1 - yfrac[k]`` goes to the index ``shg[1]`` lower
        (presumably the row above in a row-major image -- confirm).
    ysens : 1D float64 array, at least ``len(idxl)`` long
        Sensitivity multiplier for each trace element.
    full : 1D float64 array
        Flattened output image; updated in place (values are added, not
        overwritten). Contributions falling outside ``[0, len(full))`` are
        dropped.
    x0 : 1D int64 array ``(y, x)``
        Centre of the object; also sets the half-size of the scanned window.
    shd : 1D int64 array ``(ny, nx)``
        Shape limit of the direct image; pixels beyond it (or beyond
        ``flam.shape``) are skipped.
    shg : 1D int64 array
        Shape of the dispersed image; only ``shg[1]``, the row width of the
        flattened ``full``, is used.

    Returns
    -------
    True

    Raises
    ------
    ValueError
        If ``yfrac``/``ysens`` are shorter than ``idxl`` or ``x0``/``shd``/
        ``shg`` have fewer than two elements.  Bounds checking is disabled,
        so these would otherwise read past the end of the buffers.
    """
    # Signed 64-bit indices: flat offsets can exceed the C ``int`` range, and
    # comparing signed values with unsigned lengths is error-prone.
    cdef Py_ssize_t nk, nl, ny, nx, cy, cx, row_stride
    cdef Py_ssize_t i, j, k, k1, k2
    cdef double fl_ij, wht

    nk = len(idxl)
    nl = len(full)

    if len(yfrac) < nk or len(ysens) < nk:
        raise ValueError("yfrac and ysens must be at least as long as idxl")
    if len(x0) < 2 or len(shd) < 2 or len(shg) < 2:
        raise ValueError("x0, shd and shg must each have at least two elements")

    cy = x0[0]
    cx = x0[1]
    # Never read beyond the real array, even if shd overstates its size.
    ny = min(shd[0], flam.shape[0])
    nx = min(shd[1], flam.shape[1])
    row_stride = shg[1]

    for i in range(-cx, cx):
        if cx + i < 0 or cx + i >= nx:
            continue

        for j in range(-cy, cy):
            if cy + j < 0 or cy + j >= ny:
                continue

            fl_ij = flam[cy + j, cx + i]
            if fl_ij == 0:
                continue

            for k in range(nk):
                # Same evaluation order as ysens[k]*fl_ij*yfrac[k].
                wht = ysens[k] * fl_ij

                k1 = idxl[k] + j * row_stride + i
                if 0 <= k1 < nl:
                    full[k1] += wht * yfrac[k]

                # Split the remainder into the element one row-width earlier.
                k2 = k1 - row_stride
                if 0 <= k2 < nl:
                    full[k2] += wht * (1 - yfrac[k])

    return True

@cython.boundscheck(False)
@cython.wraparound(False)
@cython.embedsignature(True)
def compute_segmentation_limits(np.ndarray[FTYPE_t, ndim=2] segm, int seg_id, np.ndarray[FTYPE_t, ndim=2] flam, np.ndarray[LINT_t, ndim=1] shd):
    """Find pixel limits and weighted centroid of a segmentation region.

    Parameters
    ----------
    segm: ndarray (np.float32)
        Segmentation array; pixels equal to ``seg_id`` form the region.

    seg_id: int
        ID to test. It is compared against float32 values, so presumably
        IDs must be exactly representable in float32 -- confirm.

    flam: ndarray (np.float32)
        Weights (e.g. flux) for the centroid within the region.

    shd: [int, int]
        Number of rows and columns to scan, normally ``segm.shape``. It must
        not exceed the shape of ``segm`` or ``flam``.

    Returns
    -------
    (imin, imax, icen, jmin, jmax, jcen, area, denom)
        imin, imax, jmin, jmax : inclusive row/column limits of the matched
            pixels. If nothing matched, they stay at ``shd[0]``, 0, ``shd[1]``, 0.
        icen, jcen : weighted mean row/column, ``sum(i*flam) / denom``.
        area : number of matched pixels.
        denom : sum of ``flam`` over the matched pixels, or -99 when that sum
            is exactly 0. That happens when no pixel matched, and also when
            the matched weights cancel. Check ``area`` to tell the two apart;
            in that case the centroids are not meaningful.

    Raises
    ------
    ValueError
        If ``shd`` has fewer than two elements or exceeds the array shapes.
        Bounds checking is disabled, so either would read out of bounds.
    """
    cdef int i, j, imin, imax, jmin, jmax, area, nrow, ncol
    cdef double inumer, jnumer, denom, wht_ij

    if len(shd) < 2:
        raise ValueError("shd must have at least two elements")

    nrow = shd[0]
    ncol = shd[1]
    if (nrow > segm.shape[0] or ncol > segm.shape[1]
            or nrow > flam.shape[0] or ncol > flam.shape[1]):
        raise ValueError("shd must not exceed the shapes of segm and flam")

    area = 0

    # Start limits at "empty" values so the first matched pixel sets them.
    imin = nrow
    imax = 0
    jmin = ncol
    jmax = 0

    inumer = 0.
    jnumer = 0.
    denom = 0.

    for i in range(nrow):
        for j in range(ncol):
            if segm[i, j] != seg_id:
                continue

            area += 1
            wht_ij = flam[i, j]
            inumer += i*wht_ij
            jnumer += j*wht_ij
            denom += wht_ij

            if i < imin:
                imin = i
            if i > imax:
                imax = i

            if j < jmin:
                jmin = j
            if j > jmax:
                jmax = j

    # Zero total weight: either no matched pixels or weights that cancel.
    # Use a -99 sentinel to avoid dividing by zero.
    if denom == 0:
        denom = -99

    return imin, imax, inumer/denom, jmin, jmax, jnumer/denom, area, denom
            
            
            
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.embedsignature(True)
def seg_flux(np.ndarray[FTYPE_t, ndim=2] flam,
             np.ndarray[LINT_t, ndim=1] idxl,
             np.ndarray[DTYPE_t, ndim=1] yfrac,
             np.ndarray[DTYPE_t, ndim=1] ysens,
             np.ndarray[DTYPE_t, ndim=1] full,
             np.ndarray[LINT_t, ndim=1] x0,
             np.ndarray[LINT_t, ndim=1] shd,
             np.ndarray[LINT_t, ndim=1] shg):
    """Placeholder; not implemented.

    Takes the same arguments as ``disperse_grism_object`` but does nothing:
    ``full`` is left untouched and the return value is ``None``.
    """
    # NOTE(review): callers get None silently; consider raising
    # NotImplementedError once it is confirmed that nothing relies on this.
    return None