Coverage for src/susi/utils/collections.py: 100%
57 statements
« prev ^ index » next coverage.py v7.5.0, created at 2025-06-13 14:15 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2025-06-13 14:15 +0000
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3"""
4The `Collections` utility provides methods to deal with Collections (lists, dictionaries, arrays, ...)
6@author: hoelken
7"""
9import numpy as np
10from typing import Callable, Iterable
class Collections:
    """
    Static utility for handling collections (lists, dictionaries, numpy arrays, ...).

    All methods are static; the class only serves as a namespace.
    """

    @staticmethod
    def chunker(seq: list, size: int) -> list:
        """
        Generates chunks (slices) from a given sequence

        ### Params
        - seq: the list to chunk
        - size: the size of the chunks. Any value below `1` disables chunking.

        ### Returns
        A list of lists where each list has the
        length of the requested chunk size (maybe except the last one).
        If `size < 1` the whole sequence is returned as the single chunk.
        """
        if size < 1:
            return [seq]
        return [seq[pos : pos + size] for pos in range(0, len(seq), size)]

    @staticmethod
    def indexed_chunks(seq: list, size: int) -> dict:
        """
        Generates indexed chunks (slices) from a given sequence

        ### Params
        - seq: the list to chunk
        - size: the size of the chunks

        ### Returns
        A dictionary with the (0-based) chunk index as key and the corresponding chunk as value.
        The length of the value lists is the requested chunk size (maybe except the last one)
        """
        return dict(enumerate(Collections.chunker(seq, size)))

    @staticmethod
    def as_float_array(orig, dtype=np.float32) -> np.ndarray:
        """
        Creates a copy of the orig and converts all values to dtype (default: `np.float32`)

        ### Params
        - orig: an object that can be converted to a list
        - dtype: the numpy data type of the resulting array

        ### Returns
        Array with float values converted from the orig
        """
        return np.array(list(orig), dtype=dtype)

    @staticmethod
    def as_int_array(orig) -> np.ndarray:
        """
        Creates a copy of the orig and converts all values to `int`

        ### Params
        - orig: an object that can be converted to a list

        ### Returns
        Array with int values converted from the orig
        """
        return np.array(list(orig), dtype=int)

    @staticmethod
    def bin(orig: np.ndarray, binning: list, method: Callable = np.mean) -> np.ndarray:
        """
        Bins along a given set of axis.

        ### Params
        - orig: The original numpy array
        - binning: A list of binning values.
            - Length of the list must match the number of axis (i.e. the length of the `orig.shape`).
            - Per axis set `1` for no binning, `-1` for bin all and any positive number
              to specify the bin size along the axis.
        - method: The function to apply to the bin (e.g. np.max for max pooling, np.mean for average)

        ### Returns
        The binned array. If all binning values are `1` the original array itself
        (not a copy) is returned.

        ### Raises
        ValueError if the number of binning values does not match the number of axis.
        """
        if np.all(np.array(binning) == 1):
            # no binning whatsoever, return original
            return orig

        if len(orig.shape) != len(binning):
            raise ValueError(f"Shape {orig.shape} and number of binning axis {binning} don't match.")

        data = orig
        for ax, binsize in enumerate(binning):
            data = Collections.bin_axis(data, binsize, axis=ax, method=method)
        return data

    @staticmethod
    def bin_axis(data: np.ndarray, binsize: int, axis: int = 0, method: Callable = np.mean) -> np.ndarray:
        """
        Bins an array along a given axis.

        ### Params
        - data: The original numpy array
        - binsize: The size of each bin. A negative value bins the full axis into a single entry.
        - axis: The axis to bin along
        - method: The function to apply to the bin (e.g. np.max for max pooling, np.mean for average).
          It is called with the array and the axis to reduce.

        ### Returns
        The binned array. The number of dimensions is kept; the length of the binned
        axis is `data.shape[axis] // binsize` (trailing entries that do not fill a
        complete bin are dropped), or `1` if `binsize` is negative.

        ### Raises
        ValueError if `binsize` is `0`.
        """
        if binsize == 0:
            raise ValueError("binsize must not be 0.")

        if binsize < 0:
            # Collapse the whole axis, but keep it (with length 1) at its original position
            return np.expand_dims(method(data, axis=axis), axis=axis)

        # Swap the requested axis to the front, bin along axis 0, then swap back.
        # The permutation is its own inverse, so it can be reused for both transposes.
        argdims = np.arange(data.ndim)
        argdims[0], argdims[axis] = argdims[axis], argdims[0]
        swapped = data.transpose(argdims)
        binned = [
            method(swapped[int(i * binsize) : int(i * binsize + binsize)], 0)
            for i in np.arange(data.shape[axis] // binsize)
        ]
        return np.array(binned).transpose(argdims)

    @staticmethod
    def flatten_sort(lists: Iterable) -> list:
        """Flattens a list of lists and sorts the result"""
        result = Collections.flatten(lists)
        result.sort()
        return result

    @staticmethod
    def flatten(lists: Iterable) -> list:
        """
        Flattens arbitrarily nested lists (and numpy arrays) into a single flat list.

        ### Params
        - lists: an iterable whose entries are either values, lists or numpy arrays

        ### Returns
        A new flat list with all values in their original order (not sorted).
        Only `list` and `np.ndarray` entries are expanded, other containers
        (e.g. tuples or strings) are kept as single entries.
        """
        result = []
        for entry in lists:
            # 0-dimensional arrays cannot be iterated, so they are treated as plain values
            is_array = isinstance(entry, np.ndarray) and entry.ndim > 0
            if isinstance(entry, list) or is_array:
                result.extend(Collections.flatten(entry))
            else:
                result.append(entry)
        return result