Coverage for src/susi/reduc/fields/flat_field_slit_correction.py: 15%
259 statements
« prev ^ index » next coverage.py v7.5.0, created at 2025-08-22 09:20 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2025-08-22 09:20 +0000
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3"""
4module for moving flat field correction
6@author: iglesias
7"""
8import numpy as np
9from scipy.optimize import least_squares, minimize, brute
10from scipy.ndimage import gaussian_filter
11from scipy.interpolate import interp1d
12import os
13from datetime import datetime
14import matplotlib.pyplot as plt
16from src import susi
17from src.susi.base.header_keys import *
19from ...io import Fits, MultiHDUWriter, MultiHDUReader, Card
20from ...base import Logging, Config, Globals
21from ...reduc.average import FrameSum
22from ...utils import Collections, MP, progress
23from ...utils.sub_shift import Shifter
# Module-level logger from the project's Logging helper; SlitFFCorrector uses it for progress/info messages.
logger = Logging.get_logger()
class SlitFFCorrector:
    """
    Correct the moving (slit) flat field in a series of input frames.

    Processing steps:
    1- Read the frames and compute partial averages every Reduc.slit_ff_avg_frm frames.
    2- For every average frame, fit the flat shift that minimizes the high-frequency residuals
       in the flat-fielded frame. NOTE: the stored shifts are the opposite (negated) of the
       fitted flat shifts, so they can be applied directly to the science images.
    3- Fit a polynomial of order Reduc.slit_ff_smooth_par to the shifts and derive one shift
       per individual input frame.
    4- Read the individual frames, apply the shift and the flat field, then write them to the
       output directory.
    """

    def __init__(self, files: list, flat: 'Fits', config: 'Config'):
        """
        Initialize the SlitFFCorrector class.

        :param files: List of file paths to process (must not be empty).
        :param flat: Flat field data as a Fits object.
        :param config: Configuration object containing reduction parameters.
        """
        self.files = files
        self.flat = flat
        self.config = config
        self.avg_frames = []
        self.avg_shifts = []
        self.shifts = []  # estimated shifts as they must be applied to the sci frames before ff
        self.timestamps = []
        self.coeffs = None  # polynomial coefficients shift(timestamp); set by compute_shifts()
        self.flat_profile = None  # flat slit profile inside the ROI; set by compute_shifts()
        self.slit_roi = [0, -1]  # default slit ROI
        self.wl_roi = [0, -1]  # default wavelength ROI

        # get odir and avg_frames_file
        self.db = susi.FileDB(self.config)
        avg_fname = self.db.avrg_fname(self.files[0], self.files[-1])
        # Averaging window in minutes. True division: the value is formatted with one
        # decimal, which floor division (//) would always truncate to X.0.
        tavg = self.config.reduc.slit_ff_avg_frm * Globals.SAMP_S / 60
        avg_fname = avg_fname.replace('-cam', f'-avg{tavg:.1f}min-cam')
        first_time = self.db.file_time(self.files[0])
        odir_day = self.db.dir_path(first_time, depth='upto_day', base=self.config.data.out_path)
        self.avg_frames_ofile = os.path.join(odir_day, avg_fname)
        # new level odir
        self.odir = self.db.dir_path(first_time, depth='upto_level', base=self.config.data.out_path)
        # log directory for plots
        self.log_dir = os.path.join(
            self.config.data.log_dir, os.path.basename(self.avg_frames_ofile).replace('.fits.gz', '')
        )

    def compute_shifts(self, plot=False, ini_val=(0,), slit_roi=None, wl_roi=None):
        """
        Obtain the shifts for the individual frames (steps 1-3).

        NOTE: The returned self.shifts are the ones that must be applied to the sci frames.
        :param plot: If True, save profile and shift plots to the log directory.
        :param ini_val: Initial guess (px) for the fit that seeds the average-frame fits.
        :param slit_roi: Optional [start, stop] slice bounds along the slit axis (replaces self.slit_roi).
        :param wl_roi: Optional [start, stop] slice bounds along the wavelength axis (replaces self.wl_roi).
        :return: Array with one shift (px) per file in self.files.
        """
        if slit_roi is not None:
            self.slit_roi = slit_roi
        if wl_roi is not None:
            self.wl_roi = wl_roi

        self.__compute_avg()
        self.__compute_avg_shifts(ini_val=ini_val, plot=plot)
        self.__compute_all_shifts()
        if plot:
            self.__plot_flat_profile()
            self.__plot_shifts()
            logger.info(f'Plots saved to {self.log_dir}')
        return self.shifts

    def correct_frames(self, custom_files=None, odir_sufix='_ff'):
        """
        Apply the computed shifts to the individual frames and then divide by the flat field.

        compute_shifts() must have been called first (it sets the shift polynomial).
        NOTE: sets self.config.data.out_path to self.odir + odir_sufix (read by the workers).
        :param custom_files: Use these instead of self.files. (they must be contained within self.files!)
        :param odir_sufix: Suffix to add to the output directory. Which is the same level directory as the input files
        :raises RuntimeError: If the shift polynomial has not been computed yet.
        """
        if custom_files is not None:
            wanted = set(custom_files)
            files = [f for f in self.files if f in wanted]
        else:
            files = self.files

        if len(files) == 0:
            logger.warning('No files to correct. Returning.')
            return
        if getattr(self, 'coeffs', None) is None:
            raise RuntimeError('Shift polynomial not available: call compute_shifts() before correct_frames()')

        self.config.data.out_path = self.odir + odir_sufix
        logger.info(f'Flatfielding {len(files)} files using chunks of {self.config.reduc.slit_ff_avg_frm} frames')
        # chunk the selected files (not self.files), so custom_files is honored
        chunks = Collections.chunker(files, self.config.reduc.slit_ff_avg_frm)
        for idx, chunk in enumerate(chunks, start=1):
            logger.info(f'Processing chunk {idx} of {len(chunks)}, with {len(chunk)} files')
            tstamps = [self.db.file_time(f).timestamp() for f in chunk]
            shifts = np.polyval(self.coeffs, tstamps)
            SlitFFCorrector.correct_paralell(chunk, shifts, self.flat, self.config)

    def __compute_avg(self):
        """
        Compute average frames from the input files.

        The average is computed every Reduc.slit_ff_avg_frm frames.
        For future use, the average frames are saved in self.avg_frames_ofile.
        If the file already exists, it is read instead of recomputing the averages.
        :raises ValueError: If no average frame was computed or loaded.
        """
        if os.path.exists(self.avg_frames_ofile):
            mhdur = MultiHDUReader(self.avg_frames_ofile)
            self.avg_frames = mhdur.read(fits_list=True)
            logger.info(f'Loaded {len(self.avg_frames)} average frames from {self.avg_frames_ofile}')
        else:
            self.avg_frames = []  # start clean: never append to frames from a previous call
            chunks = Collections.chunker(self.files, self.config.reduc.slit_ff_avg_frm)
            for idx, chunk in enumerate(chunks, start=1):
                logger.info(f'Averaging chunk {idx} of {len(chunks)}, with {len(chunk)} files')
                avg = FrameSum.avg_in_chunks(chunk, header_idx=len(chunk) // 2)  # use the header of the middle file
                fits = Fits(chunk[0])
                fits.header = avg.header
                fits.data = avg.result
                fits.header.append(
                    susi.Card(susi.base.TEMPORAL_BIN, value=len(chunk), comment='No. of averaged frames').to_card()
                )
                self.avg_frames.append(fits)

            if self.avg_frames:  # do not cache an empty result on disk
                logger.info(f'Saving average frames to {self.avg_frames_ofile}')
                os.makedirs(os.path.dirname(self.avg_frames_ofile), exist_ok=True)
                mhduw = MultiHDUWriter(self.avg_frames_ofile)
                mhduw.data = [f.data for f in self.avg_frames]
                mhduw.headers = [f.header for f in self.avg_frames]
                mhduw.write_to_disk()

        if len(self.avg_frames) == 0:
            raise ValueError('No average frames computed or loaded. Check the input')

    def __compute_avg_shifts(self, ini_val=(0,), plot=False):
        """
        Compute the shifts for the average frames (step 2).

        Each shift is fitted by minimizing the high-frequency residuals of the flat-fielded
        slit profile (mean over the wavelength ROI). Two passes are made: the first uses a
        common initial condition; then a line is fitted to those shifts and evaluated per
        frame as the initial condition of a second, final 'lsq' pass.
        The stored shifts are the negated fitted flat shifts.

        :param ini_val: Initial guess (px) for the seed fit of the first pass.
        :param plot: If True, save one profile plot per average frame (first pass).
        """
        s0, s1 = self.slit_roi[0], self.slit_roi[1]
        w0, w1 = self.wl_roi[0], self.wl_roi[1]
        # Reset so repeated calls do not feed stale entries into the linear fit below.
        self.avg_shifts = []

        def roi_profile(data):
            # mean over the wavelength ROI -> 1-D profile along the slit ROI
            return np.mean(data[0, s0:s1, w0:w1], axis=1)

        self.flat_profile = roi_profile(self.flat.data)
        # NOTE(review): DATE_OBS is parsed with fromisoformat; naive datetimes are converted with the
        # local timezone by .timestamp() -- confirm this matches db.file_time() used for single frames.
        tstamps = [datetime.fromisoformat(f.header[DATE_OBS]).timestamp() for f in self.avg_frames]

        # Seed fit on the first average frame only.
        # NOTE(review): the name "global" suggests a mean over all frames may have been intended.
        global_shift = self.__fit_shift(roi_profile(self.avg_frames[0].data), ini_val=ini_val)
        logger.info(f'Computing slit profile shifts with initial condition: {global_shift:.3f} px')
        for idx, (frame, tstamp) in enumerate(zip(self.avg_frames, tstamps)):
            profile = roi_profile(frame.data)
            shift = self.__fit_shift(profile, ini_val=[global_shift])
            self.avg_shifts.append({'shift': -shift, 'timestamp': tstamp})
            if plot:
                self.__plot_avg_profile(profile, shift, idx)

        # fit a line to the shifts and use those as initial condition for a new fit
        self.__fit_poly(order=1)
        self.avg_shifts = []
        logger.info('Recomputing slit profile shifts with improved initial conditions')
        for frame, tstamp in zip(self.avg_frames, tstamps):
            profile = roi_profile(frame.data)
            # coeffs model the stored (negated) shift; negate back to get a flat-shift guess
            new_ini_val = -np.polyval(self.coeffs, tstamp)
            shift = self.__fit_shift(profile, ini_val=new_ini_val, algo='lsq', verbose=0)
            self.avg_shifts.append({'shift': -shift, 'timestamp': tstamp})

    def __fit_shift(self, signal1, ini_val=(0,), min_disp=-15, max_disp=15, algo='lsq', verbose=0):
        """
        Fit the sub-pixel shift of self.flat_profile that best flat-fields ``signal1``.

        The flat profile is shifted and the high-pass-filtered ratio signal1 / shifted_flat
        is minimized (see ff_err).

        :param signal1: 1-D profile to match (e.g. an average-frame slit profile).
        :param ini_val: Initial guess (px), clipped into [min_disp, max_disp]. Unused by 'brute'.
        :param min_disp: Lower bound of the searched shift (px).
        :param max_disp: Upper bound of the searched shift (px).
        :param algo: 'lsq' (least_squares on the residual vector), 'minimize' (minimize on the
            RMS residual) or 'brute' (coarse + fine grid search on the RMS residual).
        :param verbose: Verbosity passed to the optimizer.
        :return: The fitted shift of the flat profile (px).
        :raises ValueError: If ``algo`` is unknown.
        """
        signal2 = self.flat_profile

        def initial_guess():
            # least_squares raises "`x0` is infeasible" outside the bounds, and the polynomial
            # guesses from __compute_avg_shifts may overshoot them, so clip.
            x0 = np.atleast_1d(np.asarray(ini_val, dtype=float)).ravel()
            return np.clip(x0, min_disp, max_disp)

        if algo == 'lsq':
            res = least_squares(
                SlitFFCorrector.ff_err,
                initial_guess(),
                args=(signal1, signal2, False),
                bounds=([min_disp], [max_disp]),
                verbose=verbose,
            )
            return res.x[0]

        if algo == 'minimize':

            def obj_func(x):
                # RMS of the high-pass residual: the same scalar metric used by 'brute'
                return SlitFFCorrector.ff_err(x, signal1, signal2, True)

            options = {'disp': verbose > 0}
            res = minimize(obj_func, initial_guess(), bounds=[(min_disp, max_disp)], options=options)
            return res.x[0]

        if algo == 'brute':
            n_coarse, n_fine = 20, 50  # number of points to sample in each brute force pass
            workers = 20
            args = (signal1, signal2, True)
            coarse = brute(
                SlitFFCorrector.ff_err,
                [(min_disp, max_disp)],
                Ns=n_coarse,
                args=args,
                finish=None,
                full_output=True,
                disp=verbose > 0,
                workers=workers,
            )
            best = coarse[0]
            # run again around +- 1 the minimum found, staying within the allowed range
            fine_range = (max(min_disp, best - 1), min(max_disp, best + 1))
            fine = brute(
                SlitFFCorrector.ff_err,
                [fine_range],
                Ns=n_fine,
                args=args,
                finish=None,
                full_output=True,
                disp=verbose > 0,
                workers=workers,
            )
            return fine[0]

        raise ValueError(f'Unknown algorithm: {algo}. Use "lsq", "minimize" or "brute".')

    def __compute_all_shifts(self):
        """
        Compute the shifts for the individual frames (step 3).

        A polynomial of order Reduc.slit_ff_smooth_par is fitted to the average-frame
        shifts and evaluated at the timestamp of every file in self.files.
        """
        logger.info('Computing shifts for individual frames')
        self.__fit_poly()
        self.timestamps = [self.db.file_time(f).timestamp() for f in self.files]
        self.shifts = np.polyval(self.coeffs, self.timestamps)

    def __fit_poly(self, order=None):
        """
        Fit a polynomial shift(timestamp) to self.avg_shifts and store it in self.coeffs.

        :param order: Polynomial order; defaults to Reduc.slit_ff_smooth_par.
        """
        # NOTE: x are POSIX timestamps (~1e9 s); for higher orders np.polyfit may be
        # poorly conditioned (it can emit a RankWarning).
        x = [s['timestamp'] for s in self.avg_shifts]
        y = [s['shift'] for s in self.avg_shifts]
        if order is None:
            order = self.config.reduc.slit_ff_smooth_par
        self.coeffs = np.polyfit(x, y, order)

    def __plot_shifts(self):
        """
        Plot the average-frame shifts and the per-frame polynomial shifts; save to the log directory.
        """
        x = [s['timestamp'] for s in self.avg_shifts]
        y = [s['shift'] for s in self.avg_shifts]

        fig = plt.figure(figsize=(10, 5))
        plt.plot(x, y, 'or', label='Fit shift in avg. frames')
        # drawn in blue, as stated in the title
        plt.plot(self.timestamps, self.shifts, '-b', label=f'Ploy fit coef: {self.coeffs}')
        plt.xlabel('Timestamp')
        plt.ylabel('Shift (px)')
        plt.title('The shifts in blue will be applied to the sci. frm before flatfielding')
        plt.grid()
        plt.legend()
        plt.xticks(rotation=45)
        plt.tight_layout()

        out_file = os.path.basename(self.avg_frames_ofile.replace('.fits.gz', '_shifts.png'))
        out_file = os.path.join(self.log_dir, out_file)
        os.makedirs(os.path.dirname(out_file), exist_ok=True)
        fig.savefig(out_file)
        plt.close(fig)  # release the figure; pyplot keeps figures alive otherwise
        logger.debug(f'Saved shifts plot to {out_file}')

    def __plot_flat_profile(self):
        """
        Plot the full flat slit profile and its ROI part; save the figure to the log directory.
        """
        fig = plt.figure(figsize=(10, 5))
        plt.plot(np.mean(self.flat.data[0, :, :], axis=1), label='Slit Flat field')
        # Pixel indices of the slit ROI: slicing an index array reproduces exactly the slice
        # used to build self.flat_profile, for any (also negative) ROI bounds.
        x = np.arange(self.flat.data.shape[1])[self.slit_roi[0] : self.slit_roi[1]]
        plt.plot(x, self.flat_profile, 'r', label='Slit profile roi')
        plt.xlabel('Pixel')
        plt.ylabel('Intensity')
        plt.title('Flat field profile used for slit flat correction')
        plt.grid()
        plt.legend()
        plt.tight_layout()

        out_file = os.path.basename(self.avg_frames_ofile.replace('.fits.gz', '_flat_profile.png'))
        out_file = os.path.join(self.log_dir, out_file)
        os.makedirs(os.path.dirname(out_file), exist_ok=True)
        fig.savefig(out_file)
        plt.close(fig)
        logger.debug(f'Saved flat profile plot to {out_file}')

    def __plot_avg_profile(self, profile, shift, idx):
        """
        Plot one average slit profile against the flat profile; save it to the log directory.

        :param profile: The average slit profile (ROI) to plot.
        :param shift: The fitted shift (px) of the flat profile for this average frame.
        :param idx: Index of the average frame, used in the output file name.
        """
        fig, ax1 = plt.subplots(figsize=(10, 5))
        hp_prof = SlitFFCorrector.high_pass_profile(profile)
        ax1.plot(hp_prof, 'b', label='High-pass profile')
        ax1.set_xlabel('Pixel')
        # left axis: high-pass curve; right axis: normalized profiles
        ax1.set_ylabel('High-pass Intensity')
        ax1.set_title('Average Slit Profile with Shift')
        ax1.grid()
        ax2 = ax1.twinx()
        ax2.plot(profile / np.mean(profile), 'k', label='Slit profile')
        ax2.plot(self.flat_profile / np.mean(self.flat_profile), 'r', label='Flat profile')
        # the fitted shift is applied to the flat profile (as in ff_err)
        shifted_flat = SlitFFCorrector.apply_subpixel_shift(self.flat_profile, shift)
        ax2.plot(shifted_flat / np.mean(shifted_flat), '--g', label=f'Flat profile shifted by {shift:.3f} px')
        ax2.set_ylabel('Normalized Intensity')
        lines_1, labels_1 = ax1.get_legend_handles_labels()
        lines_2, labels_2 = ax2.get_legend_handles_labels()
        ax1.legend(lines_1 + lines_2, labels_1 + labels_2, loc='best')
        fig.tight_layout()

        out_file = os.path.basename(self.avg_frames_ofile.replace('.fits.gz', f'_avg_prof_{idx}.png'))
        out_file = os.path.join(self.log_dir, out_file)
        os.makedirs(os.path.dirname(out_file), exist_ok=True)
        fig.savefig(out_file)
        plt.close(fig)  # called once per average frame: avoid accumulating open figures
        logger.debug(f'Saved average profile plot to {out_file}')

    @staticmethod
    def ff_err(x, signal1, signal2, scalar):
        """
        High-pass residual of signal1 / (signal2 shifted by x).

        :param x: Shift (px); a scalar or a sequence whose first element is used.
        :param signal1: 1-D profile to flat-field.
        :param signal2: 1-D flat profile, shifted by x before dividing.
        :param scalar: If True return the RMS of the residual, else the residual vector.
        :return: RMS (float) or residual array; NaNs from length padding are dropped.
        """
        shift = np.ravel(x)[0]
        signal2 = SlitFFCorrector.apply_subpixel_shift(signal2, shift)
        # pad the shorter signal with NaN so the ratio can be formed; NaNs are dropped below
        if len(signal1) > len(signal2):
            signal2 = np.pad(signal2, (0, len(signal1) - len(signal2)), mode='constant', constant_values=np.nan)
        elif len(signal2) > len(signal1):
            signal1 = np.pad(signal1, (0, len(signal2) - len(signal1)), mode='constant', constant_values=np.nan)

        ratio = signal1 / signal2
        ratio = SlitFFCorrector.high_pass_profile(ratio[~np.isnan(ratio)])
        if scalar:
            return np.sqrt(np.mean(ratio**2))
        return ratio

    @staticmethod
    def apply_subpixel_shift(signal, shift):
        """
        Shift a 1-D signal by ``shift`` samples using cubic interpolation (extrapolating at the edges).

        :return: Array with result[i] = signal(i - shift).
        """
        original_indices = np.arange(len(signal))
        shifted_indices = original_indices - shift
        interpolator = interp1d(original_indices, signal, kind="cubic", fill_value="extrapolate")
        return interpolator(shifted_indices)

    @staticmethod
    def high_pass_profile(signal):
        """
        Return |signal - gaussian_filter(signal, sigma=0.7)|, i.e. the absolute high-frequency part.
        """
        signal = signal - gaussian_filter(signal, sigma=0.7)
        return np.abs(signal)

    @staticmethod
    def correct_paralell(files, shifts, flat, config, chunk_size=200):
        """
        Correct input files in chunks, by shifting them and dividing by the flat field.

        :param files: File paths to correct.
        :param shifts: One shift (px) per file, same order as ``files``.
        :param flat: Flat field Fits object.
        :param config: Configuration object (config.data.out_path is the output base).
        :param chunk_size: Number of files handled by each worker task.
        """
        shifts = np.asarray(shifts)
        chunks = Collections.chunker(files, chunk_size)
        # Slice the shifts with cumulative offsets so they stay aligned with each chunk
        # even if the last (or any) chunk is shorter than chunk_size.
        # Assumes the chunker keeps the file order -- TODO confirm.
        args = []
        start = 0
        for chunk in chunks:
            stop = start + len(chunk)
            args.append((chunk, shifts[start:stop], flat, config))
            start = stop
        MP.simultaneous(SlitFFCorrector.flatfield_frames, args)
        progress.dot(flush=True)

    @staticmethod
    def flatfield_frames(args):
        """
        Shift and flat-field each file, then write it to disk (worker for MP.simultaneous).

        :param args: Tuple (files, shifts, flat, config) with one shift (px) per file (same
            order), the flat Fits, and the config whose data.out_path is the output base.
        """
        files, shifts, flat, config = args
        # loop invariants: one FileDB and one flat short name per worker call
        db = susi.FileDB(config)
        short_fpath = db.file_short_basename(flat.path, avg=True)
        szf = flat.data.shape[1]
        for fpath, shift in zip(files, shifts):
            fits = Fits(fpath).read()
            shifted_data = Shifter.d2shift(fits.data[0], np.array([-shift, 0]))
            # crop frame (axis 0) and flat (axis 1) to their common size along the slit
            nrows = min(szf, shifted_data.shape[0])
            fits.data = shifted_data[None, :nrows, :] / flat.data[:, :nrows, :]
            try:
                fits.header["HIERARCH " + SHIFT_APPLIED] = f'{shift:.6f}, {0:.6f}'
            except KeyError:
                fits.header.append(Card(SHIFT_APPLIED, value=shift, comment='[px]').to_card())
            fits.header.append(Card(FLAT_MAP_MOVING, value=short_fpath).to_card())
            dayhourcam = db.dir_path(db.file_time(fpath), depth='day-hour-cam')
            fits.path = os.path.join(config.data.out_path, dayhourcam, os.path.basename(fpath))
            fits.write_to_disk(overwrite=config.data.force_reprocessing)
            progress.dot()