Coverage for src/susi/reduc/pipeline/blocks/block_d.py: 93%

104 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2025-06-13 14:15 +0000

1#!/usr/bin/env python3 

2# -*- coding: utf-8 -*- 

3""" 

4Module that holds the block `D` of the susi pipeline. 

5 

6@author: hoelken 

7""" 

8import os 

9 

10import numpy as np 

11from astropy.io import fits 

12 

13from .block import Block, BlockRegistry 

14from ..processing_data import ProcessingData 

15from ...demodulation import Demodulator, DemodAssembler 

16from ....base import Logging, IllegalStateException 

17from ....base.header_keys import * 

18from ....db import FileDB 

19from ....io import FitsBatch, Fits, Card 

20from ....utils import Collections 

21from ....utils.cropping import adapt_shapes, adapt_shape, common_shape 

22from ....utils.header_checks import check_same_binning 

23 

# Module-level logger obtained from the project's Logging helper.
log = Logging.get_logger()

25 

26 

class BlockD(Block):
    """
    ## BLOCK D: Demodulation

    This block applies the demodulation matrix to the FitsBatch.

    For cam3 the batch is passed through unchanged. For all other cameras the
    frames are grouped by `DemodAssembler`, each group is demodulated with
    `proc_data.demod_matrix` via `Demodulator`, and every resulting frame is
    appended to `self.result` together with a freshly merged header.
    """

    BLOCK_ID = 'D'

    @staticmethod
    def start(batch: FitsBatch, proc_data: ProcessingData) -> FitsBatch:
        """
        Callback-like entry point for the block.

        :param batch: the input frames
        :param proc_data: processing data holding the config and (de)modulation matrices
        :return: the `result` of running a new `BlockD` on `batch`
        """
        return BlockD(batch, proc_data).run().result

    def _algorithm(self):
        """Demodulate `self.batch` into `self.result` (pass-through for cam3)."""
        if self.proc_data.config.is_cam3():
            self.result = self.batch
            return

        self._get_roi()
        self._adapt_shapes()
        self._apply_binning_factor()
        for block in DemodAssembler(self.db, self.batch).run():
            stokes = Demodulator(self.proc_data.demod_matrix, block['data'], block['states'], block['name']).run()
            self._process_entry(stokes, block['start'])

    def _get_roi(self) -> None:
        """Set the target ROI from the region covered by the (de)modulation matrix."""
        self._target_roi(self._demod_mat_shape())

    def _demod_mat_shape(self) -> dict:
        """
        Return the ROI of the modulation matrix as a dict of ints with keys y0, y1, x0, x1.

        Side effect: in ``ROW_WISE`` mode the X extent stored in the modulation matrix
        header is overwritten with the common X extent of the batch.
        """
        # when using a ROW_WISE demodulation, we extend it to fill the whole X dim
        if self.proc_data.config.spol.demod_mode == 'ROW_WISE':
            data_roi = common_shape(self.batch)
            self.proc_data.mod_matrix.header[ROI_X0] = data_roi['x0']
            self.proc_data.mod_matrix.header[ROI_X1] = data_roi['x1']
        h = self.proc_data.mod_matrix.header
        return {'y0': int(h[ROI_Y0]), 'y1': int(h[ROI_Y1]), 'x0': int(h[ROI_X0]), 'x1': int(h[ROI_X1])}

    def _adapt_shapes(self):
        """
        Crop the batch and the demodulation matrix to `self.roi`.

        In ``ROW_WISE`` mode the demodulation matrix is first reduced to its median
        along axis 1 (X) and then repeated to the ROI width. As a result, every
        position along X uses the same matrix.
        """
        self.batch = adapt_shapes(self.batch, self.roi)
        if self.proc_data.config.spol.demod_mode == 'ROW_WISE':
            row_median = np.median(self.proc_data.demod_matrix, axis=1, keepdims=True)
            self.proc_data.demod_matrix = np.repeat(row_median, self.roi['x1'] - self.roi['x0'], axis=1)
        # NOTE(review): this reassigns the shared `proc_data.demod_matrix`. If the same
        # ProcessingData instance is reused for several batches, confirm that cropping
        # (and the binning in `_apply_binning_factor`) is not applied more than once.
        self.proc_data.demod_matrix = adapt_shape(
            self.proc_data.demod_matrix, self.proc_data.mod_matrix.header, self.roi
        )

    def _apply_binning_factor(self):
        """
        Bin the demodulation matrix with the spatial binning recorded in the batch headers.

        All frames must share the same binning (`check_same_binning`). Nothing is done
        when the first frame's header has no `SPATIAL_BIN` entry.
        """
        check_same_binning(self.batch)
        bins = self.batch[0]['header'].get(SPATIAL_BIN)
        if bins is None:
            return
        # Comma-separated spatial factors from the header, followed by factor 1 for
        # the output and input dimensions of the matrix (those are never binned).
        factors = bins.split(',') + [1, 1]
        self.proc_data.demod_matrix = Collections.bin(self.proc_data.demod_matrix, Collections.as_int_array(factors))

    def _process_entry(self, obj: Fits, start_file: str):
        """
        Append a demodulated frame to the result batch.

        :param obj: the demodulated frame (its `path` and `data` are used)
        :param start_file: name of the input file whose header serves as the template
        """
        self.result.batch.append(
            {
                'file': obj.path,
                'data': obj.data,
                'header': self._merged_header(start_file),
            }
        )

    def _merged_header(self, template_file: str):
        """
        Build the header of a demodulated frame.

        The header copies the configured `header_copy_fields` from the header of
        `template_file` (skipping `MOD_STATE` and fields that are not present). It then
        appends this block's ID to `BLOCKS_APPLIED` and records the demodulation matrix
        name and mode. Finally, the ROI keywords are updated.

        :param template_file: key of the input frame within `self.batch`
        :return: the new header
        """
        header = self.batch[template_file]['header']
        new_header = fits.PrimaryHDU().header
        for key in self.proc_data.config.base.header_copy_fields:
            if key not in header or key == MOD_STATE:
                continue
            new_header.append(Card.from_orig(header, key).to_card())
        Fits.override_header(new_header, BLOCKS_APPLIED, BlockD.BLOCK_ID, append=True)
        # the matrix is identified by the name of the grandparent directory of its file
        modm_name = os.path.basename(os.path.dirname(os.path.dirname(self.proc_data.mod_matrix.path)))
        new_header.append(Card(DEMOD_MAT, value=modm_name).to_card())
        new_header.append(Card(DEMOD_MODE, value=self.proc_data.config.spol.demod_mode).to_card())
        self._update_roi_header(new_header)
        return new_header

    @staticmethod
    def prepare(proc_data: ProcessingData) -> None:
        """
        Derive `proc_data.demod_matrix` from `proc_data.mod_matrix`.

        Nothing is done for cam3 or when a demodulation matrix is already present.
        Otherwise the following steps are applied:

        - The X/Y ROI keywords of the modulation matrix header are swapped.
        - Each matrix is normalized to its element [0, 0] and pseudo-inverted.
        - The result is flipped along dimension 0, and the Y ROI is mirrored to match.

        Afterwards the modulation matrix data is released (set to None).

        :raises IllegalStateException: if neither a modulation nor a demodulation matrix is given
        """
        if proc_data.config.is_cam3():
            return
        if proc_data.demod_matrix is not None:
            return
        if proc_data.mod_matrix is None:
            raise IllegalStateException('Neither modulation nor demodulation matrix given')

        log.info('Preparing demodulation matrix...')
        header = proc_data.mod_matrix.header
        # Transpose the ROI names to match the data convention: ROI Y is python dim 0
        # (i.e. vertical in a standard matplotlib plot).
        header[ROI_X0], header[ROI_Y0] = header[ROI_Y0], header[ROI_X0]
        header[ROI_X1], header[ROI_Y1] = header[ROI_Y1], header[ROI_X1]

        # Normalize w.r.t. element [0, 0], then pseudo-invert.
        # np.linalg.pinv inverts over the last two dimensions.
        modm = proc_data.mod_matrix.data.astype("float64")
        modm /= modm[:, :, 0:1, 0:1]
        # Flip dimension 0, because the modulation matrix has row 0 at the bottom (old DAQ convention).
        proc_data.demod_matrix = np.linalg.pinv(modm)[::-1, :, :, :]

        # Mirror the Y ROI to match the flipped dimension 0.
        y_stop = proc_data.config.cam.data_shape[0].stop
        old_y0 = header[ROI_Y0]
        old_y1 = header[ROI_Y1]
        header[ROI_Y1] = y_stop - old_y0
        header[ROI_Y0] = y_stop - old_y1
        proc_data.mod_matrix.data = None

    @staticmethod
    def input_needed(cam3: bool = False) -> list:
        """Return the processing inputs this block requires (none for cam3)."""
        if cam3:
            return []
        else:
            return ['demod_matrix']

    @staticmethod
    def predict_output(batch: FitsBatch, proc_data: ProcessingData):
        """
        Predict the output batch without loading or demodulating any data.

        For cam3 this block is a pass-through (see `_algorithm`), so the input batch
        is returned as is. Otherwise one entry is predicted per group found by
        `DemodAssembler`. Each entry carries the group's name, the header of the group's
        `start` file, and `data` set to None.
        """
        if proc_data.config.is_cam3():
            # Mirror `_algorithm`, which returns cam3 batches unchanged.
            return batch
        new_batch = FitsBatch()
        db = FileDB(proc_data.config)
        for block in DemodAssembler(db, batch, with_data=False).run():
            new_batch.batch.append({'file': block['name'], 'header': batch[block['start']]['header'], 'data': None})

        return new_batch

153 

154 

# Register BlockD with the BlockRegistry under its BLOCK_ID ('D') when this module is imported.
BlockRegistry().register(BlockD.BLOCK_ID, BlockD)