Coverage for src/susi/utils/processing.py: 92%
71 statements
« prev ^ index » next coverage.py v7.5.0, created at 2025-06-13 14:15 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2025-06-13 14:15 +0000
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3"""
4Utilities for (multi) processing
6@author: hoelken
7"""
9import os
10from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, Executor
12from ..base import Logging
14log = Logging.get_logger()
class Thread:
    """
    # Thread
    The thread holds the information on the function to execute in a thread or process.
    Provides an interface to the `future` object once submitted to an executor.
    """

    def __init__(self, func, args):
        """
        ### Params
        - func: [Function] the callable to execute
        - args: the single argument object handed to `func` on execution
        """
        self.function = func
        self.arguments = args
        # Set by `submit`; stays None until the task was handed to an executor.
        self.future = None

    def submit(self, executor: Executor):
        """
        Start execution via executor.
        Submitting twice is a no-op: the first future is kept.

        ### Returns
        The instance itself, to allow chaining.
        """
        if not self.is_submitted():
            # The arguments are passed as ONE positional parameter, i.e. `func(args)`.
            self.future = executor.submit(self.function, self.arguments)
        return self

    def is_submitted(self) -> bool:
        """Tell whether the task was already handed to an executor."""
        return self.future is not None

    def is_done(self) -> bool:
        """Tell whether the task was submitted and its future reports done."""
        return self.is_submitted() and self.future.done()

    def exception(self):
        """Return the exception raised by the task, or None if not (yet) done or successful."""
        if not self.is_done():
            return None
        return self.future.exception()

    def result(self):
        """
        Return the result of the task, or None if it was never submitted.
        Blocks until the future has a result; re-raises the exception of a failed task.
        """
        if not self.is_submitted():
            return None
        return self.future.result()

    def cancel(self) -> bool:
        """
        Try to cancel the task.

        ### Returns
        True if the future reports the call as cancelled, False otherwise
        (including the case that the task was never submitted).
        """
        if not self.is_submitted():
            # Nothing to cancel; avoids an AttributeError on the missing future.
            return False
        try:
            # The method must actually be *called*; a bare attribute access does nothing.
            return self.future.cancel()
        except RuntimeError as e:
            log.warning('Unable to cancel thread: %s', e)
            return False
class MP:
    """
    ## MP Multi-Processing
    Class provides housekeeping / setup methods to reduce the programming overhead of
    spawning threads or processes.
    """

    #: Default number of worker processes: about 80% of the CPUs of the current machine.
    #: `os.cpu_count()` may return None if the count cannot be determined, so fall
    #: back to a single CPU and never go below one worker.
    NUM_CPUs = max(1, round((os.cpu_count() or 1) * 0.8))

    @staticmethod
    def threaded(func, args, workers=10, raise_exception=True):
        """
        Calls the given function in multiple threads for the set of given arguments
        Note that this does not spawn processes, but threads. Use this for tasks that
        are not CPU dependent, i.e. I/O
        Method returns once all calls are done.

        ### Params
        - func: [Function] the function to call
        - args: [Iterable] the 'list' of arguments for each call (generators are accepted)
        - workers: [Integer] the number of concurrent threads to use
        - raise_exception: [Bool] Flag if an exception in a thread shall be raised or just logged

        ### Returns
        Results from all `Threads` as list, in the order of the given arguments.
        Calls that failed (and were only logged) contribute no entry.
        """
        return MP.__dispatch(ThreadPoolExecutor, func, args, workers, raise_exception)

    @staticmethod
    def simultaneous(func, args, workers=None, raise_exception=True):
        """
        Calls the given function in multiple processes for the set of given arguments
        Note that this does spawn processes, not threads. Use this for tasks that
        depend heavily on CPU and can be done in parallel.
        Method returns once all calls are done.

        ### Params
        - func: [Function] the function to call
        - args: [Iterable] the 'list' of arguments for each call (generators are accepted)
        - workers: [Integer] the number of concurrent processes to use (Default: NUM_CPUs)
        - raise_exception: [Bool] Flag if an exception in a process shall be raised or just logged

        ### Returns
        Results from all `Threads` as list, in the order of the given arguments.
        Calls that failed (and were only logged) contribute no entry.
        """
        if workers is None:
            workers = MP.NUM_CPUs
        return MP.__dispatch(ProcessPoolExecutor, func, args, workers, raise_exception)

    @staticmethod
    def collect_results(threads: list, raise_exception: bool = True) -> list:
        """
        Takes a list of threads and waits for them to be executed. Collects results.
        Threads that were never submitted are ignored. The given list is not modified.

        ### Params
        - threads: [List<Thread>] a list of submitted threads
        - raise_exception: [Bool] Flag if an exception in a thread shall be raised or just logged

        ### Returns
        Results from all `Threads` as list, in the order of the given threads.

        ### Raises
        The exception of a failed thread, if `raise_exception` is set.
        """
        # Local import keeps this block self-contained w.r.t. the module imports.
        from concurrent.futures import wait, ALL_COMPLETED, FIRST_EXCEPTION

        submitted = [t for t in threads if t.is_submitted()]
        # Block (instead of busy-polling) until everything is finished, or, when
        # failing fast, until the first task ends with an exception.
        stop_on = FIRST_EXCEPTION if raise_exception else ALL_COMPLETED
        wait([t.future for t in submitted], return_when=stop_on)

        result = []
        for thread in submitted:
            if not thread.is_done():
                # Only reachable if wait() returned early due to a failed task;
                # that failure is raised once this loop reaches it.
                continue
            if thread.exception() is not None:
                MP.__exception_handling(submitted, thread, raise_exception)
            else:
                result.append(thread.result())
        return result

    @staticmethod
    def __dispatch(executor_cls, func, args, workers, raise_exception):
        # Shared implementation of `threaded` and `simultaneous`.
        # Materialize first: `args` may be a generator without len().
        args = list(args)
        if not args:
            return []
        if len(args) == 1:
            # A single call does not justify the overhead of a pool.
            return MP.__run_inline(func, args[0], raise_exception)

        with executor_cls(workers) as ex:
            threads = [Thread(func, arg).submit(ex) for arg in args]
            # Collect while the pool is still alive, so that cancelling the
            # queued tasks after a failure actually has an effect.
            return MP.collect_results(threads, raise_exception)

    @staticmethod
    def __run_inline(func, arg, raise_exception):
        # Execute a single call in the current thread, honoring `raise_exception`
        # in the same way as the pooled execution does.
        try:
            return [func(arg)]
        except Exception as ex:
            log.critical("Execution of '%s' caused\n\t [%s]: %s",
                         MP.__name_of(func), ex.__class__.__name__, ex)
            if raise_exception:
                raise
            return []

    @staticmethod
    def __name_of(func):
        # Not every callable has a __name__ (e.g. functools.partial objects).
        return getattr(func, '__name__', repr(func))

    @staticmethod
    def __exception_handling(threads, thread, raise_exception):
        # Log the failure of `thread`; if requested cancel the others and re-raise.
        ex = thread.exception()
        log.critical("Execution of '%s' caused\n\t [%s]: %s",
                     MP.__name_of(thread.function), ex.__class__.__name__, ex)
        if raise_exception:
            # Stop all remaining threads:
            for t in threads:
                if t.is_submitted():
                    t.cancel()
            # Raise exception
            raise ex