Coverage for src/susi/utils/processing.py: 92%

71 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2025-06-13 14:15 +0000

1#!/usr/bin/env python3 

2# -*- coding: utf-8 -*- 

3""" 

4Utilities for (multi) processing 

5 

6@author: hoelken 

7""" 

8 

9import os 

10from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, Executor 

11 

12from ..base import Logging 

13 

# Module-wide logger obtained from the project's `Logging` helper (imported from ..base).
# Used below to report calls that failed or could not be cancelled.
log = Logging.get_logger()

15 

16 

class Thread:
    """
    # Thread
    Holds the function to execute in a thread or process together with its argument.
    Provides an interface to the `future` object once submitted to an executor.
    """

    def __init__(self, func, args):
        # Callable to run; it is invoked as `func(args)` (single positional argument).
        self.function = func
        # The one positional argument handed to `function`.
        self.arguments = args
        # `Future` returned by the executor; None until `submit` is called.
        self.future = None

    def submit(self, executor: Executor) -> 'Thread':
        """
        Start execution via `executor`. Submitting twice is a no-op.

        ### Returns
        `self`, so construction and submission can be chained.
        """
        if not self.is_submitted():
            # `arguments` is passed as ONE positional argument: function(arguments).
            self.future = executor.submit(self.function, self.arguments)
        return self

    def is_submitted(self) -> bool:
        """True once the call has been handed to an executor."""
        return self.future is not None

    def is_done(self) -> bool:
        """True if submitted and the future is done (finished, failed or cancelled)."""
        return self.is_submitted() and self.future.done()

    def exception(self):
        """
        Exception raised by the call, or None if the call has not finished yet
        (or was never submitted) or completed successfully.
        """
        if not self.is_done():
            return None
        return self.future.exception()

    def result(self):
        """
        Result of the call, blocking until it is available.
        Returns None if never submitted; re-raises the call's exception if it failed.
        """
        if not self.is_submitted():
            return None
        return self.future.result()

    def cancel(self) -> bool:
        """
        Try to cancel the call if it has not started yet.

        ### Returns
        True if the call was cancelled; False if it was never submitted,
        is already running / finished, or cancelling failed.
        """
        if not self.is_submitted():
            return False
        try:
            # The method must be *called*; merely referencing it cancels nothing.
            return self.future.cancel()
        except RuntimeError as e:
            log.warning('Unable to cancel thread: %s', e)
            return False

56 

57 

class MP:
    """
    ## MP Multi-Processing
    Class provides housekeeping / setup methods to reduce the programming overhead of
    spawning threads or processes.
    """

    #: Default number of worker processes: about 80% of the machine's CPUs, at least 1.
    #: `os.cpu_count()` may return None; a single CPU is assumed in that case.
    NUM_CPUs = max(1, round((os.cpu_count() or 1) * 0.8))

    @staticmethod
    def threaded(func, args, workers=10, raise_exception=True):
        """
        Calls the given function in multiple threads for the set of given arguments.
        Note that this does not spawn processes, but threads. Use this for tasks that
        are not CPU bound, e.g. I/O.
        Method returns once all calls are done.

        ### Params
        - func: [Function] the function to call, as `func(arg)`
        - args: [Iterable] the arguments, one per call (any iterable, incl. generators)
        - workers: [Integer] the number of concurrent threads to use
        - raise_exception: [Bool] Flag if an exception in a thread shall be raised or just logged

        ### Returns
        Results of all successful calls as list, in the order of `args`
        """
        return MP._dispatch(ThreadPoolExecutor, func, args, workers, raise_exception)

    @staticmethod
    def simultaneous(func, args, workers=None, raise_exception=True):
        """
        Calls the given function in multiple processes for the set of given arguments.
        Note that this does spawn processes, not threads. Use this for tasks that
        depend heavily on CPU and can be done in parallel. `func` and `args` must be
        picklable (requirement of `ProcessPoolExecutor`).
        Method returns once all calls are done.

        ### Params
        - func: [Function] the function to call, as `func(arg)`
        - args: [Iterable] the arguments, one per call (any iterable, incl. generators)
        - workers: [Integer] the number of concurrent processes to use (Default: NUM_CPUs)
        - raise_exception: [Bool] Flag if an exception in a process shall be raised or just logged

        ### Returns
        Results of all successful calls as list, in the order of `args`
        """
        if workers is None:
            workers = MP.NUM_CPUs
        return MP._dispatch(ProcessPoolExecutor, func, args, workers, raise_exception)

    @staticmethod
    def _dispatch(executor_cls, func, args, workers, raise_exception):
        """Run `func` once per element of `args` on an `executor_cls` pool with `workers` workers."""
        # Materialize once: generators become usable and `len` is available.
        args = list(args)
        if not args:
            return []
        if len(args) == 1:
            # A single call is executed directly, without setting up a pool.
            return MP._run_inline(func, args[0], raise_exception)
        with executor_cls(workers) as ex:
            threads = [Thread(func, arg).submit(ex) for arg in args]
            return MP.collect_results(threads, raise_exception)

    @staticmethod
    def _run_inline(func, arg, raise_exception):
        """Call `func(arg)` in the current thread, honouring `raise_exception` like the pooled path."""
        try:
            return [func(arg)]
        except Exception as ex:
            MP._log_failure(func, ex)
            if raise_exception:
                raise
            return []

    @staticmethod
    def collect_results(threads: list, raise_exception: bool = True) -> list:
        """
        Takes a list of threads and waits for them to be executed. Collects results.

        Waiting blocks on `concurrent.futures.wait` (no busy polling). Results are returned
        in the order of `threads`, independent of completion order. Threads that were never
        submitted are ignored. The given list is not modified.

        ### Params
        - threads: [Iterable<Thread>] submitted threads
        - raise_exception: [Bool] If True, the first failure found is logged, all remaining
          threads are cancelled and the exception is re-raised. If False, failures are only
          logged and the failed threads are left out of the result.

        ### Returns
        Results of all successful `Threads` as list, in submission order
        """
        from concurrent.futures import ALL_COMPLETED, FIRST_EXCEPTION, wait

        pending = [t for t in threads if t.is_submitted()]
        # When failures must be raised, stop waiting at the first one so we can abort early.
        wait([t.future for t in pending],
             return_when=FIRST_EXCEPTION if raise_exception else ALL_COMPLETED)

        for thread in pending:
            if thread.is_done() and thread.exception() is not None:
                # Raises (after cancelling the rest) if raise_exception; otherwise only logs.
                MP.__exception_handling(pending, thread, raise_exception)

        # All threads are finished here: `wait` only returns early on a failure,
        # and with raise_exception any failure has already been raised above.
        return [t.result() for t in pending if t.exception() is None]

    @staticmethod
    def _log_failure(func, ex):
        """Log a failed call of `func` with exception `ex` at CRITICAL level."""
        # functools.partial and other callables may not have a __name__.
        name = getattr(func, '__name__', repr(func))
        log.critical("Execution of '%s' caused\n\t [%s]: %s",
                     name, ex.__class__.__name__, ex)

    @staticmethod
    def __exception_handling(threads, thread, raise_exception):
        """Log the failure of `thread`; if `raise_exception`, cancel all `threads` and re-raise."""
        ex = thread.exception()
        MP._log_failure(thread.function, ex)
        if raise_exception:
            # Stop all remaining threads that have not started yet:
            for t in threads:
                t.cancel()
            # Raise exception
            raise ex