Coverage for src/susi/analyse/shifts.py: 80%
106 statements
« prev ^ index » next coverage.py v7.5.0, created at 2025-06-13 14:15 +0000
1"""
Find shifts (jitter) in a series of frames with respect to the first frame.
3"""
4from __future__ import annotations
6import glob
7import logging
8import pickle
9from skimage.registration import phase_cross_correlation
10import numpy as np
12from ..io import FitsBatch
13from ..base import Config, IllegalStateException, InsufficientDataException, Globals
14from ..utils import MP
15from ..utils import ExceptionHandler
17log = logging.getLogger('SUSI')
class ShiftAnalyser:
    """
    # ShiftAnalyser
    Detect and analyse (sub-)pixel shifts in a series of frames.
    """

    def __init__(self, config: Config, workers: int = None, upscale: int = 1):
        """
        ### Params
        - config: Defines the frame set to analyse.
        - workers: Number of parallel processes (default: `config.base.workers`).
        - upscale: Upsampling factor for sub-pixel shift detection (default 1).
        """
        self.batch = FitsBatch(config=config, slices=config.cam.data_shape)
        #: The config to apply (i.e. defines the frame set to analyse)
        self.config = config
        #: The number of parallel processes to use
        # Fix: an explicitly passed `workers` argument was previously ignored
        # (config.base.workers always won); honour it, fall back to config.
        self.workers = workers if workers is not None else config.base.workers
        #: The upscale factor to analyse sub-pixel shifts (default 1)
        self.upscale = upscale
        #: Path of the dark file to subtract (set via with_dark_correction)
        self.dark_file = None
        #: The result: mean-subtracted 'x'/'y' shift series, filled by run()
        self.result = {}

    def with_dark_correction(self, dark_file):
        """
        Chainable configuration method to activate dark img correction

        :param dark_file: [String] the path of the darks file to use
        :return: self
        """
        self.dark_file = dark_file
        return self

    def run(self) -> ShiftAnalyser:
        """
        Load the frames, apply the (optional) dark correction and analyse.

        :return: self (chainable); shifts are stored in `self.result`.
        """
        self.__load_data()
        self.__apply_dark()
        self.__analyse()
        return self

    def __load_data(self):
        """Glob the configured input pattern, load and time-sort all frames."""
        log.info('Looking for file(s) matching %s', self.config.input_pattern())
        file_list = glob.glob(self.config.input_pattern())
        if not file_list:
            raise InsufficientDataException('No files found!')
        log.info('Found: %s file(s)', len(file_list))
        log.debug('Loading area %s of all files ...', self.config.cam.data_shape)
        # Cap the I/O worker count at 100 regardless of the file count.
        self.batch.load(file_list, workers=min(len(file_list), 100))
        self.batch.sort_by('Timestamp ms')

    def __apply_dark(self):
        """Subtract the dark image if one was configured."""
        if self.dark_file is not None:
            self.batch.apply_dark(self.dark_file)

    def __analyse(self):
        """Run the global shift detection and store mean-subtracted series."""
        gs = GlobalShift(self.batch, self.workers, self.upscale).find()
        # NOTE(review): axes are swapped here relative to GlobalShift
        # (gs.xpos feeds 'y' and vice versa) — presumably mapping detector
        # axes onto image axes; confirm against the camera geometry.
        self.result['y'] = np.array(gs.xpos) - np.mean(gs.xpos)
        self.result['x'] = np.array(gs.ypos) - np.mean(gs.ypos)

    @staticmethod
    def power_spectrum(shifts, samp=Globals.SAMP, bin_size=None) -> dict:
        """
        Computes the power spectrum (in arbitrary units) of x, y pixel shifts.

        ### Params
        - shifts: The detected shifts.
        - samp: Sample spacing (inverse of the sampling rate). Defaults to 0.0213 [s] (default SUSI frame time).
        - bin_size: Bin size in the frequency domain (default: no binning)

        ### Returns
        A dict with 'freq', 'x' and 'y' entries, sorted by frequency.

        ### Raises
        IllegalStateException if `shifts` lacks the 'x' or 'y' series.
        """
        if 'x' not in shifts or 'y' not in shifts:
            raise IllegalStateException('Shift analysis must be executed before a power spectrum can be computed!')
        n = len(shifts['x'])
        norm_factor = samp / n  # normalize power spectra to integral(power spectrum) = variance
        freqs = np.fft.fftfreq(n, d=samp)
        x = np.abs(np.fft.fft(shifts['x'])) ** 2 * norm_factor
        y = np.abs(np.fft.fft(shifts['y'])) ** 2 * norm_factor
        sorted_ids = np.argsort(freqs)
        ps = {
            'freq': freqs[sorted_ids],
            'x': x[sorted_ids],
            'y': y[sorted_ids]
        }
        if bin_size is not None:
            ps = ShiftAnalyser.__ps_bin(ps, bin_size)
        return ps

    @staticmethod
    def __ps_bin(ps, bin_size):
        """
        Average the spectrum into frequency bins of width `bin_size` (in place).

        The first bin (everything below the lowest bin edge) is dropped, and
        the 'x'/'y' powers are doubled to fold in the negative-frequency half.
        """
        bin_values = np.arange(0, max(ps['freq']), bin_size)
        bin_idx = np.digitize(ps['freq'], bin_values)
        idx = range(len(bin_values))
        w = [np.where(bin_idx == i)[0] for i in idx]
        for key in ps.keys():
            ps[key] = [np.mean(ps[key][w[i]]) for i in idx][1:]
        for key in ['x', 'y']:
            ps[key] = [2 * val for val in ps[key]]  # account for the contribution of the negative frequencies
        return ps

    @staticmethod
    def save_result(result, out_file) -> None:
        """
        Dump the result dictionary to the specified out file using `pickle`

        :param result: The data to save
        :param out_file: The file to dump the data to
        """
        log.info('Writing result dictionary as binary pickle file to: %s', out_file)
        try:
            ShiftAnalyser.__write_pickle(out_file, result)
        except Exception as e:
            ExceptionHandler.recover_from_write_error(e, out_file, result, ShiftAnalyser.__write_pickle)

    @staticmethod
    def __write_pickle(file, data, *args):
        # *args absorbs extra arguments when invoked as a recovery callback.
        with open(file, 'wb') as f:
            pickle.dump(data, f)
class GlobalShift:
    """
    # GlobalShift
    Analyse the given batch and find the offsets in `x` and `y` direction at the
    img center with respect to the first frame.

    The analysis is done simultaneously.
    However, a cross-correlation is quite computing intense. Therefore, it's
    probably a good idea to not correlate the whole FOV (except if you have a lot of time).
    """

    def __init__(self, batch: FitsBatch, workers: int = None, upscale: int = 1):
        """
        ### Params
        - batch: The frames to analyse (the first frame is the reference).
        - workers: Number of parallel processes (default: half the CPUs).
        - upscale: Upsampling factor for sub-pixel registration (default 1).
        """
        self.workers = MP.NUM_CPUs // 2 if workers is None else workers
        self.upscale = upscale
        self.batch = batch
        #: Per-frame shifts along each axis, filled by find()
        self.xpos = []
        self.ypos = []

    def find(self) -> GlobalShift:
        """Run the parallel cross-correlation analysis; chainable (returns self)."""
        log.debug('Analysing with upscale factor %d (using %s workers)...', self.upscale, self.workers)
        args = self.__gen_arguments()
        self.__global_offset(args)
        return self

    def __global_offset(self, args: list) -> None:
        """Correlate all frames in parallel and collect the results in frame order."""
        result = dict(MP.simultaneous(GlobalShift.offset, args, workers=self.workers))
        log.debug('Analysis done, collecting results...')
        # Results arrive keyed by frame index; re-assemble them in order 1..n.
        for i in range(1, len(result) + 1):
            self.xpos.append(result[i]['x'])
            self.ypos.append(result[i]['y'])

    def __gen_arguments(self) -> list:
        """Build one task dict per frame; frame 0 serves as the template."""
        frames = self.batch.data_array()
        return [{
            'idx': i,
            'template': frames[0],
            'frame': frames[i],
            'upsample': self.upscale}
            for i in range(1, len(frames))]

    @staticmethod
    def offset(data: dict) -> tuple:
        """
        Compute the global offset at img center

        ### Param
        - data: a dictionary with
            - 'template' the base frame to correlate to
            - 'frame' the frame to correlate
            - 'idx' the index of the frame
            - 'upsample' upsampling factor for sub-pixel registration

        ### Returns
        A tuple where 0 is the idx and 1 is a dict with offset in x and y
        """
        if data['idx'] % 200 == 0:
            log.debug('\tProcessing Frame %s', data['idx'])
        # phase_cross_correlation returns the shift as (row, col), i.e. (y, x).
        shift, *_ = phase_cross_correlation(data['template'], data['frame'],
                                            upsample_factor=data['upsample'], return_error='always')
        return data['idx'], {'x': shift[1], 'y': shift[0]}