Coverage for src/susi/reduc/fields/flat_field_correction.py: 67%
61 statements
« prev ^ index » next coverage.py v7.5.0, created at 2025-06-13 14:15 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2025-06-13 14:15 +0000
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3"""
4Module for flat field correction; provides LineRemover and FFCorrector.
6@author: hoelken
7"""
8import numpy as np
10from ...io import Fits
11from ...base import Logging, DataMissMatchException
13logger = Logging.get_logger()
class LineRemover:
    """
    Strip vertical features (absorption/emission lines) from a flat field image cube.

    The input datacube must be de-smiled. Every mod state ``cube[s]`` is
    processed on its own by :meth:`run`, and the stacked output ends up in
    :attr:`result`.
    """

    def __init__(self, cube: np.array, peak_deviation: float = 10.0):
        # Input datacube. It is iterated over its first axis (mod states);
        # presumably shaped (states, rows, columns) -- TODO confirm with callers.
        self.cube = cube
        # Threshold (documented as a percentage) for how far a pixel may deviate
        # from its column mean before it is masked out of the column normalisation.
        self.peak_deviation = peak_deviation
        # Output flat field: an empty list until run() replaces it with an ndarray.
        self.result = []

    def run(self):
        """
        Process every mod state of the cube and stack the results.

        Each state has its vertical features removed while the vertical
        (per-row) gradient is kept.

        ### Returns
        [LineRemover] this instance, so the call can be chained (``LineRemover(c).run().result``).
        """
        n_states = self.cube.shape[0]
        processed = [self.__remove_lines(idx) for idx in range(n_states)]
        self.result = np.array(processed)
        return self

    def __remove_lines(self, state) -> np.array:
        # One mod state: subtract the row means, normalise every column of the
        # residual via __process_col, then add the row means back on.
        frame = self.cube[state]
        row_means = np.mean(frame, axis=1)
        residual = np.array([row - row_mean for row, row_mean in zip(frame, row_means)])
        col_means = np.mean(residual, axis=0)
        # np.transpose returns a view, so writing columns[x] updates residual in place.
        columns = np.transpose(residual)
        for x, col_mean in enumerate(col_means):
            columns[x] = self.__process_col(columns[x], col_mean)
        return np.array([row + row_mean for row, row_mean in zip(residual, row_means)])

    def __process_col(self, col, mean):
        # Normalise one column by its mean, computed with "peak" pixels replaced
        # by `mean`. Peak pixels keep their original values (scaled by the same
        # mean) in the returned array, so masking only affects the normaliser.
        # NOTE: `col` itself is overwritten with `mean` at the peak positions.
        # NOTE: if the masked mean is 0, numpy yields inf/nan for this column.
        scale = (np.max(col) - np.min(col)) / 100.0
        # NOTE(review): the deviation is *multiplied* by range/100; reading
        # peak_deviation as "percent of range" would imply dividing -- confirm intent.
        peak_idx = np.where(abs(col - mean) * scale > self.peak_deviation)
        saved = []
        for idx in peak_idx:
            saved.append(col[idx])
            col[idx] = mean
        masked_mean = np.mean(col)
        normalized = col / masked_mean
        for idx, values in zip(peak_idx, saved):
            normalized[idx] = values / masked_mean
        return normalized
class FFCorrector:
    """
    Flat field correction: divide every image frame by a single flat field frame.
    """

    def __init__(self, flat: Fits, img: Fits):
        # Flat field; run() requires its data to have shape (1, x, y).
        self.flat = flat
        # Image to correct; run() requires its data to be 3-D with the flat's (x, y).
        self.img = img

    def run(self, header_check=True):
        """
        Applies the flat field correction to the image.

        ### Params
        - header_check: [Bool, optional] set to `False` to bypass the header check;
          a 'Camera ID' mismatch is then only logged as a warning.

        ### Returns
        [Array] of img data with the flat field correction applied (img / flat), dtype float64

        ### Raises
        - DataMissMatchException: the 'Camera ID' header values differ (when header_check
          is True), or the (x, y) shapes of flat and img differ.
        - ValueError: flat or img data is not 3-dimensional, or the flat's first axis is not 1.
        """
        self.__check_header(header_check, 'Camera ID', 'Flat Field is not from the same camera!')
        self.__check_shapes()
        return self.__correct_img()

    def __check_header(self, fail_on_error, key, message):
        # Compare one header entry of flat and image; raise or only warn on mismatch.
        mismatch = self.flat.value_of(key) != self.img.value_of(key)
        if mismatch and fail_on_error:
            raise DataMissMatchException(message)
        if mismatch:
            logger.warning(message)

    def __check_shapes(self):
        # Validate rank first, then the flat's single-frame layout, then frame size.
        flat_shape = self.flat.data.shape
        img_shape = self.img.data.shape
        for name, shape in (('flat', flat_shape), ('img', img_shape)):
            if len(shape) != 3:
                raise ValueError(name + " does not have 3 dimensions")
        if flat_shape[0] != 1:
            raise ValueError("Flat Field image must have shape (1, x, y)")
        if flat_shape[1:] != img_shape[1:]:
            raise DataMissMatchException(
                f"Shapes do not match: flat: {flat_shape[1:]} img: {img_shape[1:]}")

    def __correct_img(self):
        # Divide each image frame by the flat frame. Both operands are cast to
        # float32 first (so float64 inputs lose precision here), then the division
        # itself is carried out in float64. Zeros in the flat yield inf/nan.
        flat_frame = self.flat.data[0].astype('float32')
        frames = self.img.data
        result = np.empty(frames.shape)
        for idx, frame in enumerate(frames):
            result[idx] = np.true_divide(frame.astype('float32'), flat_frame, dtype='float64')
        return result