Coverage for src/susi/reduc/fields/flat_field_correction.py: 67%

61 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2025-06-13 14:15 +0000

1#!/usr/bin/env python3 

2# -*- coding: utf-8 -*- 

3""" 

Module for flat field correction; provides LineRemover and FFCorrector.

5 

6@author: hoelken 

7""" 

8import numpy as np 

9 

10from ...io import Fits 

11from ...base import Logging, DataMissMatchException 

12 

# Module-level logger, obtained from the project's Logging helper.
logger = Logging.get_logger()

14 

15 

class LineRemover:
    """
    Removes the vertical features (i.e. absorption/emission lines) from the flat field image.
    The input datacube must be de-smiled.

    Each mod state `cube[s]` is processed as a 2D frame: the per-row mean is subtracted,
    every column of the remainder is normalized by its (peak-masked) mean, and the per-row
    mean is added back.

    Usage: `LineRemover(cube, peak_deviation=10.0).run().result`
    """

    def __init__(self, cube: np.ndarray, peak_deviation: float = 10.0):
        #: The original datacube, indexed as cube[state][row][column]
        self.cube = cube
        #: Threshold in percent of a column's value range (max - min): pixels deviating more
        #: than this from the column's mean are masked out while the column mean is computed.
        self.peak_deviation = peak_deviation
        #: The resulting flat field; filled with one processed frame per state by `run()`
        self.result = []

    def run(self) -> 'LineRemover':
        """
        Iterates over all mod states in the image cube and removes the vertical features
        while maintaining the vertical gradient.

        ### Returns
        `self`, with `self.result` holding the stacked processed frames.
        """
        self.result = np.array([self.__remove_lines(state) for state in range(self.cube.shape[0])])
        return self

    def __remove_lines(self, state: int) -> np.ndarray:
        """
        Processes the single mod state `self.cube[state]`.

        ### Returns
        [np.ndarray] a new 2D frame; `self.cube` is not modified.
        """
        frame = np.asarray(self.cube[state])
        # Per-row mean, kept as a column vector so it broadcasts over all columns.
        # It is removed here and restored at the end.
        row_mean = np.mean(frame, axis=1)[:, np.newaxis]
        residual = frame - row_mean
        col_means = np.mean(residual, axis=0)
        # NOTE(review): col_means average to zero over all columns (the row means were
        # removed), so single columns may have means close to zero or negative, which
        # __process_col then divides by -- confirm this normalization is intended.
        for x, col_mean in enumerate(col_means):
            residual[:, x] = self.__process_col(residual[:, x], col_mean)
        return residual + row_mean

    def __process_col(self, col: np.ndarray, mean: float) -> np.ndarray:
        """
        Normalizes one column by its mean, with peaks replaced by `mean` while that
        normalization mean is computed.

        ### Params
        - col: [np.ndarray] 1D column (not modified)
        - mean: [float] reference mean of the column; also the replacement value for peaks

        ### Returns
        [np.ndarray] `col / m`, where `m` is the mean of `col` with its peaks set to `mean`.
        """
        one_percent = (np.max(col) - np.min(col)) / 100.0
        if one_percent > 0:
            # Deviation from the mean in percent of the column's value range. Dividing by
            # one percent of the range keeps the threshold a scale-invariant percentage.
            peaks = np.abs(col - mean) / one_percent > self.peak_deviation
        else:
            # Constant (or NaN-containing) column: nothing can be classified as a peak.
            peaks = np.zeros(np.shape(col), dtype=bool)
        masked = np.array(col, copy=True)
        masked[peaks] = mean
        norm = np.mean(masked)
        result = masked / norm
        # Peaks keep their original values, scaled by the same normalization.
        result[peaks] = col[peaks] / norm
        return result

61 

62 

class FFCorrector:
    """
    Provides algorithm for flat field image correction.

    The flat field is expected with data of shape (1, x, y); every frame of the image
    data (shape (n, x, y)) is divided by that single flat frame.
    """

    def __init__(self, flat: Fits, img: Fits):
        #: Flat field; `flat.data` must have shape (1, x, y)
        self.flat = flat
        #: Image to correct; `img.data` must have shape (n, x, y)
        self.img = img

    def run(self, header_check=True):
        """
        Applies the flat field correction to the image.

        ### Params
        - header_check: [Bool, optional] if `True` (default), a 'Camera ID' mismatch between
          flat and image raises a `DataMissMatchException`; set to `False` to only log a
          warning instead (the check itself is still performed).

        ### Returns
        [Array] float64 array of img data with the flat field correction applied (img / flat)

        ### Raises
        - DataMissMatchException: camera IDs differ (with `header_check=True`) or the frame
          shapes of flat and img differ
        - ValueError: flat or img data is not 3-dimensional, or the flat does not contain
          exactly one frame
        """
        self.__check_header(header_check, 'Camera ID', 'Flat Field is not from the same camera!')
        self.__check_shapes()
        return self.__correct_img()

    def __check_header(self, fail_on_error, key, message):
        """Compares header `key` of flat and img; raises (or only warns) with `message` on mismatch."""
        if self.flat.value_of(key) == self.img.value_of(key):
            return
        if fail_on_error:
            raise DataMissMatchException(message)
        logger.warning(message)

    def __check_shapes(self):
        """Validates that flat is (1, x, y), img is (n, x, y) and both share the frame shape (x, y)."""
        for array_type, shape in (('flat', self.flat.data.shape), ('img', self.img.data.shape)):
            if len(shape) != 3:
                raise ValueError(array_type + " does not have 3 dimensions")
        if self.flat.data.shape[0] != 1:
            raise ValueError("Flat Field image must have shape (1, x, y)")
        if self.flat.data.shape[1:] != self.img.data.shape[1:]:
            raise DataMissMatchException(
                f"Shapes do not match: flat: {self.flat.data.shape[1:]} img: {self.img.data.shape[1:]}")

    def __correct_img(self):
        """Divides every img frame by the single flat frame and returns a new float64 array."""
        # Broadcasting divides all frames by flat.data[0] at once. Inputs are promoted
        # directly to float64; a float32 intermediate would discard precision of float64
        # or large-integer data. Zero-valued flat pixels yield inf/nan (NumPy warns).
        return np.true_divide(self.img.data, self.flat.data[0], dtype='float64')