Coverage for src/susi/analyse/grid_rot_and_shear.py: 47%
324 statements
« prev ^ index » next coverage.py v7.5.0, created at 2025-06-13 14:15 +0000
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3import os
4import numpy as np
5import pandas as pd
6import matplotlib.pyplot as plt
7from scipy.ndimage import gaussian_filter
8from skimage.filters import sobel
9from scipy.signal import find_peaks
10from scipy.interpolate import interp1d
11from scipy.optimize import brute
12from ... import susi
13from ..base.header_keys import *
14from .. import ROOT_DIR
# module logger
log = susi.Logging.get_logger()

# CSV with per-target ROIs and grid lines to skip for the shear analysis
# (the misspelled file name is part of the path and must be kept as is)
_ROI_CSV_DIR = os.path.join(ROOT_DIR, "..", "data", "susi", "reduc")
INPUT_ROI_PATH = os.path.abspath(os.path.join(_ROI_CSV_DIR, "shear_analisys_rois.csv"))
# [row slice, column slice] used if roi='default' and there is no entry in INPUT_ROI_PATH
DEFAULT_ROI = [slice(440, 1980), slice(100, 2000)]
class RotAndShearAnalysis:
    """
    Estimates the shear and rotation angle from the given grid target image

    The vertical positions of the horizontal grid lines are located along several
    columns of a region of interest (ROI). A row-shear factor is fitted by brute force
    so that all grid lines get the same tilt, and the remaining mean tilt is then
    removed with a rotation.

    ## Author(s)
    iglesias@mps.mpg.de and vukadinovic@mps.mpg.de
    """

    def __init__(
        self,
        path: str,
        odir: str,
        custom_param=None,
        roi="default",
        min_grid_distance=140,
        grid_lines_to_skip=None,
    ):
        """
        Args:
            path: path of the input grid target FITS file.
            odir: output directory for plots, csv files and the corrected FITS file.
                Nothing is written to disk if None.
            custom_param: optional [rot, shear]. If set, no fit is done and these
                values are used for the correction.
            roi: "default" (entry of INPUT_ROI_PATH for this file, else DEFAULT_ROI),
                "full" (whole image) or a [row_slice, column_slice] pair.
            min_grid_distance: minimum distance between grid lines [px].
            grid_lines_to_skip: iterable of (start, end) row ranges [px] relative to the
                full image; grid lines strictly inside a range are skipped. They are
                replaced by the ranges of the INPUT_ROI_PATH entry when that entry
                defines some. The thickest grid line of each column is always skipped.
        """
        self.ipath = path
        #: The detected rotation angle in degree
        self.angle: float = 0.0
        #: The detected shear factor
        self.shear: float = 0.0
        #: The output directory
        self.odir = odir
        #: Manual parameters [rot, shear]. If set then no fit is done and these values are used for correction
        self.custom_param = custom_param
        # own copy, so neither the caller's list nor a default is shared with the instance
        self.grid_lines_to_skip = [] if grid_lines_to_skip is None else list(grid_lines_to_skip)
        #: Region of interest, use default for best inter comparison of errors between diff targets
        if roi == "default":
            self._get_custom_ana_param()
        elif roi == "full":
            self.roi = [slice(None), slice(None)]
        else:
            self.roi = roi
        # minimum distance between grid lines [px]
        self.min_grid_distance = min_grid_distance
        # skip ranges are given in full-image rows; shift them into ROI coordinates
        yoff = self.roi[0].start or 0
        self.grid_lines_to_skip = [[g[0] - yoff, g[1] - yoff] for g in self.grid_lines_to_skip]

    def _get_custom_ana_param(self):
        """
        Set ``self.roi`` (and possibly ``self.grid_lines_to_skip``) from the row of
        INPUT_ROI_PATH whose ``lvl1_grid_filename`` equals the input file name without
        extension. Falls back to DEFAULT_ROI when there is no such row.
        """
        df = pd.read_csv(INPUT_ROI_PATH)
        fname = os.path.basename(self.ipath).split(".")[0]
        row = df[df["lvl1_grid_filename"] == fname]
        if row.empty:
            self.roi = list(DEFAULT_ROI)
            return
        self.roi = [
            slice(row["y0"].values[0], row["y1"].values[0]),
            slice(row["x0"].values[0], row["x1"].values[0]),
        ]
        log.debug(f"Using ROI {self.roi = } for shear analysis from {INPUT_ROI_PATH}")
        skip = row["grid_lines_to_skip"].values[0]
        # pd.notna also recognizes numpy.float64 NaN (all-empty column), which an
        # identity test against np.nan does not
        if pd.notna(skip):
            # the cell holds a string like "[[y0, y1], [y2, y3]]"
            tokens = str(skip).replace("[", "").replace("]", "").replace(" ", "").split(",")
            tokens = [t for t in tokens if t]  # tolerate "[]" and trailing commas
            arr = np.array(tokens).astype(int).reshape(-1, 2)
            self.grid_lines_to_skip = [[a, b] for a, b in arr]

    def get_grid_lines_loc(
        self,
        sobelx,
        xcut,
        xintp,
        x,
        oversample,
        distance=40,
        odir=None,
        oplot=False,
        close_fig=False,
    ):
        """
        Get the location of the grid lines at columns given by xcut

        The absolute column ``sobelx[:, xcut]`` is smoothed (gaussian, sigma=3),
        cubic-interpolated from ``x`` onto ``xintp`` and its peaks are taken as grid
        line positions. Peaks inside ``self.grid_lines_to_skip`` or inside the thickest
        line of this column are discarded. ``self.grid_lines_to_skip`` is not modified.

        Args:
            sobelx: Sobel edge image along axis 0.
            xcut: column index to analyse.
            xintp: oversampled vertical coordinates [px].
            x: original vertical coordinates [px].
            oversample: number of ``xintp`` samples per ``x`` sample.
            distance: minimum peak distance, in oversampled samples.
            odir: if not None, the profile and its peaks are plotted.
            oplot: plot into the current figure instead of opening a new one.
            close_fig: label, save (``odir/grid_slices.png``) and close the figure.

        Returns:
            Array of peak positions in oversampled sample units.
        """
        yslice = gaussian_filter(np.abs(sobelx[:, xcut]), sigma=3)
        yslice = interp1d(x, yslice, kind="cubic")(xintp)
        pos_peaks, _ = find_peaks(yslice, distance=distance)
        # ranges skipped for this column only: configured ones plus the thickest line
        skip_ranges = list(self.grid_lines_to_skip)
        thick = self.get_thick_grid_line_loc(yslice, pos_peaks, oversample)
        if thick is not None:
            skip_ranges.append(thick)
        if len(skip_ranges) > 0 and len(pos_peaks) > 0:
            pos_peaks = np.array(
                [p for p in pos_peaks if not any(g[0] < p / oversample < g[1] for g in skip_ranges)]
            )
            log.debug(f"Skipped grid lines at positions: {[g for g in skip_ranges]}")
        if odir is not None:
            if not oplot:
                plt.figure(figsize=(15, 7))
            plt.plot(xintp, yslice, label=f"Column {xcut}")
            # peak markers share the colour of the profile just drawn
            col = plt.gca().lines[-1].get_color()
            for pos in pos_peaks:
                plt.axvline(x=pos / oversample, c=col, lw=0.7, ls="--")
            if close_fig:
                plt.xlabel("vertical position [px]")
                plt.ylabel("edge intensity")
                plt.legend()
                plt.grid()
                plt.savefig(os.path.join(odir, "grid_slices.png"))
                plt.close()
        return pos_peaks

    def get_grid_lines_angles(self, horizontal_grids, xcuts, maxx, oversample):
        """
        Fit a straight line y = a * x + b to every grid line across the analysed columns.

        Columns whose number of detected lines differs from the maximum are dropped,
        so that line index ``ln`` refers to the same line in every remaining column.

        Args:
            horizontal_grids: per column, the grid line positions in oversampled units.
            xcuts: column index of each entry of ``horizontal_grids``.
            maxx: horizontal extent used to compute the displacement [px].
            oversample: oversampling factor of the line positions.

        Returns:
            angles: ``arctan(-a / oversample)`` of each line [deg].
            y_pos: position of each line at the first kept column [px].
            displacement: ``-a * maxx / oversample`` of each line [px].
        """
        nlines = [len(hg) for hg in horizontal_grids]
        max_nlines = np.max(nlines)
        wrong_idx = [i for i, n in enumerate(nlines) if n != max_nlines]
        if len(wrong_idx) > 0:
            log.debug(f"Skipping columns {wrong_idx} as they have different number of grid lines")
            wrong = set(wrong_idx)
            xcuts = [xc for i, xc in enumerate(xcuts) if i not in wrong]
            horizontal_grids = [hg for i, hg in enumerate(horizontal_grids) if i not in wrong]

        x = np.array(xcuts)
        angles, y_pos, displacement = [], [], []
        for ln in range(len(horizontal_grids[0])):
            # fit a line to grid line number ln
            y = np.array([hg[ln] for hg in horizontal_grids])
            slope = np.polyfit(x, y, 1)[0]
            angles.append(np.arctan(-slope / oversample) * 180 / np.pi)
            y_pos.append(y[0] / oversample)
            # displacement between x=0 and maxx
            displacement.append(-slope * maxx / oversample)
        return np.array(angles), np.array(y_pos), np.array(displacement)

    def get_grid(self, image, xcuts, distance=40, odir=None):
        """
        Get the rotation of the grid target lines in the image
        by locating them only at the xcuts columns

        Args:
            image: 2D image.
            xcuts: column indices to analyse; indices outside the image are dropped.
            distance: minimum distance between grid lines [px].
            odir: if not None, diagnostic plots and ``rotation_vs_slit.csv`` are written there.

        Returns:
            (edge magnitude, line positions at the first column [px], displacement [px],
            angles [deg], Sobel axis 0, Sobel axis 1), or six Nones if fewer than two
            columns could be analysed.
        """
        oversample = 50  # oversampling factor

        nrows, ncols = image.shape[0], image.shape[1]
        sobelx = sobel(image, axis=0)
        sobely = sobel(image, axis=1)

        # vertical coordinate of the column profiles and its oversampled version
        x = np.arange(nrows)
        xintp = np.linspace(0, nrows - 1, num=nrows * oversample)
        distance *= len(xintp) / len(x)
        oversample = float(len(xintp) / len(x))

        # xcuts index axis 1, so they must be checked against the number of columns
        orig_cuts = len(xcuts)
        xcuts = [i for i in xcuts if i < ncols]
        if len(xcuts) < orig_cuts:
            log.debug(f"Skipping columns {orig_cuts-len(xcuts)} as they are out of the transformed input image")
        horizontal_grids = []
        for i, xcut in enumerate(xcuts):
            grid_loc = self.get_grid_lines_loc(
                sobelx,
                xcut,
                xintp,
                x,
                oversample,
                distance=distance,
                odir=odir,
                oplot=i > 0,  # all columns go into one figure
                close_fig=i == len(xcuts) - 1,  # save/close after the last column
            )
            horizontal_grids.append(grid_loc)

        if len(horizontal_grids) < 2:
            log.error("No grid lines found")
            return None, None, None, None, None, None

        angles, y_pos, displacement = self.get_grid_lines_angles(horizontal_grids, xcuts, ncols - 1, oversample)

        # plot rot vs tilt
        if odir is not None:
            plt.figure(figsize=(10, 7))
            # single point with the displacement style, so the legend of the primary
            # axis also lists the curve drawn on the twin axis below
            plt.plot(y_pos[0], angles[0], "--or", label="displacement")
            plt.plot(y_pos, angles, "--ob", label="grid tilt")
            plt.legend()
            plt.xlabel("vertical position [px]")
            plt.ylabel("grid line tilt [deg]")
            ctilt = np.mean(np.abs(angles))
            cdisp = np.mean(np.abs(displacement))
            plt.title(f"Mean abs tilt: {ctilt:.4f} deg \n Mean abs displacement: {cdisp:.4f} px")
            plt.grid()
            y2 = plt.twinx()
            y2.plot(y_pos, displacement, "--or", label="displacement")
            y2.set_ylabel("vertical displacement [px]")
            plt.savefig(os.path.join(odir, "rotation_vs_slit.png"))
            plt.close()

            # saves in xpos, angles and displacement to csv
            with open(os.path.join(odir, "rotation_vs_slit.csv"), "w") as f:
                f.write("vertical position [px], grid tilt [deg], vertical displacement [px]\n")
                for yp, ang, disp in zip(y_pos, angles, displacement):
                    f.write(f"{yp}, {ang}, {disp}\n")

        return (
            np.sqrt(sobelx**2 + sobely**2),
            horizontal_grids[0] / oversample,
            np.array(displacement),
            angles,
            sobelx,
            sobely,
        )

    @staticmethod
    def transform_image(img, shear, coord_off=(0, 0)):
        """Apply the row-shear correction ``shear`` to ``img`` via ``susi.ShearDistortion``."""
        return susi.ShearDistortion(img).run(shear, coord_off=coord_off)

    def error_func(self, shear, img, xcuts, print_msg=False, peak_distance=140, coord_off=(0, 0)):
        """
        Cost function of the shear fit: dispersion of the grid line tilts after shearing.

        ``print_msg`` is unused and only kept for backward compatibility.

        Returns:
            Sum of squared deviations of the grid line angles from their mean [deg^2],
            or ``np.inf`` if no grid line could be detected for this shear.
        """
        angles = self.get_grid(
            RotAndShearAnalysis.transform_image(img, shear, coord_off=coord_off),
            xcuts,
            distance=peak_distance,
        )[3]
        if angles is None or len(angles) == 0:
            return np.inf
        # error to have all angles equal
        return np.sum(np.abs(angles - np.mean(angles)) ** 2)

    def plot_detail(self, img, path, peaks_loc, xcuts, detail_xsize):
        """
        Save the left and right ``detail_xsize`` columns of ``img`` side by side, each
        normalized by its mean, with red lines at ``peaks_loc``, blue lines at the first
        and last xcut and a white line at the seam between both borders.
        """
        plt.figure(figsize=(5, 10))
        detail_img = np.concatenate(
            [
                img[:, 0:detail_xsize] / np.mean(img[:, 0:detail_xsize]),
                img[:, -detail_xsize:] / np.mean(img[:, -detail_xsize:]),
            ],
            axis=1,
        )
        m = np.mean(detail_img)
        sd = np.std(detail_img)
        plt.imshow(detail_img, cmap="gray", vmin=m - 2 * sd, vmax=m + 2 * sd)
        if peaks_loc is not None:
            for y in peaks_loc:
                plt.axhline(y=y, c="r", lw=0.7)
        if xcuts is not None:
            # the last xcut is shifted into the coordinates of the right border panel
            detail_xcuts = [xcuts[0], detail_img.shape[1] - (img.shape[1] - xcuts[-1])]
            for x in detail_xcuts:
                plt.axvline(x=x, c="b", lw=1)
        plt.axvline(x=detail_xsize, c="w", lw=3)
        plt.xlabel("horizontal position [px]")
        plt.ylabel("vertical position [px]")
        plt.colorbar(fraction=0.046, pad=0.04, label="[DN]")
        plt.tight_layout()
        plt.savefig(path)
        plt.close()

    @staticmethod
    def _plot_lines_on_image(img, path, hlines, vlines, stats_img=None):
        """Save ``img`` with red horizontal lines at ``hlines`` and blue vertical lines at
        ``vlines``; the display range is mean +- 2 std of ``stats_img`` (default ``img``)."""
        stats_img = img if stats_img is None else stats_img
        m = np.mean(stats_img)
        sd = np.std(stats_img)
        plt.figure(figsize=(15, 10))
        plt.imshow(img, cmap="gray", vmin=m - 2 * sd, vmax=m + 2 * sd)
        for y in hlines:
            plt.axhline(y=y, c="r", lw=0.6)
        for x in vlines:
            plt.axvline(x=x, c="b", lw=1)
        plt.title("Red lines are horizontal \n Rotation is measured at blue columns")
        plt.xlabel("horizontal position [px]")
        plt.ylabel("vertical position [px]")
        plt.colorbar(fraction=0.046, pad=0.04, label="[DN]")
        plt.tight_layout()
        plt.savefig(path)
        plt.close()

    def _plot_error_vs_shear(self, img, xcuts, peak_distance, coord_off, fit_shear, odir):
        """Save ``odir/shear_vs_error.png``: error_func over 100 shear factors in [-1, 1]."""
        log.info("Computing error function vs shear factor...")
        shear = np.linspace(-1, 1, 100)
        error = np.array(
            [self.error_func(s, img, xcuts, peak_distance=peak_distance, coord_off=coord_off) for s in shear]
        )
        cerror = self.error_func(fit_shear, img, xcuts, peak_distance=peak_distance, coord_off=coord_off)
        # error_func returns inf where no grid was found; keep the marker height finite
        finite = np.isfinite(error)
        ymax = np.max(error[finite]) if np.any(finite) else 1.0
        plt.figure(figsize=(10, 7))
        plt.plot(shear, error, label=f"error: {cerror:.6f}")
        plt.vlines(fit_shear, 0, ymax, color="r", lw=1, label=f"fit shear: {fit_shear:.6f}")
        plt.xlabel("shear factor")
        plt.ylabel("grid tilt dispersion [deg^2]")
        plt.grid()
        plt.legend()
        plt.savefig(os.path.join(odir, "shear_vs_error.png"))
        plt.close()

    def get_columns_to_analyze(self, img, nxcuts):
        """
        Split the columns of ``img`` into ``nxcuts`` chunks and return, for each chunk,
        the index of the column with the highest mean value.
        """
        width = max(1, img.shape[1] // nxcuts)  # never average an empty slice
        xcuts = []
        for i in range(nxcuts):
            start = int(i * img.shape[1] / nxcuts)
            best = np.argmax(np.mean(img[:, start : start + width], axis=0))
            xcuts.append(best + start)
        log.info(f"Using columns {xcuts} for the analysis")
        return xcuts

    def get_thick_grid_line_loc(self, yslice, pos_peaks, oversample):
        """
        Locate the widest grid line among the detected peaks.

        The width of a peak is the number of samples to the first value below half the
        peak height, summed over both sides.

        Returns:
            [start, end] of the widest peak in units of ``sample index / oversample``,
            or None if ``pos_peaks`` is empty.
        """
        if len(pos_peaks) == 0:
            return None
        lefts = np.zeros(len(pos_peaks), dtype=int)
        rights = np.zeros(len(pos_peaks), dtype=int)
        for i, p in enumerate(pos_peaks):
            half = yslice[p] / 2
            lefts[i] = np.argmax(yslice[:p][::-1] < half)
            rights[i] = np.argmax(yslice[p:] < half)
        # bounds must come from the widest peak itself, not from the last one scanned
        ind = np.argmax(lefts + rights)
        p = pos_peaks[ind]
        return [(p - lefts[ind]) / oversample, (p + rights[ind]) / oversample]

    def run(self):
        """
        Run the analysis on ``self.ipath``.

        Fits the row-shear factor by brute force (unless ``custom_param`` is set), then
        removes the mean residual tilt. If ``odir`` is set, diagnostic plots/csv files and
        the corrected full image (FITS) are written.

        Returns:
            (mean_tilt [deg], fit_shear); also stored in ``self.angle`` / ``self.shear``.
        """
        # constants
        # TODO: Try again with a proper bound least-square optimization
        fit_nstep = 1000  # number of steps for brute force optimization
        peak_distance = self.min_grid_distance  # distance between peaks
        minimum_search_range = (-1, 0.5)  # search range for the shear factor
        detail_xsize = 200  # size borders used in the detail image
        ncol_to_analyze = 20  # number of columns to analyze
        if self.roi[0].start is None:
            coord_off = (0, 0)
        else:
            coord_off = (self.roi[0].start - 1, self.roi[1].start - 1)
        # ROI origin in the full image, used to place the plot overlays
        row_off = self.roi[0].start or 0
        col_off = self.roi[1].start or 0

        # output directory
        if self.odir is None:
            odir = None
        else:
            odir = os.path.join(self.odir, f'{os.path.basename(self.ipath).split(".")[0]}_plots')
            ofits_path = os.path.join(self.odir, os.path.basename(self.ipath))
            os.makedirs(odir, exist_ok=True)
        log.info(f"Running grid rotation and shear analysis for {self.ipath}")

        # read image
        ihdu = susi.Fits(self.ipath).read()
        full_img = ihdu.data[0, :, :].astype(np.float32)
        if self.roi[0].start is None:
            img = full_img
        else:
            img = full_img[self.roi[0], self.roi[1]]
        log.info(f"Original image shape: {full_img.shape}. Using ROI: {img.shape}")

        xcuts = self.get_columns_to_analyze(img, ncol_to_analyze)
        peaks_loc = self.get_grid(img, xcuts, distance=peak_distance, odir=odir)[1]

        if odir is not None and peaks_loc is not None:
            # input image with horizontal lines at xcuts[0] peak location
            self._plot_lines_on_image(img, os.path.join(odir, "img.png"), peaks_loc, xcuts)
            # (file name spelling kept for compatibility with existing outputs)
            self.plot_detail(img, os.path.join(odir, "img_dteail.png"), peaks_loc, xcuts, detail_xsize)

        # fit shear factor
        if self.custom_param is None:
            log.info("Brute force fitting nl shear factor...")
            # brute needs at least one worker; os.cpu_count() may be None or 1
            workers = max(1, (os.cpu_count() or 2) // 2)
            res = brute(
                self.error_func,
                [minimum_search_range],
                args=(img, xcuts, False, peak_distance, coord_off),
                Ns=fit_nstep,
                workers=workers,
            )
            fit_shear = res[0]
            log.info(f"Non-Linear shear factor fit: {fit_shear:.6f}")
        else:
            fit_shear = self.custom_param[1]
            log.info(f"Using custom shear factor: {fit_shear:.6f}")
        img_tr = self.transform_image(img, fit_shear, coord_off=coord_off)

        # removes the mean tilt
        if self.custom_param is None:
            log.info("Using mean rotation angle after shear correction...")
            rotation = self.get_grid(img_tr, xcuts, distance=peak_distance)[3]
            if rotation is None:
                raise RuntimeError(f"No grid lines found in {self.ipath} after shear correction")
            mean_tilt = np.mean(rotation)
        else:
            log.info("Using custom rotation angle...")
            mean_tilt = self.custom_param[0]
        log.info("Removing mean tilt {} deg".format(mean_tilt))
        img_tr = susi.RotationCorrection(img=img_tr, angle=-mean_tilt).bicubic()

        if odir is not None:
            # save transformed image in fits
            ofits = susi.Fits(ofits_path)
            ofits.data = self.transform_image(full_img, fit_shear)
            ofits.data = susi.RotationCorrection(img=ofits.data, angle=-mean_tilt).bicubic()
            ofits.header = ihdu.header
            ofits.header[SHEAR_CORR] = (fit_shear, "row-shear factor used")
            ofits.header[ROTATION_ANG] = (mean_tilt, "Rot. angle used [deg]")
            ofits.write_to_disk(overwrite=True)

            # plots transformed derotated image
            odir_tr = os.path.join(odir, "corrected")
            os.makedirs(odir_tr, exist_ok=True)
            grid, peaks_loc, _, _, grid_x, grid_y = self.get_grid(
                img_tr, xcuts, distance=peak_distance, odir=odir_tr
            )
            # saves fit_shear and mean_tilt in a .csv
            with open(os.path.join(odir_tr, "corr_params.csv"), "w") as f:
                f.write(f"shear factor, {fit_shear}\n")
                f.write(f"global rot angle [deg], {mean_tilt}\n")
                f.write(f"row offset [px], {coord_off[0]}\n")
                f.write(f"column offset [px], {coord_off[1]}\n")

            # transformed image with horizontal lines at grid 1 peak location
            self._plot_lines_on_image(img_tr, os.path.join(odir_tr, "img_transformed.png"), peaks_loc, xcuts)

            # transformed full image, overlays shifted by the ROI origin
            off_peaks_loc = [p + row_off for p in peaks_loc]
            off_xcuts = [x + col_off for x in xcuts]
            self._plot_lines_on_image(
                ofits.data,
                os.path.join(odir_tr, "img_transformed_full.png"),
                off_peaks_loc,
                off_xcuts,
                stats_img=img_tr,
            )

            # border details of the transformed ROI and full images
            self.plot_detail(
                img_tr,
                os.path.join(odir_tr, "img_transformed_detail.png"),
                peaks_loc,
                xcuts,
                detail_xsize,
            )
            self.plot_detail(
                ofits.data,
                os.path.join(odir_tr, "img_transformed_detail_full.png"),
                off_peaks_loc,
                off_xcuts,
                detail_xsize,
            )

            # error_func vs shear, only at DEBUG level (10) as it is expensive
            if log.getEffectiveLevel() == 10:
                self._plot_error_vs_shear(img, xcuts, peak_distance, coord_off, fit_shear, odir_tr)

            # plot the grid and its components
            _, axs = plt.subplots(ncols=2, nrows=2, sharex=True, sharey=True, figsize=(10, 10))
            axs[0, 0].imshow(img, cmap="gray")
            axs[0, 1].imshow(grid, cmap="gray")
            axs[1, 0].imshow(grid_x, cmap="gray")
            axs[1, 1].imshow(grid_y, cmap="gray")
            plt.tight_layout()
            plt.savefig(os.path.join(odir, "edge_detection.png"))
            plt.close()

        self.angle = mean_tilt
        self.shear = fit_shear
        return mean_tilt, fit_shear