Coverage for src/susi/reduc/fitting/line_fit.py: 95%
75 statements
« prev ^ index » next coverage.py v7.5.0, created at 2025-06-13 14:15 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2025-06-13 14:15 +0000
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3"""
4Module to help with fitting Gauss Curves to data
6@author: hoelken
7"""
8import warnings
9import numpy as np
10from scipy.optimize import curve_fit, fminbound
11from scipy.signal import find_peaks
13from ...base import Logging
15logger = Logging.get_logger()
class LineFit:
    """
    ## LineFit
    Helper class to take care of fitting noisy data of line profiles.

    Provided a set of `x` and `y` values of same dimension the process will first look for
    peaks in the `y`. Depending on the number of peaks the algorithm will try a single or
    overlapping gauss fit and will compute starting amplitude, mean and sigma from the peak(s)
    found.

    It will first try a gaussian fit. If that fit fails, or any of its errors is not below
    `error_threshold`, it will try a lorentzian fit as fallback.

    The resulting optimized values, covariance and errors can be retrieved directly after the fit was performed.
    Also the x-location of the maximum (peak) is available.
    """

    def __init__(self, xes, yes, error_threshold=2.1, centered=True):
        #: x axis
        self.xes = np.array(xes, dtype='float64')
        #: y values to x axis entries
        self.yes = np.array(yes, dtype='float64')
        #: Float to set the max error for gauss (before checking with lorentzian)
        self.error_threshold = error_threshold
        #: Flag to determine if the dataset is centered around a central peak
        self.centered = centered
        #: Initial values (amplitude, center, sigma) per peak, handed to `curve_fit` as `p0`
        self.p0_args = []
        # Results
        #: POPT: Optimized values for (amplitude, center, sigma) per peak.
        #: if more than one peak is detected this will be multiple of 3 values with (a1, c1, s1, a2, c2, s2, ...)
        self.popt = None
        #: The estimated covariance of popt.
        self.pcov = None
        #: The standard deviation errors on (amplitude, center, sigma)
        self.perr = None
        #: the absolute max location (x)
        self.max_location = None
        #: Name of the fit that produced `popt` ('gaussian' or 'lorentzian')
        self.fitting_function = None

    def run(self) -> None:
        """
        Trigger the fitting process.

        ### Raises
        `RuntimeError` if a data set is empty, no peak was detected or the fit was not successful
        """
        # Drop results of a previous call so a failing fit can never report stale values
        self.popt = self.pcov = self.perr = None
        self.max_location = None
        self.fitting_function = None

        self.__check_input()
        self.__initial_values(self.__find_peaks())
        self.__fit_line()
        self.__find_max()

    def __check_input(self):
        """Reject empty data sets before any computation is attempted."""
        if len(self.xes) == 0 or len(self.yes) == 0:
            raise RuntimeError('At least one of the given data sets is empty')

    def __fit_line(self):
        """Fit a gaussian first; fall back to a lorentzian if the gaussian errors are too large."""
        self.__fit_gauss()
        if self.perr is not None and (self.perr < self.error_threshold).all():
            return

        self.__fit_lorentz()

    def __find_peaks(self):
        """
        Return the indices of the peaks used to seed the fit.

        For centered data the middle sample is taken as the only peak.
        Otherwise peaks are detected with `scipy.signal.find_peaks`.
        """
        if self.centered:
            return [len(self.xes) // 2]

        threshold = (np.max(self.yes) - np.min(self.yes)) / 10
        # `rel_height` is only evaluated by find_peaks when `width` is given, so it never
        # filtered anything. `prominence` applies the threshold and drops ripples
        # smaller than a tenth of the data range.
        peaks, _ = find_peaks(self.yes, distance=3, prominence=threshold)
        if len(peaks) == 0:
            raise RuntimeError('No peaks detected in given dataset')
        return peaks

    def __initial_values(self, peaks):
        """Compute the start parameters (amplitude, center, sigma) for every given peak index."""
        # Start from scratch: appending to values of an earlier run would double the parameter count
        self.p0_args = []
        ymin = np.min(self.yes)
        ysum = np.sum(self.yes)
        for peak in peaks:
            center = self.xes[peak]
            # NOTE(review): the amplitude guess adds the data minimum to the peak value -- confirm intent
            self.p0_args.append(self.yes[peak] + ymin)
            self.p0_args.append(center)
            # y-weighted RMS distance of the samples from the peak as width estimate
            self.p0_args.append(np.sqrt(np.sum(self.yes * (self.xes - center) ** 2) / ysum))

    def __fit_gauss(self):
        """Try the (overlapping) gaussian model."""
        self.fitting_function = 'gaussian'
        self.__fit(overlapping_gaussian)

    def __fit_lorentz(self):
        """
        Try the (overlapping) lorentzian model.

        ### Raises
        `RuntimeError` if no usable result (errors missing or NaN) is available afterwards
        """
        # Only relabel when the lorentzian fit really replaced the results. If it fails,
        # popt/pcov/perr still hold the gaussian values and must stay labelled as such.
        if self.__fit(overlapping_lorentzian):
            self.fitting_function = 'lorentzian'
        if self.perr is None or np.isnan(self.perr.sum()):
            raise RuntimeError('Could not fit given data. Neither Gauss nor Lorentz method worked.')

    def __find_max(self):
        """Locate the x position of the maximum of the fitted curve within the x range."""
        # Evaluate the model that produced popt; gaussian and lorentzian parameters are not interchangeable
        model = overlapping_lorentzian if self.fitting_function == 'lorentzian' else overlapping_gaussian
        x0 = np.min(self.xes)
        x1 = np.max(self.xes)
        self.max_location = fminbound(lambda x: -model(x, *self.popt), x0, x1)

    def __fit(self, func) -> bool:
        """
        Run `curve_fit` with the given model function.

        On success popt, pcov and perr are replaced together and `True` is returned.
        On failure the previous values are left untouched and `False` is returned.
        """
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore")
            try:
                popt, pcov = curve_fit(func, self.xes, self.yes, p0=self.p0_args)
                perr = np.sqrt(np.diag(pcov))
            except (TypeError, RuntimeWarning, RuntimeError):
                # Best effort: a failed fit is reported to the caller via the return value
                return False
        self.popt, self.pcov, self.perr = popt, pcov, perr
        return True
def gaussian(x, amplitude, mean, sigma) -> float:
    """
    Evaluate a [Gaussian](https://en.wikipedia.org/wiki/Normal_distribution) bell curve at `x`.

    The argument layout is the one `scipy.optimize.curve_fit` requires from its model callable:
    the independent variable comes first, followed by one separate argument per parameter to fit,
    see [curve_fit documentation](https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.curve_fit.html).

    ### Params
    - `x` The free variable
    - `amplitude` The height of the curve at its center
    - `mean` The position of the center of the peak
    - `sigma` The standard deviation (the width of the peak)

    ### Returns
    The y value `amplitude * exp(-(x - mean)^2 / (2 * sigma^2))`
    """
    squared_distance = np.power(x - mean, 2.)
    twice_variance = 2 * np.power(sigma, 2.)
    return amplitude * np.exp(-squared_distance / twice_variance)
def overlapping_gaussian(x, *args):
    """
    Sum of one or more `gaussian` curves, for data with (potentially) overlapping peaks.

    There is always a single `x`. The remaining arguments are consumed in consecutive
    groups of three (amplitude, mean, sigma), one group per peak; surplus arguments that
    do not fill a complete group are ignored.

    See `gaussian` for further details
    """
    n_peaks = len(args) // 3
    total = 0
    for k in range(n_peaks):
        amplitude, mean, sigma = args[3 * k:3 * k + 3]
        total = total + gaussian(x, amplitude, mean, sigma)
    return total
def lorentzian(x, amplitude, center, width) -> float:
    """
    Evaluate a [Cauchy-Lorentzian](https://en.wikipedia.org/wiki/Cauchy_distribution) profile at `x`.

    The argument layout is the one `scipy.optimize.curve_fit` requires from its model callable:
    the independent variable comes first, followed by one separate argument per parameter to fit,
    see [curve_fit documentation](https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.curve_fit.html).

    ### Params
    - `x` The free variable
    - `amplitude` The height of the curve at its center
    - `center` The position of the center of the peak
    - `width` The width of the peak

    ### Returns
    The y value `amplitude * width^2 / ((x - center)^2 + width^2)`
    """
    width_squared = width ** 2
    distance_squared = (x - center) ** 2
    return amplitude * width_squared / (distance_squared + width_squared)
def overlapping_lorentzian(x, *args) -> float:
    """
    Sum of one or more `lorentzian` curves, for data with (potentially) overlapping peaks.

    There is always a single `x`. The remaining arguments are consumed in consecutive
    groups of three (amplitude, center, width), one group per peak; surplus arguments that
    do not fill a complete group are ignored.

    See `lorentzian` for further details
    """
    # zip stops at the shortest slice, so only complete (amplitude, center, width) groups are used
    peak_params = zip(args[0::3], args[1::3], args[2::3])
    return sum([lorentzian(x, amplitude, center, width) for amplitude, center, width in peak_params])