Coverage for src/susi/reduc/pipeline/orchestrator.py: 92% (118 statements)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Module that provides the orchestration for the SUSI processing pipeline.

@author: hoelken
"""
import os.path

from spectroflat.smile import OffsetMap
from . import BlockRegistry
from .chunker import Chunker
from .post_processor import PostProcessor
from .processing_data import ProcessingData
from .pre_processor import PreProcessor
from .processor import Processor
from ...io import Fits
from ...base import Logging, Config, Globals, IllegalArgumentException, Api
from ...utils import MP, Collections, progress
from ..shear_distortion.shear_and_rot_correction import ShearAndRotLoader

log = Logging.get_logger()


class Orchestrator:
    """
    ## Module
    A Module is a chain of `N` processing blocks that shall be executed in a row.
    Each module is composed of a full processing chain of multiple blocks.

    <pre>
    ┣┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅ MODULE ┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┫ ┣┅┅┅┅
    ╭─────────╮    ╭──────────────╮   ╭───────────────╮   ╭───────────────╮
    │ Chunker │─┬─>│ PreProcessor │──>│   Processor   │──>│ PostProcessor │─┬─>...
    ╰─────────╯ │  ╰──────────────╯   ╰───────────────╯   ╰───────────────╯ │
                │  ╭──────────────╮   ╭───────────────╮   ╭───────────────╮ │
                ├─>│ PreProcessor │──>│   Processor   │──>│ PostProcessor │─┤
                │  ╰──────────────╯   ╰───────────────╯   ╰───────────────╯ │
                │  ╭──────────────╮   ╭───────────────╮   ╭───────────────╮ │
                ╰─>│ PreProcessor │──>│   Processor   │──>│ PostProcessor │─╯
                   ╰──────────────╯   ╰───────────────╯   ╰───────────────╯
    </pre>

    The chunker will split the given file list into batches and start one job per batch created.
    All jobs will be queued and processed by worker threads. All jobs will perform the
    same set of configured operations (callbacks).
    All the actual action is performed in the callbacks (i.e. the blocks).
    To add a new processing block, implement a new `Block` subclass that follows the life-cycle
    described in the block class; see the sketch after `_prepare_processing_data` below.

    The `PreProcessor` takes care of loading the input data.
    The `Processor` executes the chain of processing blocks as configured in the pipeline.
    Finally, the `PostProcessor` provides generic metadata and writes the results back to
    the disk.
    The full result is then collected as a file list and given to the next module (if any).
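
    Usage sketch (illustrative; building a fully populated `Config` is up to the
    calling application and not shown here):

        config = ...  # a susi `Config` instance describing the pipeline to run
        result_files = Orchestrator(config).start(input_files)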
    """

    LAST_LOG = -1  # last progress percentage logged (each worker process keeps its own copy)

    def __init__(self, config: Config):
        self.proc_data = ProcessingData(config)

    def start(self, files: list) -> list:
        self._read_pipeline_calibration_data()
        for module in Globals.pipeline(self.proc_data.config.base.pipeline):
            log.info("Starting module %s with %s files", module, len(files))
            if not files:
                log.warning("No files to process. Abort computation")
                break

            self._prepare_processing_data(module, files)
            chunks = self._prepare_chunks(files, module)
            files = self._submit_jobs(chunks, module)

            if self.proc_data.config.base.obsid is not None:
                lvl = Globals.pipeline(self.proc_data.config.base.pipeline).index(module)
                api = Api(self.proc_data.config)
                api.add_pipeline_run(self.proc_data.config.base.obsid, lvl)
        return files

    def _read_pipeline_calibration_data(self):
        # runs only once, at the beginning of the whole pipeline
        if self.proc_data.config.calib_data.dark:
            self.proc_data.dark_image = Fits(self.proc_data.config.calib_data.dark).read()
        if self.proc_data.config.calib_data.mod_matrix:
            self.proc_data.mod_matrix = Fits(self.proc_data.config.calib_data.mod_matrix).read()
        if self.proc_data.config.calib_data.slit_flat:
            self.proc_data.slit_flat = Fits(self.proc_data.config.calib_data.slit_flat).read()
        if self.proc_data.config.calib_data.sensor_flat:
            self.proc_data.sensor_flat = Fits(self.proc_data.config.calib_data.sensor_flat).read()
        if self.proc_data.config.calib_data.soft_flat:
            if self.proc_data.config.base.soft_flat_correction:
                self.proc_data.soft_flat = Fits(self.proc_data.config.calib_data.soft_flat).read(hdu=0)
            if self.proc_data.config.base.prefilter_correction:
                if 'amended' in self.proc_data.config.calib_data.soft_flat:
                    self.proc_data.prefilter_map = Fits(self.proc_data.config.calib_data.soft_flat).read(hdu=1)
                    self.proc_data.wl_cal_axis = Fits(self.proc_data.config.calib_data.soft_flat).read(hdu=2)
                else:
                    raise IllegalArgumentException("The soft flat file must be amended to extract the prefilter map")
        if self.proc_data.config.calib_data.offset_map:
            self.proc_data.offset_map = OffsetMap.from_file(self.proc_data.config.calib_data.offset_map)
        if self.proc_data.config.calib_data.grid_tgt_files:
            self.proc_data.shear_factor, self.proc_data.rot_angle = ShearAndRotLoader(
                self.proc_data.config.calib_data.grid_tgt_files
            ).run()
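
    # For reference, the calib_data entries read above are file paths; a config
    # section might look like this (illustrative values, not a shipped configuration):
    #
    #   calib_data:
    #     dark: /calib/dark.fits
    #     mod_matrix: /calib/mod_matrix.fits
    #     slit_flat: /calib/slit_flat.fits
    #     sensor_flat: /calib/sensor_flat.fits
    #     soft_flat: /calib/soft_flat_amended.fits   # an 'amended' file carries HDU 1 (prefilter map) and HDU 2 (wl-cal axis)
    #     offset_map: /calib/offset_map.fits
    #     grid_tgt_files: [/calib/grid_tgt_1.fits, /calib/grid_tgt_2.fits]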

    def _prepare_chunks(self, files: list, module: list) -> dict:
        return Chunker(files, self.proc_data.config, raw="C" in module).run().chunks

    def _prepare_processing_data(self, module, files) -> None:
        for blck in module:
            if blck not in BlockRegistry():
                raise IllegalArgumentException(f"Block [{blck}] is not known.")

            block = BlockRegistry()[blck]
            block.prepare(self.proc_data, files=files)
            needed = block.input_needed(self.proc_data.config.is_cam3())
            if not self.proc_data.has(needed):
                raise IllegalArgumentException(f"Block [{blck}] needs {needed}, but not all are given.")
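
    # Hypothetical sketch of a new processing block (for illustration only: the
    # method names `prepare` and `input_needed` are taken from their use above,
    # everything else is an assumption; the authoritative life-cycle is described
    # in the `Block` base class):
    #
    #   class MyBlock(Block):
    #       def prepare(self, proc_data, files=None):
    #           ...  # acquire whatever the block needs before processing starts
    #
    #       def input_needed(self, is_cam3: bool) -> list:
    #           # names of the ProcessingData attributes this block requires
    #           return ['dark_image']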

    def _submit_jobs(self, chunks: dict, module: list) -> list:
        total = len(chunks)
        log.info("Starting %s jobs for %s worker processes.", total, self.proc_data.config.base.workers)
        args = [[nr, total, Collections.flatten(batch), self.proc_data, module] for nr, batch in chunks.items()]
        result = MP.simultaneous(Orchestrator._job, args, workers=self.proc_data.config.base.workers)
        progress.dot(flush=True)
        files, success = Orchestrator._collect_results(result)
        if not success:
            if not Orchestrator._find_error_files(files):
                raise IOError("Was not able to save all files. See log for details.")
        return files

    @staticmethod
    def _collect_results(result: list) -> tuple:
        # each entry in `result` is a (batch, success: bool, write_metadata: bool) tuple from `_job`
        log.info("All jobs done. Collecting metadata...")
        total = len(result)
        for i in range(total):
            batch, _, write_metadata = result[i]
            if write_metadata:
                batch.write_metadata(overwrite=True)
            progress.bar(i, total)
        progress.bar(total, total, flush=True)
        files = Collections.flatten_sort([e[0].file_list() for e in result])
        return files, all(e[1] for e in result)

    @staticmethod
    def _find_error_files(files) -> bool:
        log.warning("Not all jobs reported success on writing files. Checking...")
        all_good = True
        for f in files:
            if not os.path.exists(f):
                log.error("MISSING: %s", f)
                all_good = False
        if all_good:
            log.info("All expected files found. Continuing normally.")
        return all_good

    @staticmethod
    def _job(args) -> tuple:
        # args:
        # - 0: int: ID of the chunk
        # - 1: int: total number of chunks
        # - 2: list: list of files to process
        # - 3: ProcessingData: the additional processing data and configuration
        # - 4: list: the list of processing blocks to execute
        # Returns a (batch, success, write_metadata) tuple, consumed by _collect_results.
        Orchestrator._log_progress(args[0], args[1])
        prepro = PreProcessor(args[2], args[3], args[4]).run()
        if prepro.existing_result is not None:
            # a result from a previous run exists: skip processing and metadata writing
            progress.dot(char="+")
            return prepro.existing_result, True, False
        result = Processor.run(prepro.batch, prepro.callbacks, prepro.proc_data)
        popro = PostProcessor(result, args[3].config, args[4]).run()
        progress.dot(success=popro.success)
        return popro.result.header_copy(), popro.success, True

    @staticmethod
    def _log_progress(current, total):
        percent = (int(current) * 100) // int(total)
        if percent % 10 == 0 and percent != Orchestrator.LAST_LOG:
            progress.dot(flush=True)
            log.info(" \t- %s %%", percent)
            Orchestrator.LAST_LOG = percent
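
# Worked example (illustrative, not part of the module): with total=40 chunks,
# percent = nr * 100 // 40 reaches a fresh multiple of 10 at nr = 0, 4, 8, ..., 36,
# so _log_progress emits a log line at those chunks and returns silently for the
# ones in between.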