Coverage for src/susi/reduc/pipeline/blocks/block_s.py: 82% (130 statements)

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Module that holds the block `S` of the susi pipeline.

@author: iglesias
"""
import os

import numpy as np
from spectroflat.sensor.flat import Flat
from scipy.signal import savgol_filter
from src.susi import InsufficientDataException

from .block import Block, BlockRegistry
from ..processing_data import ProcessingData
from ....base import Logging, IllegalStateException
from ....base.header_keys import *
from ....io import FitsBatch, Fits, Card
from ....utils import Collections
from ....utils.cropping import adapt_shape
from ....utils.header_checks import check_same_binning
from ....utils.sub_shift import Shifter
from ....utils.sub_shift import round_shift_for_roi
from ...fields.flat_field_slit_correction import SlitFFCorrector

log = Logging.get_logger()


class BlockS(Block):
    """
    ## BLOCK S: Shift and Slit Flat fielding

    This block takes care of the following calibration steps:
    - Shift image in slit direction to the desired reference using the input
      measured slit mask shift (a smooth SLIT_FLAT_OFFSET from the metadata file)
    - Apply Slit Flat Correction
    """

    BLOCK_ID = 'S'

    @staticmethod
    def start(batch: FitsBatch, proc_data: ProcessingData) -> FitsBatch:
        """
        Callback-like entry point for the block
        """
        return BlockS(batch, proc_data).run().result
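
    # Typical use (sketch): the block is registered under its BLOCK_ID in the
    # BlockRegistry and invoked through `start()`, e.g.
    #
    #   processed = BlockS.start(batch, proc_data)
    #
    # `run()` and `result` are assumed to be provided by the `Block` base class,
    # with `run()` calling the `_algorithm()` defined below.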

    def _algorithm(self):
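        # Processing outline:
        #   1. Shift the slit flat to the reference of this dataset and fix the
        #      common ROI once for the whole batch (_prep_flat).
        #   2. Check/handle spatial binning (_apply_binning_factor).
        #   3. Shift every frame by its own measured slit shift, crop it to the
        #      common ROI and (optionally) divide it by the shifted flat.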
        self._prep_flat()
        self._apply_binning_factor()
        for entry in self.batch:
            shift = self._get_shift(entry)
            self.result.batch.append(self._process_entry(entry, shift))

    def _process_entry(self, fits: Fits, shift) -> dict:
        return {
            'file': fits.path,
            'data': self._modify_data(fits, shift),
            'header': self._modify_header(fits, shift),
        }

    def _modify_data(self, fits: Fits, shift) -> np.ndarray:
        # check data shape to avoid non-implemented cases
        if fits.data.shape[0] != 1:
            raise InsufficientDataException('Block S only works with single frame data for now')
        fits = self._shift_image(fits, shift)
        # crop to common roi of the whole shifted dataset
        fits.data = adapt_shape(fits.data, fits.header, self.roi)
        # divide by shifted flat
        if self.proc_data.config.base.slit_flat_corr_block_s:
            ff_data = np.array(Flat.save_divide(fits.data[0], self._mod_flat(int(fits.header[MOD_STATE]))))
            return ff_data[np.newaxis, :, :]
        else:
            return fits.data

    def _mod_flat(self, state: int) -> np.ndarray:
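        # A flat with a single frame is shared by all modulation states; a flat
        # cube with one frame per state is indexed by the frame's MOD_STATE.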
        if self.proc_data.slit_flat.data.shape[0] == 1:
            return self.proc_data.slit_flat.data[0]
        else:
            return self.proc_data.slit_flat.data[state]

    def _get_shift(self, entry: Fits) -> np.ndarray:
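        # Look up the slit shift measured in prepare() for this file. The returned
        # vector is [shift_along_slit, 0], i.e. the second axis is not shifted
        # (axis ordering assumed to match Shifter.d2shift).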
        file = os.path.basename(entry.path)
        try:
            idx = self.proc_data.slit_shifts['files'].index(file)
        except ValueError:
            raise ValueError(f'File {file} not found in self.proc_data.slit_shifts')
        yshift = self.proc_data.slit_shifts['shifts'][idx]
        return np.array([yshift, 0])

    def _prep_flat(self):
        # no flat needed case
        if self.proc_data.config.base.slit_flat_corr_block_s is False:
            self.roi = self.proc_data.slit_flat_shift_ref['common_roi']
            return
        # shift flat to current data reference.
        # The flat was already shifted to its own reference, so we use SLIT_FLAT_REF_OFFSET
        self.flat_shift = self.proc_data.slit_flat.header[SLIT_FLAT_REF_OFFSET]
        self.flat_shift = self.flat_shift - self.proc_data.slit_flat_shift_ref['offset']
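        # Worked example (hypothetical numbers): if the flat was shifted to its own
        # reference with SLIT_FLAT_REF_OFFSET = 2.4 px and the dataset reference
        # offset is 1.1 px, the flat is re-shifted below by 2.4 - 1.1 = 1.3 px along
        # the slit axis, so flat and frames end up on the same reference.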
        shifted_flat = []
        for i in range(self.proc_data.slit_flat.data.shape[0]):
            sh_img = Shifter.d2shift(self.proc_data.slit_flat.data[i], np.array([self.flat_shift, 0]))
            shifted_flat.append(sh_img)
        self.proc_data.slit_flat.data = np.array(shifted_flat)
        # find common roi with batch
        self._get_common_flat_roi()
        # adapt flat shape
        self.proc_data.slit_flat.data = adapt_shape(
            self.proc_data.slit_flat.data, self.proc_data.slit_flat.header, self.roi
        )
        self.proc_data.slit_flat.header[ROI_Y0] = int(self.roi['y0'])
        self.proc_data.slit_flat.header[ROI_Y1] = int(self.roi['y1'])
        self.proc_data.slit_flat.header[ROI_X0] = int(self.roi['x0'])
        self.proc_data.slit_flat.header[ROI_X1] = int(self.roi['x1'])

    def _shift_image(self, fits: Fits, shift) -> Fits:
        # positive shift moves towards 0
        fits.data = Shifter.d2shift(fits.data[0], shift)
        fits.data = fits.data[np.newaxis, :, :]
        return fits

    def _apply_binning_factor(self):
        check_same_binning(self.batch)
        bins = self.batch[0]['header'][SPATIAL_BIN] if SPATIAL_BIN in self.batch[0]['header'] else None
        if bins is None:
            return
        # TODO
        raise NotImplementedError('Binning in Block S is not implemented yet')
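        # NOTE: the statements below are unreachable until the NotImplementedError
        # above is removed; they sketch the intended binning of the slit flat.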
        bins = bins.split(',')
        # Do not bin mod state dimension
        bins.insert(0, 1)
        self.proc_data.slit_flat.data = Collections.bin(self.proc_data.slit_flat.data, Collections.as_int_array(bins))

    def _get_common_flat_roi(self) -> None:
        # TODO move to orchestrator ?
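        # Worked example (hypothetical numbers): a flat shifted by +2 px has its
        # usable ROI trimmed to y1 - 2 (a shift of -2 px raises y0 by 2 instead);
        # the final self.roi is the intersection of that trimmed ROI with the
        # dataset's common ROI: max of the lower bounds, min of the upper bounds.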
        flat_roi = self._get_shape(self.proc_data.slit_flat)
        # get usable roi of shifted flat
        shift = round_shift_for_roi(self.flat_shift)
        if shift > 0:
            flat_roi['y1'] = flat_roi['y1'] - shift
        elif shift < 0:
            flat_roi['y0'] = flat_roi['y0'] - shift
        # common area between shifted flat and shifted frames
        data_roi = self.proc_data.slit_flat_shift_ref['common_roi']
        self.roi = {
            "y0": max(flat_roi["y0"], data_roi["y0"]),
            "y1": min(flat_roi["y1"], data_roi["y1"]),
            "x0": max(flat_roi["x0"], data_roi["x0"]),
            "x1": min(flat_roi["x1"], data_roi["x1"]),
        }

    def _get_shape(self, fits: Fits) -> dict:
        h = fits.header
        x0 = int(h[ROI_X0]) if ROI_X0 in h else 0
        y0 = int(h[ROI_Y0]) if ROI_Y0 in h else 0
        x1 = int(h[ROI_X1]) if ROI_X1 in h else fits.data.shape[2]
        y1 = int(h[ROI_Y1]) if ROI_Y1 in h else fits.data.shape[1]
        return {'y0': y0, 'y1': y1, 'x0': x0, 'x1': x1}

    def _modify_header(self, fits: Fits, shift):
        Fits.override_header(fits.header, Block.BLOCKS_APPLIED, BlockS.BLOCK_ID, append=True)
        if self.proc_data.config.base.slit_flat_corr_block_s:
            Fits.override_header(
                fits.header, FLAT_MAP_MOVING, value=os.path.basename(self.proc_data.slit_flat.path).split('.')[0]
            )
            fits.header.append(Card(SHIFT_APPLIED_MOVING_FLAT, value=self.flat_shift, comment='[px]').to_card())
            fits.header.append(
                Card(SLIT_FLAT_REF_OFFSET, value=self.proc_data.slit_flat_shift_ref['offset'], comment='[px]').to_card()
            )
            fits.header.append(Card(SLIT_FLAT_REF_SLOPE, value=self.proc_data.slit_flat_shift_ref['slope']).to_card())
            fits.header.append(
                Card(SLIT_FLAT_REF_FILE, value=os.path.basename(self.proc_data.slit_flat_shift_ref['file'])).to_card()
            )
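        # Record the total shift applied to this frame: on the first pass a new
        # SHIFT_APPLIED card is written; if the key already exists, the current
        # shift is added to the stored values so the card stays cumulative.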
        if SHIFT_APPLIED not in fits.header:
            shift_str = f'{shift[0]:.6f}, {shift[1]:.6f}'
            fits.header.append(Card(SHIFT_APPLIED, value=shift_str, comment='[px]').to_card())
        else:
            shift_str = fits.header[SHIFT_APPLIED].split(',')
            shift_str[0] = f'{float(shift_str[0]) + shift[0]:.6f}'
            shift_str[1] = f'{float(shift_str[1]) + shift[1]:.6f}'
            shift_str = ', '.join(shift_str)
            fits.header.append(Card(SHIFT_APPLIED, value=shift_str, comment='[px]').to_card())
        self._update_roi_header(fits.header)
        self._update_rms_snr_mean(fits)
        return fits.header

    @staticmethod
    def prepare(proc_data: ProcessingData, files=None) -> None:
        flat_needed = proc_data.slit_flat is None and proc_data.config.base.slit_flat_corr_block_s
        if not proc_data.config.is_cam3() and flat_needed:
            raise IllegalStateException('No SLIT FLAT given')

        log.info('Starting slit flat shift estimation, this may take a while...')
        ff = SlitFFCorrector(files, proc_data.slit_flat, proc_data.config)
        ff.compute_shifts(
            plot=True,
            slit_roi=proc_data.config.reduc.slit_ff_slit_roi,
            wl_roi=proc_data.config.reduc.slit_ff_wl_roi,
            ini_val=proc_data.config.reduc.slit_ff_global_ini,
        )
        proc_data.slit_shifts = {
            'files': ff.files,
            'shifts': ff.shifts,
        }
        proc_data.slit_flat_shift_ref = {
            'file': proc_data.slit_flat.path,
        }
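        # Only the reference file is stored here; the remaining keys used by the
        # block ('common_roi', 'offset', 'slope') are expected to be filled in
        # elsewhere before _prep_flat()/_modify_header() run.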
        log.info(f'Slit flat shift estimation done for {len(proc_data.slit_shifts["files"])} files')


BlockRegistry().register(BlockS.BLOCK_ID, BlockS)