Coverage for src/susi/reduc/pipeline/blocks/block.py: 84%
99 statements
« prev ^ index » next coverage.py v7.5.0, created at 2025-06-13 14:15 +0000
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3"""
4Module that holds the base block.
6@author: hoelken
7"""
8from __future__ import annotations
10from typing import Type, Union
12from astropy.io import fits
13from numpy import shape
14from scipy.constants import h
15from src.susi.plot import roi
17from ..processing_data import ProcessingData
18from ....base import IllegalStateException
19from ....base.header_keys import *
20from ....db import FileDB
21from ....io import FitsBatch, Fits
22from ....utils.annotators import singleton
23from ....utils.cropping import common_shape
24from ....base import Logging
# Module-level logger, obtained from the project's `Logging` facade and used by the
# blocks in this module for their debug output.
log = Logging.get_logger()
class Block:
    """
    ## Block
    Base class for processing blocks.
    All blocks will be executed by the pipeline framework and undergo the same life-cycle.

    ### Add a Block
    To add a new block the minimum you have to do is to derive from this class, define a
    `BLOCK_ID` attribute (used in the log messages of `run`), and implement the static method
    `start` and the protected method `_algorithm`. Afterwards register the new block.

    **Example**
    <pre>
    class ExampleBlock(Block):
        BLOCK_ID = 'E'

        @staticmethod
        def start(batch: FitsBatch, proc_data: ProcessingData) -> FitsBatch:
            return ExampleBlock(batch, proc_data).run().result

        def _algorithm(self) -> None:
            # Implement the processing algorithm
            pass

    BlockRegistry().register('E', ExampleBlock)
    </pre>

    Please note that all registry keys must be unique; the registry raises an
    `IllegalStateException` otherwise.

    The algorithm part of the block class is executed in parallel for all chunks, but in the
    configured order within a chunk.

    ### Processing Requirements
    If your class has special requirements to be given in the `ProcessingData`, you should add the
    name of the required field(s) to the `input_needed()` result.

    **Example**
    <pre>
    @staticmethod
    def input_needed() -> list:
        return ['dark_file']
    </pre>

    ### Pre-Processing
    As the `_algorithm` part runs in parallel it might make sense to perform some pre-processing
    steps before forking individual jobs. If needed implement the `prepare` method to alter the
    given `ProcessingData`.

    **Example**
    <pre>
    @staticmethod
    def prepare(proc_data: ProcessingData) -> None:
        # Implement preparation steps
        pass
    </pre>

    ### Post-Processing
    After the processing is done there are multiple options on how to influence the processing result:

    #### Header Modification
    If you want to add header information to the result, implement the `_modify_header` method.

    **Example**
    <pre>
    def _modify_header(self) -> None:
        pass
    </pre>
    """

    BLOCKS_APPLIED = "Blocks Applied"

    # TODO: If checkups are done in the prepare method, this is not really needed.
    #       In any case this should get the full config as an input to be generic enough.
    @staticmethod
    def input_needed(cam3: bool = False) -> list:
        """
        Return a list of fields from the `ProcessingData` needed during the processing.

        The base implementation needs nothing and ignores `cam3`. The returned fields are
        released from the `ProcessingData` when `run` finishes (it calls this method
        without arguments).
        """
        return []

    @staticmethod
    def start(batch: FitsBatch, proc_data: ProcessingData) -> FitsBatch:
        """
        Entrypoint for the framework for parallel processing.

        :raises NotImplementedError: always; subclasses must override this method.
        """
        # `NotImplemented` is a constant, not an exception: calling it raised a TypeError.
        raise NotImplementedError("Subclasses of Block must implement start()")

    @staticmethod
    def prepare(proc_data: ProcessingData) -> None:
        """Apply modification to the processing data if needed (no-op in the base class)."""

    @staticmethod
    def predict_output(batch: FitsBatch, proc_data: ProcessingData) -> FitsBatch:
        """
        Predict the output basename(s) after processing.

        This method allows for a smart look-ahead and check if the
        result file(s) already exists for this module.

        :return: the FitsBatch with new filenames; the base class returns `batch` unchanged.
        """
        return batch

    def __init__(self, batch: FitsBatch, proc_data: ProcessingData):
        self.batch = batch
        self.db = FileDB(proc_data.config)
        self.proc_data = proc_data
        self.result = FitsBatch()
        # ROI dict with the keys y0, y1, x0, x1; stays None until a block sets it.
        self.roi = None

    def run(self) -> Block:
        """
        Log a short summary of the input batch, then execute `_algorithm`.

        The fields named by `input_needed()` are released from the processing data
        afterwards, even if the algorithm raises.

        :return: self, to allow chaining such as `Block(...).run().result`.
        """
        try:
            self._log_batch_summary()
            self._algorithm()
        finally:
            self.proc_data.release(self.input_needed())
        return self

    def _log_batch_summary(self) -> None:
        """Emit debug messages describing the batch this block is about to process."""
        if len(self.batch) == 0:
            log.debug(f'Starting block {self.BLOCK_ID} with empty batch')
            return

        first = self.batch[0]
        if first['data'] is None:
            log.debug(f'Starting block {self.BLOCK_ID} with {len(self.batch)} empty Fits')
            return

        # Named `data_shape` to avoid shadowing the module-level `numpy.shape` import.
        data_shape = first['data'].shape
        log.debug(f'Starting block {self.BLOCK_ID} with {len(self.batch)} Fits of shape {data_shape}')

        header = first['header']
        if header is not None and ROI_X0 in header:
            x0, x1 = header[ROI_X0], header[ROI_X1]
            y0, y1 = header[ROI_Y0], header[ROI_Y1]
            log.debug(f'Current ROI is Y0={y0}, Y1={y1}, X0={x0}, X1={x1}')

    def _algorithm(self) -> None:
        """
        Processing algorithm of the block.

        :raises NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement _algorithm()")

    def _target_roi(self, other_roi: dict) -> None:
        """
        Set `self.roi` to the intersection of `other_roi` and the ROI returned by
        `common_shape(self.batch)`.

        :param other_roi: dict with the keys y0, y1, x0, x1.
        """
        data_roi = common_shape(self.batch)
        self.roi = {
            "y0": max(other_roi["y0"], data_roi["y0"]),
            "y1": min(other_roi["y1"], data_roi["y1"]),
            "x0": max(other_roi["x0"], data_roi["x0"]),
            "x1": min(other_roi["x1"], data_roi["x1"]),
        }

    @staticmethod
    def _slice_to_roi(slices) -> dict:
        """Convert a (y, x) pair of slices into a ROI dict (start -> *0, stop -> *1)."""
        y_slice, x_slice = slices[0], slices[1]
        return {"y0": y_slice.start, "y1": y_slice.stop, "x0": x_slice.start, "x1": x_slice.stop}

    def _update_roi(self, other_roi=None, other_slice=None) -> None:
        """
        Compute a new ROI from one that is defined relative to the current ROI.

        `other_roi` takes precedence over `other_slice` if both are given. If no ROI is
        set yet, the given one is taken as is. Otherwise its start/stop values are added
        to the origin (y0, x0) of the current ROI.

        :param other_roi: dict with the keys y0, y1, x0, x1.
        :param other_slice: pair of slices (y, x).
        :raises ValueError: if neither `other_roi` nor `other_slice` is given.
        """
        if other_roi is not None:
            relative = other_roi
        elif other_slice is not None:
            relative = self._slice_to_roi(other_slice)
        else:
            raise ValueError("Either other_roi or other_slice must be given")

        if self.roi is None:
            # Copy so later mutations of the caller's dict do not leak into self.roi.
            self.roi = dict(relative)
        else:
            self.roi = {
                "y0": self.roi["y0"] + relative["y0"],
                "y1": self.roi["y0"] + relative["y1"],
                "x0": self.roi["x0"] + relative["x0"],
                "x1": self.roi["x0"] + relative["x1"],
            }

    def _update_roi_header(self, header: Union[fits.Header, dict]):
        """
        Write `self.roi` into the ROI keywords of `header` (no-op if no ROI is set).

        If the header carries `DISPERSION` and the ROI x-range differs from the one
        recorded in the header, the wavelength keywords are adjusted first, while the
        header still holds the previous x-range.
        """
        if self.roi is None:
            return

        # If wl calib is present and roi X is changed, we must update it.
        # The short-circuit matters: ROI_X0/ROI_X1 are only read when DISPERSION is present.
        if DISPERSION in header and (int(header[ROI_X0]) != self.roi["x0"] or int(header[ROI_X1]) != self.roi["x1"]):
            cam = self.proc_data.config.cam.name
            if cam == "cam1":  # min < max for cam1: MIN_WL_NM is shifted by the change of x0
                header[MIN_WL_NM] += header[DISPERSION] * (self.roi["x0"] - float(header[ROI_X0]))
                header[MAX_WL_PX] = int(self.roi["x1"] - self.roi["x0"])
                header[MAX_WL_NM] = header[MIN_WL_NM] + header[DISPERSION] * (header[MAX_WL_PX] - header[MIN_WL_PX])
            elif cam == "cam2":  # min > max for cam2: MIN_WL_NM is shifted by the change of x1
                header[MIN_WL_NM] += header[DISPERSION] * (float(header[ROI_X1]) - self.roi["x1"])
                header[MIN_WL_PX] = int(self.roi["x1"] - self.roi["x0"])
                header[MAX_WL_NM] = header[MIN_WL_NM] + header[DISPERSION] * header[MIN_WL_PX]
            # NOTE(review): other cameras keep their wavelength keywords unchanged while the
            # ROI keywords below are still overwritten -- confirm this is intended.

        Fits.override_header(header, ROI_X0, self.roi["x0"])
        Fits.override_header(header, ROI_X1, self.roi["x1"])
        Fits.override_header(header, ROI_Y0, self.roi["y0"])
        Fits.override_header(header, ROI_Y1, self.roi["y1"])

    def _update_rms_snr_mean(self, frame: Fits, roi: Union[tuple, None] = None):
        """
        Store contrast, RMS, SNR and mean of `frame` (optionally restricted to `roi`)
        in its header, each formatted as '.4e'.
        """
        Fits.override_header(frame.header, IMG_CONTRAST, f"{frame.contrast(roi):.4e}")
        Fits.override_header(frame.header, IMG_RMS, f"{frame.rms(roi):.4e}")
        Fits.override_header(frame.header, IMG_SNR, f"{frame.snr(roi):.4e}")
        Fits.override_header(frame.header, IMG_MEAN, f"{frame.mean(roi):.4e}")
@singleton
class BlockRegistry:
    """
    Singleton class to register processing blocks.

    ### Usage
    To register a block call the register function with an ID and the class object.
        BlockRegistry().register('Foo', BlockFoo)
    The class for the block can then be obtained via
        BlockRegistry()['Foo']
    """

    # Key -> block class mapping. Defined on the class, so it is shared by every instance.
    registry = {}

    def register(self, key: str, clazz: Type[Block]) -> None:
        """
        Register a new block. Keys must be unique within the registry.

        :raises IllegalStateException: if a block is already registered under `key`.
        """
        if key in self:
            raise IllegalStateException(f"Block with key {key} already registered.")
        self.registry[key] = clazz

    def __contains__(self, key: str) -> bool:
        """Return whether a block is registered under `key`."""
        return key in self.registry

    def __getitem__(self, key: str) -> Type[Block]:
        """Return the block class registered under `key` (raises KeyError if unknown)."""
        return self.registry[key]