Coverage for src/susi/reduc/pipeline/blocks/block_s.py: 82%

130 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2025-08-11 10:03 +0000

1#!/usr/bin/env python3 

2# -*- coding: utf-8 -*- 

3""" 

4Module that holds the block `S` of the susi pipeline. 

5 

6@author: iglesias 

7""" 

8import os 

9 

10import numpy as np 

11from spectroflat.sensor.flat import Flat 

12from scipy.signal import savgol_filter 

13from src.susi import InsufficientDataException 

14 

15from .block import Block, BlockRegistry 

16from ..processing_data import ProcessingData 

17from ....base import Logging, IllegalStateException 

18from ....base.header_keys import * 

19from ....io import FitsBatch, Fits, Card 

20from ....utils import Collections 

21from ....utils.cropping import adapt_shape 

22from ....utils.header_checks import check_same_binning 

23from ....utils.sub_shift import Shifter 

24from ....utils.sub_shift import round_shift_for_roi 

25from ...fields.flat_field_slit_correction import SlitFFCorrector 

26 

27 

# Module-level logger, used by `BlockS.prepare` to report progress.
log = Logging.get_logger()


class BlockS(Block):
    """
    ## BLOCK S: Shift and Slit Flat fielding

    This block takes care of the following calibration steps:
    - Shift image in slit direction to the desired reference using the input
      measured slit mask shift (a smooth SLIT_FLAT_OFFSET from the metadata file)
    - Apply Slit Flat Correction
    """

    BLOCK_ID = 'S'

    @staticmethod
    def start(batch: FitsBatch, proc_data: ProcessingData) -> FitsBatch:
        """
        Callback-like entry point for the block.

        Builds a `BlockS` for the given batch and processing data, runs it and
        returns its `result`.
        """
        return BlockS(batch, proc_data).run().result

    def _algorithm(self):
        """
        Prepare the slit flat, check the binning, then process every entry of the
        batch and append the processed entry to `self.result.batch`.
        """
        self._prep_flat()
        self._apply_binning_factor()
        for entry in self.batch:
            shift = self._get_shift(entry)
            self.result.batch.append(self._process_entry(entry, shift))

    def _process_entry(self, fits: Fits, shift) -> dict:
        """
        Build the result dict (`file`, `data`, `header`) for a single entry.

        The data is modified before the header (dict literals are evaluated in
        order), so the header update sees the already shifted and cropped `fits`.
        """
        return {
            'file': fits.path,
            'data': self._modify_data(fits, shift),
            'header': self._modify_header(fits, shift),
        }

    def _modify_data(self, fits: Fits, shift) -> np.ndarray:
        """
        Shift the frame, crop it to `self.roi` and, if enabled in the config,
        divide it by the slit flat of the frame's modulation state.

        :param fits: entry to process; `fits.data` is replaced in place by the
            shifted and cropped data
        :param shift: shift handed over to `Shifter.d2shift`
        :return: the processed data with a leading axis of length 1
        :raises InsufficientDataException: if `fits.data` holds more than one frame
        """
        # check data shape to avoid non implemented cases
        if fits.data.shape[0] != 1:
            raise InsufficientDataException('Block S only works with single frame data for now')
        fits = self._shift_image(fits, shift)
        # crop to common roi of the whole shifted dataset
        fits.data = adapt_shape(fits.data, fits.header, self.roi)
        if not self.proc_data.config.base.slit_flat_corr_block_s:
            return fits.data
        # divide by shifted flat
        flat = self._mod_flat(int(fits.header[MOD_STATE]))
        ff_data = np.array(Flat.save_divide(fits.data[0], flat))
        return ff_data[np.newaxis, :, :]

    def _mod_flat(self, state: int) -> np.ndarray:
        """
        Return the slit flat plane for the given modulation state.

        If the slit flat holds a single plane, that plane is used for every state.
        """
        flat_data = self.proc_data.slit_flat.data
        if flat_data.shape[0] == 1:
            return flat_data[0]
        return flat_data[state]

    def _get_shift(self, entry: Fits) -> np.ndarray:
        """
        Look up the slit shift of `entry` (by file base name) in
        `self.proc_data.slit_shifts`.

        :return: array `[yshift, 0]`
        :raises ValueError: if the file is not listed in the slit shifts
        """
        file = os.path.basename(entry.path)
        try:
            idx = self.proc_data.slit_shifts['files'].index(file)
        except ValueError as err:
            # keep the original lookup failure attached as the cause
            raise ValueError(f'File {file} not found in self.proc_data.slit_shifts') from err
        yshift = self.proc_data.slit_shifts['shifts'][idx]
        return np.array([yshift, 0])

    def _prep_flat(self):
        """
        Shift the slit flat to the reference of the current data, determine the
        common ROI (`self.roi`) and crop the flat to it.

        If the slit flat correction is disabled, only `self.roi` is set (to the
        common ROI of the data) and the flat is left untouched.

        NOTE(review): this overwrites `self.proc_data.slit_flat.data` and its ROI
        header keys. If the same `proc_data` instance is reused for several
        batches the flat would be shifted again on every call - confirm with the
        caller.
        """
        # no flat needed case (truthiness check, consistent with the other users of this flag)
        if not self.proc_data.config.base.slit_flat_corr_block_s:
            self.roi = self.proc_data.slit_flat_shift_ref['common_roi']
            return
        slit_flat = self.proc_data.slit_flat
        # shift flat to current data reference.
        # The flat was already shifted to its own reference, so we use SLIT_FLAT_REF_OFFSET
        self.flat_shift = slit_flat.header[SLIT_FLAT_REF_OFFSET] - self.proc_data.slit_flat_shift_ref['offset']
        slit_flat.data = np.array(
            [Shifter.d2shift(frame, np.array([self.flat_shift, 0])) for frame in slit_flat.data]
        )
        # finds common roi with batch
        self._get_common_flat_roi()
        # adapt flat shape
        slit_flat.data = adapt_shape(slit_flat.data, slit_flat.header, self.roi)
        slit_flat.header[ROI_Y0] = int(self.roi['y0'])
        slit_flat.header[ROI_Y1] = int(self.roi['y1'])
        slit_flat.header[ROI_X0] = int(self.roi['x0'])
        slit_flat.header[ROI_X1] = int(self.roi['x1'])

    def _shift_image(self, fits: Fits, shift) -> Fits:
        """
        Shift the single frame in `fits.data` by `shift` and store it back with a
        leading axis of length 1.

        :return: the same `fits` object, with its data replaced
        """
        # positive shift moves towards 0
        shifted = Shifter.d2shift(fits.data[0], shift)
        fits.data = shifted[np.newaxis, :, :]
        return fits

    def _apply_binning_factor(self):
        """
        Check that all entries share the same binning and reject binned data.

        :raises NotImplementedError: if the first entry of the batch carries a
            SPATIAL_BIN header value
        """
        check_same_binning(self.batch)
        header = self.batch[0]['header']
        bins = header[SPATIAL_BIN] if SPATIAL_BIN in header else None
        if bins is None:
            return
        # TODO implement binning of the slit flat. Intended approach: split `bins` on ',',
        #  prepend 1 (do not bin the mod state dimension) and apply `Collections.bin` to the flat.
        raise NotImplementedError('Binning in Block S is not implemented yet')

    def _get_common_flat_roi(self) -> None:
        """
        Set `self.roi` to the area common to the shifted flat and the shifted frames.
        """
        # TODO move to orchestrator ?
        flat_roi = self._get_shape(self.proc_data.slit_flat)
        # get usable roi of shifted flat
        shift = round_shift_for_roi(self.flat_shift)
        if shift > 0:
            flat_roi['y1'] = flat_roi['y1'] - shift
        elif shift < 0:
            # shift is negative here, so this moves y0 up by |shift|
            flat_roi['y0'] = flat_roi['y0'] - shift
        # common area between shifted flat and shifted frames
        data_roi = self.proc_data.slit_flat_shift_ref['common_roi']
        self.roi = {
            "y0": max(flat_roi["y0"], data_roi["y0"]),
            "y1": min(flat_roi["y1"], data_roi["y1"]),
            "x0": max(flat_roi["x0"], data_roi["x0"]),
            "x1": min(flat_roi["x1"], data_roi["x1"]),
        }

    def _get_shape(self, fits: Fits) -> dict:
        """
        Return the ROI of `fits` as a dict with keys `y0`, `y1`, `x0`, `x1`.

        Header ROI keys are used when present; otherwise the lower bounds default
        to 0 and the upper bounds to the data shape (axis 2 for x, axis 1 for y).
        """
        h = fits.header
        x0 = int(h[ROI_X0]) if ROI_X0 in h else 0
        y0 = int(h[ROI_Y0]) if ROI_Y0 in h else 0
        x1 = int(h[ROI_X1]) if ROI_X1 in h else fits.data.shape[2]
        y1 = int(h[ROI_Y1]) if ROI_Y1 in h else fits.data.shape[1]
        return {'y0': y0, 'y1': y1, 'x0': x0, 'x1': x1}

    def _modify_header(self, fits: Fits, shift):
        """
        Record the applied block, flat and shift information in the header of
        `fits`, then update its ROI and RMS/SNR/mean entries.

        :param fits: entry whose header is updated in place
        :param shift: shift applied to the frame (two components)
        :return: the updated header
        """
        shift_ref = self.proc_data.slit_flat_shift_ref
        Fits.override_header(fits.header, Block.BLOCKS_APPLIED, BlockS.BLOCK_ID, append=True)
        if self.proc_data.config.base.slit_flat_corr_block_s:
            flat_name = os.path.basename(self.proc_data.slit_flat.path).split('.')[0]
            Fits.override_header(fits.header, FLAT_MAP_MOVING, value=flat_name)
            fits.header.append(Card(SHIFT_APPLIED_MOVING_FLAT, value=self.flat_shift, comment='[px]').to_card())
        fits.header.append(Card(SLIT_FLAT_REF_OFFSET, value=shift_ref['offset'], comment='[px]').to_card())
        fits.header.append(Card(SLIT_FLAT_REF_SLOPE, value=shift_ref['slope']).to_card())
        fits.header.append(Card(SLIT_FLAT_REF_FILE, value=os.path.basename(shift_ref['file'])).to_card())
        if SHIFT_APPLIED not in fits.header:
            shift_str = f'{shift[0]:.6f}, {shift[1]:.6f}'
        else:
            # accumulate on top of the shift that is already recorded in the header
            parts = fits.header[SHIFT_APPLIED].split(',')
            parts[0] = f'{float(parts[0]) + shift[0]:.6f}'
            parts[1] = f'{float(parts[1]) + shift[1]:.6f}'
            shift_str = ', '.join(parts)
        # NOTE(review): when SHIFT_APPLIED already exists the accumulated value is appended as a
        # new card instead of replacing the old one - confirm this does not leave a stale duplicate.
        fits.header.append(Card(SHIFT_APPLIED, value=shift_str, comment='[px]').to_card())
        self._update_roi_header(fits.header)
        self._update_rms_snr_mean(fits)
        return fits.header

    @staticmethod
    def prepare(proc_data: ProcessingData, files=None) -> None:
        """
        Estimate the slit flat shifts for `files` and store them in `proc_data`.

        Sets `proc_data.slit_shifts` (keys `files` and `shifts`) and
        `proc_data.slit_flat_shift_ref` (key `file`).

        :raises IllegalStateException: if the slit flat correction is enabled, no
            slit flat is given and the camera is not cam3
        """
        flat_needed = proc_data.slit_flat is None and proc_data.config.base.slit_flat_corr_block_s
        if not proc_data.config.is_cam3() and flat_needed:
            raise IllegalStateException('No SLIT FLAT given')

        log.info('Starting slit flat shift estimation, this may take a while...')
        # NOTE(review): for cam3 `proc_data.slit_flat` may still be None at this point, in which
        # case the `.path` access below would fail - confirm the expected cam3 behaviour.
        ff = SlitFFCorrector(files, proc_data.slit_flat, proc_data.config)
        ff.compute_shifts(
            plot=True,
            slit_roi=proc_data.config.reduc.slit_ff_slit_roi,
            wl_roi=proc_data.config.reduc.slit_ff_wl_roi,
            ini_val=proc_data.config.reduc.slit_ff_global_ini,
        )
        proc_data.slit_shifts = {
            'files': ff.files,
            'shifts': ff.shifts,
        }
        proc_data.slit_flat_shift_ref = {
            'file': proc_data.slit_flat.path,
        }
        log.info(f'Slit flat shift estimation done for {len(proc_data.slit_shifts["files"])} files')


# Register the block under its ID so it can be looked up through the registry.
BlockRegistry().register(BlockS.BLOCK_ID, BlockS)