Coverage for src/susi/analyse/slit_mask_border.py: 52%
209 statements
« prev ^ index » next coverage.py v7.5.0, created at 2025-06-13 14:15 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2025-06-13 14:15 +0000
1import numpy as np
2from scipy.interpolate import interp1d
3from scipy.optimize import least_squares
4import matplotlib.pyplot as plt
5from scipy.signal import savgol_filter
6from scipy.optimize import curve_fit
7from scipy import stats
8from skimage.feature.tests.test_orb import img
9from ... import susi
10from ..utils import MP
# Module-level logger obtained from the susi package; used by the classes of this module.
log = susi.Logging.get_logger()
class GetSlitMaskBorder_v2:
    """
    Fits a line to the border of the slit mask seen in SP images.

    Returns the line coefficients ``[slope, intercept]`` (``np.polyfit`` order) of the edge row as a
    function of the column. Runs in parallel for images and optionally for chunks of columns.

    # author: iglesias@mps.mpg.de, lagg@mps.mpg.de

    The method is based on the following steps:
     - Step 1: coarse determination of the edge based on the derivative of the column-summed,
               heavily smoothed step function.
     - Step 2: subpixel accurate measurement of the edge in each column by fitting a Gaussian
               to the derivative (take only +-GAUSS_WIDTH pixels around the step). Columns that are
               not at least COL_FILTER [DN] brighter than the mean slit mask level are skipped.
     - Step 3: Fit a linear function to the per-column edge positions, placed at their real column index.
               To avoid outliers, only the points within +-0.5*sigma of the mean are taken into account.
               If more than 80% of the columns of the left, middle or right third are rejected,
               the sigma level is successively increased in steps of 0.25.
    """

    N = 40  # [px] max number of pixels to consider around the column-avg slit mask edge
    NYOFF = 20  # [px] offset in y direction to remove edge effects
    STEP = 300  # number of columns processed by each worker task
    COL_FILTER = 30  # Columns that are not at least COL_FILTER [DN] brighter than the mean slit mask level are dropped
    GAUSS_WIDTH = 20  # [px] max number of pixels to consider around the slit mask edge of each column, must be < N

    def __init__(self, roi=None, debug=False, xbin=1):
        #: ROI to analyze
        self.roi = roi
        #: debug flag
        self.debug = debug
        #: result dictionary
        self.result = None
        # set to compute columns chunks in paralell
        self.col_paralell: bool = False
        #: number of workers to use for columns, if none ncol/STEP is used
        self.col_workers = 10
        #: number of workers to use for images
        self.img_workers = 20
        #: number of columns to bin before the analysis. 1 means no binning
        self.bin = xbin

    def run_img(self, img, parallel=True):
        """
        Analyse a single 2D image ``[y, x]``.

        :param img: 2D image; it is binned along x by ``self.bin`` first.
        :param parallel: if True, chunks of columns are processed in parallel.
        :return: line coefficients of the image (also stored in ``self.result``).
        """
        self.col_paralell = parallel
        img = self._bin_data(img)
        img_args = self.__gen_img_arguments(img[None, :, :])
        _, self.result = GetSlitMaskBorder_v2.process_image(img_args[0])
        return self.result

    def run_cube(self, cube):
        """
        Analyse a 3D array ``[n, y, x]``; the images are processed in parallel.

        :return: list with one coefficient pair per image, in image order (also stored in ``self.result``).
        """
        cube = self._bin_data(cube)
        img_args = self.__gen_img_arguments(cube)
        log.info(f'Analysing slit mask shift for cube of shape {cube.shape} with {self.img_workers} workers...')
        result = MP.simultaneous(GetSlitMaskBorder_v2.process_image, img_args, workers=self.img_workers)
        log.info('Analysis done, collecting results...')
        # workers may return in any order: sort by image index
        self.result = [coeffs for _, coeffs in sorted(result, key=lambda item: item[0])]
        return self.result

    def _bin_data(self, data):
        # bin only the last (x) axis by self.bin
        if data.ndim == 2:
            return susi.utils.collections.Collections.bin(data, [1, self.bin])
        else:
            return susi.utils.collections.Collections.bin(data, [1, 1, self.bin])

    def __gen_img_arguments(self, cube):
        # one argument dict per image of the cube (consumed by process_image)
        img_args = [
            {
                'img_idx': i,
                'img': cube[i],
                'roi': self.roi,
                'debug': self.debug,
                'col_workers': self.col_workers,
                'col_paralell': self.col_paralell,
            }
            for i in range(cube.shape[0])
        ]
        return img_args

    @staticmethod
    def process_image(args):
        """
        Fit the slit mask border of a single image.

        :param args: dict with keys 'img_idx', 'img' (2D array indexed [y, x]), 'roi',
                     'debug', 'col_workers' and 'col_paralell'.
        :return: tuple ``(img_idx, coeffs)`` where ``coeffs`` are the ``np.polyfit`` coefficients
                 ``[slope, intercept]`` of the edge row vs. column, or ``[nan, nan]`` on failure.
        """
        img, roi, debug, col_workers, col_paralell = (
            args['img'],
            args['roi'],
            args['debug'],
            args['col_workers'],
            args['col_paralell'],
        )
        if roi is None:
            roi = [slice(0, img.shape[0]), slice(0, img.shape[1])]

        nx, ny = img.shape[1], img.shape[0]
        # Step 1: detect the "step" from dark area to bright area by summing up all columns along x direction
        img_col = np.sum(img, axis=1)
        img_col_sg = savgol_filter(img_col, 15, 1)
        img_coldy = np.diff(img_col_sg)
        dy_max = np.float64(np.argmax(img_coldy))
        if debug:
            plt.clf()
            plt.plot(img_col)
            plt.plot(img_col_sg)
            plt.axvline(x=dy_max)
            plt.show()

        # rows +-N around the coarse edge; the window never starts above row NYOFF nor ends below ny
        iloc = [
            int(max([dy_max - GetSlitMaskBorder_v2.N, GetSlitMaskBorder_v2.NYOFF])),
            int(min([dy_max + GetSlitMaskBorder_v2.N, ny])),
        ]
        loc_max = dy_max
        # Step 2: sub-pixel edge position of every column passing the brightness filter
        dy_cols, col_mask = GetSlitMaskBorder_v2.process_all_columns(
            img, iloc, loc_max, col_paralell, debug, col_workers, return_mask=True
        )
        if len(dy_cols) == 0:
            log.warning(f"Error during slit mask border detection of image {args['img_idx']}, returning nan")
            return args['img_idx'], [np.nan, np.nan]
        # scatter back to the full image width so every position sits at its real column index;
        # dropped columns stay NaN (otherwise the x axis of the line fit would be compressed)
        dy_max = np.full(nx, np.nan)
        dy_max[col_mask] = np.asarray(dy_cols) + iloc[0] + roi[0].start
        if np.count_nonzero(np.isfinite(dy_max)) < 2:
            log.warning(f"Error during slit mask border detection of image {args['img_idx']}, returning nan")
            return args['img_idx'], [np.nan, np.nan]

        # Step 3: reject outliers, then fit a linear function to dy_max
        dy_max[GetSlitMaskBorder_v2._find_outliers(dy_max, nx)] = np.nan
        xvec = np.arange(nx) + roi[1].start
        idx = np.isfinite(dy_max)
        if np.count_nonzero(idx) < 2:
            log.warning(f"Error during slit mask border detection of image {args['img_idx']}, returning nan")
            return args['img_idx'], [np.nan, np.nan]
        coeffs = np.polyfit(xvec[idx], dy_max[idx], 1)
        ffit = np.polyval(coeffs, xvec)
        if debug:
            plt.clf()
            plt.imshow(img, extent=(roi[1].start, roi[1].stop, roi[0].stop, roi[0].start))
            plt.plot(xvec, ffit, c='orange', lw=2.0)
            plt.plot(xvec, dy_max, c='r', lw=0.5, alpha=0.5)
            plt.title(
                "coeffs={:.3e},{:.2e}, angle={:.3e} deg".format(
                    coeffs[0], coeffs[1], np.arctan(coeffs[0]) * 180 / np.pi
                )
            )
            plt.show()
            plt.pause(0.1)
        return args['img_idx'], coeffs

    @staticmethod
    def _find_outliers(dy_max, nx):
        """
        Flag edge positions further than sigma * std from the mean (NaN-aware).

        sigma starts at 0.5 and grows in steps of 0.25 until no third (left/middle/right) of the
        ``nx`` columns has more than 80% of its columns flagged. NaN entries are never flagged.

        :return: boolean mask with the same length as ``dy_max``.
        """
        mean = np.nanmean(dy_max)
        # sample std (ddof=1), the same estimator as scipy.stats.describe(...).variance ** 0.5
        std = np.nanstd(dy_max, ddof=1)
        nxmin = nx / 3 * 0.80  # at least 20% of each section shall be valid
        thirds = (
            slice(0, int(nx / 3)),
            slice(int(nx / 3), int(2 * nx / 3)),
            slice(int(2 * nx / 3), None),
        )
        sigma = 0.5
        while True:
            with np.errstate(invalid='ignore'):
                outlier = np.abs(mean - dy_max) > sigma * std
            if all(np.sum(outlier[section]) <= nxmin for section in thirds):
                return outlier
            sigma += 0.25
            if sigma >= 3:
                log.info("Boundary not found during slit mask border analysis")

    @staticmethod
    def process_all_columns(img, iloc, loc_max, col_paralell=True, debug=False, col_workers=None, return_mask=False):
        """
        Measure the edge position in every sufficiently bright column of ``img``.

        :param img: 2D image indexed [y, x].
        :param iloc: [first, last) rows of the derivative window around the coarse edge.
        :param loc_max: coarse edge row.
        :param col_paralell: process chunks of STEP columns in parallel with MP.simultaneous.
        :param col_workers: number of workers; if None, ncols // STEP (at least 1) is used.
        :param return_mask: if True, also return the boolean mask of the columns that were analysed.
        :return: array with the edge position of each kept column relative to ``iloc[0]`` (NaN when the
                 Gaussian fit failed), or an empty list if fewer than 50 columns pass the filter.
                 With ``return_mask=True`` a tuple ``(positions, col_mask)`` is returned.
        """
        n_rows = GetSlitMaskBorder_v2.N
        # Drop all columns that are not at least COL_FILTER [DN] brighter than the mean slit mask level.
        # Keep at least one row of slit mask so a small loc_max does not turn into a negative (wrapping) slice.
        dark_stop = max(int(loc_max - n_rows), 1)
        slit_mask_level = np.mean(img[:dark_stop, :])
        col_mean = np.mean(img[int(loc_max + n_rows) :, :], axis=0)
        col_mask = col_mean > (slit_mask_level + GetSlitMaskBorder_v2.COL_FILTER)
        filt_img = img[:, col_mask]
        if filt_img.shape[1] < 50:
            log.warning("Not enough columns for slit mask border detection, image too dark ?")
            return ([], col_mask) if return_mask else []

        ncols = filt_img.shape[1]
        args = GetSlitMaskBorder_v2.gen_col_arguments(filt_img, iloc, loc_max, debug)
        if col_paralell:
            if col_workers is None:
                col_workers = max(1, ncols // GetSlitMaskBorder_v2.STEP)
            chunks = MP.simultaneous(GetSlitMaskBorder_v2.columns_shift, args, workers=col_workers)
        else:
            chunks = [GetSlitMaskBorder_v2.columns_shift(chunk_args) for chunk_args in args]

        # NaN (not 0) for any column that did not come back, so it can never pose as an edge at row 0
        dy_max = np.full(ncols, np.nan)
        for chunk in chunks:
            for idx, dy in chunk:
                dy_max[idx] = dy
        return (dy_max, col_mask) if return_mask else dy_max

    @staticmethod
    def gen_col_arguments(img, iloc, loc_max, debug) -> list:
        """Split the columns of ``img`` into chunks of STEP columns, one argument dict per chunk."""
        step = GetSlitMaskBorder_v2.STEP
        return [
            {
                'idx': j,
                'idx1': j + step,
                'img_colj': img[:, j : j + step],
                'iloc': iloc,
                'loc_max': loc_max,
                'debug': debug,
            }
            for j in range(0, img.shape[1], step)
        ]

    @staticmethod
    def _gauss(x, a, x0, sigma):
        """Gaussian profile fitted to the derivative of each column."""
        return a * np.exp(-((x - x0) ** 2) / (2 * sigma**2))

    @staticmethod
    def columns_shift(args: dict):
        """
        Compute the shift of the slit mask edge in a set of columns.

        :param args: dict built by ``gen_col_arguments`` ('idx', 'img_colj', 'iloc', 'loc_max', ...).
        :return: list of ``[column_index, position]``; the position is the fitted Gaussian centre of the
                 column derivative relative to ``iloc[0]``, or NaN if the fit failed.
        """
        cols = args['img_colj']
        row0, row1 = args['iloc'][0], args['iloc'][1]
        loc_max = args['loc_max']
        nc = GetSlitMaskBorder_v2.GAUSS_WIDTH
        # derivative of the slightly smoothed columns (all columns at once), restricted to rows iloc[0]:iloc[1]
        cols_dy = np.diff(savgol_filter(cols, 9, 1, axis=0), axis=0)[row0:row1]
        # fit window of +-GAUSS_WIDTH rows around the coarse edge, in coordinates relative to iloc[0]
        x = np.int32(np.arange(-nc, nc + 1) + loc_max) - row0
        x0_guess = loc_max - row0
        results = []
        for j in range(cols.shape[1]):
            col_cut = cols_dy[x[0] : x[-1] + 1, j]
            try:
                popt, _ = curve_fit(GetSlitMaskBorder_v2._gauss, x, col_cut, p0=[max(col_cut), x0_guess, 5.0])
                edge = popt[1]
            except Exception:
                # deliberate best effort: a failed fit (no convergence, window cut by the image border, ...)
                # only invalidates this column
                edge = np.nan
            results.append([args['idx'] + j, edge])
        return results
class GetSlitMaskBorder:
    """
    Fits a line to the border of the slit mask (level based) seen in SP images.
    Returns the line angle and its intercept wrt the reference profile.
    0 deg is considered horizontal (row) direction to the right
    90 deg is vertical (column) direction to the top

    # author: iglesias@mps.mpg.de
    """

    # number of x positions to sample the edge rotation
    NXLOC = 100
    # Oversample factor for the edge detection
    OVERSAMPLE = 100
    # Max shift [px] allowed (which defines the max angle you can find)
    MAX_YSHIFT = 30

    def __init__(self, img: np.ndarray, ref_profile=None):
        #: input image. Must be cropped to show mostly the slit mask border
        self.orig = img
        #: 1D reference profile (same size as img.shape[0]) used to compute column shifts.
        #: The row coordinate of the half max of the reference profile
        #: defines the y=0 of the returned line intercept
        #: If None then is obtained from the column with largest mean value
        #: Not useful to get an absolute shift but OK to get the angle.
        self.ref_profile = ref_profile
        #: The detected rotation angle in degree
        self.angle: float = 0.0
        #: the detected offset of the slit edge wrt to half max of the reference
        self.intercept: float = 0

    def run(self):
        """Detect the edge in NXLOC columns, fit a line and store ``angle`` / ``intercept``. Returns self."""
        xcuts = self._get_columns_to_analyze()
        yloc, xcuts = self._get_edge_yloc(xcuts)
        self.angle, self.intercept = self._get_line(xcuts, yloc)
        return self

    def _get_line(self, xcuts, yloc):
        """Fit a line to the edge points; return (angle of the line in degrees, intercept)."""
        fit = np.polyfit(xcuts, yloc, 1)
        return np.degrees(np.arctan(fit[0])), fit[1]

    def _get_edge_yloc(self, xcuts):
        """
        Locate the slit mask edge in each column of ``xcuts`` relative to the reference profile.

        :return: tuple (edge locations in original pixels, columns kept after outlier rejection)
        :raises ValueError: if the reference profile never exceeds its half max level.
        """
        cut_len = self.orig.shape[0]
        xcol = np.arange(cut_len)
        xinterp = np.linspace(0, cut_len - 1, cut_len * self.OVERSAMPLE)
        # true spacing of the oversampled grid in original pixels: (cut_len - 1) / (cut_len * OVERSAMPLE - 1),
        # which is slightly larger than 1 / OVERSAMPLE, so indices are converted with it
        step = xinterp[1] - xinterp[0]
        if self.ref_profile is None:
            # brightest of the analysed columns (the argmax indexes xcuts, not the image)
            ref_xloc = xcuts[np.mean(self.orig[:, xcuts], axis=0).argmax()]
            ref_prof = self.orig[:, ref_xloc]
            log.debug(f"Reference profile column: {ref_xloc}")
        else:
            ref_prof = self.ref_profile
        ref_prof = interp1d(xcol, ref_prof, kind="linear")(xinterp)
        yloc = np.array(
            [
                self._compute_shift(ref_prof, interp1d(xcol, self.orig[:, xcut], kind="linear")(xinterp))
                for xcut in xcuts
            ]
        )
        # filter outliers; '<=' keeps every point when all shifts are identical (std == 0)
        keep = np.abs(yloc - np.median(yloc)) <= 3 * np.std(yloc)
        # line passes for the half max of the ref profile: midway between its low and high levels
        low, high = np.percentile(ref_prof, [10, 90])
        above = np.flatnonzero(ref_prof > (low + high) / 2)
        if above.size == 0:
            raise ValueError("Reference profile never exceeds its half max level")
        half_max_loc = above[0]
        # NOTE(review): _compute_shift returns s with col(y - s) ~= ref(y) (see _apply_subpixel_shift),
        # i.e. the column edge sits at ref_edge - s, while s + half_max_loc is used here. Confirm that this
        # sign matches the intended "90 deg to the top" convention of the class docstring.
        return (yloc[keep] + half_max_loc) * step, xcuts[keep]

    def _get_columns_to_analyze(self):
        """Divide image dim 1 in NXLOC cuts and return, for each cut, the column with the highest mean value."""
        nx = self.orig.shape[1]
        # at least one column per cut, also for images narrower than NXLOC
        width = max(1, nx // self.NXLOC)
        xcuts = []
        for i in range(self.NXLOC):
            xcut = int(i * nx / self.NXLOC)
            col = np.mean(self.orig[:, xcut : xcut + width], axis=0)
            xcuts.append(np.argmax(col) + xcut)
        return np.array(xcuts)

    def _apply_subpixel_shift(self, signal, shift):
        """Return ``signal`` shifted by ``shift`` samples (out[i] = signal(i - shift)), cubic interpolation."""
        original_indices = np.arange(len(signal))
        shifted_indices = original_indices - shift
        interpolator = interp1d(original_indices, signal, kind="cubic", fill_value="extrapolate")
        shifted_signal = interpolator(shifted_indices)

        return shifted_signal

    def _diff_err(self, x, signal1, signal2):
        """Sum of squared differences between signal1 and signal2 shifted by x[0], ignoring extrapolated edges."""
        diff = signal1 - self._apply_subpixel_shift(signal2, x[0])
        int_x0 = int(np.abs(np.floor(x[0])))
        err = np.sum(diff[int_x0 + 1 : -int_x0 - 1] ** 2)
        return err

    def _compute_shift(self, signal1, signal2, min_disp=-MAX_YSHIFT * OVERSAMPLE, max_disp=MAX_YSHIFT * OVERSAMPLE):
        """Shift (in oversampled samples, within [min_disp, max_disp]) that best maps signal2 onto signal1."""
        return self._compute_shift_fit(signal1, signal2, min_disp=min_disp, max_disp=max_disp)

    def _compute_shift_fit(self, signal1, signal2, min_disp=0, max_disp=0):
        """Bounded least-squares search of the shift minimising ``_diff_err``, starting from 0."""
        res = least_squares(self._diff_err, [0], args=(signal1, signal2), bounds=([min_disp], [max_disp]))
        return res.x[0]