Coverage for src/susi/reduc/pipeline/blocks/block_f.py: 88%

91 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2025-06-13 14:15 +0000

1#!/usr/bin/env python3 

2# -*- coding: utf-8 -*- 

3""" 

4Module that holds the block `F` of the susi pipeline. 

5 

6@author: hoelken 

7""" 

8import os 

9 

10import numpy as np 

11from spectroflat.sensor.flat import Flat 

12 

13from src.susi import FLAT_MAP_SOFT 

14 

15from .block import Block, BlockRegistry 

16from ..processing_data import ProcessingData 

17from ....base import Logging, IllegalStateException, ROI_X0, ROI_X1, ROI_Y0, ROI_Y1 

18from ....base import FLAT_MAP, MOD_STATE, SPATIAL_BIN, FLAT_MAP_PREFILTER, FLAT_MAP_MOVING, FLAT_MAP_PREFILTER_MEAN 

19from ....io import FitsBatch, Fits 

20from ....utils import Collections 

21from ....utils.cropping import adapt_shapes, adapt_shape 

22from ....utils.header_checks import check_same_binning 

23 

24log = Logging.get_logger() 

25 

26 

class BlockF(Block):
    """
    ## BLOCK F: Flat fielding

    This block takes care of the following calibration steps:
    - Sensor Flat Correction
    - Optional slit flat, soft flat and prefilter map correction (never for cam3)
    - TODO Fringes ?
    """

    BLOCK_ID = 'F'

    @staticmethod
    def start(batch: FitsBatch, proc_data: ProcessingData) -> FitsBatch:
        """
        Callback-like entry point for the block

        :param batch: the batch of frames to flat field
        :param proc_data: calibration data and configuration (sensor flat etc.)
        :return: the `result` of the executed block
        """
        return BlockF(batch, proc_data).run().result

    def _algorithm(self):
        """
        Build the combined flat once, bring flat and batch to a common ROI and
        binning, then correct every entry of the batch.
        """
        self._merge_flats()
        self._get_roi()
        self._adapt_shapes()
        self._apply_binning_factor()
        for entry in self.batch:
            self.result.batch.append(self._process_entry(entry))

    def _process_entry(self, fits: Fits) -> dict:
        """
        Create the result entry (file, data, header) for a single frame.
        """
        return {
            'file': fits.path,
            'data': self._modify_data(fits),
            'header': self._modify_header(fits),
        }

    def _modify_data(self, fits: Fits) -> np.ndarray:
        """
        Divide the first data plane of the frame by the flat that belongs to
        the modulation state given in the frame header.

        :return: the corrected frame, wrapped in a new leading axis of length 1
        """
        state = int(fits.header[MOD_STATE])
        corrected = Flat.save_divide(fits.data[0], self._mod_flat(state))
        return np.array([corrected])

    def _mod_flat(self, state: int) -> np.ndarray:
        """
        Select the flat plane for the given modulation state.

        cam3, and any flat with a single plane only, always use plane 0.
        """
        if self.proc_data.config.is_cam3() or self._flat.shape[0] == 1:
            return self._flat[0]
        return self._flat[state]

    def _modify_header(self, fits: Fits):
        """
        Record the applied block and the calibration files used in the header.

        The optional slit flat, soft flat and prefilter entries are written
        under exactly the conditions under which `_merge_flats` applies them,
        so the header never names a file that was not used (and never touches
        a slit flat that may be `None` for cam3).

        :return: the updated header of the given frame
        """
        config = self.proc_data.config
        # None of the optional corrections is applied to cam3 data.
        non_cam3 = not config.is_cam3()

        Fits.override_header(fits.header, Block.BLOCKS_APPLIED, BlockF.BLOCK_ID, append=True)
        Fits.override_header(
            fits.header,
            FLAT_MAP,
            value=os.path.basename(self.proc_data.sensor_flat.path).split('.')[0],
            comment='Fixed Sensor Flat File Used',
        )
        if non_cam3 and config.base.slit_flat_corr_block_f:
            Fits.override_header(
                fits.header,
                FLAT_MAP_MOVING,
                value=os.path.basename(self.proc_data.slit_flat.path).split('.')[0],
                comment='Fixed Slit Flat File Used in block F',
            )

        if non_cam3 and config.base.soft_flat_correction:
            Fits.override_header(
                fits.header,
                FLAT_MAP_SOFT,
                value=os.path.basename(self.proc_data.soft_flat.path).split('.')[0],
                comment='Fixed Soft Flat File Used',
            )

        if non_cam3 and config.base.prefilter_correction:
            Fits.override_header(
                fits.header,
                FLAT_MAP_PREFILTER,
                value=os.path.basename(self.proc_data.prefilter_map.path).split('.')[0],
                comment='Fixed Prefilter Map File (Ext. 1) Used',
            )
            Fits.override_header(
                fits.header,
                FLAT_MAP_PREFILTER_MEAN,
                value=self.prefilter_map_mean,
                comment='Mean of Prefilter Map File (Ext. 1) Used',
            )

        # NOTE(review): both helpers are not defined in this class, presumably
        # they are provided by the Block parent class.
        self._update_roi_header(fits.header)
        self._update_rms_snr_mean(fits)
        return fits.header

    def _merge_flats(self):
        """
        Combine all configured flats into the single array `self._flat`.

        The sensor flat is always used. For everything but cam3 it is
        multiplied by the slit flat, the soft flat and the mean-normalised
        prefilter map, each only if enabled in the config.

        All products are computed out-of-place: the arrays held by
        `proc_data` must stay untouched, otherwise the corrections would
        accumulate in the sensor flat whenever `proc_data` is reused.
        """
        config = self.proc_data.config
        non_cam3 = not config.is_cam3()

        flat = self.proc_data.sensor_flat.data
        # slit flat
        if non_cam3 and config.base.slit_flat_corr_block_f:
            flat = flat * self.proc_data.slit_flat.data
        # soft flat
        if non_cam3 and config.base.soft_flat_correction:
            flat = flat * self.proc_data.soft_flat.data
        # prefilter map, normalised to its mean (the mean goes to the header)
        if non_cam3 and config.base.prefilter_correction:
            self.prefilter_map_mean = np.mean(self.proc_data.prefilter_map.data)
            flat = flat * (self.proc_data.prefilter_map.data / self.prefilter_map_mean)
        self._flat = flat

    def _adapt_shapes(self):
        """
        Cut the batch and the merged flat to the common ROI in `self.roi`.
        """
        self.batch = adapt_shapes(self.batch, self.roi)
        self._flat = adapt_shape(self._flat, self.proc_data.sensor_flat.header, self.roi)

    def _apply_binning_factor(self):
        """
        Bin the flat with the spatial binning recorded in the batch headers.

        Does nothing if the first header has no `SPATIAL_BIN` entry.
        """
        check_same_binning(self.batch)
        header = self.batch[0]['header']
        if SPATIAL_BIN not in header:
            return
        bins = header[SPATIAL_BIN]
        if bins is None:
            return
        bins = bins.split(',')
        # Do not bin mod state dimension
        bins.insert(0, 1)
        self._flat = Collections.bin(self._flat, Collections.as_int_array(bins))

    def _get_roi(self) -> None:
        """
        Determine the target ROI based on the region covered by the flat.
        """
        # NOTE(review): `_target_roi` is not defined in this class; presumably
        # the Block parent provides it and it sets `self.roi`.
        self._target_roi(self._flat_shape())

    def _flat_shape(self) -> dict:
        """
        Region covered by the flat as a dict with keys `y0`, `y1`, `x0`, `x1`.

        The values are taken from the ROI keys of the sensor flat header.
        Missing lower bounds default to 0, missing upper bounds to the size of
        the merged flat (axis 1 for y, axis 2 for x).
        """
        h = self.proc_data.sensor_flat.header
        x0 = int(h[ROI_X0]) if ROI_X0 in h else 0
        y0 = int(h[ROI_Y0]) if ROI_Y0 in h else 0
        x1 = int(h[ROI_X1]) if ROI_X1 in h else self._flat.shape[2]
        y1 = int(h[ROI_Y1]) if ROI_Y1 in h else self._flat.shape[1]
        return {'y0': y0, 'y1': y1, 'x0': x0, 'x1': x1}

    @staticmethod
    def prepare(proc_data: ProcessingData) -> None:
        """
        Check that all calibration data required by the config is available.

        :raises IllegalStateException: if the sensor flat, or an enabled slit
            flat (not required for cam3), soft flat or prefilter map is missing
        """
        if proc_data.sensor_flat is None:
            raise IllegalStateException('No SENSOR FLAT given')
        slit_needed = proc_data.slit_flat is None and proc_data.config.base.slit_flat_corr_block_f
        if not proc_data.config.is_cam3() and slit_needed:
            raise IllegalStateException('No SLIT FLAT given')
        if proc_data.config.base.soft_flat_correction and proc_data.soft_flat is None:
            raise IllegalStateException('No SOFT FLAT file given')
        if proc_data.config.base.prefilter_correction and proc_data.prefilter_map is None:
            raise IllegalStateException('No SOFT FLAT file given to extract prefilter from FITS extension 1')

    @staticmethod
    def input_needed(cam3: bool = False) -> list:
        """
        Names of the inputs this block needs; the same list for every camera.
        """
        return ['sensor_flat']  # see comment in the block parent class

165 

166 

# Register the block under its ID with the block registry.
BlockRegistry().register(
    BlockF.BLOCK_ID,
    BlockF,
)