Coverage for src/susi/reduc/pipeline/blocks/block_s.py: 86%

122 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2025-06-13 14:15 +0000

1#!/usr/bin/env python3 

2# -*- coding: utf-8 -*- 

3""" 

4Module that holds the block `S` of the susi pipeline. 

5 

6@author: iglesias 

7""" 

8import os 

9 

10import numpy as np 

11from spectroflat.sensor.flat import Flat 

12from scipy.signal import savgol_filter 

13from src.susi import InsufficientDataException 

14 

15from .block import Block, BlockRegistry 

16from ..processing_data import ProcessingData 

17from ....base import Logging, IllegalStateException 

18from ....base.header_keys import * 

19from ....io import FitsBatch, Fits, Card 

20from ....utils import Collections 

21from ....utils.cropping import adapt_shape 

22from ....utils.header_checks import check_same_binning 

23from ....utils.sub_shift import Shifter 

24from ....utils.sub_shift import round_shift_for_roi 

25 

26 

# Module-level logger obtained from the project's Logging helper.
# NOTE(review): not referenced by BlockS in this module — confirm it is still needed.
log = Logging.get_logger()

28 

29 

class BlockS(Block):
    """
    ## BLOCK S: Shift and Slit Flat fielding

    This block takes care of the following calibration steps:
     - Shift image in slit direction to the desired reference using the input
       measured slit mask shift (a smooth SLIT_FLAT_OFFSET from the metadata file)
     - Apply Slit Flat Correction
    """

    BLOCK_ID = 'S'

    @staticmethod
    def start(batch: FitsBatch, proc_data: ProcessingData) -> FitsBatch:
        """
        Callback-like entry point for the block
        """
        return BlockS(batch, proc_data).run().result

    def _algorithm(self):
        """Prepare the slit flat, then shift, crop and flat-field every entry of the batch."""
        self._prep_flat()
        self._apply_binning_factor()
        for entry in self.batch:
            shift = self._get_shift(entry)
            self.result.batch.append(self._process_entry(entry, shift))

    def _process_entry(self, fits: Fits, shift: np.ndarray) -> dict:
        """
        Build the result entry for one frame.

        Dict literal values are evaluated left to right, so `_modify_data` (which shifts and
        crops `fits.data` in place) runs before `_modify_header` receives the same object.
        """
        return {
            'file': fits.path,
            'data': self._modify_data(fits, shift),
            'header': self._modify_header(fits, shift),
        }

    def _modify_data(self, fits: Fits, shift: np.ndarray) -> np.ndarray:
        """
        Shift, crop and (if enabled) slit-flat-field the single frame held by `fits`.

        `fits.data` is replaced by the shifted and cropped frame. The flat-fielded frame,
        when the correction is enabled, is only returned (not written back to `fits`).

        :param fits: entry whose data has exactly one frame along axis 0
        :param shift: (y, x) shift in pixels, as returned by `_get_shift`
        :return: data with a leading frame axis of length 1
        :raises InsufficientDataException: if the data holds more than one frame
        """
        # check data shape to avoid non implemented cases
        if fits.data.shape[0] != 1:
            raise InsufficientDataException('Block S only works with single frame data for now')
        fits = self._shift_image(fits, shift)
        # crop to common roi of the whole shifted dataset
        fits.data = adapt_shape(fits.data, fits.header, self.roi)
        if not self.proc_data.config.base.slit_flat_corr_block_s:
            return fits.data
        # divide by the (already shifted and cropped) flat of this modulation state
        flat = self._mod_flat(int(fits.header[MOD_STATE]))
        ff_data = np.array(Flat.save_divide(fits.data[0], flat))
        return ff_data[np.newaxis, :, :]

    def _mod_flat(self, state: int) -> np.ndarray:
        """
        Return the slit flat frame to use for modulation state `state`.

        A flat with a single frame is used for every state; otherwise the frame is
        selected by indexing axis 0 with `state`.
        """
        if self.proc_data.slit_flat.data.shape[0] == 1:
            return self.proc_data.slit_flat.data[0]
        return self.proc_data.slit_flat.data[state]

    def _get_shift(self, entry: Fits) -> np.ndarray:
        """
        Look up the slit-direction shift measured for `entry`.

        :param entry: batch entry, matched by the basename of its path
        :return: array `[yshift, 0]` (no shift along x)
        :raises ValueError: if the file is not listed in `proc_data.slit_shifts['files']`
        """
        file = os.path.basename(entry.path)
        try:
            idx = self.proc_data.slit_shifts['files'].index(file)
        except ValueError as err:
            raise ValueError(f'File {file} not found in self.proc_data.slit_shifts') from err
        yshift = self.proc_data.slit_shifts['shifts'][idx]
        return np.array([yshift, 0])

    def _prep_flat(self):
        """
        Set `self.roi` and, when slit flat correction is enabled, align the slit flat to the data.

        When the correction is enabled, every frame of the slit flat is shifted by
        `self.flat_shift`, the flat is cropped to the ROI shared with the data, and the
        ROI header keys of the flat are updated accordingly.
        """
        # no flat needed case; same truthiness test as in _modify_data / _modify_header
        if not self.proc_data.config.base.slit_flat_corr_block_s:
            self.roi = self.proc_data.slit_flat_shift_ref['common_roi']
            return
        # Shift flat to current data reference.
        # The flat was already shifted to its own reference, so we use SLIT_FLAT_REF_OFFSET
        flat = self.proc_data.slit_flat
        self.flat_shift = flat.header[SLIT_FLAT_REF_OFFSET] - self.proc_data.slit_flat_shift_ref['offset']
        flat.data = np.array([Shifter.d2shift(frame, np.array([self.flat_shift, 0])) for frame in flat.data])
        # find the roi shared with the batch and crop the flat to it
        self._get_common_flat_roi()
        flat.data = adapt_shape(flat.data, flat.header, self.roi)
        for key, corner in ((ROI_Y0, 'y0'), (ROI_Y1, 'y1'), (ROI_X0, 'x0'), (ROI_X1, 'x1')):
            flat.header[key] = int(self.roi[corner])

    def _shift_image(self, fits: Fits, shift: np.ndarray) -> Fits:
        """
        Shift the single frame of `fits` by `shift` and store it back with a leading frame axis.

        :return: the same `fits` object, modified in place
        """
        # positive shift moves towards 0
        fits.data = Shifter.d2shift(fits.data[0], shift)
        fits.data = fits.data[np.newaxis, :, :]
        return fits

    def _apply_binning_factor(self):
        """
        Adapt the slit flat to the spatial binning of the batch.

        Only unbinned data is supported: nothing is done when the first header has no
        (or a None) SPATIAL_BIN value.

        :raises NotImplementedError: if the batch is binned
        """
        check_same_binning(self.batch)
        header = self.batch[0]['header']
        bins = header[SPATIAL_BIN] if SPATIAL_BIN in header else None
        if bins is None:
            return
        # TODO: implement binning of the slit flat. Planned approach: split the comma
        #  separated SPATIAL_BIN value, prepend a factor of 1 so the mod state dimension
        #  is not binned, and apply Collections.bin(..., Collections.as_int_array(bins))
        #  to self.proc_data.slit_flat.data.
        raise NotImplementedError('Binning in Block S is not implemented yet')

    def _get_common_flat_roi(self) -> None:
        """
        Set `self.roi` to the area that is valid both in the shifted flat and in the shifted data.

        The flat ROI (from its header, or its full extent) is first trimmed by the rounded
        flat shift: a positive shift trims the `y1` side, a negative one the `y0` side.
        """
        # TODO move to orchestrator ?
        flat_roi = self._get_shape(self.proc_data.slit_flat)
        # get usable roi of shifted flat
        shift = round_shift_for_roi(self.flat_shift)
        if shift > 0:
            flat_roi['y1'] = flat_roi['y1'] - shift
        elif shift < 0:
            flat_roi['y0'] = flat_roi['y0'] - shift
        # common area between shifted flat and shifted frames
        data_roi = self.proc_data.slit_flat_shift_ref['common_roi']
        self.roi = {
            "y0": max(flat_roi["y0"], data_roi["y0"]),
            "y1": min(flat_roi["y1"], data_roi["y1"]),
            "x0": max(flat_roi["x0"], data_roi["x0"]),
            "x1": min(flat_roi["x1"], data_roi["x1"]),
        }

    def _get_shape(self, fits: Fits) -> dict:
        """
        Return the ROI of `fits` as `{'y0', 'y1', 'x0', 'x1'}`.

        Missing ROI header keys default to the full extent of the data, whose axes
        are indexed as (frame, y, x).
        """
        h = fits.header
        x0 = int(h[ROI_X0]) if ROI_X0 in h else 0
        y0 = int(h[ROI_Y0]) if ROI_Y0 in h else 0
        x1 = int(h[ROI_X1]) if ROI_X1 in h else fits.data.shape[2]
        y1 = int(h[ROI_Y1]) if ROI_Y1 in h else fits.data.shape[1]
        return {'y0': y0, 'y1': y1, 'x0': x0, 'x1': x1}

    def _modify_header(self, fits: Fits, shift: np.ndarray):
        """
        Record the block, the slit flat used and the applied shift in the header of `fits`.

        The shift is accumulated into SHIFT_APPLIED ('y, x' in pixels) when an earlier
        step already wrote that keyword.

        :return: the updated header
        """
        Fits.override_header(fits.header, Block.BLOCKS_APPLIED, BlockS.BLOCK_ID, append=True)
        if self.proc_data.config.base.slit_flat_corr_block_s:
            Fits.override_header(
                fits.header, FLAT_MAP_MOVING, value=os.path.basename(self.proc_data.slit_flat.path).split('.')[0]
            )
            fits.header.append(Card(SHIFT_APPLIED_MOVING_FLAT, value=self.flat_shift, comment='[px]').to_card())
        fits.header.append(
            Card(SLIT_FLAT_REF_OFFSET, value=self.proc_data.slit_flat_shift_ref['offset'], comment='[px]').to_card()
        )
        fits.header.append(Card(SLIT_FLAT_REF_SLOPE, value=self.proc_data.slit_flat_shift_ref['slope']).to_card())
        fits.header.append(
            Card(SLIT_FLAT_REF_FILE, value=os.path.basename(self.proc_data.slit_flat_shift_ref['file'])).to_card()
        )
        if SHIFT_APPLIED not in fits.header:
            shift_str = f'{shift[0]:.6f}, {shift[1]:.6f}'
            fits.header.append(Card(SHIFT_APPLIED, value=shift_str, comment='[px]').to_card())
        else:
            parts = fits.header[SHIFT_APPLIED].split(',')
            parts[0] = f'{float(parts[0]) + shift[0]:.6f}'
            parts[1] = f'{float(parts[1]) + shift[1]:.6f}'
            # Update the existing card in place: appending a second SHIFT_APPLIED card would
            # create a duplicate keyword, and header[SHIFT_APPLIED] would keep returning
            # the first (stale) value.
            fits.header[SHIFT_APPLIED] = (', '.join(parts), '[px]')
        self._update_roi_header(fits.header)
        self._update_rms_snr_mean(fits)
        return fits.header

    @staticmethod
    def prepare(proc_data: ProcessingData) -> None:
        """
        Check that a slit flat is available when the configuration requires one.

        :raises IllegalStateException: if slit flat correction is enabled, no slit flat is
            given and the configuration is not for cam3
        """
        flat_needed = proc_data.slit_flat is None and proc_data.config.base.slit_flat_corr_block_s
        if not proc_data.config.is_cam3() and flat_needed:
            raise IllegalStateException('No SLIT FLAT given')

193 

194 

# Register BlockS under its BLOCK_ID ('S') when this module is imported;
# presumably this is how the pipeline resolves block ids to classes — confirm in BlockRegistry.
BlockRegistry().register(BlockS.BLOCK_ID, BlockS)