Coverage for src/susi/analyse/shifts.py: 80%

106 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2025-06-13 14:15 +0000

1""" 

2Find shifts (jitter) in a series of frames with respect to the first frame. 

3""" 

4from __future__ import annotations 

5 

6import glob 

7import logging 

8import pickle 

9from skimage.registration import phase_cross_correlation 

10import numpy as np 

11 

12from ..io import FitsBatch 

13from ..base import Config, IllegalStateException, InsufficientDataException, Globals 

14from ..utils import MP 

15from ..utils import ExceptionHandler 

16 

17log = logging.getLogger('SUSI') 

18 

19 

class ShiftAnalyser:
    """
    # ShiftAnalyser
    Detect and analyse (sub-)pixel shifts in a series of frames.

    Typical use: construct the analyser, optionally chain
    `with_dark_correction`, call `run` and read the mean-subtracted shifts
    from `result['x']` and `result['y']`.
    """

    def __init__(self, config: Config, workers: int | None = None, upscale: int = 1):
        """
        ### Params
        - config: The configuration; it defines the input file pattern and the
          frame region (`config.cam.data_shape`) to load.
        - workers: Number of parallel processes used for the shift search.
          If `None` (default) the value of `config.base.workers` is used.
        - upscale: Upscale factor for sub-pixel registration (default 1).
        """
        self.batch = FitsBatch(config=config, slices=config.cam.data_shape)
        #: The config to apply (i.e. defines the frame set to analyse)
        self.config = config
        #: The number of parallel processes to use
        # An explicitly given value wins; the configured value is only the fallback.
        self.workers = config.base.workers if workers is None else workers
        #: The upscale factor to analyse sub-pixel shifts (default 1)
        self.upscale = upscale
        #: Dark image
        self.dark_file = None
        #: The result
        self.result = {}

    def with_dark_correction(self, dark_file):
        """
        Chainable configuration method to activate dark img correction

        :param dark_file: [String] the path of the darks file to use

        :return: self
        """
        self.dark_file = dark_file
        return self

    def run(self) -> ShiftAnalyser:
        """
        Load the frames, apply the dark correction (if one was configured)
        and detect the shifts. The outcome is stored in `self.result`.

        :return: self
        """
        self.__load_data()
        self.__apply_dark()
        self.__analyse()
        return self

    def __load_data(self):
        """
        Load all files matching the configured input pattern into the batch
        and sort the batch by the 'Timestamp ms' key.

        Raises `InsufficientDataException` if no file matches the pattern.
        """
        pattern = self.config.input_pattern()
        log.info('Looking for file(s) matching %s', pattern)
        file_list = glob.glob(pattern)
        if not file_list:
            raise InsufficientDataException('No files found!')
        log.info('Found: %s file(s)', len(file_list))
        log.debug('Loading area %s of all files ...', self.config.cam.data_shape)
        # The number of loader workers is capped at 100.
        self.batch.load(file_list, workers=min(len(file_list), 100))
        self.batch.sort_by('Timestamp ms')

    def __apply_dark(self):
        """Apply the dark file to the batch; a no-op if none was configured."""
        if self.dark_file is not None:
            self.batch.apply_dark(self.dark_file)

    def __analyse(self):
        """Detect the per-frame shifts and store them mean-subtracted in `self.result`."""
        gs = GlobalShift(self.batch, self.workers, self.upscale).find()
        # NOTE(review): `xpos` is stored under 'y' and `ypos` under 'x'. This is
        # kept exactly as it was; confirm whether the axis swap is an intended
        # convention or a mix-up before changing it.
        self.result['y'] = np.array(gs.xpos) - np.mean(gs.xpos)
        self.result['x'] = np.array(gs.ypos) - np.mean(gs.ypos)

    @staticmethod
    def power_spectrum(shifts, samp=Globals.SAMP, bin_size=None) -> dict:
        """
        Computes the power spectrum (in arbitrary units) of x, y pixel shifts.

        ### Params
        - shifts: The detected shifts, a dictionary with the keys 'x' and 'y'.
        - samp: Sample spacing (inverse of the sampling rate).
          Defaults to `Globals.SAMP` (the default SUSI frame time).
        - bin_size: Bin size in the frequency domain (default: no binning)

        ### Returns
        A dictionary with the keys 'freq', 'x' and 'y'. Without binning the
        entries are arrays ordered by ascending frequency (negative frequencies
        included). With binning the entries are lists of per-bin mean values for
        the bins starting at frequency 0, with the 'x' and 'y' power doubled.

        ### Raises
        `IllegalStateException` if `shifts` lacks the 'x' or the 'y' entry.
        """
        if 'x' not in shifts or 'y' not in shifts:
            raise IllegalStateException('Shift analysis must be executed before a power spectrum can be computed!')

        n = len(shifts['x'])
        norm_factor = samp / n  # normalize power spectra to integral(power spectrum) = variance
        freqs = np.fft.fftfreq(n, d=samp)
        x = np.abs(np.fft.fft(shifts['x'])) ** 2 * norm_factor
        y = np.abs(np.fft.fft(shifts['y'])) ** 2 * norm_factor
        # fftfreq lists the positive frequencies first; reorder to ascending frequency.
        sorted_ids = np.argsort(freqs)
        ps = {
            'freq': freqs[sorted_ids],
            'x': x[sorted_ids],
            'y': y[sorted_ids],
        }
        if bin_size is not None:
            ps = ShiftAnalyser.__ps_bin(ps, bin_size)
        return ps

    @staticmethod
    def __ps_bin(ps, bin_size):
        """
        Average every entry of the power spectrum `ps` within frequency bins
        of width `bin_size`, starting at frequency 0.

        `ps` is modified in place and returned; its entries become lists.
        """
        edges = np.arange(0, max(ps['freq']), bin_size)
        bin_idx = np.digitize(ps['freq'], edges)
        # Bin 0 collects everything below the first edge (the negative
        # frequencies) and is skipped on purpose.
        # NOTE(review): samples at or above the last edge get the index
        # len(edges) and are therefore not part of any returned bin.
        members = [np.where(bin_idx == i)[0] for i in range(1, len(edges))]
        for key in list(ps):
            ps[key] = [np.mean(ps[key][m]) for m in members]
        for key in ('x', 'y'):
            ps[key] = [2 * val for val in ps[key]]  # account for the contribution of the negative frequencies
        return ps

    @staticmethod
    def save_result(result, out_file) -> None:
        """
        Dump the result dictionary to the specified out file using `pickle`.
        If writing fails, the error is handed over to
        `ExceptionHandler.recover_from_write_error`.

        :param result: The data to save
        :param out_file: The file to dump the data to
        """
        log.info('Writing result dictionary as binary pickle file to: %s', out_file)
        try:
            ShiftAnalyser.__write_pickle(out_file, result)
        except Exception as e:
            # Broad on purpose: every failure is delegated to the recovery
            # handler, which receives the writer callback as well.
            ExceptionHandler.recover_from_write_error(e, out_file, result, ShiftAnalyser.__write_pickle)

    @staticmethod
    def __write_pickle(file, data, *args):
        """Write `data` as binary pickle to `file`; additional positional arguments are ignored."""
        with open(file, 'wb') as f:
            pickle.dump(data, f)

132 

133 

class GlobalShift:
    """
    # GlobalShift
    Analyse the given batch and find the offsets in `x` and `y` direction at the
    img center with respect to the first frame.

    The frames are analysed in parallel.
    However, a cross-correlation is quite computing intense. Therefore, it's
    probably a good idea to not correlate the whole FOV (except if you have a lot of time).

    The first frame serves as the template only. `xpos` and `ypos` hold one
    entry per subsequent frame, in frame order.
    """

    def __init__(self, batch: FitsBatch, workers: int | None = None, upscale: int = 1):
        """
        ### Params
        - batch: The batch providing the frames via `data_array()`.
        - workers: Number of parallel workers. If `None` (default), half of
          `MP.NUM_CPUs` is used, but never less than one.
        - upscale: Upsampling factor for sub-pixel registration (default 1).
        """
        # `MP.NUM_CPUs // 2` evaluates to 0 if `MP.NUM_CPUs` is 1, hence the clamp.
        self.workers = max(1, MP.NUM_CPUs // 2) if workers is None else workers
        self.upscale = upscale
        self.batch = batch
        self.xpos = []
        self.ypos = []

    def find(self) -> GlobalShift:
        """
        Run the analysis and collect the offsets in `self.xpos` and `self.ypos`.

        :return: self
        """
        log.debug('Analysing with upscale factor %d (using %s workers)...', self.upscale, self.workers)
        args = self.__gen_arguments()
        self.__global_offset(args)
        return self

    def __global_offset(self, args: list) -> None:
        """Compute the offsets for all argument sets and append them in frame order."""
        result = dict(MP.simultaneous(GlobalShift.offset, args, workers=self.workers))

        log.debug('Analysis done, collecting results...')
        # The results are keyed by frame index (1..N); walking the indices
        # restores the frame order regardless of the order of completion.
        for i in range(1, len(result) + 1):
            self.xpos.append(result[i]['x'])
            self.ypos.append(result[i]['y'])

    def __gen_arguments(self) -> list:
        """Build one argument dictionary per frame following the first (template) frame."""
        frames = self.batch.data_array()
        template = frames[0]
        return [
            {
                'idx': i,
                'template': template,
                'frame': frames[i],
                'upsample': self.upscale,
            }
            for i in range(1, len(frames))
        ]

    @staticmethod
    def offset(data: dict) -> tuple:
        """
        Compute the global offset at img center

        ### Param
        - data: a dictionary with
            - 'template' the base frame to correlate to
            - 'frame' the frame to correlate
            - 'idx' the index of the frame
            - 'upsample' upsampling factor for sub-pixel registration

        ### Returns
        A tuple where 0 is the idx and 1 is a dict with offset in x and y
        """
        if data['idx'] % 200 == 0:
            log.debug('\tProcessing Frame %s', data['idx'])
        shift, *_ = phase_cross_correlation(data['template'], data['frame'],
                                            upsample_factor=data['upsample'], return_error='always')
        # The second shift component is reported as 'x', the first as 'y'.
        return data['idx'], {'x': shift[1], 'y': shift[0]}

193 upsample_factor=data['upsample'], return_error='always') 

194 return data['idx'], {'x': shift[1], 'y': shift[0]}