diff --git a/csst_common/file.py b/csst_common/file.py index c81b966999d4ac451f4e6b130d13698048ed2c68..6727cfcdf06d81d57e1d01c01524c3fd519f66d9 100644 --- a/csst_common/file.py +++ b/csst_common/file.py @@ -1,13 +1,28 @@ import os +from typing import Optional class File: - def __init__(self, file_path: str = "/path/to/file.fits"): + def __init__(self, file_path: str = "/path/to/file.fits", new_dir=None): self.file_path = file_path - self.file_path_prefix, self.file_path_ext = os.path.splitext(file_path) + self.dirname = os.path.dirname(self.file_path) + self.file_name = os.path.basename(self.file_path) + self.prefix, self.ext = os.path.splitext(self.file_name) + self.new_dir = new_dir if new_dir is not None else self.dirname - def replace_ext(self, ext="wht.fits") -> str: - if ext.startswith("."): - return self.file_path_prefix + ext + def replace_ext( + self, new_ext: str = "wht.fits", new_dir: Optional[str] = None + ) -> str: + if new_dir is None: + new_dir = self.new_dir + + if new_ext.startswith("."): + return os.path.join( + new_dir if new_dir is not None else self.dirname, + self.prefix + new_ext, + ) else: - return self.file_path_prefix + "_" + ext + return os.path.join( + new_dir if new_dir is not None else self.dirname, + self.prefix + "_" + new_ext, + )