Coverage for src/susi/reduc/fields/flat_field_slit_correction.py: 15%

259 statements  


#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Module for moving flat field correction.

@author: iglesias
"""
import numpy as np
from scipy.optimize import least_squares, minimize, brute
from scipy.ndimage import gaussian_filter
from scipy.interpolate import interp1d
import os
from datetime import datetime
import matplotlib.pyplot as plt

from src import susi
from src.susi.base.header_keys import *

from ...io import Fits, MultiHDUWriter, MultiHDUReader, Card
from ...base import Logging, Config, Globals
from ...reduc.average import FrameSum
from ...utils import Collections, MP, progress
from ...utils.sub_shift import Shifter


logger = Logging.get_logger()


class SlitFFCorrector:
    """
    Corrects for the moving flat field in the input frames.
    1- Read the frames and compute partial averages every Reduc.slit_ff_avg_frm frames.
    2- For every averaged frame, fit the flat shift that minimizes the high-frequency residuals
       in the flatfielded frame. NOTE: the stored shifts are the opposite, so they can be applied
       directly to the sci images.
    3- Fit a polynomial of order Reduc.slit_ff_smooth_par to the shifts and derive a shift for each
       individual input frame.
    4- Read the individual frames, apply the shift and the flat field, then write them to the given odir.
    """
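    # Usage sketch (illustrative only; the file list, the flat file path and any Config
    # fields not referenced in this class are assumptions):
    #
    #   flat = Fits('<slit_flat_file>').read()
    #   corrector = SlitFFCorrector(files, flat, config)
    #   shifts = corrector.compute_shifts(plot=True)   # steps 1-3
    #   corrector.correct_frames()                     # step 4, writes to odir + '_ff'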

    def __init__(self, files: list, flat: Fits, config: Config):
        """
        Initialize the SlitFFCorrector class.

        :param files: List of file paths to process.
        :param flat: Flat field data as a Fits object.
        :param config: Configuration object containing reduction parameters.
        """
        self.files = files
        self.flat = flat
        self.config = config
        self.avg_frames = []
        self.avg_shifts = []
        self.odir = None  # new level odir
        self.log_dir = None  # log directory for plots
        self.avg_frames_ofile = None
        self.shifts = []  # estimated shifts as they must be applied to the sci frames before ff
        self.timestamps = []
        self.db = None
        self.slit_roi = [0, -1]  # default slit ROI
        self.wl_roi = [0, -1]  # default wavelength ROI

        # get odir and avg_frames_file
        self.db = susi.FileDB(self.config)
        self.avg_frames_ofile = self.db.avrg_fname(self.files[0], self.files[-1])
        tavg = self.config.reduc.slit_ff_avg_frm * Globals.SAMP_S // 60
        self.avg_frames_ofile = self.avg_frames_ofile.replace('-cam', f'-avg{tavg:.1f}min-cam')
        odir_day = self.db.dir_path(self.db.file_time(self.files[0]), depth='upto_day', base=self.config.data.out_path)
        self.avg_frames_ofile = os.path.join(odir_day, self.avg_frames_ofile)
        self.odir = self.db.dir_path(
            self.db.file_time(self.files[0]), depth='upto_level', base=self.config.data.out_path
        )
        self.log_dir = os.path.join(
            self.config.data.log_dir, os.path.basename(self.avg_frames_ofile).replace('.fits.gz', '')
        )

    def compute_shifts(self, plot=False, ini_val=[0], slit_roi=None, wl_roi=None):
        """
        Obtain the shifts for the individual frames (steps 1-3).
        NOTE: The returned self.shifts are the ones that must be applied to the sci frames.
        :param plot: If True, plot the fitted shifts to the log directory.
        :param ini_val: Initial guess for the shift fit, in pixels.
        :param slit_roi: Optional [start, end] ROI along the slit axis used for the fit.
        :param wl_roi: Optional [start, end] ROI along the wavelength axis used for the fit.
        """
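        # The three private calls below implement steps 1-3 of the class docstring:
        # __compute_avg (step 1), __compute_avg_shifts (step 2) and __compute_all_shifts (step 3).
        # The sign of the shifts is already flipped in step 2, so self.shifts can be applied
        # directly to the science frames.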

        if slit_roi is not None:
            self.slit_roi = slit_roi

        if wl_roi is not None:
            self.wl_roi = wl_roi

        self.__compute_avg()
        self.__compute_avg_shifts(ini_val=ini_val, plot=plot)
        self.__compute_all_shifts()
        if plot:
            self.__plot_flat_profile()
            self.__plot_shifts()
            logger.info(f'Plots saved to {self.log_dir}')
        return self.shifts

    def correct_frames(self, custom_files=None, odir_sufix='_ff'):
        """
        Apply the computed shifts to the individual frames and then divide by the flat field.
        :param custom_files: Use these instead of self.files (they must be contained within self.files!).
        :param odir_sufix: Suffix appended to the output directory, which is at the same level as the input files.
        """
        if custom_files is not None:
            files = [i for i in self.files if i in custom_files]
        else:
            files = self.files

        if len(files) == 0:
            logger.warning('No files to correct. Returning.')
            return

        self.config.data.out_path = self.odir + odir_sufix
        logger.info(f'Flatfielding {len(files)} files using chunks of {self.config.reduc.slit_ff_avg_frm} frames')
        chunks = Collections.chunker(files, self.config.reduc.slit_ff_avg_frm)
        for n_chunk, chunk in enumerate(chunks):
            logger.info(f'Processing chunk {n_chunk + 1} of {len(chunks)}, with {len(chunk)} files')
            tstamps = [self.db.file_time(f).timestamp() for f in chunk]
            shifts = np.polyval(self.coeffs, tstamps)
            SlitFFCorrector.correct_paralell(chunk, shifts, self.flat, self.config)

    def __compute_avg(self):
        """
        Compute average frames from the input files in parallel.
        The average is computed every Reduc.slit_ff_avg_frm frames.
        For future use, the average frames are saved in self.avg_frames_ofile.
        If the file already exists, it is read instead of recomputing the averages.
        """
        if os.path.exists(self.avg_frames_ofile):
            mhdur = MultiHDUReader(self.avg_frames_ofile)
            self.avg_frames = mhdur.read(fits_list=True)
            logger.info(f'Loaded {len(self.avg_frames)} average frames from {self.avg_frames_ofile}')
        else:
            chunks = Collections.chunker(self.files, self.config.reduc.slit_ff_avg_frm)
            for n_chunk, chunk in enumerate(chunks):
                logger.info(f'Averaging chunk {n_chunk + 1} of {len(chunks)}, with {len(chunk)} files')
                avg = FrameSum.avg_in_chunks(chunk, header_idx=len(chunk) // 2)  # use the header of the middle file
                fits = Fits(chunk[0])
                fits.header = avg.header
                fits.data = avg.result
                fits.header.append(
                    susi.Card(susi.base.TEMPORAL_BIN, value=len(chunk), comment='No. of averaged frames').to_card()
                )
                self.avg_frames.append(fits)

            logger.info(f'Saving average frames to {self.avg_frames_ofile}')
            os.makedirs(os.path.dirname(self.avg_frames_ofile), exist_ok=True)
            mhduw = MultiHDUWriter(self.avg_frames_ofile)
            mhduw.data = [i.data for i in self.avg_frames]
            mhduw.headers = [i.header for i in self.avg_frames]
            mhduw.write_to_disk()

        if len(self.avg_frames) == 0:
            raise ValueError('No average frames computed or loaded. Check the input')

    def __compute_avg_shifts(self, ini_val=[0], plot=False):
        """
        Compute the shifts for the average frames.
        The shifts are computed by minimizing the high-frequency residuals
        in the flatfielded frames.
        """
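        # Two-pass strategy: fit a single global shift on the first averaged frame and use it
        # as the initial guess for every frame; then fit a line to those shifts (__fit_poly with
        # order=1) and refit each frame using the line prediction as an improved initial guess.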

        slit_roi = self.slit_roi
        wl_roi = self.wl_roi
        self.flat_profile = np.mean(self.flat.data[0, slit_roi[0] : slit_roi[1], wl_roi[0] : wl_roi[1]], axis=1)
        global_profile = np.array(
            [i.data[0, slit_roi[0] : slit_roi[1], wl_roi[0] : wl_roi[1]] for i in self.avg_frames]
        )[0]
        global_profile = np.mean(global_profile, axis=1)
        global_shift = self.__fit_shift(global_profile, ini_val=ini_val)
        logger.info(f'Computing slit profile shifts with initial condition: {global_shift:.3f} px')
        for idx, avg_frame in enumerate(self.avg_frames):
            slit_profile = np.mean(avg_frame.data[0, slit_roi[0] : slit_roi[1], wl_roi[0] : wl_roi[1]], axis=1)
            shift = self.__fit_shift(slit_profile, ini_val=[global_shift])
            self.avg_shifts.append(
                {'shift': -shift, 'timestamp': datetime.fromisoformat(avg_frame.header[DATE_OBS]).timestamp()}
            )
            if plot:
                self.__plot_avg_profile(slit_profile, shift, idx)

        # fit a line to the shifts and use those as initial condition for a new fit
        self.__fit_poly(order=1)
        self.avg_shifts = []
        logger.info('Recomputing slit profile shifts with improved initial conditions')
        for avg_frame in self.avg_frames:
            slit_profile = np.mean(avg_frame.data[0, slit_roi[0] : slit_roi[1], wl_roi[0] : wl_roi[1]], axis=1)
            new_ini_val = -np.polyval(self.coeffs, datetime.fromisoformat(avg_frame.header[DATE_OBS]).timestamp())
            shift = self.__fit_shift(slit_profile, ini_val=new_ini_val, algo='lsq', verbose=0)
            self.avg_shifts.append(
                {'shift': -shift, 'timestamp': datetime.fromisoformat(avg_frame.header[DATE_OBS]).timestamp()}
            )

    def __fit_shift(self, signal1, ini_val=[0], min_disp=-15, max_disp=15, algo='lsq', verbose=0):
        # compute the shift by minimizing the high-pass-filtered ratio signal1/signal2
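        # For a candidate shift s the residual is r(s) = |highpass(signal1 / shift(signal2, s))|
        # (see ff_err and high_pass_profile): 'lsq' minimizes the residual vector,
        # 'minimize' its standard deviation, and 'brute' its RMS on a coarse grid
        # followed by a finer grid around the first minimum.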

        signal2 = self.flat_profile
        if algo == 'lsq':
            scalar_error = False
            res = least_squares(
                SlitFFCorrector.ff_err,
                ini_val,
                args=(signal1, signal2, scalar_error),
                bounds=([min_disp], [max_disp]),
                verbose=verbose,
            )
            return res.x[0]
        elif algo == 'minimize':
            scalar_error = False

            def obj_func(x):
                err = SlitFFCorrector.ff_err([x], signal1, signal2, scalar_error)
                return np.std(err)

            options = {'disp': verbose > 0}
            res = minimize(obj_func, ini_val, bounds=[(min_disp, max_disp)], options=options)
            return res.x[0]
        elif algo == 'brute':
            Ns = [20, 50]  # number of points to sample in the brute force search
            workers = 20
            scalar_error = True
            ranges = [(min_disp, max_disp)]
            res = brute(
                SlitFFCorrector.ff_err,
                ranges,
                Ns=Ns[0],
                args=(signal1, signal2, scalar_error),
                finish=None,
                full_output=True,
                disp=verbose > 0,
                workers=workers,
            )
            # run again around +- 1 the minimum found
            res = brute(
                SlitFFCorrector.ff_err,
                [(res[0] - 1, res[0] + 1)],
                Ns=Ns[1],
                args=(signal1, signal2, scalar_error),
                finish=None,
                full_output=True,
                disp=verbose > 0,
                workers=workers,
            )
            return res[0]
        else:
            raise ValueError(f'Unknown algorithm: {algo}. Use "lsq", "minimize" or "brute".')

    def __compute_all_shifts(self):
        """
        Compute the shifts for the individual frames.
        The shifts are computed by fitting a polynomial to the average shifts
        and evaluating it at each frame's timestamp.
        """
        logger.info('Computing shifts for individual frames')
        self.__fit_poly()

        self.timestamps = [self.db.file_time(i).timestamp() for i in self.files]

        self.shifts = np.polyval(self.coeffs, self.timestamps)

    def __fit_poly(self, order=None):
        """
        Fit a polynomial to the average shifts. The order defaults to
        Reduc.slit_ff_smooth_par unless explicitly given.
        """
        x = [i['timestamp'] for i in self.avg_shifts]
        y = [i['shift'] for i in self.avg_shifts]
        if order is None:
            order = self.config.reduc.slit_ff_smooth_par
        self.coeffs = np.polyfit(x, y, order)
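        # np.polyfit returns the coefficients with the highest power first, so the
        # predicted shift at a timestamp t is np.polyval(self.coeffs, t), as done in
        # __compute_all_shifts and correct_frames.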

    def __plot_shifts(self):
        """
        Plot the shifts and save the figure to the output directory.
        """

        x = [i['timestamp'] for i in self.avg_shifts]
        y = [i['shift'] for i in self.avg_shifts]

        plt.figure(figsize=(10, 5))
        plt.plot(x, y, 'or', label='Fit shift in avg. frames')
        plt.plot(self.timestamps, self.shifts, '-k', label=f'Poly fit coef: {self.coeffs}')
        plt.xlabel('Timestamp')
        plt.ylabel('Shift (px)')
        plt.title('The fitted shifts (black line) will be applied to the sci. frames before flatfielding')
        plt.grid()
        plt.legend()
        plt.xticks(rotation=45)
        plt.tight_layout()

        out_file = os.path.basename(self.avg_frames_ofile.replace('.fits.gz', '_shifts.png'))
        out_file = os.path.join(self.log_dir, out_file)
        os.makedirs(os.path.dirname(out_file), exist_ok=True)
        plt.savefig(out_file)
        logger.debug(f'Saved shifts plot to {out_file}')

    def __plot_flat_profile(self):
        """
        Plot the flat profile and save the figure to the output directory.
        """
        plt.figure(figsize=(10, 5))
        plt.plot(np.mean(self.flat.data[0, :, :], axis=1), label='Slit flat field')
        if self.slit_roi[1] == -1:
            x = np.arange(self.slit_roi[0], self.flat.data.shape[1] - 1)
        else:
            x = np.arange(self.slit_roi[0], self.slit_roi[1])
        plt.plot(x, self.flat_profile, 'r', label='Slit profile ROI')
        plt.xlabel('Pixel')
        plt.ylabel('Intensity')
        plt.title('Flat field profile used for slit flat correction')
        plt.grid()
        plt.legend()
        plt.tight_layout()

        out_file = os.path.basename(self.avg_frames_ofile.replace('.fits.gz', '_flat_profile.png'))
        out_file = os.path.join(self.log_dir, out_file)
        os.makedirs(os.path.dirname(out_file), exist_ok=True)
        plt.savefig(out_file)
        logger.debug(f'Saved flat profile plot to {out_file}')

    def __plot_avg_profile(self, profile, shift, idx):
        """
        Plot the average profile and save the figure to the output directory.
        :param profile: The average profile to plot.
        :param shift: The shift applied to the profile.
        :param idx: Index of the averaged frame, used in the output file name.
        """
        fig, ax1 = plt.subplots(figsize=(10, 5))
        hp_prof = SlitFFCorrector.high_pass_profile(profile)
        ax1.plot(hp_prof, 'b', label='High-pass profile')
        ax1.set_xlabel('Pixel')
        ax1.set_ylabel('High-pass intensity')
        ax1.set_title('Average Slit Profile with Shift')
        ax1.grid()
        ax2 = ax1.twinx()
        ax2.plot(profile / np.mean(profile), 'k', label='Slit profile')
        ax2.plot(self.flat_profile / np.mean(self.flat_profile), 'r', label='Flat profile')
        ax2.set_ylabel('Normalized intensity')
        lines_1, labels_1 = ax1.get_legend_handles_labels()
        lines_2, labels_2 = ax2.get_legend_handles_labels()
        ax1.legend(lines_1 + lines_2, labels_1 + labels_2, loc='best')
        fig.tight_layout()

        out_file = os.path.basename(self.avg_frames_ofile.replace('.fits.gz', f'_avg_prof_{idx}.png'))
        out_file = os.path.join(self.log_dir, out_file)
        os.makedirs(os.path.dirname(out_file), exist_ok=True)
        plt.savefig(out_file)
        logger.debug(f'Saved average profile plot to {out_file}')

    @staticmethod
    def ff_err(x, signal1, signal2, scalar):
        signal2 = SlitFFCorrector.apply_subpixel_shift(signal2, x[0])
        if len(signal1) > len(signal2):
            signal2 = np.pad(signal2, (0, len(signal1) - len(signal2)), mode='constant', constant_values=np.nan)
        elif len(signal2) > len(signal1):
            signal1 = np.pad(signal1, (0, len(signal2) - len(signal1)), mode='constant', constant_values=np.nan)

        ratio = signal1 / signal2
        ratio = SlitFFCorrector.high_pass_profile(ratio[~np.isnan(ratio)])
        if scalar:
            return np.sqrt(np.mean(ratio**2))
        else:
            return ratio

    @staticmethod
    def apply_subpixel_shift(signal, shift):
        # Compute the original and shifted indices.
        original_indices = np.arange(len(signal))
        shifted_indices = original_indices - shift
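        # Sampling the signal at (index - shift) moves its features by +shift pixels,
        # i.e. a positive shift displaces the profile towards larger indices.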

        # Interpolate signal at the shifted indices.
        interpolator = interp1d(original_indices, signal, kind="cubic", fill_value="extrapolate")
        shifted_signal = interpolator(shifted_indices)

        return shifted_signal

    @staticmethod
    def high_pass_profile(signal):
        signal = signal - gaussian_filter(signal, sigma=0.7)
        return np.abs(signal)

    @staticmethod
    def correct_paralell(files, shifts, flat, config, chunk_size=200):
        """
        Correct input files in chunks, by shifting them and dividing by the flat field.
        """
        chunks = Collections.chunker(files, chunk_size)
        args = [(c, shifts[i * chunk_size : (i + 1) * chunk_size], flat, config) for i, c in enumerate(chunks)]
        MP.simultaneous(SlitFFCorrector.flatfield_frames, args)
        progress.dot(flush=True)

    @staticmethod
    def flatfield_frames(args):
        files, shifts, flat, config = args
        for i, f in enumerate(files):
            fits = Fits(f).read()
            shifted_data = Shifter.d2shift(fits.data[0], np.array([-shifts[i], 0]))
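            # The flat and the (shifted) frame may differ in extent along the slit axis;
            # crop the longer of the two so the division below is element-wise.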

            szf = flat.data.shape[1]
            szd = shifted_data.shape[0]
            if szf > szd:
                shifted_data = shifted_data[None,]
                crop_flat = flat.data[:, :szd, :]
            elif szf < szd:
                shifted_data = shifted_data[None, :szf, :]
                crop_flat = flat.data
            else:
                shifted_data = shifted_data[None,]
                crop_flat = flat.data
            corrected_data = shifted_data / crop_flat
            fits.data = corrected_data
            db = susi.FileDB(config)
            try:
                fits.header["HIERARCH " + SHIFT_APPLIED] = f'{shifts[i]:.6f}, {0:.6f}'
            except KeyError:
                fits.header.append(Card(SHIFT_APPLIED, value=shifts[i], comment='[px]').to_card())

            short_fpath = db.file_short_basename(flat.path, avg=True)
            fits.header.append(Card(FLAT_MAP_MOVING, value=short_fpath).to_card())
            dayhourcam = db.dir_path(db.file_time(f), depth='day-hour-cam')
            fits.path = os.path.join(config.data.out_path, dayhourcam, os.path.basename(f))
            fits.write_to_disk(overwrite=config.data.force_reprocessing)
            progress.dot()