Coverage for src/susi/reduc/fields/shielded_px_correction.py: 77%

109 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2025-06-13 14:15 +0000

1#!/usr/bin/env python3 

2# -*- coding: utf-8 -*- 

3""" 

4module for shielded pixels' correction provides SHIPXCorrector 

5 

6@author: iglesias, hoelken 

7""" 

8 

9import numpy as np 

10from scipy import signal 

11 

12from ...base import Logging, Config 

13from ...io import Fits 

14from ...utils import MP 

15 

# Module-level logger; used below by SHIPXCorrector (unknown mode warning)
# and ShiPixColAndRow (stray-light border selection debug messages).
logger = Logging.get_logger()

17 

18 

class SHIPXCorrector:
    """
    Provides algorithm for shielded pixel correction
    (aka banding correction).

    The concrete algorithm is selected by ``config.cam.shielded_px_mode``:
    ``"mean"``, ``"linear_row"``, ``"median_col"``,
    ``"linear_row_and_median_col"`` or ``"N/A"`` (no correction).
    """

    def __init__(self, config: Config, data: Fits):
        """
        :param config: pipeline configuration; ``config.cam.shielded_px_mode``
            selects the algorithm, the remaining ``config.cam`` values are
            read by the selected algorithm.
        :param data: a Fits object whose ``data`` attribute holds the image
            cube, or the cube itself as a numpy array.
            Expected data shape is (frames, x, y).
        """
        self.config = config
        # A bare ndarray is accepted as well as a Fits container.
        # (ndarray has its own `.data` buffer attribute, hence the explicit check.)
        self.data = data if isinstance(data, np.ndarray) else data.data

    def run(self) -> np.ndarray:
        """
        Applies the shielded pixel correction selected by
        ``config.cam.shielded_px_mode`` to every frame of the data cube.

        :return: [Array] of img data with the shielded pixel correction applied.
            For mode "N/A" or an unrecognized mode the input data is returned
            unchanged (same object, no copy).
        """
        mode = self.config.cam.shielded_px_mode
        if mode == "N/A":
            return self.data

        # Built at call time because the algorithm classes are defined
        # further down in this module.
        algorithms = {
            "mean": ShiPixGlobalMean,
            "linear_row": ShiPixLinearRow,
            "median_col": ShiPixMedianCol,
            "linear_row_and_median_col": ShiPixColAndRow,
        }
        algorithm = algorithms.get(mode)
        if algorithm is None:
            logger.warning(
                f"Value of config.cam.shielded_px_mode ({mode!r}) not recognized, "
                "skipping shielded px correction"
            )
            return self.data
        return algorithm(self.config, self.data).run()

54 

55 

class ShiPixLinearRow:
    """
    Algorithm for row wise linear fit.

    For every row of every frame, the mean level of the (non-hot) shielded
    pixels at the start of the row (``shielded_px[2]`` columns) and at its end
    (``shielded_px[3]`` columns) is computed, and a straight line between the
    two levels is subtracted from the row.
    """

    # Pixels above mean + SIGMA * std of the shielded columns are treated as
    # hot and ignored when estimating the row levels.
    SIGMA = 2

    def __init__(self, config: Config, data: np.ndarray):
        self.config = config
        # A constant offset of 50 is added to the whole (float) cube.
        # TODO only works with positive data?
        self.data = data.astype(float) + 50

    def run(self) -> np.ndarray:
        """
        Apply the row wise correction to every frame.

        :return: the corrected cube (a float copy of the input)
        """
        for i in range(self.data.shape[0]):
            self._correct_frame(i)
        return self.data

    def _correct_frame(self, frame_no: int):
        """Correct every row of one frame; rows are processed via MP.threaded."""
        mean, std = self._compute_dark(frame_no)
        hot_limit = mean + std * ShiPixLinearRow.SIGMA
        n_rows = self.data.shape[1]
        res = dict(
            MP.threaded(
                self._correct_row,
                [(i, self.data[frame_no, i], hot_limit) for i in range(n_rows)],
            )
        )
        # Results are written back only after all rows are computed, so every
        # row is corrected against the unmodified frame.
        self.data[frame_no] = np.array([res[i] for i in range(n_rows)])

    def _compute_dark(self, frame_no: int) -> tuple:
        """
        Mean and standard deviation of the left and right shielded columns of
        one frame, i.e. the same pixels `_correct_row` reads for every row.

        NOTE(review): previously this sliced the first/last *rows*
        (``frame[0:s]``) and concatenated them on axis=1, which fails unless
        shielded_px[2] == shielded_px[3] and mismatches `_correct_row`.
        """
        frame = self.data[frame_no]
        left = frame[:, 0 : self.config.cam.shielded_px[2]]
        right = frame[:, -self.config.cam.shielded_px[3] :]
        dd = np.concatenate((left, right), axis=1)
        return np.mean(dd), np.std(dd)

    def _correct_row(self, arg: tuple) -> tuple:
        """
        Subtract the linear shielded pixel level from one row.

        :param arg: (row index, row data, hot pixel limit)
        :return: (row index, corrected row)
        """
        row_no, row, hot_limit = arg
        y1 = ShiPixLinearRow._cold_mean(
            row[0 : self.config.cam.shielded_px[2]], hot_limit
        )
        y2 = ShiPixLinearRow._cold_mean(
            row[-self.config.cam.shielded_px[3] :], hot_limit
        )
        return row_no, row - ShiPixLinearRow._line(y1, y2, len(row))

    @staticmethod
    def _cold_mean(values: np.ndarray, hot_limit: float) -> float:
        """
        Mean of the values strictly below `hot_limit`.

        If every value is at or above the limit, the median of all values is
        returned instead of NaN, so a single bright border does not turn the
        whole row into NaN.
        """
        cold = values[values < hot_limit]
        if cold.size == 0:
            return np.median(values)
        return np.mean(cold)

    @staticmethod
    def _line(y1: float, y2: float, length: int) -> np.ndarray:
        """
        Straight line of `length` samples with slope (y2 - y1) / length,
        shifted so that its minimum equals min(y1, y2).
        """
        xes = np.arange(length)
        line = (y2 - y1) / length * xes
        line = line - line.min() + min(y1, y2)
        return line

107 

108 

class ShiPixGlobalMean:
    """
    Algorithm for global mean.

    Subtracts, per frame, one offset: the mean of the median-filtered shielded
    pixels of the four borders (``shielded_px`` = top rows, bottom rows,
    left columns, right columns).
    """

    def __init__(self, config: Config, data: np.ndarray):
        self.config = config
        self.data = data.astype(float)

    def run(self) -> np.ndarray:
        """
        :return: the corrected cube (a float copy of the input)
        """
        for i in range(self.data.shape[0]):
            #: Correction per frame / state
            self.data[i] -= self._shpx_mean(i)
        return self.data

    def _shpx_mean(self, frame_no: int) -> float:
        """
        Compute the mean value of the shielded pixels of one frame.
        Outliers within the shielded px are removed using a median filter
        (kernel ``config.cam.shielded_px_win``) applied to the concatenated
        shielded pixel values before averaging.
        """
        frame_data = self.data[frame_no]
        spix = self.config.cam.shielded_px
        top, bottom, left, right = spix[0], spix[1], spix[2], spix[3]
        n_rows, n_cols = frame_data.shape
        # Explicit end indices: negative slicing with `-n : -1` silently
        # dropped the last row/column, and `-0` selects the wrong range
        # for zero-width borders.
        row_end = n_rows - bottom
        col_end = n_cols - right
        shielded = np.concatenate(
            (
                frame_data[:top, :].ravel(),
                frame_data[row_end:, :].ravel(),
                frame_data[top:row_end, :left].ravel(),
                frame_data[top:row_end, col_end:].ravel(),
            )
        )
        return np.mean(
            signal.medfilt(shielded, kernel_size=self.config.cam.shielded_px_win)
        )

142 

143 

class ShiPixMedianCol:
    """
    Algorithm for col wise median correction.

    For every frame and every column, the median of the selected shielded
    border rows (``shielded_px[2]`` top rows and/or ``shielded_px[3]`` bottom
    rows) is subtracted from that column.
    """

    def __init__(self, config: Config, data: np.ndarray, borders=(True, True)):
        self.config = config
        self.data = data.astype(float)
        self.borders = borders  # border to use (top, bottom)

    def run(self) -> np.ndarray:
        """
        :return: the corrected cube (a float copy of the input)
        :raises ValueError: if neither border is selected
        """
        if not self.borders[0] and not self.borders[1]:
            raise ValueError("No border selected for shielded pixel correction")
        for i in range(self.data.shape[0]):
            self._correct_frame(i)
        return self.data

    def _correct_frame(self, frame_no: int):
        """Subtract the per-column median of the reference rows from one frame."""
        frame = self.data[frame_no]
        column_levels = np.median(self._reference_rows(frame), axis=0)
        self.data[frame_no] = frame - column_levels

    def _reference_rows(self, frame: np.ndarray) -> np.ndarray:
        """
        Stack the selected shielded border rows of one frame.

        Top and bottom borders may have different heights; concatenating them
        (instead of building a ragged array) keeps np.median well defined.
        """
        n_top = self.config.cam.shielded_px[2]
        n_bottom = self.config.cam.shielded_px[3]
        n_rows = frame.shape[0]
        parts = []
        if self.borders[0]:
            parts.append(frame[:n_top])
        if self.borders[1]:
            # explicit start index: `-0:` would select the whole column
            parts.append(frame[max(n_rows - n_bottom, 0):])
        return np.concatenate(parts, axis=0)

181 

182 

class ShiPixColAndRow:
    """
    Combined algorithm: first a column wise median correction
    (ShiPixMedianCol), using only the shielded borders that are not affected
    by stray light, then a row wise linear fit (ShiPixLinearRow).
    """

    # If the ratio between the medians of the shielded borders is larger than
    # this value (or smaller than its inverse), only the border with the
    # smaller median is used.
    ILLUM_SHI_PIX_CRITERION = 1.3  # >1

    def __init__(self, config: Config, data: np.ndarray):
        self.config = config
        self.data = data.astype(float)

    def run(self) -> np.ndarray:
        """
        :return: the corrected cube (a float copy of the input)
        """
        borders = self._get_valid_borders()
        self.data = ShiPixMedianCol(self.config, self.data, borders=borders).run()
        self.data = ShiPixLinearRow(self.config, self.data).run()
        return self.data

    def _get_valid_borders(self) -> tuple:
        """
        Get the valid (non illuminated) borders for the shielded px.

        Compares the medians, taken over all frames, of the first
        ``shielded_px[2]`` and the last ``shielded_px[3]`` rows.

        :return: (use_top, use_bottom) flags as expected by ShiPixMedianCol
        """
        spix = self.config.cam.shielded_px
        row_median_top = np.median(self.data[:, 0 : spix[2]])
        row_median_bottom = np.median(self.data[:, -spix[3] :])
        ratio = row_median_top / row_median_bottom
        if ratio > ShiPixColAndRow.ILLUM_SHI_PIX_CRITERION:
            logger.debug(
                f"Shielded px levels not equal due to stray light (ratio {ratio}), using only the bottom border"
            )
            return (False, True)
        if ratio < 1 / ShiPixColAndRow.ILLUM_SHI_PIX_CRITERION:
            logger.debug(
                f"Shielded px levels not equal due to stray light (ratio {ratio}), using only the top border"
            )
            return (True, False)
        return (True, True)