Coverage for src/susi/utils/collections.py: 100%

57 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2025-06-13 14:15 +0000

1#!/usr/bin/env python3 

2# -*- coding: utf-8 -*- 

3""" 

4The `Collections` utility provides methods to deal with Collections (lists, dictionaries, arrays, ...) 

5 

6@author: hoelken 

7""" 

8 

9import numpy as np 

10from typing import Callable, Iterable 

11 

12 

class Collections:
    """
    Static utility for handling collections (lists, dictionaries, numpy arrays).

    All methods are static; the class only serves as a namespace.
    """

    @staticmethod
    def chunker(seq: list, size: int) -> list:
        """
        Generates chunks (slices) from a given sequence

        ### Params
        - seq: the list to chunk
        - size: the size of the chunks. Values below `1` disable chunking.

        ### Returns
        A list of lists where each list has the
        length of the requested chunk size (maybe except the last one).
        If `size < 1` the result is a list with the unmodified `seq` as its only entry.
        """
        if size < 1:
            return [seq]
        return [seq[pos : pos + size] for pos in range(0, len(seq), size)]

    @staticmethod
    def indexed_chunks(seq: list, size: int) -> dict:
        """
        Generates indexed chunks (slices) from a given sequence

        ### Params
        - seq: the list to chunk
        - size: the size of the chunks

        ### Returns
        A dictionary with the (zero based) index as key and the corresponding chunk as value.
        The length of the value lists is the requested chunk size (maybe except the last one)
        """
        return dict(enumerate(Collections.chunker(seq, size)))

    @staticmethod
    def as_float_array(orig, dtype=np.float32) -> np.ndarray:
        """
        Creates a copy of the orig and converts all values to dtype (default: `np.float32`)

        ### Params
        - orig: an object that can be converted to a list
        - dtype: the numpy data type of the resulting array

        ### Returns
        Array with float values converted from the orig
        """
        return np.array(list(orig), dtype=dtype)

    @staticmethod
    def as_int_array(orig) -> np.ndarray:
        """
        Creates a copy of the orig and converts all values to `int`

        ### Params
        - orig: an object that can be converted to a list

        ### Returns
        Array with int values converted from the orig
        """
        return np.array(list(orig), dtype=int)

    @staticmethod
    def bin(orig: np.ndarray, binning: list, method: Callable = np.mean) -> np.ndarray:
        """
        Bins along a given set of axis.

        ### Params
        - orig: The original numpy array
        - binning: A list of binning values.
            - Length of the list must match the number of axis (i.e. the length of the `orig.shape`).
            - Per axis set `1` for no binning, `-1` for bin all and any positive number
              to specify the bin size along the axis.
        - method: The function to apply to the bin (e.g. np.max for max pooling, np.mean for average)

        ### Returns
        The binned array. If all binning values are `1` the original array is returned as is (no copy).

        ### Raises
        - ValueError: if the number of binning values does not match the number of axis of `orig`
        """
        if np.all(np.array(binning) == 1):
            # no binning whatsoever, return original
            return orig

        if len(orig.shape) != len(binning):
            raise ValueError(f"Shape {orig.shape} and number of binning axis {binning} don't match.")

        data = orig
        for ax, binsize in enumerate(binning):
            data = Collections.bin_axis(data, binsize, axis=ax, method=method)
        return data

    @staticmethod
    def bin_axis(data: np.ndarray, binsize: int, axis: int = 0, method: Callable = np.mean) -> np.ndarray:
        """
        Bins an array along a given axis.

        ### Params
        - data: The original numpy array
        - binsize: The size of each bin. A negative value collapses the full axis into a single bin.
        - axis: The axis to bin along
        - method: The function to apply to the bin (e.g. np.max for max pooling, np.mean for average).
          It must accept the array as first and the axis as second (or `axis` keyword) argument.

        ### Returns
        The binned array. Entries that do not fill a complete bin at the end of the axis are dropped.

        ### Raises
        - ValueError: if `binsize` is `0`
        """
        if binsize < 0:
            # NOTE(review): the collapsed axis is always re-added at position 0, also for axis != 0.
            # Kept as is to not change the result shape for existing callers -- confirm this is intended.
            return np.array([method(data, axis=axis)])

        if binsize == 0:
            raise ValueError("binsize must not be 0")

        # Swap the requested axis to the front, so the bins can be taken along axis 0.
        # Swapping two axis is its own inverse, thus the same permutation restores the order afterwards.
        argdims = np.arange(data.ndim)
        argdims[0], argdims[axis] = argdims[axis], argdims[0]
        swapped = data.transpose(argdims)

        n_bins = int(data.shape[axis] // binsize)
        binned = [
            method(np.take(swapped, np.arange(int(i * binsize), int(i * binsize + binsize)), 0), 0)
            for i in range(n_bins)
        ]
        return np.array(binned).transpose(argdims)

    @staticmethod
    def flatten_sort(lists: Iterable) -> list:
        """Flattens a list of lists and sorts the result"""
        result = Collections.flatten(lists)
        result.sort()
        return result

    @staticmethod
    def flatten(lists: Iterable) -> list:
        """
        Flattens arbitrarily nested lists and numpy arrays into a single flat list.

        ### Params
        - lists: an iterable whose entries may be values, lists or numpy arrays (nested to any depth)

        ### Returns
        A flat list of all values in order of appearance. The result is not sorted.
        """
        result = []
        for entry in lists:
            # 0-dimensional arrays cannot be iterated, they are treated as plain values
            is_array = isinstance(entry, np.ndarray) and entry.ndim > 0
            if is_array or isinstance(entry, list):
                result.extend(Collections.flatten(entry))
            else:
                result.append(entry)
        return result