Coverage for pesummary/utils/tqdm.py: 96.4%

55 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-05-02 08:42 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3from tqdm import tqdm as _tqdm 

4from tqdm.utils import _unicode 

5import time 

6 

# Module author(s), each given as "Name <email>".
__author__ = [
    "Charlie Hoy <charlie.hoy@ligo.org>",
]

8 

9 

class tqdm(_tqdm):
    """Subclass of ``tqdm.tqdm`` that can also echo progress updates
    through a ``logging.Logger``

    When ``logger`` is given, every status update is sent to
    ``logger.debug``. Each line written to the output stream is also
    rendered with the %-style format string of the logger's first handler,
    so progress lines look like the logger's own records.

    Parameters
    ----------
    *args, **kwargs:
        passed straight through to ``tqdm.tqdm``
    logger: logging.Logger, optional
        logger used to echo progress updates. Default None
    logger_level: str, optional
        value substituted for ``%(levelname)s`` in the rendered prefix. It
        does not change the level updates are logged at (always DEBUG).
        Default "INFO"
    """

    # strftime format used to fill the ``%(asctime)s`` field
    _TIME_FORMAT = "%Y-%m-%d %H:%M:%S"

    def __init__(self, *args, logger=None, logger_level="INFO", **kwargs):
        # These must exist before ``_tqdm.__init__`` runs, because it may
        # call ``display`` (and therefore ``format_dict``) while constructing.
        self.logger = logger
        self.logger_level = logger_level
        super().__init__(*args, **kwargs)
        # A disabled bar (including ``disable=None`` on a non-TTY stream)
        # returns early from ``_tqdm.__init__`` without setting ``gui`` or
        # ``fp``. Nothing will be printed, so there is no printer to build.
        if self.disable:
            return
        logger_prefix = self._logger_prefix(self.logger)
        if not self.gui:
            self.sp = self.status_printer(
                self.fp, logger=self.logger, logger_prefix=logger_prefix,
                **self.format_dict
            )

    @staticmethod
    def _logger_prefix(logger, default='%(message)s'):
        """Return the %-style format string of ``logger``'s first handler

        Falls back to ``default`` when there is no logger, when the logger
        has no handlers of its own, or when its first handler has no
        formatter. In the last case ``logging`` itself would only emit the
        message.

        Parameters
        ----------
        logger: logging.Logger or None
            logger whose first handler's format string is wanted
        default: str, optional
            format string returned when none can be found
        """
        handlers = getattr(logger, "handlers", None) if logger is not None else None
        if not handlers:
            return default
        formatter = handlers[0].formatter
        # ``_fmt`` is a private attribute of ``logging.Formatter``; guard
        # against formatter objects that do not define it.
        # NOTE(review): a '{'/'$'-style formatter's ``_fmt`` will not be
        # substituted by the %-formatting used below.
        fmt = getattr(formatter, "_fmt", None)
        return fmt if fmt else default

    @staticmethod
    def status_printer(file, logger=None, logger_prefix='%(message)s', **kwargs):
        """Extension of the tqdm.status_printer function to allow for tqdm
        to interact with logger

        Parameters
        ----------
        file: file-like
            stream the status line is written to
        logger: logging.Logger, optional
            if given, every status message is also sent to ``logger.debug``
        logger_prefix: str, optional
            %-style format string used to render each line written to
            ``file``. It is filled from ``kwargs`` plus ``message``, and
            ``asctime`` when a time is supplied. If it references a field
            that is not available, the bare message is written instead
        **kwargs:
            extra fields made available to ``logger_prefix``

        Returns
        -------
        print_status: callable
            ``print_status(s, time=None)`` which writes ``s`` to ``file``
        """
        fp = file
        fp_flush = getattr(fp, 'flush', lambda: None)  # pragma: no cover

        def fp_write_log(s):
            logger.debug(_unicode(s))

        def fp_write(s):
            fp.write(_unicode(s))
            fp_flush()

        # Length of the previous message. It is held in a list so the nested
        # function can update it.
        last_len = [0]

        def print_status(s, time=None):
            len_s = len(s)
            # Pad with spaces so that a shorter message fully overwrites the
            # previous one after the carriage return.
            _message = s + (' ' * max(last_len[0] - len_s, 0))
            kwargs["message"] = _message
            if logger is not None:
                fp_write_log(_message)
            if time is not None:
                kwargs["asctime"] = time
            try:
                line = logger_prefix % kwargs
            except (KeyError, TypeError, ValueError):
                # The handler format uses a field tqdm does not provide
                # (e.g. ``%(lineno)d``). Show the bare message rather than
                # crash the progress bar.
                line = _message
            fp_write('\r' + line)
            last_len[0] = len_s

        return print_status

    @property
    def format_dict(self):
        """Extension of the tqdm.format_dict property to add extra quantities

        Adds ``asctime`` (the current local time). When a logger is set, it
        also adds ``levelname`` (``self.logger_level``) and ``name`` (the
        logger's name), so a logging-style prefix can be filled in.
        """
        base = super().format_dict
        if self.logger is not None:
            base.update(
                {"levelname": self.logger_level, "name": self.logger.name}
            )
        base.update({"asctime": time.strftime(self._TIME_FORMAT)})
        return base

    def __str__(self):
        """Hack of the tqdm.__str__ function to prevent duplicating the entirety
        of the tqdm.display function

        Returns the message currently being displayed, if ``display`` was
        called with one; otherwise the normal tqdm rendering.
        """
        msg = getattr(self, "display_msg", None)
        if msg is not None:
            return msg
        return super().__str__()

    def display(self, msg=None, pos=None):
        """Extension of the tqdm.display function to allow for the time to be
        passed to the status_printer function

        Parameters
        ----------
        msg: str, optional
            message to display instead of the rendered bar
        pos: int, optional
            position passed through to ``tqdm.display``
        """
        # ``_tqdm.display`` draws ``self.__str__()`` when given ``msg=None``.
        # Stash the caller's message so ``__str__`` can return it, and
        # temporarily wrap ``self.sp`` so each write carries a timestamp.
        self.display_msg = msg
        _original_sp = self.sp
        self.sp = lambda _msg: _original_sp(
            _msg, time.strftime(self._TIME_FORMAT)
        )
        try:
            return super().display(msg=None, pos=pos)
        finally:
            # Always restore state, even if drawing raised. Otherwise later
            # calls would stack wrappers, and ``str(self)`` would keep
            # returning an old message.
            self.sp = _original_sp
            self.display_msg = None

88 

89 

def trange(*args, **kwargs):
    """Return a ``tqdm`` progress bar wrapped around ``range(*args)``

    Positional arguments are forwarded to ``range``. Keyword arguments,
    including ``logger`` and ``logger_level``, are forwarded to ``tqdm``.
    """
    iterable = range(*args)
    return tqdm(iterable, **kwargs)