Coverage for src/susi/analyse/grid_rot_and_shear.py: 47%

324 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2025-06-13 14:15 +0000

1#!/usr/bin/env python3 

2# -*- coding: utf-8 -*- 

3import os 

4import numpy as np 

5import pandas as pd 

6import matplotlib.pyplot as plt 

7from scipy.ndimage import gaussian_filter 

8from skimage.filters import sobel 

9from scipy.signal import find_peaks 

10from scipy.interpolate import interp1d 

11from scipy.optimize import brute 

12from ... import susi 

13from ..base.header_keys import * 

14from .. import ROOT_DIR 

15 

# Module logger shared by the whole analysis
log = susi.Logging.get_logger()

# Table (one row per grid target file) with the ROI and the grid lines to skip for the shear analysis
INPUT_ROI_PATH = os.path.abspath(
    os.path.join(
        ROOT_DIR,
        "..",
        "data",
        "susi",
        "reduc",
        "shear_analisys_rois.csv",
    )
)

# [rows, columns] ROI to use if roi='default' and the input file has no entry in INPUT_ROI_PATH
DEFAULT_ROI = [
    slice(440, 1980),
    slice(100, 2000),
]

20 

21 

class RotAndShearAnalysis:
    """
    Estimates the shear and rotation angle from the given grid target image.

    The horizontal grid lines are located along a set of image columns and a straight line is
    fitted to each of them across those columns. The shear factor is found by a brute-force
    search that makes the tilts of all grid lines equal; the remaining common (mean) tilt is
    taken as the rotation angle.

    ## Author(s)
    iglesias@mps.mpg.de and vukadinovic@mps.mpg.de
    """

    def __init__(
        self,
        path: str,
        odir: str,
        custom_param=None,
        roi="default",
        min_grid_distance=140,
        grid_lines_to_skip=None,
    ):
        """
        :param path: path of the input FITS file containing the grid target image.
        :param odir: output directory for the corrected FITS file and diagnostic plots;
            ``None`` disables all output.
        :param custom_param: manual ``[rotation_deg, shear]``. If set, no fit is done and these
            values are used for the correction.
        :param roi: ``"default"`` (ROI looked up in ``INPUT_ROI_PATH``, falling back to
            ``DEFAULT_ROI``), ``"full"`` for the whole image, or a ``[row_slice, col_slice]`` pair.
        :param min_grid_distance: minimum distance between grid lines [px].
        :param grid_lines_to_skip: ``[row_start, row_end]`` ranges in px relative to the full image;
            grid lines inside any range are ignored. They are replaced by the ranges loaded from
            ``INPUT_ROI_PATH`` when ``roi="default"`` and the matching entry lists any. The widest
            grid line is always ignored in addition.
        """
        self.ipath = path
        #: The detected rotation angle in degree
        self.angle: float = 0.0
        #: The detected shear factor
        self.shear: float = 0.0
        #: The output directory
        self.odir = odir
        #: Manual parameters [rot, shear]. If set then no fit is done and these values are used for correction
        self.custom_param = custom_param
        # Copy so the caller's list is never modified
        self.grid_lines_to_skip = [] if grid_lines_to_skip is None else list(grid_lines_to_skip)
        #: Region of interest [rows, columns]; use default for best inter comparison of errors between diff targets
        if roi == "default":
            self._get_custom_ana_param()
        elif roi == "full":
            self.roi = [slice(None), slice(None)]
        else:
            self.roi = roi
        # First ROI row in the full image; a missing start means the ROI begins at row 0
        yoff = self.roi[0].start or 0
        #: minimum distance between grid lines [px]
        self.min_grid_distance = min_grid_distance
        # Convert the skip ranges from full-image to ROI row coordinates
        self.grid_lines_to_skip = [[g[0] - yoff, g[1] - yoff] for g in self.grid_lines_to_skip]

    def _get_custom_ana_param(self):
        """
        Load the ROI, and optionally the grid lines to skip, for ``self.ipath`` from ``INPUT_ROI_PATH``.

        Rows are matched on the ``lvl1_grid_filename`` column against the input file name without
        extension. Without a match ``DEFAULT_ROI`` is used and ``self.grid_lines_to_skip`` is kept.
        """
        df = pd.read_csv(INPUT_ROI_PATH)
        fname = os.path.basename(self.ipath).split(".")[0]
        row = df[df["lvl1_grid_filename"] == fname]
        if row.empty:
            self.roi = list(DEFAULT_ROI)
            return

        entry = row.iloc[0]
        self.roi = [
            slice(int(entry["y0"]), int(entry["y1"])),
            slice(int(entry["x0"]), int(entry["x1"])),
        ]
        log.debug(f"Using ROI {self.roi = } for shear analysis from {INPUT_ROI_PATH}")

        skip = entry["grid_lines_to_skip"] if "grid_lines_to_skip" in row.columns else np.nan
        # pd.notna (not `is not np.nan`) also recognises the float64 NaN of an all-empty column
        if pd.notna(skip):
            # stored as text like "[100, 200], [300, 400]" -> flat list of ints -> [start, end] pairs
            cleaned = str(skip).replace("[", "").replace("]", "").replace(" ", "")
            tokens = [t for t in cleaned.split(",") if t]
            if tokens:
                arr = np.array(tokens).astype(int).reshape(-1, 2)
                self.grid_lines_to_skip = [[int(start), int(end)] for start, end in arr]

    def get_grid_lines_loc(
        self,
        sobelx,
        xcut,
        xintp,
        x,
        oversample,
        distance=40,
        odir=None,
        oplot=False,
        close_fig=False,
        skip_ranges=None,
    ):
        """
        Get the location of the grid lines at the column given by xcut.

        The absolute vertical edge profile of the column is smoothed (gaussian, sigma=3),
        cubically interpolated from ``x`` onto ``xintp`` and its peaks are taken as grid lines.
        Peaks inside any of ``skip_ranges`` or inside the widest detected line are discarded.

        :param sobelx: edge image (sobel along axis 0).
        :param xcut: column index to analyse.
        :param xintp: oversampled row coordinates for the interpolation.
        :param x: original row coordinates.
        :param oversample: ratio ``len(xintp) / len(x)``.
        :param distance: minimum peak distance in samples of ``xintp``.
        :param odir: if given, the profile is plotted and, with ``close_fig``, saved to
            ``odir/grid_slices.png``.
        :param oplot: draw on the current figure instead of opening a new one.
        :param close_fig: label, save and close the figure after drawing.
        :param skip_ranges: ``[start, end]`` row ranges [px] to skip. The widest line found here is
            appended to it so callers can share it between columns. Defaults to a fresh copy of
            ``self.grid_lines_to_skip``, which is never modified.
        :return: kept peak positions as indices into ``xintp`` (rows times ``oversample``).
        """
        yslice = gaussian_filter(np.abs(sobelx[:, xcut]), sigma=3)
        yslice = interp1d(x, yslice, kind="cubic")(xintp)
        pos_peaks, _ = find_peaks(yslice, distance=distance)

        if skip_ranges is None:
            skip_ranges = list(self.grid_lines_to_skip)
        # The thickest line is always ignored
        thick_line = self.get_thick_grid_line_loc(yslice, pos_peaks, oversample)
        if thick_line is not None:
            skip_ranges.append(thick_line)
        if skip_ranges and len(pos_peaks) > 0:
            pos_peaks = np.array(
                [p for p in pos_peaks if not any(g[0] < p / oversample < g[1] for g in skip_ranges)]
            )
        log.debug(f"Skipped grid lines at positions: {list(skip_ranges)}")

        if odir is not None:
            if not oplot:
                plt.figure(figsize=(15, 7))
            (profile,) = plt.plot(xintp, yslice, label=f"Column {xcut}")
            for pos in pos_peaks:
                plt.axvline(x=pos / oversample, c=profile.get_color(), lw=0.7, ls="--")
            if close_fig:
                plt.xlabel("vertical position [px]")
                plt.ylabel("edge intensity")
                plt.legend()
                plt.grid()
                plt.savefig(os.path.join(odir, "grid_slices.png"))
                plt.close()
        return pos_peaks

    def get_grid_lines_angles(self, horizontal_grids, xcuts, maxx, oversample):
        """
        Fit a straight line to every grid line across the analysed columns.

        Columns that do not contain the maximum number of detected lines are dropped, because
        their line indices would not refer to the same physical grid line.

        :param horizontal_grids: per column, the oversampled row positions of the detected lines.
        :param xcuts: column index of each entry in ``horizontal_grids``.
        :param maxx: horizontal extent [px] over which the displacement is evaluated.
        :param oversample: vertical oversampling factor of the positions.
        :return: ``(angles [deg], y_pos [px], displacement [px])`` arrays with one entry per grid
            line; all empty if fewer than two consistent columns remain.
        """
        nlines = [len(hg) for hg in horizontal_grids]
        max_nlines = max(nlines)
        wrong_idx = [i for i, n in enumerate(nlines) if n != max_nlines]
        if wrong_idx:
            log.debug(f"Skipping columns {wrong_idx} as they have different number of grid lines")
        keep = [i for i, n in enumerate(nlines) if n == max_nlines]

        if len(keep) < 2:
            # a line cannot be fitted through a single column
            log.error("Less than two columns with a consistent number of grid lines")
            return np.array([]), np.array([]), np.array([])

        x = np.array([xcuts[i] for i in keep])
        # shape (n_columns, n_lines): each column of `lines` is one grid line
        lines = np.array([np.asarray(horizontal_grids[i]) for i in keep], dtype=float)

        angles, y_pos, displacement = [], [], []
        for y in lines.T:
            slope = np.polyfit(x, y, 1)[0]
            angles.append(np.arctan(-slope / oversample) * 180 / np.pi)
            y_pos.append(y[0] / oversample)
            # displacement between x=0 and maxx
            displacement.append(-slope * maxx / oversample)
        return np.array(angles), np.array(y_pos), np.array(displacement)

    def get_grid(self, image, xcuts, distance=40, odir=None):
        """
        Get the rotation of the grid target lines in the image
        by locating them only at the xcuts columns.

        :param image: 2D image; axis 0 is the vertical (row) direction.
        :param xcuts: column indices (axis 1) where the grid lines are located.
        :param distance: minimum distance between grid lines [px].
        :param odir: if given, diagnostic plots and ``rotation_vs_slit.csv`` are written there.
        :return: ``(edge magnitude, grid line rows at the first column [px], displacement per line
            [px], tilt per line [deg], sobel along axis 0, sobel along axis 1)``, or six ``None``
            if no usable grid lines are found.
        """
        oversample = 50  # oversampling factor of the vertical profiles

        nrows, ncols = image.shape[0], image.shape[1]
        sobelx = sobel(image, axis=0)
        sobely = sobel(image, axis=1)

        rows = np.arange(nrows)
        xintp = np.linspace(0, nrows - 1, num=nrows * oversample)
        # oversampling actually achieved by the interpolation grid
        oversample = float(len(xintp) / len(rows))
        distance *= oversample

        # xcuts index columns (axis 1), so they are bounded by the number of columns
        orig_cuts = len(xcuts)
        xcuts = [c for c in xcuts if c < ncols]
        if len(xcuts) < orig_cuts:
            log.debug(f"Skipping columns {orig_cuts-len(xcuts)} as they are out of the transformed input image")

        # Shared by the columns of this image only: a thick line found in one column is also
        # skipped in the following ones, but nothing leaks into later calls.
        skip_ranges = list(self.grid_lines_to_skip)
        last = len(xcuts) - 1
        horizontal_grids = [
            self.get_grid_lines_loc(
                sobelx,
                xcut,
                xintp,
                rows,
                oversample,
                distance=distance,
                odir=odir,
                oplot=i > 0,  # the first column opens the figure, the others draw on it
                close_fig=i == last,  # the last column saves and closes it
                skip_ranges=skip_ranges,
            )
            for i, xcut in enumerate(xcuts)
        ]

        if len(horizontal_grids) < 2:
            log.error("No grid lines found")
            return None, None, None, None, None, None

        angles, y_pos, displacement = self.get_grid_lines_angles(horizontal_grids, xcuts, ncols - 1, oversample)
        if len(angles) == 0:
            log.error("No grid lines found")
            return None, None, None, None, None, None

        if odir is not None:
            # plot rot vs tilt
            plt.figure(figsize=(10, 7))
            # single red point so the displacement curve of the twin axis appears in this legend
            plt.plot(y_pos[0], angles[0], "--or", label="displacement")
            plt.plot(y_pos, angles, "--ob", label="grid tilt")
            plt.legend()
            plt.xlabel("vertical position [px]")
            plt.ylabel("grid line tilt [deg]")
            ctilt = np.mean(np.abs(angles))
            cdisp = np.mean(np.abs(displacement))
            plt.title(f"Mean abs tilt: {ctilt:.4f} deg \n Mean abs displacement: {cdisp:.4f} px")
            plt.grid()
            y2 = plt.twinx()
            y2.plot(y_pos, displacement, "--or", label="displacement")
            y2.set_ylabel("vertical displacement [px]")
            plt.savefig(os.path.join(odir, "rotation_vs_slit.png"))
            plt.close()

            # saves in xpos, angles and displacement to csv
            with open(os.path.join(odir, "rotation_vs_slit.csv"), "w") as f:
                f.write("vertical position [px], grid tilt [deg], vertical displacement [px]\n")
                f.writelines(f"{y}, {a}, {d}\n" for y, a, d in zip(y_pos, angles, displacement))

        return (
            np.sqrt(sobelx**2 + sobely**2),
            horizontal_grids[0] / oversample,
            np.array(displacement),
            angles,
            sobelx,
            sobely,
        )

    @staticmethod
    def transform_image(img, shear, coord_off=(0, 0)):
        """Apply the shear distortion ``shear`` to ``img`` (``coord_off`` is forwarded unchanged)."""
        return susi.ShearDistortion(img).run(shear, coord_off=coord_off)

    def error_func(self, shear, img, xcuts, print_msg=False, peak_distance=140, coord_off=(0, 0)):
        """
        Cost of a shear factor: spread of the grid line tilts after shearing ``img``.

        :param print_msg: unused; kept because ``brute`` passes it positionally.
        :return: sum of squared deviations [deg^2] of the line tilts from their mean, or ``np.inf``
            if no grid lines are found in the sheared image.
        """
        _, _, _, angles, _, _ = self.get_grid(
            self.transform_image(img, shear, coord_off=coord_off),
            xcuts,
            distance=peak_distance,
        )
        if angles is None or len(angles) == 0:
            return np.inf
        # error to have all angles equal
        return np.sum((angles - np.mean(angles)) ** 2)

    def plot_detail(self, img, path, peaks_loc, xcuts, detail_xsize):
        """
        Save a zoom on the left and right borders (``detail_xsize`` columns each) of ``img``.

        Each border is normalised by its own mean and both are shown side by side, separated by a
        white line. Red lines mark ``peaks_loc``; blue lines mark the first and the last of
        ``xcuts``, the latter shifted into the concatenated frame.
        """
        left = img[:, :detail_xsize]
        right = img[:, -detail_xsize:]
        detail_img = np.concatenate([left / np.mean(left), right / np.mean(right)], axis=1)
        m = np.mean(detail_img)
        sd = np.std(detail_img)

        plt.figure(figsize=(5, 10))
        plt.imshow(detail_img, cmap="gray", vmin=m - 2 * sd, vmax=m + 2 * sd)
        if peaks_loc is not None:
            for y in peaks_loc:
                plt.axhline(y=y, c="r", lw=0.7)
        if xcuts is not None:
            detail_xcuts = [xcuts[0], detail_img.shape[1] - (img.shape[1] - xcuts[-1])]
            for x in detail_xcuts:
                plt.axvline(x=x, c="b", lw=1)
        plt.axvline(x=detail_xsize, c="w", lw=3)
        plt.xlabel("horizontal position [px]")
        plt.ylabel("vertical position [px]")
        plt.colorbar(fraction=0.046, pad=0.04, label="[DN]")
        plt.tight_layout()
        plt.savefig(path)
        plt.close()

    def get_columns_to_analyze(self, img, nxcuts):
        """
        Pick one column per vertical strip of ``img``.

        The image is divided into ``nxcuts`` strips along axis 1 and, in each strip, the column
        with the highest mean value is selected.

        :return: list of absolute column indices, one per strip.
        """
        ncols = img.shape[1]
        width = ncols // nxcuts
        xcuts = []
        for i in range(nxcuts):
            start = int(i * ncols / nxcuts)
            col_means = np.mean(img[:, start : start + width], axis=0)
            xcuts.append(start + np.argmax(col_means))
        log.info(f"Using columns {xcuts} for the analysis")
        return xcuts

    def get_thick_grid_line_loc(self, yslice, pos_peaks, oversample):
        """
        Return the ``[start, end]`` row range [px] of the widest peak of ``yslice``.

        The width of each peak is measured from the peak to the first sample below half of the
        peak value on each side (a side that never drops below contributes 0).

        :return: ``[start, end]`` in px (sample index / ``oversample``), or ``None`` without peaks.
        """
        if len(pos_peaks) == 0:
            return None
        lefts = np.zeros(len(pos_peaks), dtype=int)
        rights = np.zeros(len(pos_peaks), dtype=int)
        for i, p in enumerate(pos_peaks):
            half_max = yslice[p] / 2
            below_left = yslice[:p][::-1] < half_max
            lefts[i] = np.argmax(below_left) if below_left.size else 0
            rights[i] = np.argmax(yslice[p:] < half_max)
        ind = int(np.argmax(lefts + rights))
        # use the half-max bounds of the widest peak itself (not of the last one inspected)
        return [
            float((pos_peaks[ind] - lefts[ind]) / oversample),
            float((pos_peaks[ind] + rights[ind]) / oversample),
        ]

    @staticmethod
    def _plot_image_with_grid(img, peaks_loc, xcuts, path, stats_img=None):
        """
        Save ``img`` with red horizontal lines at ``peaks_loc`` and blue vertical lines at ``xcuts``.

        The gray scale spans mean +- 2 std of ``stats_img`` (defaults to ``img``).
        """
        ref = img if stats_img is None else stats_img
        m = np.mean(ref)
        sd = np.std(ref)
        plt.figure(figsize=(15, 10))
        plt.imshow(img, cmap="gray", vmin=m - 2 * sd, vmax=m + 2 * sd)
        if peaks_loc is not None:
            for y in peaks_loc:
                plt.axhline(y=y, c="r", lw=0.6)
        for x in xcuts:
            plt.axvline(x=x, c="b", lw=1)
        plt.title("Red lines are horizontal \n Rotation is measured at blue columns")
        plt.xlabel("horizontal position [px]")
        plt.ylabel("vertical position [px]")
        plt.colorbar(fraction=0.046, pad=0.04, label="[DN]")
        plt.tight_layout()
        plt.savefig(path)
        plt.close()

    def _plot_error_vs_shear(self, img, xcuts, fit_shear, peak_distance, coord_off, path):
        """Save ``error_func`` sampled over shear factors in [-1, 1], with ``fit_shear`` marked."""
        log.info("Computing error function vs shear factor...")
        shear = np.linspace(-1, 1, 100)
        error = np.array(
            [self.error_func(s, img, xcuts, peak_distance=peak_distance, coord_off=coord_off) for s in shear]
        )
        cerror = self.error_func(fit_shear, img, xcuts, peak_distance=peak_distance, coord_off=coord_off)
        # error_func returns inf where no grid is found; keep the marker height finite
        finite = error[np.isfinite(error)]
        ymax = np.max(finite) if finite.size else 1.0

        plt.figure(figsize=(10, 7))
        plt.plot(shear, error, label=f"error: {cerror:.6f}")
        plt.vlines(fit_shear, 0, ymax, color="r", lw=1, label=f"fit shear: {fit_shear:.6f}")
        plt.xlabel("shear factor")
        plt.ylabel("mean displacement between cuts [px]")
        plt.grid()
        plt.legend()
        plt.savefig(path)
        plt.close()

    def run(self):
        """
        Run the analysis on ``self.ipath``.

        Fits (or takes from ``custom_param``) the shear factor, removes the remaining mean tilt
        and, if an output directory is set, writes the corrected full-frame FITS file together
        with diagnostic plots and CSV files.

        :return: ``(mean_tilt [deg], shear factor)``; also stored in ``self.angle``/``self.shear``.
        :raises RuntimeError: if the tilt has to be fitted and no grid lines are found after the
            shear correction.
        """
        # constants
        # TODO: Try again with a proper bound least-square optimization
        fit_nstep = 1000  # number of steps for brute force optimization
        peak_distance = self.min_grid_distance  # distance between peaks
        minimum_search_range = (-1, 0.5)  # search range for the shear factor
        detail_xsize = 200  # size borders used in the detail image
        ncol_to_analyze = 20  # number of columns to analyze
        if self.roi[0].start is None:
            coord_off = (0, 0)
        else:
            coord_off = (self.roi[0].start - 1, self.roi[1].start - 1)

        # output directory
        odir = None
        ofits_path = None
        if self.odir is not None:
            in_name = os.path.basename(self.ipath)
            odir = os.path.join(self.odir, f'{in_name.split(".")[0]}_plots')
            ofits_path = os.path.join(self.odir, in_name)
            os.makedirs(odir, exist_ok=True)
        log.info(f"Running grid rotation and shear analysis for {self.ipath}")

        # read image
        ihdu = susi.Fits(self.ipath).read()
        full_img = ihdu.data[0, :, :].astype(np.float32)
        if self.roi[0].start is None:
            img = full_img
        else:
            img = full_img[self.roi[0], self.roi[1]]
        log.info(f"Original image shape: {full_img.shape}. Using ROI: {img.shape}")

        xcuts = self.get_columns_to_analyze(img, ncol_to_analyze)
        _, peaks_loc, _, _, _, _ = self.get_grid(img, xcuts, distance=peak_distance, odir=odir)

        if odir is not None:
            # input image with horizontal lines at xcuts[0] peak location (none if detection failed)
            self._plot_image_with_grid(img, peaks_loc, xcuts, os.path.join(odir, "img.png"))
            self.plot_detail(img, os.path.join(odir, "img_dteail.png"), peaks_loc, xcuts, detail_xsize)

        # fit shear factor
        if self.custom_param is None:
            log.info("Brute force fitting nl shear factor...")
            # brute requires workers >= 1; os.cpu_count() may be 1 or None
            workers = max(1, (os.cpu_count() or 1) // 2)
            res = brute(
                self.error_func,
                [minimum_search_range],
                args=(img, xcuts, False, peak_distance, coord_off),
                Ns=fit_nstep,
                workers=workers,
            )
            fit_shear = res[0]
            log.info(f"Non-Linear shear factor fit: {fit_shear:.6f}")
        else:
            fit_shear = self.custom_param[1]
            log.info(f"Using custom shear factor: {fit_shear:.6f}")
        img_tr = self.transform_image(img, fit_shear, coord_off=coord_off)

        # removes the mean tilt
        if self.custom_param is None:
            log.info("Using mean rotation angle after shear correction...")
            _, _, _, rotation, _, _ = self.get_grid(img_tr, xcuts, distance=peak_distance)
            if rotation is None:
                raise RuntimeError(f"No grid lines found in {self.ipath} after shear correction")
            mean_tilt = np.mean(rotation)
        else:
            log.info("Using custom rotation angle...")
            mean_tilt = self.custom_param[0]
        log.info("Removing mean tilt {} deg".format(mean_tilt))
        img_tr = susi.RotationCorrection(img=img_tr, angle=-mean_tilt).bicubic()

        if odir is not None:
            # save transformed full image in fits (no ROI offset for the full frame)
            ofits = susi.Fits(ofits_path)
            ofits.data = self.transform_image(full_img, fit_shear)
            ofits.data = susi.RotationCorrection(img=ofits.data, angle=-mean_tilt).bicubic()
            ofits.header = ihdu.header
            ofits.header[SHEAR_CORR] = (fit_shear, "row-shear factor used")
            ofits.header[ROTATION_ANG] = (mean_tilt, "Rot. angle used [deg]")
            ofits.write_to_disk(overwrite=True)

            # plots transformed derotated image
            odir_tr = os.path.join(odir, "corrected")
            os.makedirs(odir_tr, exist_ok=True)
            grid, peaks_loc, _, _, grid_x, grid_y = self.get_grid(
                img_tr, xcuts, distance=peak_distance, odir=odir_tr
            )
            # saves fit_shear and mean_tilt in a .csv
            with open(os.path.join(odir_tr, "corr_params.csv"), "w") as f:
                f.write(f"shear factor, {fit_shear}\n")
                f.write(f"global rot angle [deg], {mean_tilt}\n")
                f.write(f"row offset [px], {coord_off[0]}\n")
                f.write(f"column offset [px], {coord_off[1]}\n")

            # transformed image with horizontal lines at grid 1 peak location
            self._plot_image_with_grid(img_tr, peaks_loc, xcuts, os.path.join(odir_tr, "img_transformed.png"))

            # ROI -> full-frame coordinates; a missing slice start (roi="full") means 0
            row_start = self.roi[0].start or 0
            col_start = self.roi[1].start or 0
            off_peaks_loc = None if peaks_loc is None else [p + row_start for p in peaks_loc]
            off_xcuts = [x + col_start for x in xcuts]
            self._plot_image_with_grid(
                ofits.data,
                off_peaks_loc,
                off_xcuts,
                os.path.join(odir_tr, "img_transformed_full.png"),
                stats_img=img_tr,
            )

            # details of the borders, ROI and full frame
            self.plot_detail(
                img_tr,
                os.path.join(odir_tr, "img_transformed_detail.png"),
                peaks_loc,
                xcuts,
                detail_xsize,
            )
            self.plot_detail(
                ofits.data,
                os.path.join(odir_tr, "img_transformed_detail_full.png"),
                off_peaks_loc,
                off_xcuts,
                detail_xsize,
            )

            # error_func vs shear, only at DEBUG level (10) since it re-runs the detection 100 times
            if log.getEffectiveLevel() == 10:
                self._plot_error_vs_shear(
                    img,
                    xcuts,
                    fit_shear,
                    peak_distance,
                    coord_off,
                    os.path.join(odir_tr, "shear_vs_error.png"),
                )

            # corrected image and its edge maps (all from the get_grid call on img_tr)
            if grid is not None:
                _, axs = plt.subplots(ncols=2, nrows=2, sharex=True, sharey=True, figsize=(10, 10))
                axs[0, 0].imshow(img_tr, cmap="gray")
                axs[0, 1].imshow(grid, cmap="gray")
                axs[1, 0].imshow(grid_x, cmap="gray")
                axs[1, 1].imshow(grid_y, cmap="gray")
                plt.tight_layout()
                plt.savefig(os.path.join(odir, "edge_detection.png"))
                plt.close()

        self.angle = mean_tilt
        self.shear = fit_shear
        return mean_tilt, fit_shear