Coverage for src/susi/reduc/fields/shielded_px_correction.py: 77%
109 statements
« prev ^ index » next coverage.py v7.5.0, created at 2025-06-13 14:15 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2025-06-13 14:15 +0000
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3"""
Module for shielded-pixel correction (aka banding correction); provides SHIPXCorrector.
6@author: iglesias, hoelken
7"""
9import numpy as np
10from scipy import signal
12from ...base import Logging, Config
13from ...io import Fits
14from ...utils import MP
# Module-level logger; used below for the "mode not recognized" warning and
# for the stray-light debug messages when selecting valid shielded borders.
logger = Logging.get_logger()
class SHIPXCorrector:
    """
    Provides algorithm for shielded pixel correction
    (aka banding correction).

    Dispatches to one of the algorithm classes of this module depending on
    ``config.cam.shielded_px_mode``:

    - ``"N/A"``: no correction, the data is returned unchanged
    - ``"mean"``: ShiPixGlobalMean
    - ``"linear_row"``: ShiPixLinearRow
    - ``"median_col"``: ShiPixMedianCol
    - ``"linear_row_and_median_col"``: ShiPixColAndRow

    Any other value logs a warning and skips the correction.
    """

    # TODO: Why data need to be a Fits object?
    def __init__(self, config: Config, data: Fits):
        self.config = config
        # Only the pixel array of the Fits object is kept.
        # Expected data shape is (frames, x, y)
        self.data = data.data

    def run(self) -> np.ndarray:
        """
        Apply the shielded pixel correction selected by
        ``config.cam.shielded_px_mode``.

        :return: [Array] of img data with the shielded pixel correction applied,
            or the unmodified data if the mode is "N/A" or not recognized
        """
        mode = self.config.cam.shielded_px_mode
        if mode == "N/A":
            return self.data

        # Built at call time (not as a class attribute) because the algorithm
        # classes are defined further down in this module.
        algorithms = {
            "mean": ShiPixGlobalMean,
            "linear_row": ShiPixLinearRow,
            "median_col": ShiPixMedianCol,
            "linear_row_and_median_col": ShiPixColAndRow,
        }
        algorithm = algorithms.get(mode)
        if algorithm is None:
            logger.warning(
                f"Value of config.cam.shielded_px_mode ({mode!r}) not recognized, "
                "skipping shielded px correction"
            )
            return self.data
        return algorithm(self.config, self.data).run()
class ShiPixLinearRow:
    """
    Algorithm for row wise linear fit.

    For every row, the mean levels of its left (first ``shielded_px[2]``
    columns) and right (last ``shielded_px[3]`` columns) shielded pixels are
    computed, ignoring hot pixels, and a linear ramp between those two levels
    is subtracted from the row.
    """

    # Shielded pixels at or above mean + SIGMA * std of the frame's shielded
    # columns are treated as hot and ignored when computing the row levels.
    SIGMA = 2

    def __init__(self, config: Config, data: np.ndarray):
        self.config = config
        # The +50 offset cancels out in `_correct_row`: the subtracted ramp is
        # computed from the same offset data, so it only affects rounding.
        self.data = data.astype(float) + 50  # TODO only works with positive data?

    def run(self) -> np.ndarray:
        """Correct every frame and return the corrected data."""
        for i in range(self.data.shape[0]):
            self._correct_frame(i)
        return self.data

    def _correct_frame(self, frame_no: int):
        """Correct all rows of one frame in place (rows handled via MP.threaded)."""
        mean, std = self._compute_dark(frame_no)
        hot_limit = mean + std * ShiPixLinearRow.SIGMA
        res = dict(
            MP.threaded(
                self._correct_row,
                [
                    (i, self.data[frame_no, i], hot_limit)
                    for i in range(self.data.shape[1])
                ],
            )
        )
        self.data[frame_no] = np.array([res[i] for i in range(self.data.shape[1])])

    def _border_columns(self, arr: np.ndarray) -> tuple:
        """Return the (left, right) shielded columns of `arr` along its last axis."""
        n_left = self.config.cam.shielded_px[2]
        n_right = self.config.cam.shielded_px[3]
        width = arr.shape[-1]
        # `width - n_right` instead of `-n_right`: with n_right == 0 a negative
        # index of -0 would select every column instead of none.
        return arr[..., :n_left], arr[..., width - n_right :]

    def _compute_dark(self, frame_no: int) -> tuple:
        """
        Mean and standard deviation over the left and right shielded columns
        (all rows) of one frame.
        """
        # Slice along the column axis so the statistics cover exactly the
        # pixels that `_correct_row` averages per row.
        left, right = self._border_columns(self.data[frame_no])
        dd = np.concatenate((left, right), axis=1)
        return np.mean(dd), np.std(dd)

    def _correct_row(self, arg: tuple) -> tuple:
        """
        :param arg: (row index, row values, hot pixel limit)
        :return: (row index, row minus the ramp between its left and right
            shielded levels)
        """
        row_no, row, hot_limit = arg
        left, right = self._border_columns(row)
        y1 = ShiPixLinearRow._cold_mean(left, hot_limit)
        y2 = ShiPixLinearRow._cold_mean(right, hot_limit)
        return row_no, row - ShiPixLinearRow._line(y1, y2, len(row))

    @staticmethod
    def _cold_mean(values: np.ndarray, limit: float) -> float:
        """
        Mean of the entries of `values` strictly below `limit`.

        If every entry is at or above the limit, the mean of all entries is
        used instead of returning NaN for the whole row. An empty `values`
        still yields NaN.
        """
        cold = values[values < limit]
        return np.mean(cold) if cold.size else np.mean(values)

    @staticmethod
    def _line(y1: float, y2: float, length: int) -> np.ndarray:
        """
        Linear ramp of `length` samples going from the left level y1 towards
        the right level y2.

        The slope is (y2 - y1) / length and the ramp is shifted so that its
        minimum equals min(y1, y2). It therefore hits the lower level exactly
        and stops one step (|y2 - y1| / length) short of the higher one.
        """
        xes = np.arange(length)
        line = (y2 - y1) / length * xes
        line = line - line.min() + min(y1, y2)
        return line
class ShiPixGlobalMean:
    """
    Algorithm for global mean.

    Per frame, subtracts the (median-filtered) mean level of all shielded
    pixels (top and bottom rows plus left and right columns) from the whole
    frame.
    """

    def __init__(self, config: Config, data: np.ndarray):
        self.config = config
        self.data = data.astype(float)

    def run(self) -> np.ndarray:
        """Correct every frame and return the corrected data."""
        for i in range(self.data.shape[0]):
            #: Correction per frame / state
            self.data[i] = self.data[i] - self._shpx_mean(i)
        return self.data

    def _shpx_mean(self, frame_no: int) -> float:
        """
        Compute the mean value of the shielded pixels of one frame.

        ``config.cam.shielded_px`` is read as
        (top rows, bottom rows, left columns, right columns). The left/right
        columns are only taken between the top and bottom borders so that no
        pixel is counted twice. Outliers within the shielded px are removed
        with a median filter of size ``config.cam.shielded_px_win`` before
        averaging.
        """
        frame_data = self.data[frame_no]
        spix = self.config.cam.shielded_px
        n_top, n_bottom, n_left, n_right = spix[0], spix[1], spix[2], spix[3]
        n_rows, n_cols = frame_data.shape
        # Use explicit end indices rather than negative ones. This includes the
        # last shielded row/column (a `[-n:-1]` slice would drop it), and a
        # zero-width border selects nothing instead of the whole frame.
        inner_rows = frame_data[n_top : n_rows - n_bottom]
        shpx = np.concatenate(
            (
                frame_data[:n_top].ravel(),
                frame_data[n_rows - n_bottom :].ravel(),
                inner_rows[:, :n_left].ravel(),
                inner_rows[:, n_cols - n_right :].ravel(),
            )
        )
        return np.mean(
            signal.medfilt(shpx, kernel_size=self.config.cam.shielded_px_win)
        )
class ShiPixMedianCol:
    """
    Algorithm for col wise median correction.

    For every column, the median of its shielded pixels at the top (first
    ``shielded_px[2]`` entries) and/or bottom (last ``shielded_px[3]`` entries)
    is subtracted from the column. `borders` selects which of the two
    borders (top, bottom) are used.
    """

    def __init__(self, config: Config, data: np.ndarray, borders=(True, True)):
        self.config = config
        self.data = data.astype(float)
        self.borders = borders  # border to use (top, bottom)

    def run(self) -> np.ndarray:
        """
        Correct every frame and return the corrected data.

        :raises ValueError: if neither border is selected
        """
        if not self.borders[0] and not self.borders[1]:
            raise ValueError("No border selected for shielded pixel correction")
        for i in range(self.data.shape[0]):
            self._correct_frame(i)
        return self.data

    def _correct_frame(self, frame_no: int):
        """Correct all columns of one frame in place (columns handled via MP.threaded)."""
        res = dict(
            MP.threaded(
                self._correct_col,
                [(i, self.data[frame_no, :, i]) for i in range(self.data.shape[2])],
            )
        )
        # One result per column: stack them and transpose back to (rows, cols).
        self.data[frame_no] = np.array([res[i] for i in range(self.data.shape[2])]).T

    def _correct_col(self, arg: tuple) -> tuple:
        """
        :param arg: (column index, column values)
        :return: (column index, column minus the median of its selected
            shielded border pixels)
        """
        col_no, col = arg
        return col_no, col - np.median(self._border_pixels(col))

    def _border_pixels(self, col: np.ndarray) -> np.ndarray:
        """
        Collect the shielded pixels of `col` from the selected borders.

        :raises ValueError: if neither border is selected
        """
        # NOTE(review): the top/bottom widths come from shielded_px[2]/[3] here,
        # while ShiPixGlobalMean reads the top/bottom rows from shielded_px[0]/[1].
        # Kept as is; confirm against the camera configuration.
        n_top = self.config.cam.shielded_px[2]
        n_bottom = self.config.cam.shielded_px[3]
        parts = []
        if self.borders[0]:
            parts.append(col[:n_top])
        if self.borders[1]:
            # `len(col) - n_bottom` so a zero-width border selects nothing
            # (a `-0:` slice would select the whole column).
            parts.append(col[len(col) - n_bottom :])
        if not parts:
            raise ValueError("No border selected for shielded pixel correction")
        # Concatenate instead of passing a list of arrays to np.median: borders
        # of different widths would form a ragged list that numpy rejects.
        return np.concatenate(parts)
class ShiPixColAndRow:
    """
    Algorithm that first applies a column wise median correction
    (ShiPixMedianCol) and then a row wise linear fit (ShiPixLinearRow).

    If the median levels of the top and bottom shielded borders differ by more
    than ILLUM_SHI_PIX_CRITERION (e.g. due to stray light), only the border
    with the smaller median is used for the column correction.
    """

    # If the ratio between the medians of the shielded px borders is larger
    # than this value (or smaller than its inverse), only the border with the
    # smaller median is used.
    ILLUM_SHI_PIX_CRITERION = 1.3  # >1

    def __init__(self, config: Config, data: np.ndarray):
        self.config = config
        self.data = data.astype(float)

    def run(self) -> np.ndarray:
        """Apply the column median correction, then the row linear fit."""
        borders = self._get_valid_borders()
        self.data = ShiPixMedianCol(self.config, self.data, borders=borders).run()
        self.data = ShiPixLinearRow(self.config, self.data).run()
        return self.data

    def _get_valid_borders(self) -> tuple:
        """
        Get the valid (non illuminated) borders for the shielded px.

        The medians are taken over all frames of the first ``shielded_px[2]``
        and the last ``shielded_px[3]`` entries of axis 1.

        :return: (use_top, use_bottom) flags as expected by ShiPixMedianCol
        """
        n_top = self.config.cam.shielded_px[2]
        n_bottom = self.config.cam.shielded_px[3]
        n_rows = self.data.shape[1]
        row_median_top = np.median(self.data[:, 0:n_top])
        # `n_rows - n_bottom` so a zero-width bottom border selects nothing
        # (a `-0:` slice would select every row).
        row_median_bottom = np.median(self.data[:, n_rows - n_bottom :])
        # NOTE(review): the ratio test assumes both medians are positive.
        ratio = row_median_top / row_median_bottom
        if ratio > ShiPixColAndRow.ILLUM_SHI_PIX_CRITERION:
            logger.debug(
                f"Shielded px levels not equal due to stray light (ratio {ratio}), using only the bottom border"
            )
            return (False, True)
        if ratio < 1 / ShiPixColAndRow.ILLUM_SHI_PIX_CRITERION:
            logger.debug(
                f"Shielded px levels not equal due to stray light (ratio {ratio}), using only the top border"
            )
            return (True, False)
        return (True, True)