Coverage for src/susi/analyse/slit_mask_border.py: 52%

209 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2025-06-13 14:15 +0000

1import numpy as np 

2from scipy.interpolate import interp1d 

3from scipy.optimize import least_squares 

4import matplotlib.pyplot as plt 

5from scipy.signal import savgol_filter 

6from scipy.optimize import curve_fit 

7from scipy import stats 

8from skimage.feature.tests.test_orb import img 

9from ... import susi 

10from ..utils import MP 

11 

12 

# Module-level logger from the susi package; used by the slit mask border classes in this module.
log = susi.Logging.get_logger()

14 

15 

class GetSlitMaskBorder_v2:
    """
    Fits a line to the border of the slit mask seen in SP images.
    Returns the line coefficients (runs in parallel for images and optionally for columns)

    # author: iglesias@mps.mpg.de, lagg@mps.mpg.de

    The method is based on the following steps:
    - Step 1: coarse determination of the edge based on the derivative of the column-summed,
      smoothed step function.
    - Step 2: subpixel accurate measurement of the edge in each column by fitting a Gaussian
      to the derivative (take only +-GAUSS_WIDTH pixels around the step). Columns that are not
      at least COL_FILTER [DN] brighter than the mean slit mask level are skipped; the remaining
      columns keep their original x position for the line fit.
    - Step 3: Fit a linear function to the location of the Gaussian.
      To avoid outliers, only the points within +-0.5*sigma of the mean are taken into account
      (failed fits are ignored for the statistics). If not enough points fulfil this criterion,
      the sigma level is successively increased in steps of 0.25.
    """

    N = 40  # [px] max number of pixels to consider around the column-avg slit mask edge
    NYOFF = 20  # [px] offset in y direction to remove edge effects
    STEP = 300  # number of columns processed by each worker
    COL_FILTER = 30  # Columns that are not at least COL_FILTER [DN] brighter than the mean slit mask level are dropped
    GAUSS_WIDTH = 20  # [px] max number of pixels to consider around the slit mask edge of each column must be < N
    MIN_COLUMNS = 50  # minimum number of bright columns required to run the analysis
    SIGMA_WARN = 3  # outlier threshold [std] from which on the border is reported as not found

    def __init__(self, roi=None, debug=False, xbin=1):
        #: ROI to analyze
        self.roi = roi
        #: debug flag
        self.debug = debug
        #: result dictionary
        self.result = None
        # set to compute columns chunks in parallel
        self.col_paralell: bool = False
        #: number of workers to use for columns, if None ncol/STEP is used
        self.col_workers = 10
        #: number of workers to use for images
        self.img_workers = 20
        #: number of columns to bin before the analysis. 1 means no binning
        self.bin = xbin

    def run_img(self, img, parallel=True):
        """
        Analyse a single 2D image (axis 0 = rows/y, axis 1 = columns/x).

        :param img: 2D image, binned along the columns by ``self.bin`` before the analysis
        :param parallel: process chunks of columns in parallel
        :return: ``[slope, intercept]`` of the fitted border line (see ``process_image``)
        """
        self.col_paralell = parallel
        img = self._bin_data(img)
        img_args = self.__gen_img_arguments(img[None, :, :])
        _, self.result = GetSlitMaskBorder_v2.process_image(img_args[0])
        return self.result

    def run_cube(self, cube):
        """
        Analyse a 3D array indexed ``[image, row (y), column (x)]``; images run in parallel.

        :return: list with the line coefficients of every image, in image order
        """
        cube = self._bin_data(cube)
        img_args = self.__gen_img_arguments(cube)
        log.info(f'Analysing slit mask shift for cube of shape {cube.shape} with {self.img_workers} workers...')
        result = MP.simultaneous(GetSlitMaskBorder_v2.process_image, img_args, workers=self.img_workers)
        log.info('Analysis done, collecting results...')
        # workers may finish in any order: restore the image order via the returned index
        self.result = [coeffs for _, coeffs in sorted(result, key=lambda item: item[0])]
        return self.result

    def _bin_data(self, data):
        """Bin the last (column) axis of a 2D or 3D array by ``self.bin``."""
        if data.ndim == 2:
            return susi.utils.collections.Collections.bin(data, [1, self.bin])
        else:
            return susi.utils.collections.Collections.bin(data, [1, 1, self.bin])

    def __gen_img_arguments(self, cube):
        """Build one ``process_image`` argument dict per image of ``cube``."""
        return [
            {
                'img_idx': i,
                'img': cube[i],
                'roi': self.roi,
                'debug': self.debug,
                'col_workers': self.col_workers,
                'col_paralell': self.col_paralell,
            }
            for i in range(cube.shape[0])
        ]

    @staticmethod
    def process_image(args):
        """
        Fit the slit mask border line of a single image.

        :param args: dict with the keys ``img_idx``, ``img`` (2D array, axis 0 = rows/y,
            axis 1 = columns/x), ``roi`` (pair of slices or None), ``debug``,
            ``col_workers`` and ``col_paralell``.
        :return: ``(img_idx, coeffs)`` where ``coeffs`` are the ``np.polyfit`` coefficients
            ``[slope, intercept]`` of the border row as a function of the column (both offset
            by the ROI start), or ``[nan, nan]`` if the border could not be measured.
        """
        img = args['img']
        roi = args['roi']
        debug = args['debug']
        img_idx = args['img_idx']
        if roi is None:
            roi = [slice(0, img.shape[0]), slice(0, img.shape[1])]

        ny, nx = img.shape[0], img.shape[1]
        # Step 1: detect the "step" from dark area to bright area by summing up all columns
        img_col = np.sum(img, axis=1)
        img_col_sg = savgol_filter(img_col, 15, 1)
        loc_max = np.float64(np.argmax(np.diff(img_col_sg)))
        if debug:
            plt.clf()
            plt.plot(img_col)
            plt.plot(img_col_sg)
            plt.axvline(x=loc_max)
            plt.show()

        # rows +-N pixels around the coarse edge, staying NYOFF rows away from the top border
        iloc = [
            int(max(loc_max - GetSlitMaskBorder_v2.N, GetSlitMaskBorder_v2.NYOFF)),
            int(min(loc_max + GetSlitMaskBorder_v2.N, ny)),
        ]
        # Step 2: sub-pixel edge position in every bright column (relative to iloc[0])
        dy_max, xcols = GetSlitMaskBorder_v2.process_all_columns(
            img, iloc, loc_max, args['col_paralell'], debug, args['col_workers'], return_columns=True
        )
        if len(dy_max) == 0:
            log.warning(f"Error during slit mask border detection of image {args['img_idx']}, returning nan")
            return img_idx, [np.nan, np.nan]
        dy_max = dy_max + iloc[0] + roi[0].start

        # Step 3: drop outliers, then fit a line at the true column positions of the kept columns
        dy_max[GetSlitMaskBorder_v2._outlier_mask(dy_max, xcols, nx)] = np.nan
        xvec = xcols + roi[1].start
        idx = np.isfinite(dy_max)
        if np.count_nonzero(idx) < 2:
            log.warning(
                f"Not enough valid columns to fit the slit mask border of image {args['img_idx']}, returning nan"
            )
            return img_idx, [np.nan, np.nan]
        coeffs = np.polyfit(xvec[idx], dy_max[idx], 1)
        if debug:
            ffit = np.polyval(coeffs, xvec)
            plt.clf()
            plt.imshow(img, extent=(roi[1].start, roi[1].stop, roi[0].stop, roi[0].start))
            plt.plot(xvec, ffit, c='orange', lw=2.0)
            plt.plot(xvec, dy_max, c='r', lw=0.5, alpha=0.5)
            plt.title(
                "coeffs={:.3e},{:.2e}, angle={:.3e} deg".format(
                    coeffs[0], coeffs[1], np.arctan(coeffs[0]) * 180 / np.pi
                )
            )
            plt.show()
            plt.pause(0.1)
        return img_idx, coeffs

    @staticmethod
    def _outlier_mask(dy, xcols, nx):
        """
        Flag edge positions deviating from the mean by more than sigma standard deviations.

        sigma starts at 0.5 and grows in steps of 0.25 until none of the three image sections
        (left/middle/right third of an image ``nx`` columns wide, selected by the column
        positions ``xcols``) holds more than 80% of nx/3 outliers. Non-finite entries (failed
        Gaussian fits) are excluded from mean/std and are never flagged.

        :return: boolean array, True where ``dy`` is an outlier
        """
        finite = np.isfinite(dy)
        if np.count_nonzero(finite) < 2:
            return np.zeros(dy.shape, dtype=bool)
        mean = np.mean(dy[finite])
        # ddof=1: the same (unbiased) variance scipy.stats.describe uses
        std = np.std(dy[finite], ddof=1)
        # deviation 0 for failed fits so that they can never become outliers
        deviation = np.where(finite, np.abs(dy - mean), 0.0)
        third, two_thirds = int(nx / 3), int(2 * nx / 3)
        sections = (
            xcols < third,
            (xcols >= third) & (xcols < two_thirds),
            xcols >= two_thirds,
        )
        max_outliers = nx / 3 * 0.80  # at least 20% of each section shall be valid
        sigma = 0.5
        warned = False
        # terminates: once sigma * std exceeds the largest deviation there are no outliers left
        while True:
            outlier = deviation > sigma * std
            if all(np.count_nonzero(outlier & section) <= max_outliers for section in sections):
                return outlier
            sigma += 0.25
            if sigma >= GetSlitMaskBorder_v2.SIGMA_WARN and not warned:
                log.info("Boundary not found during slit mask border analysis")
                warned = True

    @staticmethod
    def process_all_columns(
        img, iloc, loc_max, col_paralell=True, debug=False, col_workers=None, return_columns=False
    ):
        """
        Measure the sub-pixel edge row in every column that is bright enough.

        :param img: 2D image (axis 0 = rows, axis 1 = columns)
        :param iloc: ``[start, stop]`` rows of the derivative window around the coarse edge
        :param loc_max: coarse edge row
        :param col_paralell: process chunks of STEP columns in parallel with ``MP.simultaneous``
        :param debug: passed on to the column workers
        :param col_workers: number of parallel workers; None -> one per STEP columns (at least 1)
        :param return_columns: also return the indices (in ``img``) of the analysed columns
        :return: float array with the edge row of each analysed column relative to ``iloc[0]``
            (NaN where the Gaussian fit failed), or ``[]`` if the image is unusable. With
            ``return_columns=True`` a tuple ``(dy_max, columns)`` is returned instead.
        """
        no_cols = np.array([], dtype=int)
        failed = ([], no_cols) if return_columns else []
        # Clamp at 0: an edge closer than N rows to the top would otherwise give a negative
        # (wrap-around) slice bound that includes the bright part of the image.
        mask_rows = img[: max(int(loc_max - GetSlitMaskBorder_v2.N), 0), :]
        bright_rows = img[int(loc_max + GetSlitMaskBorder_v2.N) :, :]
        if mask_rows.size == 0 or bright_rows.size == 0:
            log.warning("Slit mask edge too close to the image border for slit mask border detection")
            return failed

        # Drop all columns that are not at least COL_FILTER [DN] brighter than the mean slit mask level
        slit_mask_level = np.mean(mask_rows)
        col_mean = np.mean(bright_rows, axis=0)
        col_mask = col_mean > (slit_mask_level + GetSlitMaskBorder_v2.COL_FILTER)
        columns = np.flatnonzero(col_mask)
        filt_img = img[:, col_mask]
        if filt_img.shape[1] < GetSlitMaskBorder_v2.MIN_COLUMNS:
            log.warning("Not enough columns for slit mask border detection, image too dark ?")
            return failed

        dy_max = np.zeros(filt_img.shape[1])
        chunk_args = GetSlitMaskBorder_v2.gen_col_arguments(filt_img, iloc, loc_max, debug)
        if col_paralell:
            if col_workers is None:
                # floor division alone yields 0 workers for fewer than STEP columns
                col_workers = max(1, filt_img.shape[1] // GetSlitMaskBorder_v2.STEP)
            chunks = MP.simultaneous(GetSlitMaskBorder_v2.columns_shift, chunk_args, workers=col_workers)
        else:
            chunks = [GetSlitMaskBorder_v2.columns_shift(args) for args in chunk_args]
        for chunk in chunks:
            for col_idx, dy in chunk:
                dy_max[col_idx] = dy
        return (dy_max, columns) if return_columns else dy_max

    @staticmethod
    def gen_col_arguments(img, iloc, loc_max, debug) -> list:
        """Split the columns of ``img`` into chunks of STEP columns, one ``columns_shift`` dict each."""
        step = GetSlitMaskBorder_v2.STEP
        return [
            {
                'idx': j,
                'idx1': j + step,
                'img_colj': img[:, j : j + step],
                'iloc': iloc,
                'loc_max': loc_max,
                'debug': debug,
            }
            for j in range(0, img.shape[1], step)
        ]

    @staticmethod
    def _gauss(x, a, x0, sigma):
        """Gaussian profile fitted to the derivative of each column."""
        return a * np.exp(-((x - x0) ** 2) / (2 * sigma**2))

    @staticmethod
    def columns_shift(args: dict):
        """
        Compute the sub-pixel position of the slit mask edge in a chunk of columns.

        :param args: dict from ``gen_col_arguments`` (``idx``, ``img_colj``, ``iloc``, ``loc_max``)
        :return: list of ``[column index, edge row relative to iloc[0]]``; the row is NaN when
            the Gaussian fit fails.
        """
        img_colj = args['img_colj']
        iloc0, iloc1 = args['iloc'][0], args['iloc'][1]
        loc_max = args['loc_max']
        nc = GetSlitMaskBorder_v2.GAUSS_WIDTH
        # fit window: +-GAUSS_WIDTH rows around the coarse edge, in coordinates relative to iloc[0]
        x = np.int32(np.arange(-nc, nc + 1) + loc_max) - iloc0
        results = []
        for j in range(img_colj.shape[1]):
            # derivative of the slightly smoothed column
            col_sg = savgol_filter(img_colj[:, j], 9, 1)
            col_dy = np.diff(col_sg)[iloc0:iloc1]
            col_cut = col_dy[x[0] : x[-1] + 1]
            try:
                popt, _ = curve_fit(
                    GetSlitMaskBorder_v2._gauss, x, col_cut, p0=[max(col_cut), loc_max - iloc0, 5.0]
                )
                dy = popt[1]
            except (RuntimeError, ValueError, TypeError, np.linalg.LinAlgError):
                # no convergence, non-finite data or a truncated window near the image border
                dy = np.nan
            results.append([args['idx'] + j, dy])
        return results

261 

262 

class GetSlitMaskBorder:
    """
    Fits a line to the border of the slit mask (level based) seen in SP images.
    Returns the line angle and its intercept wrt the reference profile.
    0 deg is considered horizontal (row) direction to the right
    90 deg is vertical (column) direction to the top

    # author: iglesias@mps.mpg.de
    """

    # number of x positions to sample the edge rotation
    NXLOC = 100
    # Oversample factor for the edge detection
    OVERSAMPLE = 100
    # Max shift [px] allowed (which defines the max angle you can find)
    MAX_YSHIFT = 30

    def __init__(self, img: np.ndarray, ref_profile=None):
        #: input image. Must be cropped to show mostly the slit mask border
        self.orig = img
        #: 1D reference profile (same size as img.shape[0]) used to compute column shifts.
        #: The row coordinate of the half max of the reference profile
        #: defines the y=0 of the returned line intercept
        #: If None then is obtained from the column with largest mean value
        #: Not useful to get an absolute shift but OK to get the angle.
        self.ref_profile = ref_profile
        #: The detected rotation angle in degree
        self.angle: float = 0.0
        #: the detected offset of the slit edge wrt to half max of the reference
        self.intercept: float = 0.0

    def run(self):
        """Detect the edge in NXLOC columns, fit a line and store ``angle`` and ``intercept``."""
        xcuts = self._get_columns_to_analyze()
        yloc, xcuts = self._get_edge_yloc(xcuts)
        self.angle, self.intercept = self._get_line(xcuts, yloc)
        return self

    def _get_line(self, xcuts, yloc):
        """Fit a line to the edge points; return its angle [deg] and intercept."""
        fit = np.polyfit(xcuts, yloc, 1)
        return np.degrees(np.arctan(fit[0])), fit[1]

    def _get_edge_yloc(self, xcuts):
        """
        Locate the slit mask edge in the columns ``xcuts`` relative to the reference profile.

        :return: ``(yloc, xcuts)`` with the edge row [px] of every non-outlier column
            (offset by the half-max row of the reference profile) and the matching columns.
        """
        xcuts = np.asarray(xcuts)
        cut_len = self.orig.shape[0]
        xcol = np.arange(cut_len)
        # Grid spacing of exactly 1/OVERSAMPLE px, so an oversampled index divided by OVERSAMPLE
        # is a pixel coordinate (a linspace with cut_len*OVERSAMPLE samples has a smaller spacing).
        xinterp = np.arange((cut_len - 1) * self.OVERSAMPLE + 1) / self.OVERSAMPLE
        if self.ref_profile is None:
            # argmax indexes xcuts; map it back to an image column
            ref_xloc = xcuts[np.argmax(np.mean(self.orig[:, xcuts], axis=0))]
            ref_prof = self.orig[:, ref_xloc]
            log.debug(f"Reference profile column: {ref_xloc}")
        else:
            ref_prof = self.ref_profile
        ref_prof = interp1d(xcol, ref_prof, kind="linear")(xinterp)

        yloc = np.array(
            [
                self._compute_shift(ref_prof, interp1d(xcol, self.orig[:, xcut], kind="linear")(xinterp))
                for xcut in xcuts
            ]
        )
        # filter outliers; "<=" keeps every point when all shifts agree (std == 0)
        ind = np.abs(yloc - np.median(yloc)) <= 3 * np.std(yloc)
        # line passes through the half max of the ref profile: midway between its low and high level
        low, high = np.percentile(ref_prof, [10, 90])
        half_max_level = low + (high - low) / 2
        half_max_loc = np.flatnonzero(ref_prof > half_max_level).min()
        return (yloc[ind] + half_max_loc) / self.OVERSAMPLE, xcuts[ind]

    def _get_columns_to_analyze(self):
        """Divide image dim 1 in NXLOC cuts and return, per cut, the column with the highest mean value."""
        ncols = self.orig.shape[1]
        # at least one column per cut, also for images narrower than NXLOC
        width = max(1, ncols // self.NXLOC)
        starts = [int(i * ncols / self.NXLOC) for i in range(self.NXLOC)]
        return np.array(
            [start + np.argmax(np.mean(self.orig[:, start : start + width], axis=0)) for start in starts]
        )

    def _apply_subpixel_shift(self, signal, shift):
        """Shift ``signal`` by ``shift`` samples using cubic interpolation (extrapolated at the ends)."""
        original_indices = np.arange(len(signal))
        shifted_indices = original_indices - shift
        interpolator = interp1d(original_indices, signal, kind="cubic", fill_value="extrapolate")
        shifted_signal = interpolator(shifted_indices)

        return shifted_signal

    def _diff_err(self, x, signal1, signal2):
        """Sum of squared differences between signal1 and signal2 shifted by x[0], excluding the shifted ends."""
        diff = signal1 - self._apply_subpixel_shift(signal2, x[0])
        int_x0 = int(np.abs(np.floor(x[0])))
        err = np.sum(diff[int_x0 + 1 : -int_x0 - 1] ** 2)
        return err

    def _compute_shift(self, signal1, signal2, min_disp=-MAX_YSHIFT * OVERSAMPLE, max_disp=MAX_YSHIFT * OVERSAMPLE):
        """Shift (in oversampled samples) that best aligns signal2 with signal1."""
        return self._compute_shift_fit(signal1, signal2, min_disp=min_disp, max_disp=max_disp)

    def _compute_shift_fit(self, signal1, signal2, min_disp=0, max_disp=0):
        """Bounded least-squares search for the shift minimizing ``_diff_err``, starting at 0."""
        res = least_squares(self._diff_err, [0], args=(signal1, signal2), bounds=([min_disp], [max_disp]))
        return res.x[0]