Coverage for src/susi/reduc/pipeline/blocks/block.py: 84%

99 statements  

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Module that holds the base block.

@author: hoelken
"""
from __future__ import annotations

from typing import Type, Union

from astropy.io import fits

from ..processing_data import ProcessingData
from ....base import IllegalStateException, Logging
from ....base.header_keys import *
from ....db import FileDB
from ....io import FitsBatch, Fits
from ....utils.annotators import singleton
from ....utils.cropping import common_shape

log = Logging.get_logger()


class Block:
    """
    ## Block
    Base class for processing blocks.
    All blocks are executed by the pipeline framework and undergo the same life-cycle.

    ### Add a Block
    To add a new block, at minimum derive from this class and implement the
    static method `start` and the protected method `_algorithm`. Afterwards,
    register the new block.

    **Example**
    <pre>
    class ExampleBlock(Block):
        @staticmethod
        def start(batch: FitsBatch, proc_data: ProcessingData) -> FitsBatch:
            return ExampleBlock(batch, proc_data).run().result

        def _algorithm(self) -> None:
            # Implement the processing algorithm
            pass

    BlockRegistry().register('E', ExampleBlock)
    </pre>

    Please note that all block identifiers must be unique; otherwise the registry raises an error.

    The algorithm part of the block class is executed in parallel for all chunks,
    but in the configured order within a chunk.

    ### Processing Requirements
    If your class has special requirements on the given `ProcessingData`, add the
    name(s) of the required field(s) to the `input_needed()` result.

    **Example**
    <pre>
    @staticmethod
    def input_needed() -> list:
        return ['dark_file']
    </pre>

    ### Pre-Processing
    As the `_algorithm` part runs in parallel, it might make sense to perform some
    pre-processing steps before forking individual jobs. If needed, implement the
    `prepare` method to alter the given `ProcessingData`.

    **Example**
    <pre>
    @staticmethod
    def prepare(proc_data: ProcessingData) -> None:
        # Implement preparation steps
    </pre>

    ### Post-Processing
    After the processing is done, there are multiple options to influence the
    processing result.

    #### Header Modification
    If you want to add header information to the result, implement the
    `_modify_header` method.

    **Example**
    <pre>
    def _modify_header(self) -> None:
        pass
    </pre>
    """

    BLOCKS_APPLIED = "Blocks Applied"

    # TODO: If checkups are done in the prepare method, this is not really needed.
    #  In any case this should get the full config as an input to be generic enough.
    @staticmethod
    def input_needed(cam3: bool = False) -> list:
        """Return a list of fields from the `ProcessingData` needed during the processing."""
        return []

    @staticmethod
    def start(batch: FitsBatch, proc_data: ProcessingData) -> FitsBatch:
        """Entry point for the framework for parallel processing"""
        raise NotImplementedError()

    @staticmethod
    def prepare(proc_data: ProcessingData) -> None:
        """Apply modifications to the processing data if needed"""
        pass

    @staticmethod
    def predict_output(batch: FitsBatch, proc_data: ProcessingData):
        """
        Predict the output basename(s) after processing.

        This method allows for a smart look-ahead to check whether the
        result file(s) for this module already exist.

        @returns the FitsBatch with new filenames
        """
        return batch
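    # A hedged sketch (illustration only, not part of this module) of how a
    # subclass might override `predict_output` to announce renamed results.
    # The 'path' key on batch entries is an assumption made for this example:
    #
    #     @staticmethod
    #     def predict_output(batch: FitsBatch, proc_data: ProcessingData):
    #         for entry in batch:
    #             entry['path'] = entry['path'].replace('_l0.fits', '_l1.fits')
    #         return batch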

    def __init__(self, batch: FitsBatch, proc_data: ProcessingData):
        self.batch = batch
        self.db = FileDB(proc_data.config)
        self.proc_data = proc_data
        self.result = FitsBatch()
        self.roi = None

    def run(self) -> Block:
        try:
            if len(self.batch) > 0:
                if self.batch[0]['data'] is not None:
                    shape = self.batch[0]['data'].shape
                    log.debug(f'Starting block {self.BLOCK_ID} with {len(self.batch)} Fits of shape {shape}')
                    if self.batch[0]['header'] is not None:
                        if ROI_X0 in self.batch[0]['header']:
                            x0, x1 = self.batch[0]['header'][ROI_X0], self.batch[0]['header'][ROI_X1]
                            y0, y1 = self.batch[0]['header'][ROI_Y0], self.batch[0]['header'][ROI_Y1]
                            log.debug(f'Current ROI is Y0={y0}, Y1={y1}, X0={x0}, X1={x1}')
                else:
                    log.debug(f'Starting block {self.BLOCK_ID} with {len(self.batch)} empty Fits')
            else:
                log.debug(f'Starting block {self.BLOCK_ID} with empty batch')
            self._algorithm()
        finally:
            self.proc_data.release(self.input_needed())
        return self
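    # Hedged sketch of the life-cycle the framework drives (names per the class
    # docstring above; the actual scheduler lives outside this module):
    #
    #     SomeBlock.prepare(proc_data)                 # once, before forking
    #     result = SomeBlock.start(batch, proc_data)   # per chunk, in parallel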

    def _algorithm(self) -> None:
        raise NotImplementedError()

    def _target_roi(self, other_roi: dict) -> None:
        # Intersect the given ROI with the common data shape of the batch.
        data_roi = common_shape(self.batch)
        self.roi = {
            "y0": max(other_roi["y0"], data_roi["y0"]),
            "y1": min(other_roi["y1"], data_roi["y1"]),
            "x0": max(other_roi["x0"], data_roi["x0"]),
            "x1": min(other_roi["x1"], data_roi["x1"]),
        }
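    # Worked example: other_roi = {'y0': 10, 'y1': 200, 'x0': 0, 'x1': 100}
    # intersected with data_roi = {'y0': 0, 'y1': 180, 'x0': 5, 'x1': 120}
    # yields self.roi = {'y0': 10, 'y1': 180, 'x0': 5, 'x1': 100}.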

    def _update_roi(self, other_roi=None, other_slice=None) -> None:
        # Compute a new absolute ROI from `other_roi` or `other_slice`,
        # both of which are defined relative to the current ROI.
        if other_roi is not None:
            if self.roi is None:
                self.roi = other_roi
            else:
                self.roi = {
                    "y0": self.roi["y0"] + other_roi["y0"],
                    "y1": self.roi["y0"] + other_roi["y1"],
                    "x0": self.roi["x0"] + other_roi["x0"],
                    "x1": self.roi["x0"] + other_roi["x1"],
                }
        elif other_slice is not None:
            if self.roi is None:
                self.roi = {
                    "y0": other_slice[0].start,
                    "y1": other_slice[0].stop,
                    "x0": other_slice[1].start,
                    "x1": other_slice[1].stop,
                }
            else:
                self.roi = {
                    "y0": self.roi["y0"] + other_slice[0].start,
                    "y1": self.roi["y0"] + other_slice[0].stop,
                    "x0": self.roi["x0"] + other_slice[1].start,
                    "x1": self.roi["x0"] + other_slice[1].stop,
                }
        else:
            raise ValueError("Either other_roi or other_slice must be given")
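    # Worked example: with self.roi = {'y0': 10, 'y1': 110, 'x0': 20, 'x1': 220}
    # and a relative other_roi = {'y0': 5, 'y1': 95, 'x0': 0, 'x1': 180}, the new
    # absolute ROI becomes {'y0': 15, 'y1': 105, 'x0': 20, 'x1': 200}: both
    # bounds of each axis are shifted by the current origin (y0, x0).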

    def _update_roi_header(self, header: Union[fits.Header, dict]):
        if self.roi is None:
            return

        # If a wavelength calibration is present and the ROI changed in X,
        # the calibration keywords must be updated as well.
        if DISPERSION in header and (int(header[ROI_X0]) != self.roi["x0"] or int(header[ROI_X1]) != self.roi["x1"]):
            if self.proc_data.config.cam.name == "cam1":  # min < max for cam1
                header[MIN_WL_NM] += header[DISPERSION] * (self.roi["x0"] - float(header[ROI_X0]))
                header[MAX_WL_PX] = int(self.roi["x1"] - self.roi["x0"])
                header[MAX_WL_NM] = header[MIN_WL_NM] + header[DISPERSION] * (header[MAX_WL_PX] - header[MIN_WL_PX])
            elif self.proc_data.config.cam.name == "cam2":  # min > max for cam2
                header[MIN_WL_NM] += header[DISPERSION] * (float(header[ROI_X1]) - self.roi["x1"])
                header[MIN_WL_PX] = int(self.roi["x1"] - self.roi["x0"])
                header[MAX_WL_NM] = header[MIN_WL_NM] + header[DISPERSION] * header[MIN_WL_PX]

        Fits.override_header(header, ROI_X0, self.roi["x0"])
        Fits.override_header(header, ROI_X1, self.roi["x1"])
        Fits.override_header(header, ROI_Y0, self.roi["y0"])
        Fits.override_header(header, ROI_Y1, self.roi["y1"])
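    # Numeric sketch of the wavelength update above (cam1, values illustrative):
    # with DISPERSION = 0.002 nm/px, old ROI_X0 = 100 and new x0 = 150,
    # MIN_WL_NM grows by 0.002 * (150 - 100) = 0.1 nm; MAX_WL_PX becomes the
    # new ROI width and MAX_WL_NM follows from it.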

    def _update_rms_snr_mean(self, frame: Fits, roi: tuple = None):
        Fits.override_header(frame.header, IMG_CONTRAST, f"{frame.contrast(roi):.4e}")
        Fits.override_header(frame.header, IMG_RMS, f"{frame.rms(roi):.4e}")
        Fits.override_header(frame.header, IMG_SNR, f"{frame.snr(roi):.4e}")
        Fits.override_header(frame.header, IMG_MEAN, f"{frame.mean(roi):.4e}")

@singleton
class BlockRegistry:
    """
    Singleton class to register processing blocks.

    ### Usage
    To register a block, call the `register` function with an ID and the class object:
        BlockRegistry().register('Foo', BlockFoo)
    The class for the block can then be obtained via:
        BlockRegistry()['Foo']
    """

    registry = {}

    def register(self, key: str, clazz: Type[Block]) -> None:
        """Register a new block. Keys must be unique within the registry."""
        if key in self:
            raise IllegalStateException(f"Block with key {key} already registered.")
        self.registry[key] = clazz

    def __contains__(self, key: str) -> bool:
        return key in self.registry

    def __getitem__(self, key: str) -> Type[Block]:
        return self.registry[key]
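# Usage sketch: registering and resolving a block class (ExampleBlock as in
# the Block docstring; batch and proc_data are placeholders). Registering the
# same key twice raises IllegalStateException.
#
#     BlockRegistry().register('E', ExampleBlock)
#     block_cls = BlockRegistry()['E']
#     result = block_cls.start(batch, proc_data)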