Coverage for src/susi/reduc/fitting/line_detector.py: 100%

55 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2025-06-13 14:15 +0000

1import numpy as np 

2from scipy import signal as sig 

3 

4from ...utils import MP, progress 

5from .line import Line 

6from .line_fit import LineFit 

7from ...base import Logging 

8 

# Module-wide logger, used by LineDetector and by the worker function _detect_line.
logger = Logging.get_logger()

10 

11 

class LineDetector:
    """
    Detect absorption (or emission) lines in an image and trace them across columns.

    First, every row is averaged over all columns and local maxima of this row
    profile are taken as the approximate line centers (``scipy.signal.find_peaks``).
    Averaging reduces noise and gives a good first estimate of where to expect
    line(s).

    Then, for each anchor column, the rows around the estimated center are fitted
    with ``LineFit`` (per the note on ``error_threshold``: a gaussian, with a
    lorentzian fallback) to detect the actual peak. The traced lines are available
    via the ``lines`` attribute once :meth:`run` has finished.

    NOTE(review): line centers are found as *maxima* of the row means, so
    absorption lines presumably need to appear as peaks in ``image_data``
    (e.g. inverted by the caller) -- TODO confirm.
    """

    def __init__(self, image_data, anchors=170, line_distance=100, error_threshold=2.1):
        #: number of anchor columns to sample each line at (at most ``anchors - 1`` are used)
        self.anchors = anchors
        #: Integer >= 1: minimum distance (in rows) of two lines, passed to ``find_peaks``
        self.line_distance = line_distance
        #: Float to set the max error for gauss (before trying with lorentzian)
        self.error_threshold = error_threshold
        #: The image data as 2-dim matrix (rows x cols)
        self.data = np.array(image_data)
        #: columns at which every line is fitted; computed by run()
        self.check_cols = []
        #: resulting list of lines detected
        self.lines = []

    def run(self) -> None:
        """
        Detect lines at anchor points.

        Normalizes ``data``, locates the approximate line centers, chooses the
        anchor columns and finally replaces ``lines`` with the fitted ``Line``
        objects returned by the parallel workers.
        """
        self.__normalize()
        self.__detect_line_centers()
        self.__determine_cols_to_check()
        self.__detect_lines_per_col()

    def __normalize(self) -> None:
        """Scale ``data`` to unit standard deviation and shift its minimum to zero."""
        std = np.std(self.data)
        # A constant image has zero spread; dividing by it would turn the data into NaN.
        if std > 0:
            self.data = self.data / std
        self.data = self.data - np.min(self.data)

    def __detect_line_centers(self) -> None:
        """Append one ``Line`` to ``lines`` per peak of the mean profile over all columns."""
        row_profile = self.data.mean(axis=1)
        # threshold: minimal vertical distance of a peak to its direct neighbours
        # (in units of the normalized data); distance: minimal peak spacing in rows.
        peaks, _ = sig.find_peaks(row_profile, threshold=0.005, distance=self.line_distance)
        logger.info('detected %s lines', len(peaks))
        self.lines.extend(self.__create_line(peak) for peak in peaks)

    def __create_line(self, peak: int) -> Line:
        """
        Create a ``Line`` centered at row ``peak``.

        The number of rows (``data.shape[0]``) is passed as second argument --
        presumably the valid row range for the line; TODO confirm against ``Line``.
        """
        return Line(peak, self.data.shape[0],
                    rot_anker=0,
                    line_distance=self.line_distance)

    def __determine_cols_to_check(self) -> None:
        """Choose up to ``anchors - 1`` evenly spaced columns at which each line is fitted."""
        n_cols = self.data.shape[1]
        distance = int(np.ceil(n_cols / self.anchors))
        logger.debug('Creating anchors every %s cols', distance)
        # Anchors start at ``distance`` (column 0 is never used). Positions past the
        # last column are clamped to it and the duplicates this creates are removed.
        cols = [min(n_cols - 1, distance * i) for i in range(1, self.anchors)]
        self.check_cols = np.unique(np.array(cols, dtype=int))

    def __detect_lines_per_col(self) -> None:
        """Fit all lines at ``check_cols`` in parallel and store the results in ``lines``."""
        data = [self.__line_args(line) for line in self.lines]
        # NOTE(review): assumes MP.simultaneous returns results in input order -- TODO confirm
        self.lines = MP.simultaneous(_detect_line, data)
        progress.dot(flush=True)

    def __line_args(self, line):
        """Bundle everything ``_detect_line`` needs to process ``line`` in a worker."""
        return {'line': line, 'cols': self.check_cols, 'data': self.data, 'error': self.error_threshold}

77 

78 

def _detect_line(args):
    """
    Fit one line at every requested column and record the detected peak positions.

    Method on module level to allow parallelization.

    :param args: dict with the keys
        ``'line'``  -- the ``Line`` to trace,
        ``'cols'``  -- the column indices to check,
        ``'data'``  -- the 2-dim image data (rows x cols),
        ``'error'`` -- error threshold forwarded to ``LineFit``.
    :return: the same ``Line`` object, with a ``(max_location, col)`` entry added
        for every column whose fit succeeded.
    """
    line = args['line']
    cols = args['cols']
    data = args['data']
    logger.debug('Processing line at %s', line.center)
    error_cols = []
    for col in cols:
        # Rows around the expected line center at this column.
        area = line.area(col)
        # data[:, col] is the same column as np.transpose(data)[col] for 2-dim data,
        # without building a transposed view on every iteration.
        fitter = LineFit(area, data[:, col][area], error_threshold=args['error'])
        try:
            fitter.run()
        except RuntimeError:
            # The fit did not converge at this column; skip it.
            error_cols.append(col)
        else:
            line.add((fitter.max_location, col))
    if error_cols:
        logger.debug('Fit failed for line at %s in %s of %s cols',
                     line.center, len(error_cols), len(cols))
    # Success: more than 65 % of the requested columns yielded a point.
    # The previous expression had the ratio inverted (cols/map), which was True for
    # any non-empty map. Assumes ``line.map`` holds one entry per successful fit.
    success = len(cols) > 0 and len(line.map) / len(cols) > 0.65
    progress.dot(success=success)
    return line