utils.py 4.85 KB
Newer Older
GZhao's avatar
GZhao committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import numpy as np
import scipy.ndimage as nd
import logging
import random
import matplotlib.pyplot as plt 
# DO NOT IMPORT CPICIMGSIM MODULES HERE


class Logger(object):
    """
    Configure the shared ``'cpism_log'`` logger with console and file output.

    Three handlers are attached:

    * a console handler printing records at INFO and above but below WARNING,
      formatted as the bare message;
    * a console handler printing WARNING and above, with level name, source
      file, function and line number;
    * a ``logging.FileHandler`` writing timestamped records at ``level`` and
      above to ``filename``.

    The logger itself is set to DEBUG so that each handler does its own
    filtering.

    Parameters
    ----------
    filename: str
        Path of the log file.
    level: str or int
        Minimum level written to the file, either a level name such as
        'DEBUG' / 'info' / 'warn' (case-insensitive) or a numeric level such
        as ``logging.WARNING``. Default: 'INFO'.

    Raises
    ------
    ValueError
        If ``level`` is a name that the logging module does not know.

    Notes
    -----
    ``logging.getLogger`` returns the same object for the same name, so any
    handlers already attached to ``'cpism_log'`` (e.g. by an earlier Logger)
    are removed and closed first. Otherwise every message would be emitted
    once per Logger ever created, and old log files would stay open.
    """

    def __init__(self, filename, level='INFO'):
        # Resolve the level before touching any handlers so that a bad value
        # leaves the current logging setup intact.
        file_level = self._resolve_level(level)

        self.logger = logging.getLogger('cpism_log')
        self.logger.setLevel(logging.DEBUG)

        # Detach and close handlers left by a previous configuration.
        for old_handler in list(self.logger.handlers):
            self.logger.removeHandler(old_handler)
            old_handler.close()

        # Console: plain messages below WARNING (WARNING+ is handled by ``sh``).
        shinfo = logging.StreamHandler()
        fmtstr = '%(message)s'
        shinfo.setFormatter(logging.Formatter(fmtstr))  # on-screen format
        shinfo.setLevel(logging.INFO)
        shinfo.addFilter(lambda record: record.levelno < logging.WARNING)

        # Console: WARNING and above, annotated with the source location.
        sh = logging.StreamHandler()
        fmtstr = '!%(levelname)s!: %(message)s [%(filename)s - %(funcName)s (line: %(lineno)d)]: '
        sh.setFormatter(logging.Formatter(fmtstr))  # on-screen format
        sh.setLevel(logging.WARNING)

        # File: a plain (non-rotating) handler writing to ``filename``.
        th = logging.FileHandler(filename)
        fmtstr = '%(asctime)s %(filename)s [%(funcName)s] - %(levelname)s: %(message)s'
        th.setFormatter(logging.Formatter(fmtstr))  # format written to the file
        th.setLevel(file_level)

        self.logger.addHandler(shinfo)
        self.logger.addHandler(sh)
        self.logger.addHandler(th)

    @staticmethod
    def _resolve_level(level):
        """Return the numeric logging level for a level name or number."""
        if isinstance(level, int):
            return level
        # getLevelName maps a registered name to its number; for unknown names
        # it returns the string "Level <name>" instead.
        value = logging.getLevelName(str(level).upper())
        if not isinstance(value, int):
            raise ValueError(f"Unknown logging level: {level!r}")
        return value


def random_seed_select(seed=-1):
    """
    Seed numpy's global random generator and report the seed that was used.

    Parameters
    ----------
    seed: int
        Seed to apply. The sentinel value -1 (the default) means "pick one":
        a fresh seed in [0, 2**32 - 1] is drawn from Python's ``random`` module.

    Returns
    -------
    int
        The value that was passed to ``np.random.seed``.
    """
    chosen = random.randint(0, 2**32 - 1) if seed == -1 else seed
    np.random.seed(chosen)
    return chosen


def region_replace(
    background: np.ndarray,
    front: np.ndarray,
    shift: list,
    bmask: float = 1.0,
    fmask: float = 1.0,
    padded_in: bool = False,
    padded_out: bool = False,
    subpix: bool = False
):
    """
    Replace a region of the background with the front image.

    Inside the area covered by ``front`` the result is
    ``bmask * background + fmask * front``; elsewhere the background is kept.

    Parameters
    ----------
    background: np.ndarray
        The 2D background image, indexed [row (y), column (x)].
    front: np.ndarray
        The 2D front image, indexed [row (y), column (x)].
    shift: list
        The [x, y] shift of the front image. Unit: pixel.
        Relative to the lower-left corner of the background image.
        [0, 0] means the lower-left corner of the front image is at the
        lower-left corner of the background image.
        The integer part is obtained with ``astype(int)`` (truncation toward zero).
    bmask: float
        Factor applied to the background inside the covered region. Default: 1.0
        0.0 means the background is masked (replaced by the front image).
        1.0 means the background is kept and the front image is added to it.
    fmask: float
        Factor applied to the front image. Default: 1.0
        0.0 means the front image is masked (not added).
        1.0 means the front image is fully added.
    padded_in: bool
        Whether the input background image is already padded. Default: False
        Internally the background is padded by the size of the front image on
        every side. If True, ``background`` is assumed to carry that padding
        already and is modified in place.
    padded_out: bool
        Whether to return the padded image. Default: False
        padded_in and padded_out are designed for calling region_replace
        several times on the same canvas without re-padding each time.
    subpix: bool
        Whether to honour the fractional part of the shift. Default: False
        If True, the fractional part is applied to the front image with
        scipy.ndimage.shift before it is placed.
        If False, only the integer part is used (numpy slicing).

    Returns
    -------
    np.ndarray
        The output image.
        shape = unpadded background shape if padded_out = False
        shape = unpadded background shape + 2 * front.shape if padded_out = True
        If the shift along either axis is below -front size or above the
        background size, the front cannot overlap the background. In that
        case nothing is added and the background (padded when padded_out)
        is returned as is.
    """

    int_shift = np.array(shift).astype(int)
    # ``shift`` is [x, y] while numpy axes are (row, col) = (y, x); reorder once
    # so that every comparison below pairs each shift with its own axis size.
    int_yx = int_shift[::-1]
    b_sz = np.array(background.shape)
    f_sz = np.array(front.shape)

    if padded_in:
        padded = background
        b_sz = b_sz - f_sz * 2
    else:
        padded = np.pad(background, ((f_sz[0], f_sz[0]), (f_sz[1], f_sz[1])))

    # The padding leaves a margin of one front size on every side, so any shift
    # in [-f_sz, b_sz] fits inside ``padded``. Outside that range the front
    # cannot overlap the background. A negative slice start would also wrap
    # around to the opposite edge, so leave the image untouched.
    if np.any((int_yx < -f_sz) | (int_yx > b_sz)):
        if padded_out:
            return padded
        return background

    if subpix:
        subs = np.array(shift) - int_shift
        # nd.shift takes one shift per array axis, i.e. in (y, x) order.
        front = nd.shift(front, (subs[1], subs[0]))

    roi_y, roi_x = int_yx + f_sz
    roi = (slice(roi_y, roi_y + f_sz[0]), slice(roi_x, roi_x + f_sz[1]))
    padded[roi] *= bmask
    padded[roi] += fmask * front

    if padded_out:
        return padded

    return padded[f_sz[0]:b_sz[0]+f_sz[0], f_sz[1]:b_sz[1]+f_sz[1]]


def psf_imshow(psf, vmin=1e-8, vmax=0.1, log=True, region=1):
    """
    Display a 2D PSF with ``plt.imshow`` after min-max normalisation to [0, 1].

    Parameters
    ----------
    psf: np.ndarray
        2D image to display, indexed [row, column]. It is not modified.
    vmin, vmax: float
        Display limits passed to ``plt.imshow``. They apply to the normalised
        image, after the log stretch when ``log`` is True.
    log: bool
        If True, apply ``log10(9 * x + 1)``, which maps [0, 1] onto [0, 1]
        while compressing the bright end.
    region: float
        Centred fraction of the image shown along each axis (sets the x/y
        limits); 1 shows the whole image.

    Returns
    -------
    matplotlib.image.AxesImage
        The artist returned by ``plt.imshow``.
    """
    lo = psf.min()
    span = psf.max() - lo
    if span == 0:
        # A constant image would otherwise divide by zero and render as NaN.
        focal_img = np.zeros_like(psf, dtype=float)
    else:
        focal_img = (psf - lo) / span
    if log:
        focal_img = np.log10(focal_img * 9 + 1)

    image = plt.imshow(focal_img, origin='lower', cmap='gray', vmin=vmin, vmax=vmax)

    shape = psf.shape
    plt.xlim(shape[1] * (1 - region) / 2, shape[1] * (1 + region) / 2)
    plt.ylim(shape[0] * (1 - region) / 2, shape[0] * (1 + region) / 2)
    return image