Coverage for src/susi/reduc/pipeline/blocks/block_s.py: 86%
122 statements
« prev ^ index » next coverage.py v7.5.0, created at 2025-06-13 14:15 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2025-06-13 14:15 +0000
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3"""
4Module that holds the block `S` of the susi pipeline.
6@author: iglesias
7"""
8import os
10import numpy as np
11from spectroflat.sensor.flat import Flat
12from scipy.signal import savgol_filter
13from src.susi import InsufficientDataException
15from .block import Block, BlockRegistry
16from ..processing_data import ProcessingData
17from ....base import Logging, IllegalStateException
18from ....base.header_keys import *
19from ....io import FitsBatch, Fits, Card
20from ....utils import Collections
21from ....utils.cropping import adapt_shape
22from ....utils.header_checks import check_same_binning
23from ....utils.sub_shift import Shifter
24from ....utils.sub_shift import round_shift_for_roi
# Module-level logger obtained from the project's Logging helper.
log = Logging.get_logger()
class BlockS(Block):
    """
    ## BLOCK S: Shift and Slit Flat fielding

    This block takes care of the following calibration steps:
    - Shift image in slit direction to the desired reference using the input
      measured slit mask shift (a smooth SLIT_FLAT_OFFSET from the metadata file)
    - Apply Slit Flat Correction
    """

    BLOCK_ID = 'S'

    @staticmethod
    def start(batch: FitsBatch, proc_data: ProcessingData) -> FitsBatch:
        """
        Callback-like entry point for the block
        """
        return BlockS(batch, proc_data).run().result

    def _algorithm(self):
        """Prepare the slit flat once, then shift / flat-field every entry of the batch."""
        self._prep_flat()
        self._apply_binning_factor()
        for entry in self.batch:
            shift = self._get_shift(entry)
            self.result.batch.append(self._process_entry(entry, shift))

    def _process_entry(self, fits: Fits, shift) -> dict:
        """
        Build the result record (file, data, header) of a single entry.

        The data must be processed before the header is updated: `_modify_data` mutates
        `fits` (shift, crop, flat field) and `_modify_header` reads the modified `fits`.
        """
        data = self._modify_data(fits, shift)
        header = self._modify_header(fits, shift)
        return {'file': fits.path, 'data': data, 'header': header}

    def _modify_data(self, fits: Fits, shift) -> np.ndarray:
        """
        Shift, crop and (if enabled) slit-flat-correct a single-frame entry.

        `fits.data` is updated in place at every step so that whatever is computed from
        `fits` afterwards (e.g. `_update_rms_snr_mean` in `_modify_header`, which presumably
        reads `fits.data` — confirm) describes the data that is actually returned.

        :param fits: input entry; `fits.data` must have exactly one frame along axis 0
        :param shift: `[y, x]` shift in px, as returned by `_get_shift`
        :return: processed data with a leading singleton axis
        :raises InsufficientDataException: if the entry holds more than one frame
        """
        # check data shape to avoid non implemented cases
        if fits.data.shape[0] != 1:
            raise InsufficientDataException('Block S only works with single frame data for now')
        fits = self._shift_image(fits, shift)
        # crop to common roi of the whole shifted dataset
        fits.data = adapt_shape(fits.data, fits.header, self.roi)
        if self.proc_data.config.base.slit_flat_corr_block_s:
            # divide by the shifted flat of this entry's modulation state
            flat = self._mod_flat(int(fits.header[MOD_STATE]))
            ff_data = np.array(Flat.save_divide(fits.data[0], flat))
            # keep fits.data in sync so header statistics reflect the flat-fielded frame
            fits.data = ff_data[np.newaxis, :, :]
        return fits.data

    def _mod_flat(self, state: int) -> np.ndarray:
        """
        Return the slit flat frame for modulation state `state`.

        A flat holding a single frame along axis 0 is used for every state.
        """
        flat = self.proc_data.slit_flat.data
        return flat[0] if flat.shape[0] == 1 else flat[state]

    def _get_shift(self, entry: Fits) -> np.ndarray:
        """
        Look up the slit-direction shift of `entry` in `proc_data.slit_shifts`.

        The entry is matched by the base name of its path against `slit_shifts['files']`.

        :return: `np.array([y_shift, 0])`
        :raises ValueError: if the entry's file name is not listed
        """
        file = os.path.basename(entry.path)
        try:
            idx = self.proc_data.slit_shifts['files'].index(file)
        except ValueError:
            raise ValueError(f'File {file} not found in self.proc_data.slit_shifts') from None
        return np.array([self.proc_data.slit_shifts['shifts'][idx], 0])

    def _prep_flat(self):
        """
        Shift the slit flat to the data reference and crop it to the common ROI.

        Always sets `self.roi`. When slit flat correction is enabled, it also sets
        `self.flat_shift` and rewrites `self.proc_data.slit_flat` in place (data and the
        ROI_* header keys).
        """
        # no flat needed case. Use truthiness like `_modify_data` / `_modify_header`, so a
        # falsy-but-not-False setting (e.g. None) does not try to prepare a missing flat.
        if not self.proc_data.config.base.slit_flat_corr_block_s:
            self.roi = self.proc_data.slit_flat_shift_ref['common_roi']
            return
        flat = self.proc_data.slit_flat
        # shift flat to current data reference.
        # The flat was already shifted to its own reference, so we use SLIT_FLAT_REF_OFFSET
        self.flat_shift = flat.header[SLIT_FLAT_REF_OFFSET] - self.proc_data.slit_flat_shift_ref['offset']
        # NOTE(review): proc_data.slit_flat is modified in place; if the same ProcessingData
        # were reused for another batch, the flat would be shifted and cropped again — confirm.
        shift_vec = np.array([self.flat_shift, 0])
        flat.data = np.array([Shifter.d2shift(frame, shift_vec) for frame in flat.data])
        # finds common roi with batch
        self._get_common_flat_roi()
        # adapt flat shape and record the new ROI in its header
        flat.data = adapt_shape(flat.data, flat.header, self.roi)
        for header_key, roi_key in ((ROI_Y0, 'y0'), (ROI_Y1, 'y1'), (ROI_X0, 'x0'), (ROI_X1, 'x1')):
            flat.header[header_key] = int(self.roi[roi_key])

    def _shift_image(self, fits: Fits, shift) -> Fits:
        """Shift the single frame of `fits` by `shift` (in place) and return `fits`."""
        # positive shift moves towards 0
        fits.data = Shifter.d2shift(fits.data[0], shift)[np.newaxis, :, :]
        return fits

    def _apply_binning_factor(self):
        """
        Bin the slit flat like the batch data.

        :raises NotImplementedError: whenever the first batch header carries SPATIAL_BIN
        """
        check_same_binning(self.batch)
        first_header = self.batch[0]['header']
        bins = first_header[SPATIAL_BIN] if SPATIAL_BIN in first_header else None
        if bins is None:
            return
        # TODO: implement binning of the slit flat. The statements after the raise are an
        # unreachable draft of the intended implementation.
        raise NotImplementedError('Binning in Block S is not implemented yet')
        bins = bins.split(',')
        # Do not bin mod state dimension
        bins.insert(0, 1)
        self.proc_data.slit_flat.data = Collections.bin(self.proc_data.slit_flat.data, Collections.as_int_array(bins))

    def _get_common_flat_roi(self) -> None:
        """
        Set `self.roi` to the overlap of the usable area of the shifted flat and the
        common ROI of the shifted data (`slit_flat_shift_ref['common_roi']`).
        """
        # TODO move to orchestrator ?
        flat_roi = self._get_shape(self.proc_data.slit_flat)
        # usable roi of the shifted flat: a positive shift trims rows at the y1 end,
        # a negative one trims rows at the y0 end
        shift = round_shift_for_roi(self.flat_shift)
        if shift > 0:
            flat_roi['y1'] -= shift
        elif shift < 0:
            flat_roi['y0'] -= shift
        # common area between shifted flat and shifted frames
        data_roi = self.proc_data.slit_flat_shift_ref['common_roi']
        self.roi = {
            "y0": max(flat_roi["y0"], data_roi["y0"]),
            "y1": min(flat_roi["y1"], data_roi["y1"]),
            "x0": max(flat_roi["x0"], data_roi["x0"]),
            "x1": min(flat_roi["x1"], data_roi["x1"]),
        }

    def _get_shape(self, fits: Fits) -> dict:
        """
        Return the ROI of `fits` as `{'y0', 'y1', 'x0', 'x1'}`.

        The ROI_* header keys are used when present. Otherwise the full extent of
        `fits.data` along axes 1 (y) and 2 (x) is used.
        """
        h = fits.header
        x0 = int(h[ROI_X0]) if ROI_X0 in h else 0
        y0 = int(h[ROI_Y0]) if ROI_Y0 in h else 0
        x1 = int(h[ROI_X1]) if ROI_X1 in h else fits.data.shape[2]
        y1 = int(h[ROI_Y1]) if ROI_Y1 in h else fits.data.shape[1]
        return {'y0': y0, 'y1': y1, 'x0': x0, 'x1': x1}

    def _modify_header(self, fits: Fits, shift):
        """
        Record the processing of this block in the header of `fits` and return it.

        This adds the block id, the flat-related keys (when slit flat correction is
        enabled), the accumulated SHIFT_APPLIED, the ROI and the RMS/SNR/mean statistics.
        """
        Fits.override_header(fits.header, Block.BLOCKS_APPLIED, BlockS.BLOCK_ID, append=True)
        if self.proc_data.config.base.slit_flat_corr_block_s:
            Fits.override_header(
                fits.header, FLAT_MAP_MOVING, value=os.path.basename(self.proc_data.slit_flat.path).split('.')[0]
            )
            fits.header.append(Card(SHIFT_APPLIED_MOVING_FLAT, value=self.flat_shift, comment='[px]').to_card())
            fits.header.append(
                Card(SLIT_FLAT_REF_OFFSET, value=self.proc_data.slit_flat_shift_ref['offset'], comment='[px]').to_card()
            )
            fits.header.append(Card(SLIT_FLAT_REF_SLOPE, value=self.proc_data.slit_flat_shift_ref['slope']).to_card())
            fits.header.append(
                Card(SLIT_FLAT_REF_FILE, value=os.path.basename(self.proc_data.slit_flat_shift_ref['file'])).to_card()
            )
        if SHIFT_APPLIED not in fits.header:
            shift_str = f'{shift[0]:.6f}, {shift[1]:.6f}'
            fits.header.append(Card(SHIFT_APPLIED, value=shift_str, comment='[px]').to_card())
        else:
            # Accumulate onto the shift applied by earlier processing. Override the existing
            # card rather than appending a duplicate keyword: with an astropy-style header a
            # keyword lookup returns the first matching card, so an appended duplicate would
            # never be read back.
            parts = fits.header[SHIFT_APPLIED].split(',')
            parts[0] = f'{float(parts[0]) + shift[0]:.6f}'
            parts[1] = f'{float(parts[1]) + shift[1]:.6f}'
            Fits.override_header(fits.header, SHIFT_APPLIED, value=', '.join(parts))
        self._update_roi_header(fits.header)
        self._update_rms_snr_mean(fits)
        return fits.header

    @staticmethod
    def prepare(proc_data: ProcessingData) -> None:
        """
        Check that a slit flat is available when slit flat correction is enabled.

        The check is skipped when `proc_data.config.is_cam3()` is true.
        NOTE(review): for cam3 with the correction enabled and no flat, `_prep_flat` would
        still dereference the missing flat — confirm cam3 never runs this block that way.

        :raises IllegalStateException: if the correction is enabled but no slit flat is given
        """
        flat_missing = proc_data.slit_flat is None and proc_data.config.base.slit_flat_corr_block_s
        if not proc_data.config.is_cam3() and flat_missing:
            raise IllegalStateException('No SLIT FLAT given')
# Register BlockS under its BLOCK_ID ('S') at import time, presumably so the pipeline
# can resolve the block by id — confirm against BlockRegistry.
BlockRegistry().register(BlockS.BLOCK_ID, BlockS)