Coverage for src/susi/reduc/shift/slit_shift.py: 95%
82 statements
« prev ^ index » next coverage.py v7.5.0, created at 2025-06-13 14:15 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2025-06-13 14:15 +0000
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3"""
4provides slit shift class
6@author: iglesias
7"""
8from tkinter import E
9from weakref import ref
10import numpy as np
11import copy
12import os
14from scipy.sparse.tests.test_array_api import B
15from ...base import Logging, Globals
16from ...io import FitsBatch
17from ...base.header_keys import *
18from ...utils.cropping import common_shape
19from ...utils.sub_shift import round_shift_for_roi
22logger = Logging.get_logger()
class SlitShiftRef:
    """
    Select the slit-flat shift reference and compute the per-frame shifts of a set of files.

    1- Selects the frame to be used as reference for the slit flat shift correction
       (the middle frame of the batch once it is sorted by DATE_OBS)
    2- Reads all the shifts in ``files`` and computes the common roi to crop the output data

    :param files: list of FITS files whose headers carry SLIT_FLAT_OFFSET and SLIT_FLAT_SLOPE
    :param proce_param: polynomial order passed to SlitShiftProcessor to smooth the shifts
    """

    def __init__(self, files: list, proce_param: int):
        # reference values, filled by run() / _get_reference()
        self.ref_offset = None
        self.ref_slope = None
        self.ref_file = None
        # common ROI of the shifted batch, filled by run()
        self.common_roi = None
        self.files = files
        self.proce_param = proce_param

    def run(self):
        """
        Compute the shift reference, the per-frame shifts and the common ROI.

        :return: tuple ``(ref_dict, shifts_dict)``. ``ref_dict`` has the keys ``offset``, ``slope``,
            ``file`` and ``common_roi``. ``shifts_dict`` has the keys ``files`` (base names in the
            DATE_OBS-sorted batch order) and ``shifts`` (one shift per file, as returned by
            SlitShiftProcessor).
        """
        batch = self._load_data()
        self._get_reference(batch)
        shifts, common_roi = SlitShiftProcessor(batch, self.ref_offset, self.proce_param).run()
        # keep the result on the instance too; the attribute was previously never updated
        self.common_roi = common_roi
        logger.info(
            f'shift ref. ={self.ref_file}: offset={self.ref_offset}, slope={self.ref_slope}, Common ROI: {common_roi}'
        )
        files = [os.path.basename(batch[i]['file']) for i in range(len(batch))]
        shifts_dict = {'files': files, 'shifts': shifts}
        ref_dict = {'offset': self.ref_offset, 'slope': self.ref_slope, 'file': self.ref_file, 'common_roi': common_roi}
        return ref_dict, shifts_dict

    def _load_data(self):
        """Load only the headers of ``self.files`` into a FitsBatch sorted by DATE_OBS."""
        # TODO change to read SLIT_FLAT_OFFSET and SLIT_FLAT_SLOPE from metadata files instead the header
        batch = FitsBatch()
        logger.info(f'Reading the header of {len(self.files)} files to compute shift reference and out roi')
        batch.load(self.files, header_only=True, sort_by=DATE_OBS)
        return batch

    def _get_reference(self, batch):
        """Use the middle frame of ``batch`` as reference and store its offset, slope and file name."""
        fits = batch[len(batch) // 2]
        self.ref_offset = fits['header'][SLIT_FLAT_OFFSET]
        self.ref_slope = fits['header'][SLIT_FLAT_SLOPE]
        self.ref_file = fits['file']
class SlitShiftProcessor:
    """
    Class to process slit shifts from a batch (assumed to be sorted by observation time).

    1- get the raw shifts from the SLIT_FLAT_OFFSET keyword of every frame header
    2- smooth the shifts with a polynomial of order ``proce_param`` (x = frame index)
    3- subtract the reference shift
    4- apply the same (mean) shift to every frame of a modulation cycle
    5- compute the common roi to crop the output data

    :param batch: batch of frames; only ``batch.header_copy()`` and ``batch[0]['data']`` are used
    :param ref_shift: reference offset subtracted from the smoothed shifts (number or numeric string)
    :param proce_param: degree of the polynomial fitted to the shifts
    """

    def __init__(self, batch, ref_shift, proce_param):
        # private copy of the headers: _get_common_shape rewrites the ROI keywords
        self.batch = copy.deepcopy(batch.header_copy())
        self.data_shape = batch[0]['data'].shape if batch[0]['data'] is not None else None
        self.proce_param = proce_param
        self.ref_shift = ref_shift
        # results
        self.shifts = []
        self.common_roi = None

    def run(self):
        """
        Run the full shift processing.

        :return: tuple ``(shifts, common_roi)``. ``shifts`` is a float array with one shift per frame,
            relative to ``ref_shift``. ``common_roi`` is the dict returned by ``common_shape``, with
            ``y0`` increased and ``y1`` decreased by 1.
        :raises ValueError: if there are fewer than 3 frames, or a frame ROI is smaller than the common ROI
        """
        self._get_offset()
        self._get_shifts()
        self._get_common_shape()
        return self.shifts, self.common_roi

    def _get_offset(self):
        """Read the raw SLIT_FLAT_OFFSET of every frame, as float64, into ``self.shifts``."""
        # Rebuild the list instead of appending: after a run() self.shifts is an ndarray, so
        # appending raised AttributeError. Access frames the same way as the other methods.
        self.shifts = [np.float64(self.batch[i]['header'][SLIT_FLAT_OFFSET]) for i in range(len(self.batch))]

    @staticmethod
    def _is_cycle_start(mod_state) -> bool:
        """Return True if ``mod_state`` is modulation state 0 (accepts the string '0' and numeric 0)."""
        try:
            return float(mod_state) == 0
        except (TypeError, ValueError):
            # non-numeric / missing values never start a cycle (as with the former '0' comparison)
            return False

    def _get_shifts(self):
        """Smooth the raw offsets, subtract the reference and average them per modulation cycle."""
        n_frames = len(self.shifts)
        # sanity check
        if n_frames != len(self.batch):
            raise ValueError('Number of shifts different from number of frames')
        if n_frames < 3:
            raise ValueError('Number of frames is less than 3, cant smooth/fit the slit shifts')
        # fit polynomial of order self.proce_param to the shifts (x = frame index)
        frame_idx = np.arange(n_frames)
        fit = np.polyfit(frame_idx, self.shifts, self.proce_param)
        # subtract the reference shift; float() also accepts a numeric header string
        # NOTE(review): the reference is subtracted from the *smoothed* curve, so the reference
        # frame's shift is not necessarily 0 -- confirm this is intended
        self.shifts = np.polyval(fit, frame_idx) - float(self.ref_shift)
        # Use the same avg shift for each modulation cycle: a frame in modulation state 0 starts a
        # cycle of Globals.MOD_CYCLE_FRAMES frames (a trailing partial cycle averages what is left)
        cycle_len = Globals.MOD_CYCLE_FRAMES
        for i in range(n_frames):
            if self._is_cycle_start(self.batch[i]['header'][MOD_STATE]):
                self.shifts[i : i + cycle_len] = np.mean(self.shifts[i : i + cycle_len])

    def _get_common_shape(self):
        """Shrink each frame ROI by its rounded shift, then compute and validate the common ROI."""
        for i in range(len(self.batch)):
            header = self.batch[i]['header']
            # modify ROI keywords to account for the shift that will be implemented
            shift = round_shift_for_roi(self.shifts[i])
            if shift > 0:  # moves up and thus reduces Y1 to account for the lost roi
                header[ROI_Y1] = int(float(header[ROI_Y1]) - shift)
            elif shift < 0:  # moves down and thus increases Y0 to account for the lost roi
                header[ROI_Y0] = int(float(header[ROI_Y0]) - shift)
        # get the common roi of the full batch
        self.common_roi = common_shape(self.batch)
        # reduce common roi by 1 to avoid numeric effects in rounding shifts
        self.common_roi['y0'] += 1
        self.common_roi['y1'] -= 1
        # check that there is no frame with roi smaller than the common roi
        for i in range(len(self.batch)):
            header = self.batch[i]['header']
            # compare numerically: frames with zero shift keep their original header values, which
            # (judging by the float() conversion above) may not be plain numbers
            if float(header[ROI_Y1]) < self.common_roi['y1'] or float(header[ROI_Y0]) > self.common_roi['y0']:
                raise ValueError(f'ROI of frame {i} is smaller than common roi')