Coverage for src/susi/reduc/fitting/line_detector.py: 100%
55 statements
« prev ^ index » next coverage.py v7.5.0, created at 2025-06-13 14:15 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2025-06-13 14:15 +0000
1import numpy as np
2from scipy import signal as sig
4from ...utils import MP, progress
5from .line import Line
6from .line_fit import LineFit
7from ...base import Logging
# Module-level logger, used by LineDetector and by the _detect_line worker function.
9logger = Logging.get_logger()
12class LineDetector:
13 """
14 This class aims to detect absorption (or emission) lines in an image.
16 It first bins all rows and looks for peaks to detect the approx line center.
17 This reduces noise and allows for a good first estimate where to expect line(s).
19 Then, for each anchor row, the area around the estimated center is fitted with a gaussian
20 to detect the actual peak. A map of those detected peaks is available via the `lines` variable
21 at the end of the process.
22 """
24 def __init__(self, image_data, anchors=170, line_distance=100, error_threshold=2.1):
25 #: number of anchor points to take for each line
26 self.anchors = anchors
27 #: Integer > 1 to define the minimum distance of two lines.
28 self.line_distance = line_distance
29 #: Float to set the max error for gauss (before trying with lorentzian)
30 self.error_threshold = error_threshold
31 #: The image data as 2-dim matrix
32 self.data = np.array(image_data)
33 # list of cols to check
34 self.check_cols = []
35 #: resulting list of lines detected
36 self.lines = []
38 def run(self):
39 """
40 Detect lines at anchor points
41 """
42 self.__normalize()
43 self.__detect_line_centers()
44 self.__determine_cols_to_check()
45 self.__detect_lines_per_col()
47 def __normalize(self):
48 self.data = self.data / np.std(self.data)
49 self.data = self.data - np.min(self.data)
51 def __detect_line_centers(self) -> None:
52 row_means = self.data.mean(axis=1)
53 peaks, _ = sig.find_peaks(row_means, threshold=0.005, distance=self.line_distance)
54 logger.info('detected %s lines', len(peaks))
55 for peak in peaks:
56 self.lines.append(self.__create_line(peak))
58 def __create_line(self, peak: int) -> Line:
59 return Line(peak, self.data.shape[0],
60 rot_anker=0,
61 line_distance=self.line_distance)
63 def __determine_cols_to_check(self) -> None:
64 distance = int(np.ceil(self.data.shape[1] / self.anchors))
65 logger.debug('Creating anchors every %s rows', distance)
66 self.check_cols = np.array([min(self.data.shape[1] - 1, distance * i) for i in range(1, self.anchors)],
67 dtype=int)
68 self.check_cols = np.unique(self.check_cols)
70 def __detect_lines_per_col(self):
71 data = [self.__line_args(line) for line in self.lines]
72 self.lines = MP.simultaneous(_detect_line, data)
73 progress.dot(flush=True)
75 def __line_args(self, line):
76 return {'line': line, 'cols': self.check_cols, 'data': self.data, 'error': self.error_threshold}
def _detect_line(args):
    """
    Fit a single line at every anchor column and return it.

    Defined at module level (not as a method) to allow parallelization via
    ``MP.simultaneous``.

    :param args: dict with the keys
        ``'line'``  - the :class:`Line` to fit; fitted points are added to it,
        ``'cols'``  - sequence of column indices to fit at,
        ``'data'``  - the (normalized) image data,
        ``'error'`` - error threshold forwarded to :class:`LineFit`.
    :return: the same ``Line`` object, with a ``(max_location, col)`` point added
        for every column whose fit did not raise ``RuntimeError``.
    """
    line = args['line']
    cols = args['cols']
    logger.debug('Processing line at %s', line.center)
    # Transpose once, outside the loop: columns[col] is image column `col`.
    columns = np.transpose(args['data'])
    error_cols = []
    for col in cols:
        rows = line.area(col)
        fitter = LineFit(rows, columns[col][rows], error_threshold=args['error'])
        try:
            fitter.run()
            line.add((fitter.max_location, col))
        except RuntimeError:
            # LineFit could not produce a result for this column; skip it.
            error_cols.append(col)
    if error_cols:
        logger.debug('Line at %s: fit failed in %s of %s columns',
                     line.center, len(error_cols), len(cols))
    # A line counts as successful when more than 65 % of the anchor columns
    # yielded a point. The ratio must be map/cols: the inverse (cols/map) is
    # >= 1 whenever any point exists (assuming at most one point per column),
    # which would flag every non-empty line as a success.
    success = len(cols) > 0 and len(line.map) / len(cols) > 0.65
    progress.dot(success=success)
    return line