diff --git a/csst_proto/flip_image.py b/csst_proto/flip_image.py index d10254ccba585cea9a3feca6ad75d0a8a9c0b57b..818fc097805b2956f7e66eaa97f0944cc8729c7e 100644 --- a/csst_proto/flip_image.py +++ b/csst_proto/flip_image.py @@ -1,4 +1,6 @@ import numpy as np +import joblib +import multiprocessing from . import PACKAGE_PATH @@ -50,3 +52,17 @@ def read_test_image(): fp_img = PACKAGE_PATH + "/data/test_image.txt" print("reading file {} ...".format(fp_img)) return np.loadtxt(fp_img, dtype=int) + + +def flip_multiple_images_mp(imgs: list, n_jobs: int) -> list: + """ parallel with multiprocessing """ + with multiprocessing.Pool(n_jobs) as p: + results = p.map(flip_image, imgs) + return results + + +def flip_multiple_images_jl(imgs: list, n_jobs: int) -> list: + """ parallel with joblib """ + return joblib.Parallel(n_jobs=n_jobs)( + joblib.delayed(flip_image)(img) for img in imgs + ) diff --git a/csst_proto/scratch.py b/csst_proto/scratch.py deleted file mode 100644 index 394e31fcfb38f8258810d6964319ec320891e875..0000000000000000000000000000000000000000 --- a/csst_proto/scratch.py +++ /dev/null @@ -1,33 +0,0 @@ -import numpy as np - - -def print_some_numbers(start: int, stop=10): - """ this function prints some numbers - - Parameters - ---------- - start: int - start number - stop: int - stop number - - Notes - ----- - this algorithm ... - """ - for i in np.arange(start, stop): - print(i) - return - - -class AClass: - def __init__(self): - self.a = 1 - - def print_a(self): - if True: - print(self.a) - - -if __name__ == "__main__": - print_some_numbers(start=1, stop=5) diff --git a/tests/test_flip_image.py b/tests/test_flip_image.py index a657430ac458b10f771fefd41c2c4f60ca7b7727..a76b4e0acafea10bc15bddf5c4d417fdc42c6d23 100644 --- a/tests/test_flip_image.py +++ b/tests/test_flip_image.py @@ -3,6 +3,7 @@ import unittest import numpy as np from csst_proto.top_level_interface import flip_image, read_test_image +from csst_proto.flip_image import flip_multiple_images_jl, flip_multiple_images_mp class FlipImageTestCase(unittest.TestCase): @@ -15,3 +16,23 @@ class FlipImageTestCase(unittest.TestCase): # the code fails for 1D array with self.assertRaises(AssertionError): flip_image(np.array([1, 2, 3, 4])) + + def test_flip_multiple_images_mp(self): + """ test flip multiple images with multiprocessing """ + n_jobs = 10 + imgs = [read_test_image() for i_job in range(n_jobs)] + flipped_imgs = flip_multiple_images_mp(imgs, n_jobs) + for i_job in range(n_jobs): + self.assertTrue( + np.all(flipped_imgs[i_job] == np.array([[4, 3], [2, 1]])) + ) + + def test_flip_multiple_images_jl(self): + """ test flip multiple images with joblib """ + n_jobs = 10 + imgs = [read_test_image() for i_job in range(n_jobs)] + flipped_imgs = flip_multiple_images_jl(imgs, n_jobs) + for i_job in range(n_jobs): + self.assertTrue( + np.all(flipped_imgs[i_job] == np.array([[4, 3], [2, 1]])) + )