Coverage for src/susi/reduc/pipeline/blocks/block_f.py: 86%
109 statements
« prev ^ index » next coverage.py v7.5.0, created at 2025-08-22 09:20 +0000
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3"""
4Module that holds the block `F` of the susi pipeline.
6@author: hoelken
7"""
8import os
10import numpy as np
11from spectroflat.sensor.flat import Flat
13from .block import Block, BlockRegistry
14from ..processing_data import ProcessingData
15from ....base import Logging, IllegalStateException
16from ....base.header_keys import *
17from ....io import FitsBatch, Fits
18from ....utils import Collections
19from ....utils.cropping import adapt_shapes, adapt_shape
20from ....utils.header_checks import check_same_binning
21from ....utils.sub_shift import Shifter
# Module-level logger used by the block (e.g. for the flat-map wavelength-shift warning).
log = Logging.get_logger()
class BlockF(Block):
    """
    ## BLOCK F: Flat fielding

    This block takes care of the following calibration steps:
    - Sensor Flat Correction
    - Slit flat, soft flat and prefilter map correction, each only if enabled
      in the config and never for cam3
    - TODO Fringes ?
    """

    BLOCK_ID = 'F'

    @staticmethod
    def start(batch: FitsBatch, proc_data: ProcessingData) -> FitsBatch:
        """
        Callback-like entry point for the block
        """
        return BlockF(batch, proc_data).run().result

    def _algorithm(self):
        # Build the combined flat first, derive the ROI from the sensor flat header,
        # crop batch and flat to it, bin the flat like the batch, then divide per frame.
        self._merge_flats()
        self._get_roi()
        self._adapt_shapes()
        self._apply_binning_factor()
        for entry in self.batch:
            self.result.batch.append(self._process_entry(entry))

    def _process_entry(self, fits: Fits) -> dict:
        """Build the result entry (path, flat-fielded data, updated header) for one file."""
        return {
            'file': fits.path,
            'data': self._modify_data(fits),
            'header': self._modify_header(fits),
        }

    def _modify_data(self, fits: Fits) -> np.ndarray:
        """Divide the first frame of `fits` by the flat matching its MOD_STATE header value."""
        return np.array([Flat.save_divide(fits.data[0], self._mod_flat(int(fits.header[MOD_STATE])))])

    def _mod_flat(self, state: int) -> np.ndarray:
        """
        Return the flat to use for the given modulation state.

        For cam3, or if the merged flat holds a single entry along its first axis,
        that first entry is used for every state.
        """
        if self.proc_data.config.is_cam3() or self._flat.shape[0] == 1:
            return self._flat[0]
        return self._flat[state]

    @staticmethod
    def _file_stem(path: str) -> str:
        """Return the file name of `path` cut at its first dot (value written to FLAT_MAP* keys)."""
        return os.path.basename(path).split('.')[0]

    def _modify_header(self, fits: Fits):
        """
        Record block F and the used flat maps in the header of `fits`.

        Map keywords are only written for maps that `_merge_flats` actually
        multiplied into the flat, i.e. slit flat, soft flat and prefilter map
        are never reported for cam3.

        :return: the updated header
        """
        config = self.proc_data.config
        # Same gate as in _merge_flats: cam3 only uses the sensor flat. prepare() does not
        # demand a slit flat for cam3, so touching proc_data.slit_flat there could hit None.
        extra_maps = not config.is_cam3()

        Fits.override_header(fits.header, Block.BLOCKS_APPLIED, BlockF.BLOCK_ID, append=True)
        Fits.override_header(
            fits.header,
            FLAT_MAP,
            value=self._file_stem(self.proc_data.sensor_flat.path),
            comment='Fixed Sensor Flat File Used',
        )
        if extra_maps and config.base.slit_flat_corr_block_f:
            Fits.override_header(
                fits.header,
                FLAT_MAP_MOVING,
                value=self._file_stem(self.proc_data.slit_flat.path),
                comment='Fixed Slit Flat File Used in block F',
            )

        if extra_maps and config.base.soft_flat_correction:
            Fits.override_header(
                fits.header,
                FLAT_MAP_SOFT,
                value=self._file_stem(self.proc_data.soft_flat.path),
                comment='Fixed Soft Flat File Used',
            )

        if extra_maps and config.base.prefilter_correction:
            Fits.override_header(
                fits.header,
                FLAT_MAP_PREFILTER,
                value=self._file_stem(self.proc_data.prefilter_map.path),
                comment='Fixed Prefilter Map File (Ext. 1) Used',
            )
            Fits.override_header(
                fits.header,
                FLAT_MAP_PREFILTER_MEAN,
                value=self.prefilter_map_mean,
                comment='Mean of Prefilter Map File (Ext. 1) Used',
            )

        if config.reduc.soft_flat_wl_offset is not None:
            Fits.override_header(
                fits.header,
                SOFT_FLAT_WL_OFFSET,
                value=config.reduc.soft_flat_wl_offset,
                comment='WL shift for Soft Flat and Prefilter maps',
            )
        self._update_roi_header(fits.header)
        self._update_rms_snr_mean(fits)
        return fits.header

    def _merge_flats(self):
        """
        Combine all flat maps that apply to this camera into ``self._flat``.

        - sensor flat (always), times the slit flat if enabled (not for cam3)
        - times the soft flat if enabled (not for cam3)
        - times the prefilter map normalized by its mean if enabled (not for cam3);
          the mean is kept in ``self.prefilter_map_mean`` for the header.
        Soft flat and prefilter map are shifted in wavelength if a
        ``soft_flat_wl_offset`` is configured.
        """
        config = self.proc_data.config
        cam3 = config.is_cam3()

        # sensor and slit flats
        if cam3 or not config.base.slit_flat_corr_block_f:
            # Work on a copy: the in-place products below must not modify the sensor flat
            # held by proc_data (otherwise a reused proc_data would get the maps applied again).
            self._flat = np.array(self.proc_data.sensor_flat.data, copy=True)
        else:
            self._flat = self.proc_data.sensor_flat.data * self.proc_data.slit_flat.data

        if cam3:
            return

        # soft flat
        if config.base.soft_flat_correction:
            self._flat *= self._wl_shifted(self.proc_data.soft_flat.data)

        # prefilter map
        if config.base.prefilter_correction:
            self.prefilter_map_mean = np.mean(self.proc_data.prefilter_map.data)
            self._flat *= self._wl_shifted(self.proc_data.prefilter_map.data / self.prefilter_map_mean)

    def _wl_shifted(self, flat_map: np.ndarray) -> np.ndarray:
        """Return `flat_map` shifted by the configured soft_flat_wl_offset, or unchanged if none is set."""
        offset = self.proc_data.config.reduc.soft_flat_wl_offset
        if offset is None:
            return flat_map
        return self._shift_map_wl(flat_map, float(offset))

    def _shift_map_wl(self, flat_map: np.ndarray, wl_offset: float) -> np.ndarray:
        """
        Shift the map along its last axis by the given wavelength offset.

        The offset is converted to pixels using the DISPERSION header value of the
        wavelength calibration axis. A zero offset returns the map unchanged.

        :raises IllegalStateException: if the map is neither 2D nor 3D
        """
        wl_offset = float(wl_offset)
        if wl_offset == 0:
            return flat_map
        px_offset = wl_offset / float(self.proc_data.wl_cal_axis.header[DISPERSION])
        log.warning(f'Shifting flat map by {wl_offset} nm ({px_offset} px)')
        ndim = np.ndim(flat_map)
        if ndim == 2:
            return Shifter.d2shift(flat_map, [0, -px_offset])
        if ndim == 3:
            return Shifter.d2shift(flat_map, [0, 0, -px_offset])
        raise IllegalStateException(f'Flat map has unexpected shape {flat_map.shape}, cannot shift it')

    def _adapt_shapes(self):
        """Crop the batch and the merged flat to the target ROI."""
        self.batch = adapt_shapes(self.batch, self.roi)
        self._flat = adapt_shape(self._flat, self.proc_data.sensor_flat.header, self.roi)

    def _apply_binning_factor(self):
        """Bin the flat spatially like the batch (SPATIAL_BIN header), never along the mod state axis."""
        check_same_binning(self.batch)
        header = self.batch[0]['header']
        bins = header[SPATIAL_BIN] if SPATIAL_BIN in header else None
        if bins is None:
            return
        bins = bins.split(',')
        # Do not bin mod state dimension
        bins.insert(0, 1)
        self._flat = Collections.bin(self._flat, Collections.as_int_array(bins))

    def _get_roi(self) -> None:
        self._target_roi(self._flat_shape())

    def _flat_shape(self) -> dict:
        """
        ROI of the sensor flat from its ROI_* header keys.

        Missing keys default to 0 for the lower corner and to the merged flat's
        extent (axes 2 and 1) for the upper corner.
        """
        h = self.proc_data.sensor_flat.header
        x0 = int(h[ROI_X0]) if ROI_X0 in h else 0
        y0 = int(h[ROI_Y0]) if ROI_Y0 in h else 0
        x1 = int(h[ROI_X1]) if ROI_X1 in h else self._flat.shape[2]
        y1 = int(h[ROI_Y1]) if ROI_Y1 in h else self._flat.shape[1]
        return {'y0': y0, 'y1': y1, 'x0': x0, 'x1': x1}

    @staticmethod
    def prepare(proc_data: ProcessingData, files=None) -> None:
        """
        Validate that all flat maps required by the configuration are available.

        :raises IllegalStateException: if a required map (or the wl axis for shifting) is missing
        """
        if proc_data.sensor_flat is None:
            raise IllegalStateException('No SENSOR FLAT given')
        slit_needed = proc_data.slit_flat is None and proc_data.config.base.slit_flat_corr_block_f
        if not proc_data.config.is_cam3() and slit_needed:
            raise IllegalStateException('No SLIT FLAT given')
        if proc_data.config.base.soft_flat_correction and proc_data.soft_flat is None:
            raise IllegalStateException('No SOFT FLAT file given')
        if proc_data.config.base.prefilter_correction and proc_data.prefilter_map is None:
            raise IllegalStateException('No amended SOFT FLAT file given to extract prefilter from FITS extension 1')
        if proc_data.config.reduc.soft_flat_wl_offset is not None and proc_data.wl_cal_axis is None:
            raise IllegalStateException('No amended soft_flat given to shift (need wl axis), see soft_flat_wl_offset')

    @staticmethod
    def input_needed(cam3: bool = False) -> list:
        return ['sensor_flat']  # see comment in the block parent class
# Register BlockF under its BLOCK_ID ('F'); presumably this lets the pipeline resolve the block by id.
BlockRegistry().register(BlockF.BLOCK_ID, BlockF)