Coverage for src/susi/reduc/pipeline/chunker.py: 89%

140 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2025-06-13 14:15 +0000

1#!/usr/bin/env python3 

2# -*- coding: utf-8 -*- 

3""" 

4Module that provides methods to chunk file lists 

5 

6@author: hoelken 

7""" 

8from __future__ import annotations 

9 

10import os 

11 

12import numpy as np 

13 

14from .blocks import BlockD 

15from ..demodulation import ModStateDetector 

16from ...base import Logging, InsufficientDataException, Config 

17from ...base.header_keys import * 

18from ...db import Metadata 

19from ...io import FitsBatch 

20from ...io.state_link import StateLink 

21from ...utils import Collections 

22from ...utils.reports import create_metadata_report 

23 

# Module-level logger shared by the Chunker class below.
log = Logging.get_logger()

25 

26 

class Chunker:
    """Group a list of FITS files into processing chunks.

    Depending on the configuration the chunks are either plain fixed-size
    slices of the file list ("raw" mode) or complete modulation cycles,
    assembled from the modulation state of every frame. States are taken
    from the FITS headers, detected from the PMU angle, or faked when the
    PMU status is ignored.
    """

    def __init__(self, files: list, config: Config, raw: bool = False):
        # Paths of the files to be chunked.
        self.files = files
        self.batch = FitsBatch(config)
        self.config = config
        self.raw = raw
        # The map of the modulation states corresponding to the files in batch
        self.state_map = None
        # block_id -> list of StateLink; one complete modulation cycle each.
        self.mod_state_blocks = {}
        # True when the states in state_map were generated, not measured.
        self._fake_states = False
        # The generated processing blocks
        self.chunks = {}

    def run(self) -> Chunker:
        """Load the headers and build the chunks.

        :return: self, for call chaining.
        """
        self._load_headers()
        self._gen_report()
        if self._build_raw():
            log.info('Building raw chunks')
            self._raw_chunks()
        else:
            self._detect_mod_state()
            self._remove_overhang_frames()
            self._gen_mod_cycle_blocks()
            self._build_processing_chunks()
        return self

    def _gen_report(self):
        """Write a PDF metadata report if a log dir is set and reports are enabled."""
        if self.config.data.log_dir is None:
            return
        if self.config.data.generate_pdf_reports:
            create_metadata_report(self.config, os.path.join(self.config.data.root, self.config.data.level))

    def _build_raw(self):
        """Decide whether plain (state-agnostic) chunking must be used."""
        if self.config.spol.ignore_mod_states:
            # Yeah sure, ignore everything. Why not...
            return True
        if self.batch.is_applied(BlockD.BLOCK_ID):
            # Ok, we have demodulated, this makes no sense anymore
            return True
        if self.raw and any(HK_PMU_ANG not in e['header'] for e in self.batch.batch):
            # This is applicable if the Camera HK is not decoded yet
            return True
        return False

    def _load_headers(self) -> None:
        """Fill the batch with headers: DB first, FITS files for the rest."""
        leftovers = self.__load_from_db()
        self.__load_from_fits(leftovers)
        self.__sort_batch()
        log.debug('Creating missing metadata entries')
        self.batch.write_metadata(overwrite=False)

    def __sort_batch(self):
        """Sort the batch chronologically, falling back to DATE_OBS when allowed."""
        try:
            self.batch.sort_by(TIMESTAMP_US)
        except KeyError as e:
            if self.raw or self.config.spol.pmu_status:
                log.debug("Could not sort batch by '%s': Field does not exist.", TIMESTAMP_US)
                log.debug(
                    "Since we currently don't depend on mod. state, we ignore this and sort by '%s' instead", DATE_OBS
                )
                self.batch.sort_by(DATE_OBS)
            else:
                raise e

    def __load_from_db(self) -> list:
        """Load headers from the metadata DB(s).

        :return: the files whose metadata was not found in any DB.
        """
        leftovers = []
        metadata = {}
        # The file list may span multiple DBs; collect them all (deduplicated).
        dbs = {Metadata.db_path(f) for f in self.files}
        log.debug('Collect metadata from %s DB(s)...', len(dbs))
        for db in [Metadata(path) for path in dbs if path is not None]:
            metadata.update(db.data)

        for file in self.files:
            if not self.__search_dbs(file, metadata):
                leftovers.append(file)
        return leftovers

    def __search_dbs(self, file: str, metadata: dict) -> bool:
        """Append a batch entry for `file` if its header is in `metadata`.

        :return: True when an entry was appended, False otherwise.
        """
        if os.path.basename(file) in metadata:
            self.batch.batch.append(
                {
                    'file': file,
                    # Drop None values so absent header fields stay absent.
                    'header': {k: v for k, v in metadata[os.path.basename(file)].items() if v is not None},
                    'data': None,
                }
            )
            return True
        return False

    def __load_from_fits(self, leftovers: list) -> None:
        """Read the remaining headers directly from the FITS files.

        :raises InsufficientDataException: when more than 0.1% of the
            headers could not be loaded.
        """
        prev_len = len(self.batch)
        log.debug('Loaded %s, missing: %s. Loading header from fits files', prev_len, len(leftovers))
        self.batch.load(
            leftovers,
            workers=self.config.base.io_speed.value,
            sort_by=self.config.base.timestamp_field,
            header_only=True,
            append=True
        )
        if len(self.batch) - prev_len < len(leftovers):
            rest = len(leftovers) - (len(self.batch) - prev_len)
            # Tolerate a tiny fraction of unreadable files, but not more.
            if rest > len(self.batch) * 0.001:
                raise InsufficientDataException(
                    f'Could not load {rest} file headers in the range {leftovers[0]} to {leftovers[-1]}'
                )
            else:
                log.warning('Could not load %s file headers in the range %s to %s', rest, leftovers[0], leftovers[-1])

    def _detect_mod_state(self) -> None:
        """Populate state_map from headers, PMU angles, or fake states."""
        if all(MOD_STATE in e['header'] for e in self.batch.batch):
            log.info('Using modulation state key in header.')
            self._use_header_info()
        elif self.config.spol.pmu_status:
            log.info('Determining modulation states of all frames in the file list.')
            self._detect_from_pmu_angle()
        else:
            log.info('PMU state ignored. faking mod states')
            self._fake_mode_states()

    # Backward-compatible alias for the old (typo'd) method name.
    _detect_mod_sate = _detect_mod_state

    def _raw_chunks(self):
        """Chunk the plain file list by temporal binning (48 frames for raw data)."""
        size = 48 if self.raw and self.config.cam.temporal_binning <= 1 else self.config.cam.temporal_binning
        self.chunks.update(enumerate(Collections.chunker(self.files, size)))

    def _use_header_info(self):
        """Take the modulation state of each frame from its header."""
        self.state_map = [int(e['header'][MOD_STATE]) for e in self.batch.batch]

    def _detect_from_pmu_angle(self):
        """Derive the modulation state of each frame from the PMU angle HK."""
        ms = ModStateDetector(self.config, self.batch)
        ms.analyze()
        self.state_map = ms.states

    def _fake_mode_states(self):
        """Assign cyclic pseudo states 0..mod_cycle_frames-1 to the frames.

        :raises InsufficientDataException: when fewer frames than one
            modulation cycle are available.
        """
        self._fake_states = True
        if len(self.batch.batch) < self.config.spol.mod_cycle_frames:
            raise InsufficientDataException('There are not enough frames to complete one modulation cycle.')
        self.state_map = np.arange(len(self.batch.batch)) % self.config.spol.mod_cycle_frames

    def _remove_overhang_frames(self) -> None:
        """Drop frames of incomplete modulation cycles from both ends of the batch.

        :raises InsufficientDataException: when no complete cycle remains.
        """
        # The last state of a cycle; was hard-coded to 11, now derived from config.
        last_state = self.config.spol.mod_cycle_frames - 1
        while len(self.state_map) > 0 and self.state_map[0] != 0:
            log.info('\tIgnore start Frame: %s [OVERHANG]', self.batch.file_by(0))
            self.batch.batch.pop(0)
            self.state_map = np.delete(self.state_map, 0)
        while len(self.state_map) > 0 and self.state_map[-1] != last_state:
            log.info('\tIgnore end Frame: %s [OVERHANG]', self.batch.file_by(-1))
            self.batch.batch.pop()
            self.state_map = np.delete(self.state_map, -1)
        if len(self.state_map) == 0:
            raise InsufficientDataException('No complete modulation cycle left after removing overhang frames.')

    def _gen_mod_cycle_blocks(self) -> None:
        """Split the batch into blocks of one complete modulation cycle each."""
        last_state = self.config.spol.mod_cycle_frames - 1
        block_id = 0
        mod_state_block = []
        for idx, state in enumerate(self.state_map):
            mod_state_block.append(
                StateLink(state, self.batch.file_by(idx, full_path=True), self._fake_states)
            )
            if state == last_state:
                # Cycle complete: store the block and start a new one.
                self.mod_state_blocks[block_id] = mod_state_block
                block_id += 1
                mod_state_block = []
        log.info('Detected %s modulation cycles', block_id)

    def _build_processing_chunks(self) -> None:
        """Group the modulation-cycle blocks by temporal binning.

        NOTE: chunk keys are strings here but ints in _raw_chunks — kept
        as-is since downstream consumers may rely on either form.
        """
        if self.config.cam.temporal_binning < 0:
            # Average all
            self.chunks['0'] = [self.mod_state_blocks[i] for i in range(len(self.mod_state_blocks.keys()))]
        else:
            idx = 0
            for i in range(0, len(self.mod_state_blocks.keys()), self.config.cam.temporal_binning):
                end = min(i + self.config.cam.temporal_binning, len(self.mod_state_blocks.keys()))
                self.chunks[str(idx)] = [self.mod_state_blocks[j] for j in range(i, end)]
                idx += 1