diff --git a/csst_common/decorator.py b/csst_common/decorator.py new file mode 100644 index 0000000000000000000000000000000000000000..05d5b3996c5e7d306f96286966e4dfb1063a9f23 --- /dev/null +++ b/csst_common/decorator.py @@ -0,0 +1,99 @@ +import functools +import logging +import time +import traceback +from typing import Callable, NamedTuple, Optional + +from csst_common.status import CsstResult, CsstStatus +from csst_common.logger import get_logger + +__all__ = ["ModuleResult", "parameterized_module_decorator"] + + +# module should return ModuleResult as result +class ModuleResult(NamedTuple): + module: str + cost: float + status: CsstStatus.ERROR + files: Optional[list] + output: dict + + +def parameterized_module_decorator( + logger: Optional[logging.Logger] = None, +) -> Callable: + # use default logger + if logger is None: + logger = get_logger() + + def module_decorator(func: Callable) -> Callable: + """ + A general wrapper for algorithm module. + + This wrapper can be used for an algorithm module that returns `csst_common.CsstResult` object. + + Parameters + ---------- + func : Callable + The algorithm module interface function. + + Returns + ------- + Callable + The wrapped module. + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + logger.info(f"=====================================================") + t_start = time.time() + logger.info(f"Starting Module: **{func.__name__}**") + # logger.info(f"Additional arguments: {args} {kwargs}") + try: + # if the module works well + res: CsstResult = func(*args, **kwargs) + assert isinstance(res, CsstResult) + # define results + status = res.status + files = res.files + output = res.output + except Exception as e: + # if the module raises error + exc_info = traceback.format_exc() # traceback info + logger.error(f"Error occurs! \n{exc_info}") + # define results + status = CsstStatus.ERROR # default status if exceptions occur + files = None + output = {"exc_info": exc_info} # default output if exceptions occur + finally: + t_stop = time.time() + t_cost = t_stop - t_start + if isinstance(status, CsstStatus): + # status is + logger.info( + f"Module finished: status={status} | cost={t_cost:.1f} sec" + ) + else: + # invalid status + logger.error( + f"Invalid status: {status} is not a CsstResult object!" + ) + # record exception traceback info + logger.info( + f"ModuleResult: \n" + f" - name: {func.__name__}\n" + f" - status: {status}\n" + f" - files: {files}\n" + f" - output: {output}\n" + ) + return ModuleResult( + module=func.__name__, + cost=t_cost, + status=status, + files=files, + output=output, + ) + + return wrapper + + return module_decorator diff --git a/tests/test_decorator.py b/tests/test_decorator.py new file mode 100644 index 0000000000000000000000000000000000000000..8c14d3299ffa7fe822434ce706eaf36f86932259 --- /dev/null +++ b/tests/test_decorator.py @@ -0,0 +1,23 @@ +import unittest +from csst_common.status import CsstStatus, CsstResult +from csst_common import parameterized_module_decorator + + +class TestDecorator(unittest.TestCase): + def test_parameterized_module_decorator(self): + @parameterized_module_decorator() + def call_add(a, b): + if isinstance(a, float) and isinstance(b, float): + return CsstResult(CsstStatus.PERFECT, files=None, result=a + b) + else: + return CsstResult(CsstStatus.ERROR, files=None, result=a + b) + + mres_int = call_add(1, 2) + self.assertEqual(mres_int.module, "call_add") + self.assertGreater(mres_int.cost, 0) + self.assertTrue(mres_int.status, CsstStatus(2)) + + mres_float = call_add(1, 2) + self.assertEqual(mres_float.module, "call_add") + self.assertGreater(mres_float.cost, 0) + self.assertTrue(mres_float.status, CsstStatus(0))