Coverage for pesummary/utils/list.py: 89.1%

64 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-05-02 08:42 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3import numpy as np 

4 

5__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"] 

6 

7 

class List(list):
    """Base list class to extend the core `list` class

    Keeps track of the values the list was created with, the values that
    have since been added to it and the values that have been removed from
    it. Items can optionally be converted by a callable on access.

    Parameters
    ----------
    *args: tuple
        either zero or one iterable, which provides the initial contents
        (as for ``list()`` / ``list(iterable)``), or the five values
        ``(data, original, cls, added, removed)`` produced by
        ``__reduce__``
    **kwargs: dict
        only used in the zero/one-argument form. The keys ``cls``,
        ``added`` and ``removed`` set the initial values of the
        corresponding attributes; any other key is ignored. kwargs are
        not forwarded to the list class

    Attributes
    ----------
    original: list
        copy of the values the list was created with
    cls: callable or None
        if not None, applied to each item returned by indexing/slicing
    added: list
        list of values appended to the original list
    removed: list
        list of values removed from the list
    """
    __slots__ = ["original", "cls", "added", "removed"]

    def __init__(self, *args, **kwargs):
        if len(args) <= 1:
            # Materialise the input once so that one-shot iterables
            # (generators, iterators) populate both `original` and the list
            # itself; `list(*args)` is `[]` when no argument is given
            self.original = list(*args)
            self.cls = kwargs.get("cls", None)
            self.added = kwargs.get("added", [])
            self.removed = kwargs.get("removed", [])
            super(List, self).__init__(self.original)
        else:
            # five-argument form emitted by `__reduce__`
            _, self.original, self.cls, self.added, self.removed = args
            super(List, self).__init__(_)

    @property
    def ndim(self):
        """Number of dimensions of ``np.array(self)``"""
        return np.array(self).ndim

    def __reduce__(self):
        """Pickle support: rebuild through the five-argument ``__init__``

        Returns
        -------
        tuple
            ``(self.__class__, (list(self), original, cls, added, removed))``
        """
        _slots = [getattr(self, i) for i in self.__slots__]
        slots = [list(self)] + _slots
        return (self.__class__, tuple(slots))

    def __setstate__(self, state):
        """Restore the slot attributes from ``state``

        Parameters
        ----------
        state: tuple
            only ``state[1]`` is read; it must map "original", "cls",
            "added" and "removed" to their values (presumably the default
            ``(dict_state, slot_state)`` pair used for slotted objects)
        """
        _state = state[1]
        self.original = _state["original"]
        self.cls = _state["cls"]
        self.added = _state["added"]
        self.removed = _state["removed"]

    def __getitem__(self, *args, **kwargs):
        """Return the item(s) at the given index/slice, converted by `cls`

        A slice yields a plain ``list`` whose elements are each passed
        through `cls`; a single index yields ``cls(item)``. When `cls` is
        None the underlying value is returned unchanged.
        """
        output = super(List, self).__getitem__(*args, **kwargs)
        if self.cls is None:
            return output
        if isinstance(output, list):
            return [self.cls(value) for value in output]
        else:
            return self.cls(output)

    def __add__(self, *args, **kwargs):
        """Return a new List holding ``self + other``

        The left operand is left untouched. The result gets independent
        copies of `original`, `added` (extended by the right operand's
        values) and `removed`, and shares `cls` with ``self``.
        """
        # list.__add__ validates the operand (raises TypeError for
        # non-lists) before any bookkeeping happens
        output = super(List, self).__add__(*args, **kwargs)
        obj = List(output)
        obj.original = list(self.original)
        obj.cls = self.cls
        obj.added = list(self.added) + list(*args)
        obj.removed = list(self.removed)
        return obj

    def __iadd__(self, *args, **kwargs):
        """Extend the list in place (``self += iterable``) and return self"""
        # materialise once so iterators feed both the list and `added`
        values = list(*args)
        super(List, self).__iadd__(values, **kwargs)
        self.added.extend(values)
        return self

    def append(self, *args, **kwargs):
        """Append a value and record it in `added`"""
        self.added.append(*args)
        return super(List, self).append(*args, **kwargs)

    def extend(self, *args, **kwargs):
        """Extend the list with an iterable and record the values in `added`"""
        # materialise once so iterators feed both the list and `added`
        values = list(*args)
        self.added.extend(values)
        return super(List, self).extend(values, **kwargs)

    def insert(self, index, obj, **kwargs):
        """Insert `obj` before `index` and record it in `added`"""
        self.added.append(obj)
        return super(List, self).insert(index, obj, **kwargs)

    def _record_removal(self, element):
        """Log `element` in `removed` and drop it from `added` if present"""
        self.removed.append(element)
        if element in self.added:
            self.added.remove(element)

    def remove(self, element, **kwargs):
        """Remove the first occurrence of `element` and record the removal"""
        obj = super(List, self).remove(element, **kwargs)
        self._record_removal(element)
        return obj

    def pop(self, index=-1, **kwargs):
        """Remove and return the item at `index` (default: last item)

        The raw stored value (not converted by `cls`) is returned and
        recorded, matching the behaviour of `remove`.
        """
        obj = super(List, self).pop(index)
        self._record_removal(obj)
        return obj