Coverage for src/susi/reduc/pipeline/orchestrator.py: 91%
121 statements
coverage.py v7.5.0, created at 2025-06-13 14:15 +0000
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Module that provides the orchestration for the SUSI processing pipeline.

@author: hoelken
"""
import os.path
import numpy as np

from spectroflat.smile import OffsetMap

from . import BlockRegistry
from .chunker import Chunker
from .post_processor import PostProcessor
from .processing_data import ProcessingData
from .pre_processor import PreProcessor
from .processor import Processor
from ...io import Fits
from ...base import Logging, Config, Globals, IllegalArgumentException, Api
from ...utils import MP, Collections, progress
from ..shear_distortion.shear_and_rot_correction import ShearAndRotLoader
from ..shift.slit_shift import SlitShiftRef

log = Logging.get_logger()


class Orchestrator:
    """
    ## Module
    A Module is a chain of `N` processing blocks that shall be executed in a row.
    Each module is composed of a full processing chain of multiple blocks.

    <pre>
    ┣┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅ MODULE ┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┅┫ ┣┅┅┅┅
    ╭─────────╮    ╭──────────────╮   ╭───────────────╮   ╭───────────────╮
    │ Chunker │─┬─>│ PreProcessor │──>│   Processor   │──>│ PostProcessor │─┬─>...
    ╰─────────╯ │  ╰──────────────╯   ╰───────────────╯   ╰───────────────╯ │
                │  ╭──────────────╮   ╭───────────────╮   ╭───────────────╮ │
                ├─>│ PreProcessor │──>│   Processor   │──>│ PostProcessor │─┤
                │  ╰──────────────╯   ╰───────────────╯   ╰───────────────╯ │
                │  ╭──────────────╮   ╭───────────────╮   ╭───────────────╮ │
                ╰─>│ PreProcessor │──>│   Processor   │──>│ PostProcessor │─╯
                   ╰──────────────╯   ╰───────────────╯   ╰───────────────╯
    </pre>

    The chunker will split the given file list into batches and start one job per batch.
    All jobs will be queued and processed by worker processes. All jobs will perform the
    same set of configured operations (callbacks).
    All actual work is performed in the callbacks (i.e. the blocks).
    To add a new processing block, implement a new `Block` subclass that follows the life-cycle
    described in the block class.

    The `PreProcessor` takes care of loading the input data.
    The `Processor` executes the chain of processing blocks as configured in the pipeline.
    Finally, the `PostProcessor` provides generic metadata and writes the results back to
    the disk.
    The full result is then collected as a file list and given to the next module (if any).
    """

    LAST_LOG = -1

    def __init__(self, config: Config):
        self.proc_data = ProcessingData(config)

    def start(self, files: list) -> list:
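        """
        Run all configured pipeline modules in sequence on the given files.

        The files produced by one module are fed as input to the next one;
        the file list produced by the last module is returned.
        """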
        self._read_pipeline_calibration_data()
        for module in Globals.pipeline(self.proc_data.config.base.pipeline):
            log.info("Starting module %s with %s files", module, len(files))
            if not files:
                log.warning("No files to process. Abort computation")
                break

            self._prepare_processing_data(module, files)
            chunks = self._prepare_chunks(files, module)
            files = self._submit_jobs(chunks, module)

            if self.proc_data.config.base.obsid is not None:
                lvl = Globals.pipeline(self.proc_data.config.base.pipeline).index(module)
                api = Api(self.proc_data.config)
                api.add_pipeline_run(self.proc_data.config.base.obsid, lvl)
        return files

    def _read_pipeline_calibration_data(self):
        # run only once at the beginning of the whole pipeline
        if self.proc_data.config.calib_data.dark:
            self.proc_data.dark_image = Fits(self.proc_data.config.calib_data.dark).read()
        if self.proc_data.config.calib_data.mod_matrix:
            self.proc_data.mod_matrix = Fits(self.proc_data.config.calib_data.mod_matrix).read()
        if self.proc_data.config.calib_data.slit_flat:
            self.proc_data.slit_flat = Fits(self.proc_data.config.calib_data.slit_flat).read()
        if self.proc_data.config.calib_data.sensor_flat:
            self.proc_data.sensor_flat = Fits(self.proc_data.config.calib_data.sensor_flat).read()
        if self.proc_data.config.calib_data.soft_flat:
            if self.proc_data.config.base.soft_flat_correction:
                self.proc_data.soft_flat = Fits(self.proc_data.config.calib_data.soft_flat).read(hdu=0)
            if self.proc_data.config.base.prefilter_correction:
                self.proc_data.prefilter_map = Fits(self.proc_data.config.calib_data.soft_flat).read(hdu=1)
                self.proc_data.wl_cal_axis = Fits(self.proc_data.config.calib_data.soft_flat).read(hdu=2)
        if self.proc_data.config.calib_data.offset_map:
            self.proc_data.offset_map = OffsetMap.from_file(self.proc_data.config.calib_data.offset_map)
        if self.proc_data.config.calib_data.grid_tgt_files:
            self.proc_data.shear_factor, self.proc_data.rot_angle = ShearAndRotLoader(
                self.proc_data.config.calib_data.grid_tgt_files
            ).run()

    def _prepare_chunks(self, files: list, module: list) -> dict:
        return Chunker(files, self.proc_data.config, raw="C" in module).run().chunks

    def _prepare_processing_data(self, module, files) -> None:
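        """
        Validate the blocks of the given module and let them prepare the shared ProcessingData.

        Raises IllegalArgumentException if a block is unknown or required inputs are missing.
        """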
        for blck in module:
            if blck not in BlockRegistry():
                raise IllegalArgumentException(f"Block [{blck}] is not known.")

            block = BlockRegistry()[blck]
            block.prepare(self.proc_data)
            if not self.proc_data.has(block.input_needed(self.proc_data.config.is_cam3())):
                raise IllegalArgumentException(f"Block [{blck}] needs {block.input_needed()}, but not all are given.")

            # TODO move to block.prepare()
            if blck == 'S':
                self.proc_data.slit_flat_shift_ref, self.proc_data.slit_shifts = SlitShiftRef(
                    files, self.proc_data.config.reduc.slit_shift_proc_param
                ).run()
                if self.proc_data.slit_flat_shift_ref is None:
                    raise ValueError("No slit shift reference found. Abort processing.")

    def _submit_jobs(self, chunks: dict, module: list) -> list:
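        """
        Start one job per chunk on the configured number of worker processes.

        Returns the list of result files written by the jobs. Raises IOError if
        not all expected files could be written.
        """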
        total = len(chunks)
        log.info("Starting %s jobs for %s worker processes.", total, self.proc_data.config.base.workers)
        args = [[nr, total, Collections.flatten(batch), self.proc_data, module] for nr, batch in chunks.items()]
        result = MP.simultaneous(Orchestrator._job, args, workers=self.proc_data.config.base.workers)
        progress.dot(flush=True)
        files, success = Orchestrator._collect_results(result)
        if not success:
            if not Orchestrator._find_error_files(files):
                raise IOError("Was not able to save all files. See log for details.")
        return files

    @staticmethod
    def _collect_results(result: list) -> tuple:
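        """
        Write per-batch metadata and gather the files produced by all jobs.

        Returns a tuple of (file list, overall success flag).
        """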
        # each entry of `result` is a tuple: (batch, success: bool, write_metadata: bool)
        log.info("All jobs done. Collecting metadata...")
        total = len(result)
        for i in range(total):
            batch, _, write_metadata = result[i]
            if write_metadata:
                batch.write_metadata(overwrite=True)
            progress.bar(i, total)
        progress.bar(total, total, flush=True)
        files = Collections.flatten_sort([e[0].file_list() for e in result])
        return files, all(e[1] for e in result)

    @staticmethod
    def _find_error_files(files) -> bool:
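        """Check that every expected result file exists on disk; return True if all are present."""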
        log.warning("Not all jobs reported success on writing files. Checking...")
        all_good = True
        for f in files:
            if not os.path.exists(f):
                log.error("MISSING: %s", f)
                all_good = False
        if all_good:
            log.info("All expected files found. Continue normally.")
        return all_good

    @staticmethod
    def _job(args) -> tuple:
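        """
        Worker entry point: pre-process, process and post-process one chunk of files.

        Returns a tuple of (result batch, success flag, write_metadata flag) as
        consumed by `_collect_results`.
        """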
        # args:
        # - 0: int: ID of the chunk
        # - 1: int: Total number of chunks
        # - 2: list: List of files to process
        # - 3: ProcessingData: The additional processing data and configuration
        # - 4: list: The list of processing blocks to execute.
        Orchestrator._log_progress(args[0], args[1])
        prepro = PreProcessor(args[2], args[3], args[4]).run()
        if prepro.existing_result is not None:
            progress.dot(char="+")
            return prepro.existing_result, True, False
        result = Processor.run(prepro.batch, prepro.callbacks, prepro.proc_data)
        popro = PostProcessor(result, args[3].config, args[4]).run()
        progress.dot(success=popro.success)
        return popro.result.header_copy(), popro.success, True

    @staticmethod
    def _log_progress(current, total):
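        """Log the overall job progress in steps of roughly 10 percent."""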
        percent = (int(current) * 100) // int(total)
        if percent % 10 == 0 and percent != Orchestrator.LAST_LOG:
            progress.dot(flush=True)
            log.info(" \t- %s %%", percent)
            Orchestrator.LAST_LOG = percent