Coverage for src/susi/reduc/pipeline/blocks/block_m.py: 85%

95 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2025-08-22 09:20 +0000

1#!/usr/bin/env python3 

2# -*- coding: utf-8 -*- 

3""" 

4Module that holds the block `M` of the susi pipeline. 

5 

6@author: hoelken 

7""" 

8import os 

9 

10import numpy as np 

11from spectroflat.smile import SmileInterpolator 

12 

13from .block import Block, BlockRegistry 

14from ..processing_data import ProcessingData 

15from ....base import Logging, IllegalStateException, Globals 

16from ....base.header_keys import * 

17from ....io import FitsBatch, Fits, Card 

18from ....utils import Collections 

19from ....utils.cropping import adapt_shapes 

20from ....utils.header_checks import check_same_binning 

21from ....utils.sub_shift import Shifter 

22 

# Module-level logger used by the BlockM methods below.
log = Logging.get_logger()

24 

25 

class BlockM(Block):
    """
    ## BLOCK M: Morphological Operations

    This block takes care of morphological calibration steps:
     - Smile Correction
     - TODO Support data binned in spatial and/or spectral dimension
     - TODO Image Registration
     - TODO Wavelength calibration (using amended offset map)
    """

    BLOCK_ID = 'M'

    @staticmethod
    def start(batch: FitsBatch, proc_data: ProcessingData) -> FitsBatch:
        """
        Callback-like entry point for the block.

        :param batch: the frames to process
        :param proc_data: processing context (config, offset map, ...)
        :return: the `result` of a `BlockM` run on `batch`
        """
        return BlockM(batch, proc_data).run().result

    def _algorithm(self):
        """
        Apply the morphological operations to every frame of the batch.

        CAM3 data is passed through untouched. Otherwise the offset map is
        (optionally) shifted in wavelength, the ROI is taken from the offset map
        header, batch and offset map are adapted to that ROI and every frame is
        smile-corrected and appended to `self.result.batch`.
        """
        if self.proc_data.config.is_cam3():
            log.debug('CAM3 run: Morphological operations skipped...')
            self.result = self.batch
            return

        self._shift_offset_map_wl()
        self._get_roi()
        self._adapt_shapes()
        for entry in self.batch:
            self.result.batch.append(self._process_entry(entry))

    def _process_entry(self, fits: Fits) -> dict:
        """Build the result entry (file path, corrected data, updated header) for one frame."""
        return {
            'file': fits.path,
            'data': self._modify_data(fits),
            'header': self._modify_header(fits),
        }

    def _modify_data(self, fits: Fits) -> np.ndarray:
        """
        Smile-correct the first image of `fits`.

        The frame's MOD_STATE header value is passed to the `SmileInterpolator`.
        A single-state offset map (first axis of length 1) is repeated to
        `Globals.MOD_CYCLE_FRAMES` states on first use, so every modulation state
        has an entry.

        :param fits: the frame to correct
        :return: the corrected image wrapped in a new leading axis of length 1
        """
        offset_map = self.proc_data.offset_map
        if offset_map.map.shape[0] == 1:
            offset_map.map = np.repeat(offset_map.map, Globals.MOD_CYCLE_FRAMES, axis=0)

        sc = SmileInterpolator(offset_map, fits.data[0], mod_state=int(fits.header[MOD_STATE])).run()
        return np.array([sc.result])

    def _modify_header(self, fits: Fits):
        """
        Update the frame header with the block ID, offset map, ROI and wavelength calibration info.

        If a wavelength offset is configured, it is recorded in OFFSET_MAP_WL_OFFSET. A non-zero
        offset is also added to MIN_WL_NM / MAX_WL_NM.

        :param fits: the frame whose header is modified in place
        :return: the modified header
        """
        offmap_header = self.proc_data.offset_map.header
        Fits.override_header(fits.header, Block.BLOCKS_APPLIED, BlockM.BLOCK_ID, append=True)
        Fits.override_header(
            fits.header, OFFSET_MAP, value=os.path.basename(self.proc_data.offset_map.path).split('.')[0]
        )
        self._update_roi_header(fits.header)

        # check if wl-calibrated offset map, i.e. if DISPERSION is present
        if DISPERSION in offmap_header:
            # override_header already sets WL_CALIBRATED; appending another card would duplicate the keyword.
            Fits.override_header(fits.header, WL_CALIBRATED, value='TRUE')
            for key in (DISPERSION, MIN_WL_NM, MIN_WL_PX, MAX_WL_NM, MAX_WL_PX):
                fits.header.append(Card(key, offmap_header[key]).to_card())
        else:
            Fits.override_header(fits.header, WL_CALIBRATED, value='FALSE')

        wl_offset = self.proc_data.config.reduc.offset_map_wl_offset
        if wl_offset is not None:
            # Mirror the early-return condition of _shift_offset_map_wl: only a non-zero
            # offset shifted the map and defined a meaningful px_offset.
            if wl_offset != 0:
                log.warning(f'Shifting wl-axis by {wl_offset} nm ({self.px_offset} px)')
                fits.header[MIN_WL_NM] += wl_offset
                fits.header[MAX_WL_NM] += wl_offset
            Fits.override_header(
                fits.header,
                OFFSET_MAP_WL_OFFSET,
                value=wl_offset,
                comment='WL shift for Offset map',
            )
        return fits.header

    def _get_roi(self) -> None:
        """
        Pass the ROI stored in the offset map header (ROI_Y0/Y1/X0/X1) to `_target_roi`.

        `_target_roi` is defined on `Block`. It presumably sets `self.roi`, which
        `_adapt_shapes` reads afterwards. TODO confirm.
        """
        h = self.proc_data.offset_map.header
        self._target_roi({'y0': int(h[ROI_Y0]), 'y1': int(h[ROI_Y1]), 'x0': int(h[ROI_X0]), 'x1': int(h[ROI_X1])})

    def _adapt_shapes(self):
        """
        Adapt the batch to `self.roi` and slice the offset map to the same region.

        The offset map may itself already be cropped. The ROI is therefore translated
        into map coordinates by subtracting the map's own origin (ROI_X0 / ROI_Y0 of
        its header, 0 if absent). The first (modulation state) axis is kept whole.
        """
        self.batch = adapt_shapes(self.batch, self.roi)
        h = self.proc_data.offset_map.header
        dx = int(h[ROI_X0]) if ROI_X0 in h else 0
        dy = int(h[ROI_Y0]) if ROI_Y0 in h else 0
        roi = (
            slice(None, None),
            slice(self.roi['y0'] - dy, self.roi['y1'] - dy),
            slice(self.roi['x0'] - dx, self.roi['x1'] - dx),
        )
        self.proc_data.offset_map.map = self.proc_data.offset_map.map[roi]

    def _shift_offset_map_wl(self) -> None:
        """
        Shift the offset map along its last (wavelength) axis by the configured wavelength offset.

        The offset in nm (`config.reduc.offset_map_wl_offset`) is converted to pixels
        via the DISPERSION keyword of the offset map header. Always sets
        `self.wl_offset` and `self.px_offset`. `px_offset` is 0.0 when no shift is applied.

        :raises IllegalStateException: if a non-zero offset is configured but the offset
            map has no DISPERSION, or if the map is neither 2D nor 3D
        """
        self.wl_offset = self.proc_data.config.reduc.offset_map_wl_offset
        # Define px_offset on every path so later readers never hit an AttributeError.
        self.px_offset = 0.0
        if self.wl_offset is None or self.wl_offset == 0:
            return

        header = self.proc_data.offset_map.header
        if DISPERSION not in header:
            raise IllegalStateException(
                f'Cannot shift offset map by {self.wl_offset} nm: offset map is not wavelength calibrated '
                f'(no {DISPERSION} in header)'
            )
        self.px_offset = float(self.wl_offset) / float(header[DISPERSION])
        log.warning(f'Shifting offset map by {self.wl_offset} nm ({self.px_offset} px)')

        offmap = self.proc_data.offset_map.map
        if np.ndim(offmap) == 2:
            shift = [0, -self.px_offset]
        elif np.ndim(offmap) == 3:
            shift = [0, 0, -self.px_offset]
        else:
            raise IllegalStateException(f'Offset map has unexpected shape {offmap.shape}')
        # Store in `.map`: that is the attribute _adapt_shapes and _modify_data read.
        # The previous write to `.data` was never read, so the shift had no effect.
        self.proc_data.offset_map.map = Shifter.d2shift(offmap, shift)

    def _apply_binning_factor(self):
        """
        Bin the offset map by the SPATIAL_BIN factors of the batch (mod state axis unbinned).

        NOTE(review): not called from `_algorithm` (see the binning TODO in the class
        docstring). It also accesses `self.batch[0]['header']` by item, while the other
        methods use `fits.header`. Verify that batch entries support both before enabling it.
        """
        check_same_binning(self.batch)
        bins = self.batch[0]['header'][SPATIAL_BIN] if SPATIAL_BIN in self.batch[0]['header'] else None
        if bins is None:
            return
        bins = bins.split(',')
        # Do not bin mod state dimension
        bins.insert(0, 1)
        self.proc_data.offset_map.map = Collections.bin(self.proc_data.offset_map.map, Collections.as_int_array(bins))

    @staticmethod
    def prepare(proc_data: ProcessingData, files=None) -> None:
        """
        Check the preconditions of the block.

        :raises IllegalStateException: if this is not a CAM3 run and no offset map is available
        """
        if not proc_data.config.is_cam3() and proc_data.offset_map is None:
            raise IllegalStateException('No OFFSET MAP given')

    @staticmethod
    def input_needed(cam3: bool = False) -> list:
        """Return the names of the inputs this block requires: none for CAM3, else the offset map."""
        return [] if cam3 else ['offset_map']

157 

158 

# Register BlockM with the BlockRegistry under its BLOCK_ID ('M'); runs at import time.
BlockRegistry().register(BlockM.BLOCK_ID, BlockM)