Coverage for src/susi/reduc/shift/slit_shift.py: 95%

82 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2025-06-13 14:15 +0000

1#!/usr/bin/env python3 

2# -*- coding: utf-8 -*- 

3""" 

4provides slit shift class 

5 

6@author: iglesias 

7""" 

8from tkinter import E 

9from weakref import ref 

10import numpy as np 

11import copy 

12import os 

13 

14from scipy.sparse.tests.test_array_api import B 

15from ...base import Logging, Globals 

16from ...io import FitsBatch 

17from ...base.header_keys import * 

18from ...utils.cropping import common_shape 

19from ...utils.sub_shift import round_shift_for_roi 

20 

21 

# Module-level logger obtained from the project's Logging helper; used by the classes below for info messages.
22logger = Logging.get_logger() 

23 

24 

class SlitShiftRef:
    """
    Select the slit-shift reference frame and derive the shifts of a file list.

    1-Selects the frame to be used as reference for the slit flat shift correction
    2-Read all the shifts in flist and compute the common roi to crop the output data

    After ``run`` the attributes ``ref_offset``, ``ref_slope``, ``ref_file`` and
    ``common_roi`` hold the values that are also returned in the reference dict.
    """

    def __init__(self, files: list, proce_param: int):
        """
        :param files: files whose headers are read to get the slit shifts
        :param proce_param: forwarded to ``SlitShiftProcessor``, which uses it as
            the degree of the polynomial fitted to the shifts
        """
        # results, filled in by run()
        self.ref_offset = None
        self.ref_slope = None
        self.ref_file = None
        self.common_roi = None
        # inputs
        self.files = files
        self.proce_param = proce_param

    def run(self):
        """
        Load the headers, pick the reference frame and process the shifts.

        :return: tuple ``(ref_dict, shifts_dict)`` where ``ref_dict`` has the keys
            'offset', 'slope', 'file' and 'common_roi', and ``shifts_dict`` has the
            keys 'files' (base names, in batch order) and 'shifts'
        """
        batch = self._load_data()
        self._get_reference(batch)
        shifts, common_roi = SlitShiftProcessor(batch, self.ref_offset, self.proce_param).run()
        # store the result on the instance as well; the attribute is declared in
        # __init__ but was previously never assigned
        self.common_roi = common_roi
        logger.info(
            f'shift ref. ={self.ref_file}: offset={self.ref_offset}, slope={self.ref_slope}, Common ROI: {common_roi}'
        )
        # index access is kept on purpose: it is the access pattern this module
        # uses to reach the 'file' / 'header' entries of a batch item
        files = [os.path.basename(batch[idx]['file']) for idx in range(len(batch))]
        shifts_dict = {'files': files, 'shifts': shifts}
        ref_dict = {
            'offset': self.ref_offset,
            'slope': self.ref_slope,
            'file': self.ref_file,
            'common_roi': self.common_roi,
        }
        return ref_dict, shifts_dict

    def _load_data(self):
        """
        Read the headers (no data) of ``self.files`` into a batch sorted by DATE_OBS.

        :return: the loaded ``FitsBatch``
        """
        # TODO change to read SLIT_FLAT_OFFSET and SLIT_FLAT_SLOPE from metadata files instead the header
        batch = FitsBatch()
        logger.info(f'Reading the header of {len(self.files)} files to compute shift reference and out roi')
        batch.load(self.files, header_only=True, sort_by=DATE_OBS)
        return batch

    def _get_reference(self, batch):
        """
        Take the frame in the middle of ``batch`` as the shift reference.

        Stores its SLIT_FLAT_OFFSET and SLIT_FLAT_SLOPE header values and its file
        name in ``ref_offset``, ``ref_slope`` and ``ref_file``.

        :param batch: loaded batch; must not be empty
        """
        middle = len(batch) // 2
        fits = batch[middle]
        self.ref_offset = fits['header'][SLIT_FLAT_OFFSET]
        self.ref_slope = fits['header'][SLIT_FLAT_SLOPE]
        self.ref_file = fits['file']

63 

64 

class SlitShiftProcessor:
    """
    Class to process slit shifts from a batch (assume is sorted by OBS_TIME)
    1- get shifts from frm header
    2- smooth the shifts
    3- subtract the reference shift
    4- apply the same shift for each modulation cycle
    5- compute the common roi to crop the output data

    The instance works on a deep copy of the batch headers, so the ROI keywords
    of the batch passed in are not modified.
    """

    def __init__(self, batch, ref_shift, proce_param):
        """
        :param batch: batch providing ``header_copy()`` and index access to items
            with a 'data' entry
        :param ref_shift: reference shift subtracted from the smoothed shifts
        :param proce_param: degree of the polynomial fitted to the shifts
        """
        self.batch = copy.deepcopy(batch.header_copy())
        first_data = batch[0]['data']
        self.data_shape = first_data.shape if first_data is not None else None
        self.proce_param = proce_param
        self.ref_shift = ref_shift
        # results
        self.shifts = []
        self.common_roi = None

    def run(self):
        """
        Execute all processing steps.

        The ROI keywords of the internal header copy are modified in place, so
        this is meant to be called once per instance.

        :return: tuple ``(shifts, common_roi)``
        """
        self._get_offset()
        self._get_shifts()
        self._get_common_shape()
        return self.shifts, self.common_roi

    def _get_offset(self):
        """Read SLIT_FLAT_OFFSET of every frame into ``self.shifts`` as float64."""
        # Rebuild the list instead of appending: self.shifts is rebound to an
        # ndarray by _get_shifts, so appending would fail (or duplicate entries)
        # if this method were called more than once.
        self.shifts = [np.float64(fits.header[SLIT_FLAT_OFFSET]) for fits in self.batch]

    def _get_shifts(self):
        """
        Smooth the shifts, subtract the reference and equalize each modulation cycle.

        :raises ValueError: if the number of shifts differs from the number of
            frames, or if there are fewer than 3 frames
        """
        n_frames = len(self.shifts)
        # sanity check
        if n_frames != len(self.batch):
            raise ValueError('Number of shifts different from number of frames')
        if n_frames < 3:
            raise ValueError('Number of frames is less than 3, cant smooth/fit the slit shifts')
        # fit polynomial of order self.proce_param to the shifts, using the frame
        # index as abscissa, and evaluate it at the same positions
        frame_idx = np.arange(n_frames)
        fit = np.polyfit(frame_idx, self.shifts, self.proce_param)
        # subtract the reference shift from the smoothed values
        self.shifts = np.polyval(fit, frame_idx) - self.ref_shift
        # Use the same avg shift for each modulation cycle
        # (a frame with MOD_STATE '0' is presumably the first of its cycle - TODO confirm)
        for i in range(n_frames):
            if self.batch[i]['header'][MOD_STATE] == '0':
                cycle = slice(i, i + Globals.MOD_CYCLE_FRAMES)
                self.shifts[cycle] = np.mean(self.shifts[cycle])

    def _get_common_shape(self):
        """
        Shrink the ROI of every shifted frame and compute the ROI common to all.

        The result is stored in ``self.common_roi`` with 'y0' increased and 'y1'
        decreased by one.

        :raises ValueError: if a frame ROI is smaller than the common ROI
        """
        for i in range(len(self.batch)):
            # modify ROI keywords to account for the shift that will be implemented
            header = self.batch[i]['header']
            shift = round_shift_for_roi(self.shifts[i])
            if shift > 0:  # moves up and thus reduces Y1 to account for the lost roi
                header[ROI_Y1] = int(float(header[ROI_Y1]) - shift)
            elif shift < 0:  # moves down and thus increases Y0 to account for the lost roi
                header[ROI_Y0] = int(float(header[ROI_Y0]) - shift)
        # get the common roi of the full batch
        self.common_roi = common_shape(self.batch)
        # reduce common roi by 1 to avoid numeric effects in rounding shifts
        self.common_roi['y0'] += 1
        self.common_roi['y1'] -= 1
        # check that there is no frame with roi smaller than the common roi
        for i in range(len(self.batch)):
            header = self.batch[i]['header']
            # Frames with zero shift keep their raw header values, which are not
            # converted above; go through float() here too so a non-numeric
            # representation (assumed possible given the conversion above) still
            # compares correctly.
            roi_y1 = float(header[ROI_Y1])
            roi_y0 = float(header[ROI_Y0])
            if roi_y1 < self.common_roi['y1'] or roi_y0 > self.common_roi['y0']:
                raise ValueError(f'ROI of frame {i} is smaller than common roi')