Coverage for src/susi/reduc/pipeline/orchestrator.py: 91%

121 statements  


#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Module that provides the orchestration for the SUSI processing pipeline.

@author: hoelken
"""
import os.path

import numpy as np

from spectroflat.smile import OffsetMap

from . import BlockRegistry
from .chunker import Chunker
from .post_processor import PostProcessor
from .processing_data import ProcessingData
from .pre_processor import PreProcessor
from .processor import Processor
from ...io import Fits
from ...base import Logging, Config, Globals, IllegalArgumentException, Api
from ...utils import MP, Collections, progress
from ..shear_distortion.shear_and_rot_correction import ShearAndRotLoader
from ..shift.slit_shift import SlitShiftRef

log = Logging.get_logger()


class Orchestrator:
    """
    ## Module
    A Module is a chain of `N` processing blocks that shall be executed in a row.
    Each module is composed of a full processing chain of multiple blocks.

    <pre>
    ┣┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅ MODULE ┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┫ ┣┅┅┅┅
    ╭─────────╮    ╭──────────────╮   ╭───────────────╮   ╭───────────────╮
    │ Chunker │─┬─>│ PreProcessor │──>│   Processor   │──>│ PostProcessor │─┬─>...
    ╰─────────╯ │  ╰──────────────╯   ╰───────────────╯   ╰───────────────╯ │
                │  ╭──────────────╮   ╭───────────────╮   ╭───────────────╮ │
                ├─>│ PreProcessor │──>│   Processor   │──>│ PostProcessor │─┤
                │  ╰──────────────╯   ╰───────────────╯   ╰───────────────╯ │
                │  ╭──────────────╮   ╭───────────────╮   ╭───────────────╮ │
                ╰─>│ PreProcessor │──>│   Processor   │──>│ PostProcessor │─╯
                   ╰──────────────╯   ╰───────────────╯   ╰───────────────╯
    </pre>

    The chunker will split the given file list into batches and start one job per batch created.
    All jobs will be queued and processed by worker threads. All jobs will perform the
    same set of configured operations (callbacks).
    All the actual work is to be performed in the callbacks (i.e. the blocks).
    To add a new processing block, implement a new `Block` subclass that follows the life cycle
    described in the block class.

    The `PreProcessor` takes care of loading the input data.
    The `Processor` executes the chain of processing blocks as configured in the pipeline.
    Finally, the `PostProcessor` provides generic metadata and writes the results back to
    disk.
    The full result is then collected as a file list and given to the next module (if any).
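
    Typical usage (sketch; assumes a fully populated `Config` instance and a list of
    raw input files to reduce):

        orchestrator = Orchestrator(config)
        reduced_files = orchestrator.start(raw_file_list)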

    """

    LAST_LOG = -1

    def __init__(self, config: Config):
        self.proc_data = ProcessingData(config)

    def start(self, files: list) -> list:
        self._read_pipeline_calibration_data()
        for module in Globals.pipeline(self.proc_data.config.base.pipeline):
            log.info("Starting module %s with %s files", module, len(files))
            if not files:
                log.warning("No files to process. Abort computation")
                break

            self._prepare_processing_data(module, files)
            chunks = self._prepare_chunks(files, module)
            files = self._submit_jobs(chunks, module)

            if self.proc_data.config.base.obsid is not None:
                lvl = Globals.pipeline(self.proc_data.config.base.pipeline).index(module)
                api = Api(self.proc_data.config)
                api.add_pipeline_run(self.proc_data.config.base.obsid, lvl)
        return files

    def _read_pipeline_calibration_data(self):
        # run only once at the beginning of the whole pipeline
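        # Calibration products consulted below (attribute names as used in this method,
        # all optional): calib_data.dark, .mod_matrix, .slit_flat, .sensor_flat,
        # .soft_flat, .offset_map and .grid_tgt_files; base.soft_flat_correction and
        # base.prefilter_correction select which HDUs of the soft flat file are loaded.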

        if self.proc_data.config.calib_data.dark:
            self.proc_data.dark_image = Fits(self.proc_data.config.calib_data.dark).read()
        if self.proc_data.config.calib_data.mod_matrix:
            self.proc_data.mod_matrix = Fits(self.proc_data.config.calib_data.mod_matrix).read()
        if self.proc_data.config.calib_data.slit_flat:
            self.proc_data.slit_flat = Fits(self.proc_data.config.calib_data.slit_flat).read()
        if self.proc_data.config.calib_data.sensor_flat:
            self.proc_data.sensor_flat = Fits(self.proc_data.config.calib_data.sensor_flat).read()
        if self.proc_data.config.calib_data.soft_flat:
            if self.proc_data.config.base.soft_flat_correction:
                self.proc_data.soft_flat = Fits(self.proc_data.config.calib_data.soft_flat).read(hdu=0)
            if self.proc_data.config.base.prefilter_correction:
                self.proc_data.prefilter_map = Fits(self.proc_data.config.calib_data.soft_flat).read(hdu=1)
                self.proc_data.wl_cal_axis = Fits(self.proc_data.config.calib_data.soft_flat).read(hdu=2)
        if self.proc_data.config.calib_data.offset_map:
            self.proc_data.offset_map = OffsetMap.from_file(self.proc_data.config.calib_data.offset_map)
        if self.proc_data.config.calib_data.grid_tgt_files:
            self.proc_data.shear_factor, self.proc_data.rot_angle = ShearAndRotLoader(
                self.proc_data.config.calib_data.grid_tgt_files
            ).run()

    def _prepare_chunks(self, files: list, module: list) -> dict:
        return Chunker(files, self.proc_data.config, raw="C" in module).run().chunks

    def _prepare_processing_data(self, module, files) -> None:
        for blck in module:
            if blck not in BlockRegistry():
                raise IllegalArgumentException(f"Block [{blck}] is not known.")

            block = BlockRegistry()[blck]
            block.prepare(self.proc_data)
            if not self.proc_data.has(block.input_needed(self.proc_data.config.is_cam3())):
                raise IllegalArgumentException(f"Block [{blck}] needs {block.input_needed()}, but not all are given.")

            # TODO move to block.prepare()
            if blck == 'S':
                self.proc_data.slit_flat_shift_ref, self.proc_data.slit_shifts = SlitShiftRef(
                    files, self.proc_data.config.reduc.slit_shift_proc_param
                ).run()
                if self.proc_data.slit_flat_shift_ref is None:
                    raise ValueError("No slit shift reference found. Abort processing.")

    def _submit_jobs(self, chunks: dict, module: list) -> list:
        total = len(chunks)
        log.info("Starting %s jobs for %s worker processes.", total, self.proc_data.config.base.workers)
        args = [[nr, total, Collections.flatten(batch), self.proc_data, module] for nr, batch in chunks.items()]
        result = MP.simultaneous(Orchestrator._job, args, workers=self.proc_data.config.base.workers)
        progress.dot(flush=True)
        files, success = Orchestrator._collect_results(result)
        if not success:
            if not Orchestrator._find_error_files(files):
                raise IOError("Was not able to save all files. See log for details.")
        return files

    @staticmethod
    def _collect_results(result: list) -> tuple:
        # each entry in `result` is a (batch, success: bool, write_metadata: bool) tuple,
        # as returned by `_job`; this method returns (file_list: list, success: bool)
        log.info("All jobs done. Collecting metadata...")
        total = len(result)
        for i in range(total):
            batch, _, write_metadata = result[i]
            if write_metadata:
                batch.write_metadata(overwrite=True)
            progress.bar(i, total)
        progress.bar(total, total, flush=True)
        files = Collections.flatten_sort([e[0].file_list() for e in result])
        return files, all(e[1] for e in result)

    @staticmethod
    def _find_error_files(files) -> bool:
        log.warning("Not all jobs reported success on writing files. Checking...")
        all_good = True
        for f in files:
            if not os.path.exists(f):
                log.error("MISSING: %s", f)
                all_good = False
        if all_good:
            log.info("All expected files found. Continue normally.")
        return all_good

    @staticmethod
    def _job(args) -> tuple:
        # args:
        # - 0: int: ID of the chunk
        # - 1: int: Total number of chunks
        # - 2: list: List of files to process
        # - 3: ProcessingData: The additional processing data and configuration
        # - 4: list: The list of processing blocks to execute.
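        # Illustrative example (hypothetical file names): for chunk 3 of 12 running the
        # blocks "C" and "S", args would look like
        #   [3, 12, ["<scan_a>.fits", "<scan_b>.fits"], proc_data, ["C", "S"]]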

        Orchestrator._log_progress(args[0], args[1])
        prepro = PreProcessor(args[2], args[3], args[4]).run()
        if prepro.existing_result is not None:
            progress.dot(char="+")
            return prepro.existing_result, True, False
        result = Processor.run(prepro.batch, prepro.callbacks, prepro.proc_data)
        popro = PostProcessor(result, args[3].config, args[4]).run()
        progress.dot(success=popro.success)
        return popro.result.header_copy(), popro.success, True

    @staticmethod
    def _log_progress(current, total):
        percent = (int(current) * 100) // int(total)
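        # e.g. with total=12: chunk 6 gives percent=50 (a full 10% step, logged once),
        # chunk 5 gives percent=41 (not logged)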

        if percent % 10 == 0 and percent != Orchestrator.LAST_LOG:
            progress.dot(flush=True)
            log.info(" \t- %s %%", percent)
            Orchestrator.LAST_LOG = percent