Coverage for src/susi/analyse/img_rotation.py: 52%

97 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2025-06-13 14:15 +0000

1import numpy as np 

2from skimage.feature import canny 

3from skimage.transform import probabilistic_hough_line, rotate 

4from ..base import InsufficientDataException 

5from ... import susi 

6from scipy.interpolate import interp1d 

7from scipy.optimize import least_squares 

8 

# Module-level logger, obtained from the project's susi logging facility.
log = susi.Logging.get_logger()

10 

11 

class RotationAnalysis:
    """
    Estimate the global rotation of an image from the straight lines it contains.

    Edges are extracted with the Canny detector, straight segments are then found
    with scikit-image's probabilistic
    [Hough Transform](https://en.wikipedia.org/wiki/Hough_transform), and the
    rotation is derived from the median orientation of those segments.
    """

    def __init__(self, img: np.ndarray):
        # Input image; it is only read, never modified.
        self.orig = img
        #: The detected rotation angle in degree (median segment angle + 90)
        self.angle: float = 0.0

    def run(self):
        """
        Run the detection and store the result in ``self.angle``.

        Returns ``self`` so that calls can be chained.
        Raises ``InsufficientDataException`` when no line segment is found.
        """
        segments = self.__detect_lines()
        if len(segments) == 0:
            raise InsufficientDataException("No lines detected. Cannot determine image rotation.")
        self.__detect_rotation(segments)
        return self

    def __detect_lines(self):
        # Edge map first (sigma=2, low/high hysteresis thresholds 3 and 50),
        # then straight segments of at least 150 px with gaps up to 5 px.
        edge_map = canny(self.orig, sigma=2, low_threshold=3, high_threshold=50)
        return probabilistic_hough_line(edge_map, threshold=10, line_length=150, line_gap=5)

    def __detect_rotation(self, lines):
        # Orientation of every segment, in degrees, from its two endpoints.
        # NOTE(review): arctan2 depends on the endpoint order, so the same physical
        # line can appear as a or a - 180; the median does not fold these — confirm.
        segment_angles = []
        for segment in lines:
            (x_start, y_start), (x_end, y_end) = segment
            segment_angles.append(np.degrees(np.arctan2(y_end - y_start, x_end - x_start)))
        self.angle = np.median(segment_angles) + 90

42 

43 

class RotationCorrection:
    """
    Rotate an image by a fixed angle and crop away the border region that the
    rotation leaves outside the original frame.
    """

    def __init__(self, img: np.ndarray, angle: float):
        self.img = img
        # Rotation angle in degrees, forwarded unchanged to ``rotate``.
        self.angle = angle
        #: The output image roi
        self.out_roi: tuple = slice(None, None), slice(None, None)

    def __cut_shape(self) -> tuple:
        # Rows/columns to drop on each side: half the opposite image dimension
        # times tan(|angle|), rounded up.
        slope = np.tan(np.abs(self.angle) * np.pi / 180)
        n_rows = self.img.shape[0]
        n_cols = self.img.shape[1]
        row_cut = int(np.ceil(slope * n_cols / 2))
        col_cut = int(np.ceil(slope * n_rows / 2))
        return slice(row_cut, n_rows - row_cut), slice(col_cut, n_cols - col_cut)

    def bicubic(self) -> np.ndarray:
        """
        Rotate ``self.img`` by ``self.angle`` degrees with cubic (order=3) interpolation.

        Side effect: ``self.out_roi`` is set to the pair of slices describing the
        returned region of the image.

        ### Returns
        The rotated and cropped image, or the input image itself when the angle is 0.
        """
        if self.angle != 0:
            self.out_roi = self.__cut_shape()
            return rotate(self.img, self.angle, order=3)[self.out_roi]
        self.out_roi = slice(0, self.img.shape[0]), slice(0, self.img.shape[1])
        return self.img

73 

74 

class GetRotationFromSlitMask:
    """
    Measure the rotation using the border of the slit mask seen in SP images.

    The position of the mask edge along axis 0 is measured in a set of columns
    spread over the image, and a straight line is fitted through these points.

    Angle convention:
    0 deg is considered horizontal (row) direction to the right
    90 deg is vertical (column) direction to the top
    """

    # number of x positions (columns) at which the edge position is sampled
    NXLOC = 100
    # Oversample factor applied to the column profiles for the sub-pixel edge search
    OVERSAMPLE = 100
    # Max shift [px] allowed between a column and the mean profile (defines the max angle you can find)
    MAX_YSHIFT = 30

    def __init__(self, img: np.ndarray):
        self.orig = img
        #: The detected rotation angle in degree
        self.angle: float = 0.0
        #: offset of the slit edge (fitted y at x = 0, in pixels)
        self.intercept: float = 0

    def run(self):
        """
        Detect the slit edge and store ``angle`` [deg] and ``intercept`` [px].

        Returns ``self`` so that calls can be chained.
        """
        xcuts = self._get_columns_to_analyze()
        yloc, xcuts = self._get_edge_yloc(xcuts)
        self.angle, self.intercept = self._get_line(xcuts, yloc)
        return self

    def _get_line(self, xcuts, yloc):
        # Fit a line y = slope * x + intercept to the edge points; angle = arctan(slope) in degrees.
        slope, intercept = np.polyfit(xcuts, yloc, 1)
        return np.degrees(np.arctan(slope)), intercept

    def _get_edge_yloc(self, xcuts):
        """
        Locate the slit edge (in pixels along axis 0) in each column listed in ``xcuts``.

        Every column is aligned to the mean profile of all selected columns. The
        fitted sub-pixel shift gives the edge position relative to the mean
        profile, and the half-max crossing of the mean profile supplies the
        absolute offset. Points farther than 3 sigma from the median are dropped.

        Returns ``(yloc, xcuts)`` restricted to the retained columns.
        Raises ValueError if the mean profile has no sample above its half-max level.
        """
        n_rows = self.orig.shape[0]
        rows = np.arange(n_rows)
        # Grid with a spacing of exactly 1 / OVERSAMPLE px over [0, n_rows - 1], so that
        # dividing an oversampled index by OVERSAMPLE (below) gives pixels exactly.
        fine_rows = np.arange((n_rows - 1) * self.OVERSAMPLE + 1) / self.OVERSAMPLE
        avg_profile = np.mean(self.orig[:, xcuts], axis=1)
        avg_profile = interp1d(rows, avg_profile, kind="linear")(fine_rows)

        yloc = []
        for xcut in xcuts:
            col = interp1d(rows, self.orig[:, xcut], kind="linear")(fine_rows)
            yloc.append(self._compute_shift(avg_profile, col))
        yloc = np.array(yloc)

        # Filter outliers. "<=" keeps every point when all shifts are identical
        # (std == 0); a strict "<" would discard all of them and break the fit.
        keep = np.abs(yloc - np.median(yloc)) <= 3 * np.std(yloc)

        # Offset so the positions refer to the half-max crossing of the mean profile:
        # halfway between its low (10th pct) and high (90th pct) levels.
        # NOTE(review): the first sample above the level is used, which assumes the
        # profile rises along axis 0 (mask side first) — confirm for all instruments.
        low, high = np.percentile(avg_profile, [10, 90])
        half_max_level = low + (high - low) / 2
        above = np.flatnonzero(avg_profile > half_max_level)
        if above.size == 0:
            raise ValueError("Mean slit profile has no contrast; cannot locate the half-max level.")
        half_max_loc = above[0]
        return (yloc[keep] + half_max_loc) / self.OVERSAMPLE, xcuts[keep]

    def _get_columns_to_analyze(self):
        # Divide image dim 1 into NXLOC windows and take, in each, the column with the highest mean value.
        n_cols = self.orig.shape[1]
        # At least one column per window; otherwise images narrower than NXLOC give
        # empty windows and argmax fails.
        width = max(1, n_cols // self.NXLOC)
        xcuts = []
        for i in range(self.NXLOC):
            start = int(i * n_cols / self.NXLOC)
            window_mean = np.mean(self.orig[:, start : start + width], axis=0)
            xcuts.append(start + np.argmax(window_mean))
        return np.array(xcuts)

    def _apply_subpixel_shift(self, signal, shift):
        # Resample ``signal`` at (index - shift) with a cubic interpolant; samples
        # outside the original index range are extrapolated.
        original_indices = np.arange(len(signal))
        shifted_indices = original_indices - shift
        interpolator = interp1d(original_indices, signal, kind="cubic", fill_value="extrapolate")
        return interpolator(shifted_indices)

    def _diff_err(self, x, signal1, signal2):
        # Scalar cost: sum of squared differences between signal1 and signal2 shifted
        # by x[0] samples. Both ends are trimmed by |floor(x[0])| + 1 samples to skip
        # the extrapolated region; the trimmed length therefore varies with x.
        diff = signal1 - self._apply_subpixel_shift(signal2, x[0])
        int_x0 = int(np.abs(np.floor(x[0])))
        return np.sum(diff[int_x0 + 1 : -int_x0 - 1] ** 2)

    def _compute_shift(self, signal1, signal2, min_disp=-MAX_YSHIFT * OVERSAMPLE, max_disp=MAX_YSHIFT * OVERSAMPLE):
        # Shift (in samples of the given signals) that best aligns signal2 to signal1,
        # searched in [min_disp, max_disp]. Defaults are bound from the class constants
        # at class-definition time.
        return self._compute_shift_fit(signal1, signal2, min_disp=min_disp, max_disp=max_disp)

    def _compute_shift_fit(self, signal1, signal2, min_disp=-MAX_YSHIFT * OVERSAMPLE, max_disp=MAX_YSHIFT * OVERSAMPLE):
        # Bounded least-squares search for the shift, starting from 0.
        # least_squares requires min_disp < max_disp, so the defaults match _compute_shift
        # (equal bounds such as 0/0 are rejected).
        res = least_squares(self._diff_err, [0], args=(signal1, signal2), bounds=([min_disp], [max_disp]))
        return res.x[0]