Coverage for src/susi/reduc/pipeline/orchestrator.py: 92%

118 statements  

coverage.py v7.5.0, created at 2025-08-11 10:03 +0000

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Module that provides the orchestration for the SUSI processing pipeline.

@author: hoelken
"""
import os.path

import numpy as np

from spectroflat.smile import OffsetMap

from . import BlockRegistry
from .chunker import Chunker
from .post_processor import PostProcessor
from .processing_data import ProcessingData
from .pre_processor import PreProcessor
from .processor import Processor
from ...io import Fits
from ...base import Logging, Config, Globals, IllegalArgumentException, Api
from ...utils import MP, Collections, progress
from ..shear_distortion.shear_and_rot_correction import ShearAndRotLoader

log = Logging.get_logger()


class Orchestrator:
    """
    ## Module
    A Module is a chain of `N` processing blocks that are executed in sequence.
    Each module is composed of a full processing chain of multiple blocks.

    <pre>
    ┣┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅ MODULE ┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┫ ┣┅┅┅┅
    ╭─────────╮    ╭──────────────╮   ╭───────────────╮   ╭───────────────╮
    │ Chunker │─┬─>│ PreProcessor │──>│   Processor   │──>│ PostProcessor │─┬─>...
    ╰─────────╯ │  ╰──────────────╯   ╰───────────────╯   ╰───────────────╯ │
                │  ╭──────────────╮   ╭───────────────╮   ╭───────────────╮ │
                ├─>│ PreProcessor │──>│   Processor   │──>│ PostProcessor │─┤
                │  ╰──────────────╯   ╰───────────────╯   ╰───────────────╯ │
                │  ╭──────────────╮   ╭───────────────╮   ╭───────────────╮ │
                ╰─>│ PreProcessor │──>│   Processor   │──>│ PostProcessor │─╯
                   ╰──────────────╯   ╰───────────────╯   ╰───────────────╯
    </pre>

    The chunker will split the given file list into batches and start one job per batch created.
    All jobs will be queued and processed by worker processes. All jobs will perform the
    same set of configured operations (callbacks).
    All the actual work is performed in the callbacks (i.e. the blocks).
    To add a new processing block, implement a new `Block` subclass that follows the
    life-cycle described in the block class.

    The `PreProcessor` takes care of loading the input data.
    The `Processor` executes the chain of processing blocks as configured in the pipeline.
    Finally, the `PostProcessor` provides generic metadata and writes the results back to
    the disk.
    The full result is then collected as a file list and given to the next module (if any).
    """
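
    # Illustrative usage sketch: `start` runs every configured module in turn and
    # returns the file list produced by the last one. The names `config` and
    # `raw_files` below are placeholders; building the Config and the initial file
    # list happens outside this module.
    #
    #   orchestrator = Orchestrator(config)
    #   final_files = orchestrator.start(raw_files)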

    LAST_LOG = -1

    def __init__(self, config: Config):
        self.proc_data = ProcessingData(config)

    def start(self, files: list) -> list:
        self._read_pipeline_calibration_data()
        for module in Globals.pipeline(self.proc_data.config.base.pipeline):
            log.info("Starting module %s with %s files", module, len(files))
            if not files:
                log.warning("No files to process. Abort computation")
                break

            self._prepare_processing_data(module, files)
            chunks = self._prepare_chunks(files, module)
            files = self._submit_jobs(chunks, module)

            if self.proc_data.config.base.obsid is not None:
                lvl = Globals.pipeline(self.proc_data.config.base.pipeline).index(module)
                api = Api(self.proc_data.config)
                api.add_pipeline_run(self.proc_data.config.base.obsid, lvl)
        return files

    def _read_pipeline_calibration_data(self):
        # Run only once, at the beginning of the whole pipeline.
        if self.proc_data.config.calib_data.dark:
            self.proc_data.dark_image = Fits(self.proc_data.config.calib_data.dark).read()
        if self.proc_data.config.calib_data.mod_matrix:
            self.proc_data.mod_matrix = Fits(self.proc_data.config.calib_data.mod_matrix).read()
        if self.proc_data.config.calib_data.slit_flat:
            self.proc_data.slit_flat = Fits(self.proc_data.config.calib_data.slit_flat).read()
        if self.proc_data.config.calib_data.sensor_flat:
            self.proc_data.sensor_flat = Fits(self.proc_data.config.calib_data.sensor_flat).read()
        if self.proc_data.config.calib_data.soft_flat:
            if self.proc_data.config.base.soft_flat_correction:
                self.proc_data.soft_flat = Fits(self.proc_data.config.calib_data.soft_flat).read(hdu=0)
            if self.proc_data.config.base.prefilter_correction:
                if 'amended' in self.proc_data.config.calib_data.soft_flat:
                    self.proc_data.prefilter_map = Fits(self.proc_data.config.calib_data.soft_flat).read(hdu=1)
                    self.proc_data.wl_cal_axis = Fits(self.proc_data.config.calib_data.soft_flat).read(hdu=2)
                else:
                    raise IllegalArgumentException("The soft flat file must be amended to extract the prefilter map")
        if self.proc_data.config.calib_data.offset_map:
            self.proc_data.offset_map = OffsetMap.from_file(self.proc_data.config.calib_data.offset_map)
        if self.proc_data.config.calib_data.grid_tgt_files:
            self.proc_data.shear_factor, self.proc_data.rot_angle = ShearAndRotLoader(
                self.proc_data.config.calib_data.grid_tgt_files
            ).run()
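    # Summary of the optional calibration inputs read above (attribute names as used
    # in the code; the concrete config file format is defined elsewhere): dark,
    # mod_matrix, slit_flat, sensor_flat, soft_flat (HDU 0, plus HDU 1/2 of an
    # "amended" soft flat for the prefilter map and wavelength-calibration axis),
    # offset_map and grid_tgt_files (shear factor and rotation angle).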

    def _prepare_chunks(self, files: list, module: list) -> dict:
        return Chunker(files, self.proc_data.config, raw="C" in module).run().chunks

    def _prepare_processing_data(self, module, files) -> None:
        for blck in module:
            if blck not in BlockRegistry():
                raise IllegalArgumentException(f"Block [{blck}] is not known.")

            block = BlockRegistry()[blck]
            block.prepare(self.proc_data, files=files)
            if not self.proc_data.has(block.input_needed(self.proc_data.config.is_cam3())):
                raise IllegalArgumentException(f"Block [{blck}] needs {block.input_needed()}, but not all are given.")
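    # Note on the block contract as exercised above (the authoritative definition
    # lives in the `Block` base class and `BlockRegistry`, not in this module):
    # a block must be registered (`blck in BlockRegistry()`), provide
    # `prepare(proc_data, files=...)`, and report its required inputs via
    # `input_needed(...)` so they can be checked against the ProcessingData.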

    def _submit_jobs(self, chunks: dict, module: list) -> list:
        total = len(chunks)
        log.info("Starting %s jobs for %s worker processes.", total, self.proc_data.config.base.workers)
        args = [[nr, total, Collections.flatten(batch), self.proc_data, module] for nr, batch in chunks.items()]
        result = MP.simultaneous(Orchestrator._job, args, workers=self.proc_data.config.base.workers)
        progress.dot(flush=True)
        files, success = Orchestrator._collect_results(result)
        if not success:
            if not Orchestrator._find_error_files(files):
                raise IOError("Was not able to save all files. See log for details.")
        return files

    @staticmethod
    def _collect_results(result: list) -> tuple:
        # Each entry in `result` is the tuple returned by `_job`:
        # (batch, success: bool, write_metadata: bool)
        log.info("All jobs done. Collecting metadata...")
        total = len(result)
        for i in range(total):
            batch, _, write_metadata = result[i]
            if write_metadata:
                batch.write_metadata(overwrite=True)
            progress.bar(i, total)
        progress.bar(total, total, flush=True)
        files = Collections.flatten_sort([e[0].file_list() for e in result])
        return files, all(e[1] for e in result)

    @staticmethod
    def _find_error_files(files) -> bool:
        log.warning("Not all jobs reported success on writing files. Checking...")
        all_good = True
        for f in files:
            if not os.path.exists(f):
                log.error("MISSING: %s", f)
                all_good = False
        if all_good:
            log.info("All expected files found. Continue normally.")
        return all_good

    @staticmethod
    def _job(args) -> tuple:
        # args:
        # - 0: int: ID of the chunk
        # - 1: int: total number of chunks
        # - 2: list: list of files to process
        # - 3: ProcessingData: the additional processing data and configuration
        # - 4: list: the list of processing blocks to execute
        Orchestrator._log_progress(args[0], args[1])
        prepro = PreProcessor(args[2], args[3], args[4]).run()
        if prepro.existing_result is not None:
            progress.dot(char="+")
            return prepro.existing_result, True, False
        result = Processor.run(prepro.batch, prepro.callbacks, prepro.proc_data)
        popro = PostProcessor(result, args[3].config, args[4]).run()
        progress.dot(success=popro.success)
        return popro.result.header_copy(), popro.success, True

    @staticmethod
    def _log_progress(current, total):
        percent = (int(current) * 100) // int(total)
        if percent % 10 == 0 and percent != Orchestrator.LAST_LOG:
            progress.dot(flush=True)
            log.info(" \t- %s %%", percent)
            Orchestrator.LAST_LOG = percent