Coverage for src/susi/reduc/pipeline/blocks/block_d.py: 93%
104 statements
« prev ^ index » next coverage.py v7.5.0, created at 2025-06-13 14:15 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2025-06-13 14:15 +0000
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3"""
4Module that holds the block `D` of the susi pipeline.
6@author: hoelken
7"""
8import os
10import numpy as np
11from astropy.io import fits
13from .block import Block, BlockRegistry
14from ..processing_data import ProcessingData
15from ...demodulation import Demodulator, DemodAssembler
16from ....base import Logging, IllegalStateException
17from ....base.header_keys import *
18from ....db import FileDB
19from ....io import FitsBatch, Fits, Card
20from ....utils import Collections
21from ....utils.cropping import adapt_shapes, adapt_shape, common_shape
22from ....utils.header_checks import check_same_binning
# Module-level logger (used e.g. by ``BlockD.prepare`` for progress messages).
log = Logging.get_logger()
class BlockD(Block):
    """
    ## BLOCK D: Demodulation

    This block applies the demodulation matrix to the FitsBatch.
    For cam3 configurations the batch is passed through unchanged.
    """

    BLOCK_ID = 'D'

    @staticmethod
    def start(batch: FitsBatch, proc_data: ProcessingData) -> FitsBatch:
        """
        Callback-like entry point for the block.

        :param batch: the batch of frames to demodulate
        :param proc_data: processing data holding the config and the (de)modulation matrices
        :return: the ``result`` batch of the block after ``run()``
        """
        return BlockD(batch, proc_data).run().result

    def _algorithm(self):
        # cam3 is not demodulated: hand the input batch through as-is.
        if self.proc_data.config.is_cam3():
            self.result = self.batch
            return

        # Bring data and demodulation matrix onto a common ROI and binning first.
        self._get_roi()
        self._adapt_shapes()
        self._apply_binning_factor()
        # Each assembled block provides 'data', 'states', 'name' and 'start';
        # presumably one group of frames covering all modulation states — TODO confirm.
        for block in DemodAssembler(self.db, self.batch).run():
            stokes = Demodulator(self.proc_data.demod_matrix, block['data'], block['states'], block['name']).run()
            self._process_entry(stokes, block['start'])

    def _get_roi(self) -> None:
        """
        Derive the target ROI from the extent of the modulation matrix.

        Delegates to the base-class ``_target_roi``; ``self.roi`` is read
        afterwards by ``_adapt_shapes`` (presumably set there — verify in Block).
        """
        self._target_roi(self._demod_mat_shape())

    def _demod_mat_shape(self) -> dict:
        """
        Return the ROI covered by the modulation matrix.

        The ROI is read from the modulation matrix header and returned as a
        dict with int values under the keys ``y0``, ``y1``, ``x0``, ``x1``.

        In ``ROW_WISE`` demodulation mode the header's X extent is first
        overwritten with the X extent that ``common_shape`` reports for the
        batch, so that the matrix is later extended to fill the whole X dim.
        """
        if self.proc_data.config.spol.demod_mode == 'ROW_WISE':
            data_roi = common_shape(self.batch)
            self.proc_data.mod_matrix.header[ROI_X0] = data_roi['x0']
            self.proc_data.mod_matrix.header[ROI_X1] = data_roi['x1']
        h = self.proc_data.mod_matrix.header
        return {'y0': int(h[ROI_Y0]), 'y1': int(h[ROI_Y1]), 'x0': int(h[ROI_X0]), 'x1': int(h[ROI_X1])}

    def _adapt_shapes(self):
        """
        Adapt the batch and the demodulation matrix to ``self.roi``.

        In ``ROW_WISE`` mode the demodulation matrix is first reduced to its
        median along axis 1 and repeated to the ROI's X width. As a result,
        every position along axis 1 of a row shares one matrix.

        Note: this reassigns ``self.proc_data.demod_matrix``, i.e. it mutates
        the ProcessingData object passed to the block.
        """
        self.batch = adapt_shapes(self.batch, self.roi)
        if self.proc_data.config.spol.demod_mode == 'ROW_WISE':
            row_median = np.median(self.proc_data.demod_matrix, axis=1, keepdims=True)
            self.proc_data.demod_matrix = np.repeat(row_median, self.roi['x1'] - self.roi['x0'], axis=1)
        self.proc_data.demod_matrix = adapt_shape(
            self.proc_data.demod_matrix, self.proc_data.mod_matrix.header, self.roi
        )

    def _apply_binning_factor(self):
        """
        Bin the demodulation matrix like the data, if the data is binned.

        The binning factors come from the SPATIAL_BIN card of the first frame.
        ``check_same_binning`` is called first, presumably to ensure all frames
        share it. SPATIAL_BIN is split on ',' and a factor of 1 is appended for
        the two trailing matrix axes (output and input dimension, in that
        order), so those are never binned.
        """
        check_same_binning(self.batch)
        header = self.batch[0]['header']
        bins = header[SPATIAL_BIN] if SPATIAL_BIN in header else None
        if bins is None:
            return
        # Do not bin the output and the input dimension.
        factors = bins.split(',') + [1, 1]
        self.proc_data.demod_matrix = Collections.bin(self.proc_data.demod_matrix, Collections.as_int_array(factors))

    def _process_entry(self, obj: Fits, start_file: str):
        """
        Append one demodulated result to ``self.result``.

        :param obj: demodulated object; its ``path`` and ``data`` are taken over
        :param start_file: batch key of the frame whose header serves as template
        """
        self.result.batch.append(
            {
                'file': obj.path,
                'data': obj.data,
                'header': self._merged_header(start_file),
            }
        )

    def _merged_header(self, template_file: str):
        """
        Build the header of a demodulated frame from a template frame.

        - Copies every key of ``config.base.header_copy_fields`` that exists in
          the template header, except MOD_STATE (presumably because the result
          combines several modulation states).
        - Marks block 'D' as applied in BLOCKS_APPLIED.
        - Records the modulation matrix name (the grandparent directory of the
          matrix file) and the demodulation mode.
        - Updates the ROI cards via ``_update_roi_header``.

        :param template_file: batch key of the template frame
        :return: the new header
        """
        header = self.batch[template_file]['header']
        new_header = fits.PrimaryHDU().header
        for key in self.proc_data.config.base.header_copy_fields:
            if key not in header or key == MOD_STATE:
                continue
            new_header.append(Card.from_orig(header, key).to_card())
        Fits.override_header(new_header, BLOCKS_APPLIED, BlockD.BLOCK_ID, append=True)
        modm_name = os.path.basename(os.path.dirname(os.path.dirname(self.proc_data.mod_matrix.path)))
        new_header.append(Card(DEMOD_MAT, value=modm_name).to_card())
        new_header.append(Card(DEMOD_MODE, value=self.proc_data.config.spol.demod_mode).to_card())
        self._update_roi_header(new_header)
        return new_header

    @staticmethod
    def prepare(proc_data: ProcessingData) -> None:
        """
        Compute ``proc_data.demod_matrix`` from ``proc_data.mod_matrix``.

        Does nothing for cam3 or when a demodulation matrix is already set.
        Otherwise, the method:

        - swaps the X/Y ROI cards of the modulation matrix header;
        - normalizes the matrix w.r.t. its element (0, 0) and pseudo-inverts
          it over the last two axes;
        - flips axis 0 and mirrors the Y ROI accordingly;
        - drops the modulation matrix data afterwards.

        :param proc_data: processing data to update in place
        :raises IllegalStateException: if neither a modulation nor a
            demodulation matrix is available
        """
        if proc_data.config.is_cam3():
            return
        if proc_data.demod_matrix is not None:
            return
        if proc_data.mod_matrix is None:
            raise IllegalStateException('Neither modulation nor demodulation matrix given')

        log.info('Preparing demodulation matrix...')
        header = proc_data.mod_matrix.header
        # Transpose the ROI names to match the data convention:
        # ROI Y is python dim 0 (the vertical axis in a standard matplotlib plot).
        header[ROI_X0], header[ROI_Y0] = header[ROI_Y0], header[ROI_X0]
        header[ROI_X1], header[ROI_Y1] = header[ROI_Y1], header[ROI_X1]

        # Demodulation matrix: normalize w.r.t. element (0, 0) of the last two axes.
        # The divisor is copied so the in-place division never reads memory it is
        # writing to (instead of relying on NumPy's overlap detection).
        demo = proc_data.mod_matrix.data.astype("float64")
        norm = demo[:, :, 0:1, 0:1].copy()
        demo /= norm
        # pinv works on stacked matrices: one pseudo-inverse per (dim0, dim1) position.
        proc_data.demod_matrix = np.linalg.pinv(demo)
        # Invert dimension 0 of the matrix, because it has row 0 at the bottom
        # (old DAQ convention), and mirror the Y ROI accordingly.
        proc_data.demod_matrix = proc_data.demod_matrix[::-1, :, :, :]
        frame_stop = proc_data.config.cam.data_shape[0].stop
        # Cast like _demod_mat_shape does, so non-int header values do not break the arithmetic.
        old_y0 = int(header[ROI_Y0])
        old_y1 = int(header[ROI_Y1])
        header[ROI_Y1] = frame_stop - old_y0
        header[ROI_Y0] = frame_stop - old_y1
        proc_data.mod_matrix.data = None

    @staticmethod
    def input_needed(cam3: bool = False) -> list:
        """Return the processing inputs this block needs (none for cam3)."""
        return [] if cam3 else ['demod_matrix']

    @staticmethod
    def predict_output(batch: FitsBatch, proc_data: ProcessingData):
        """
        Predict the output batch without loading or demodulating any data.

        Each assembled demodulation block yields one entry. The entry is named
        after the block, carries the (unmodified) header of the block's start
        frame, and has ``data`` set to None.
        """
        new_batch = FitsBatch()
        db = FileDB(proc_data.config)
        for block in DemodAssembler(db, batch, with_data=False).run():
            new_batch.batch.append({'file': block['name'], 'header': batch[block['start']]['header'], 'data': None})
        return new_batch
# Import-time side effect: register BlockD in the BlockRegistry under its BLOCK_ID ('D'),
# presumably so the pipeline can resolve the block by that ID.
BlockRegistry().register(BlockD.BLOCK_ID, BlockD)