Coverage for src/susi/reduc/pipeline/blocks/block_f.py: 86%

109 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2025-08-22 09:20 +0000

1#!/usr/bin/env python3 

2# -*- coding: utf-8 -*- 

3""" 

4Module that holds the block `F` of the susi pipeline. 

5 

6@author: hoelken 

7""" 

8import os 

9 

10import numpy as np 

11from spectroflat.sensor.flat import Flat 

12 

13from .block import Block, BlockRegistry 

14from ..processing_data import ProcessingData 

15from ....base import Logging, IllegalStateException 

16from ....base.header_keys import * 

17from ....io import FitsBatch, Fits 

18from ....utils import Collections 

19from ....utils.cropping import adapt_shapes, adapt_shape 

20from ....utils.header_checks import check_same_binning 

21from ....utils.sub_shift import Shifter 

22 

# Module-level logger (obtained from the package's `Logging` helper) used by this block.
log = Logging.get_logger()

24 

25 

class BlockF(Block):
    """
    ## BLOCK F: Flat fielding

    This block takes care of the following calibration steps:
    - Sensor Flat Correction
    - TODO Fringes ?
    """

    BLOCK_ID = 'F'

    @staticmethod
    def start(batch: FitsBatch, proc_data: ProcessingData) -> FitsBatch:
        """
        Callback-like entry point for the block
        """
        return BlockF(batch, proc_data).run().result

    def _algorithm(self):
        """
        Run the flat-fielding steps on the current batch.

        The order matters:
        - The merged flat must exist before the ROI is determined, because the
          ROI falls back to the flat's shape.
        - Batch and flat are adapted to that ROI before the flat is binned to
          the batch binning.
        """
        self._merge_flats()
        self._get_roi()
        self._adapt_shapes()
        self._apply_binning_factor()
        for entry in self.batch:
            self.result.batch.append(self._process_entry(entry))

    def _process_entry(self, fits: Fits) -> dict:
        """Build the result entry (path, flat-corrected data, updated header) for one frame."""
        return {
            'file': fits.path,
            'data': self._modify_data(fits),
            'header': self._modify_header(fits),
        }

    def _modify_data(self, fits: Fits) -> np.ndarray:
        """
        Divide the first plane of ``fits.data`` by the flat of the frame's modulation state.

        The division is done by ``Flat.save_divide``, presumably guarding against
        division by zero. The quotient is wrapped in a new leading axis of length 1.
        """
        flat = self._mod_flat(int(fits.header[MOD_STATE]))
        return np.array([Flat.save_divide(fits.data[0], flat)])

    def _mod_flat(self, state: int) -> np.ndarray:
        """
        Return the flat plane to use for modulation state ``state``.

        CAM3 data, or a flat holding a single state, always uses plane 0.
        """
        if self.proc_data.config.is_cam3() or self._flat.shape[0] == 1:
            return self._flat[0]
        return self._flat[state]

    @staticmethod
    def _file_stem(path) -> str:
        """Return the file name of ``path`` up to its first dot (stored in the header)."""
        return os.path.basename(path).split('.')[0]

    def _modify_header(self, fits: Fits):
        """
        Record in ``fits.header`` that block F ran and which flat products were used.

        Only products that ``_merge_flats`` actually multiplies into the flat are
        recorded, so the guards here mirror the ones there. In particular, nothing
        but the sensor flat is referenced for CAM3, where ``prepare`` does not
        require a slit flat. Also updates the ROI and RMS/SNR/mean entries via
        the parent helpers.

        :return: the updated header
        """
        config = self.proc_data.config
        cam3 = config.is_cam3()

        Fits.override_header(fits.header, Block.BLOCKS_APPLIED, BlockF.BLOCK_ID, append=True)
        Fits.override_header(
            fits.header,
            FLAT_MAP,
            value=self._file_stem(self.proc_data.sensor_flat.path),
            comment='Fixed Sensor Flat File Used',
        )

        # The slit flat is never applied for CAM3 and may be None there.
        if not cam3 and config.base.slit_flat_corr_block_f:
            Fits.override_header(
                fits.header,
                FLAT_MAP_MOVING,
                value=self._file_stem(self.proc_data.slit_flat.path),
                comment='Fixed Slit Flat File Used in block F',
            )

        # The soft flat is never applied for CAM3 (see _merge_flats).
        if not cam3 and config.base.soft_flat_correction:
            Fits.override_header(
                fits.header,
                FLAT_MAP_SOFT,
                value=self._file_stem(self.proc_data.soft_flat.path),
                comment='Fixed Soft Flat File Used',
            )

        if not cam3 and config.base.prefilter_correction:
            Fits.override_header(
                fits.header,
                FLAT_MAP_PREFILTER,
                value=self._file_stem(self.proc_data.prefilter_map.path),
                comment='Fixed Prefilter Map File (Ext. 1) Used',
            )
            # prefilter_map_mean is set by _merge_flats under the same condition.
            Fits.override_header(
                fits.header,
                FLAT_MAP_PREFILTER_MEAN,
                value=self.prefilter_map_mean,
                comment='Mean of Prefilter Map File (Ext. 1) Used',
            )

        if config.reduc.soft_flat_wl_offset is not None:
            Fits.override_header(
                fits.header,
                SOFT_FLAT_WL_OFFSET,
                value=config.reduc.soft_flat_wl_offset,
                comment='WL shift for Soft Flat and Prefilter maps',
            )
        self._update_roi_header(fits.header)
        self._update_rms_snr_mean(fits)
        return fits.header

    def _merge_flats(self):
        """
        Combine all enabled flat products into ``self._flat``.

        The sensor flat is always used. For non-CAM3 data, the slit flat, the soft
        flat and the mean-normalised prefilter map are multiplied in as configured,
        optionally shifted by ``soft_flat_wl_offset``. Also sets
        ``self.prefilter_map_mean`` when the prefilter correction is applied.
        """
        config = self.proc_data.config
        cam3 = config.is_cam3()

        # sensor and slit flats
        if cam3 or not config.base.slit_flat_corr_block_f:
            # Copy: the in-place products below must never modify the sensor flat
            # held by proc_data (it is reused for every batch).
            self._flat = np.array(self.proc_data.sensor_flat.data, copy=True)
        else:
            self._flat = self.proc_data.sensor_flat.data * self.proc_data.slit_flat.data

        if cam3:
            return

        # soft flat
        if config.base.soft_flat_correction:
            self._flat *= self._apply_wl_offset(self.proc_data.soft_flat.data)

        # prefilter map, normalised to its mean
        if config.base.prefilter_correction:
            self.prefilter_map_mean = np.mean(self.proc_data.prefilter_map.data)
            self._flat *= self._apply_wl_offset(self.proc_data.prefilter_map.data / self.prefilter_map_mean)

    def _apply_wl_offset(self, flat_map: np.ndarray) -> np.ndarray:
        """Shift ``flat_map`` by the configured ``soft_flat_wl_offset``; return it unchanged if none is set."""
        wl_offset = self.proc_data.config.reduc.soft_flat_wl_offset
        if wl_offset is None:
            return flat_map
        return self._shift_map_wl(flat_map, wl_offset)

    def _shift_map_wl(self, flat_map: np.ndarray, wl_offset: float) -> np.ndarray:
        """
        Shift the map along its last axis by the given wavelength offset.

        The offset is converted to pixels using the DISPERSION header value of the
        wavelength calibration axis and applied with a negative sign.

        :param flat_map: 2D or 3D map to shift
        :param wl_offset: offset in the units of DISPERSION (logged as nm); converted with ``float()``
        :return: the shifted map (the input itself if the offset is 0)
        :raises IllegalStateException: if the map is neither 2D nor 3D
        """
        wl_offset = float(wl_offset)
        if wl_offset == 0:
            return flat_map
        px_offset = wl_offset / float(self.proc_data.wl_cal_axis.header[DISPERSION])
        log.warning(f'Shifting flat map by {wl_offset} nm ({px_offset} px)')
        ndim = np.ndim(flat_map)
        if ndim == 2:
            return Shifter.d2shift(flat_map, [0, -px_offset])
        if ndim == 3:
            return Shifter.d2shift(flat_map, [0, 0, -px_offset])
        raise IllegalStateException(f'Flat map has unexpected shape {flat_map.shape}, cannot shift it')

    def _adapt_shapes(self):
        """Adapt the batch and the merged flat to ``self.roi``."""
        self.batch = adapt_shapes(self.batch, self.roi)
        self._flat = adapt_shape(self._flat, self.proc_data.sensor_flat.header, self.roi)

    def _apply_binning_factor(self):
        """Bin the flat with the SPATIAL_BIN factors of the batch headers, if present."""
        check_same_binning(self.batch)
        header = self.batch[0]['header']
        bins = header[SPATIAL_BIN] if SPATIAL_BIN in header else None
        if bins is None:
            return
        bins = bins.split(',')
        # Do not bin mod state dimension
        bins.insert(0, 1)
        self._flat = Collections.bin(self._flat, Collections.as_int_array(bins))

    def _get_roi(self) -> None:
        """Determine the target ROI from the sensor flat's ROI (see ``_flat_shape``)."""
        self._target_roi(self._flat_shape())

    def _flat_shape(self) -> dict:
        """
        Return the ROI of the sensor flat as ``{'y0', 'y1', 'x0', 'x1'}``.

        Values come from the sensor flat header. Missing keys default to 0 for
        the start and to the merged flat's extent (axes 1 and 2) for the end.
        """
        h = self.proc_data.sensor_flat.header
        x0 = int(h[ROI_X0]) if ROI_X0 in h else 0
        y0 = int(h[ROI_Y0]) if ROI_Y0 in h else 0
        x1 = int(h[ROI_X1]) if ROI_X1 in h else self._flat.shape[2]
        y1 = int(h[ROI_Y1]) if ROI_Y1 in h else self._flat.shape[1]
        return {'y0': y0, 'y1': y1, 'x0': x0, 'x1': x1}

    @staticmethod
    def prepare(proc_data: ProcessingData, files=None) -> None:
        """
        Check that every input needed by the configured flat corrections is present.

        :raises IllegalStateException: if a required flat, prefilter map or
            wavelength axis is missing
        """
        if proc_data.sensor_flat is None:
            raise IllegalStateException('No SENSOR FLAT given')
        slit_needed = proc_data.slit_flat is None and proc_data.config.base.slit_flat_corr_block_f
        if not proc_data.config.is_cam3() and slit_needed:
            raise IllegalStateException('No SLIT FLAT given')
        if proc_data.config.base.soft_flat_correction and proc_data.soft_flat is None:
            raise IllegalStateException('No SOFT FLAT file given')
        if proc_data.config.base.prefilter_correction and proc_data.prefilter_map is None:
            raise IllegalStateException('No amended SOFT FLAT file given to extract prefilter from FITS extension 1')
        if proc_data.config.reduc.soft_flat_wl_offset is not None and proc_data.wl_cal_axis is None:
            raise IllegalStateException('No amended soft_flat given to shift (need wl axis), see soft_flat_wl_offset')

    @staticmethod
    def input_needed(cam3: bool = False) -> list:
        """Return the names of the processing-data inputs this block requires."""
        return ['sensor_flat']  # see comment in the block parent class

201 

202 

203BlockRegistry().register(BlockF.BLOCK_ID, BlockF)