Commit 16cecad4 authored by BO ZHANG's avatar BO ZHANG 🏀
Browse files

update decorator and its unit test

parent 1e398d53
import functools
import logging
import time
import traceback
from typing import Callable, NamedTuple, Optional
from csst_common.status import CsstResult, CsstStatus
from csst_common.logger import get_logger
__all__ = ["ModuleResult", "parameterized_module_decorator"]
# module should return ModuleResult as result
class ModuleResult(NamedTuple):
module: str
cost: float
status: CsstStatus.ERROR
files: Optional[list]
output: dict
def parameterized_module_decorator(
logger: Optional[logging.Logger] = None,
) -> Callable:
# use default logger
if logger is None:
logger = get_logger()
def module_decorator(func: Callable) -> Callable:
"""
A general wrapper for algorithm module.
This wrapper can be used for an algorithm module that returns `csst_common.CsstResult` object.
Parameters
----------
func : Callable
The algorithm module interface function.
Returns
-------
Callable
The wrapped module.
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
logger.info(f"=====================================================")
t_start = time.time()
logger.info(f"Starting Module: **{func.__name__}**")
# logger.info(f"Additional arguments: {args} {kwargs}")
try:
# if the module works well
res: CsstResult = func(*args, **kwargs)
assert isinstance(res, CsstResult)
# define results
status = res.status
files = res.files
output = res.output
except Exception as e:
# if the module raises error
exc_info = traceback.format_exc() # traceback info
logger.error(f"Error occurs! \n{exc_info}")
# define results
status = CsstStatus.ERROR # default status if exceptions occur
files = None
output = {"exc_info": exc_info} # default output if exceptions occur
finally:
t_stop = time.time()
t_cost = t_stop - t_start
if isinstance(status, CsstStatus):
# status is
logger.info(
f"Module finished: status={status} | cost={t_cost:.1f} sec"
)
else:
# invalid status
logger.error(
f"Invalid status: {status} is not a CsstResult object!"
)
# record exception traceback info
logger.info(
f"ModuleResult: \n"
f" - name: {func.__name__}\n"
f" - status: {status}\n"
f" - files: {files}\n"
f" - output: {output}\n"
)
return ModuleResult(
module=func.__name__,
cost=t_cost,
status=status,
files=files,
output=output,
)
return wrapper
return module_decorator
import unittest
from csst_common.status import CsstStatus, CsstResult
from csst_common import parameterized_module_decorator
class TestDecorator(unittest.TestCase):
def test_parameterized_module_decorator(self):
@parameterized_module_decorator()
def call_add(a, b):
if isinstance(a, float) and isinstance(b, float):
return CsstResult(CsstStatus.PERFECT, files=None, result=a + b)
else:
return CsstResult(CsstStatus.ERROR, files=None, result=a + b)
mres_int = call_add(1, 2)
self.assertEqual(mres_int.module, "call_add")
self.assertGreater(mres_int.cost, 0)
self.assertTrue(mres_int.status, CsstStatus(2))
mres_float = call_add(1, 2)
self.assertEqual(mres_float.module, "call_add")
self.assertGreater(mres_float.cost, 0)
self.assertTrue(mres_float.status, CsstStatus(0))
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment