Coverage for src/susi/io/fits_batch.py: 74%

194 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2025-08-11 10:03 +0000

1#!/usr/bin/env python3 

2# -*- coding: utf-8 -*- 

3""" 

4module provides FitsBatch 

5 

6@author: hoelken, iglesias 

7""" 

8from __future__ import annotations 

9 

10import warnings 

11from datetime import timedelta 

12from typing import Union 

13 

14import numpy as np 

15import os 

16 

17from astropy.io import fits 

18import warnings 

19from astropy.utils.exceptions import AstropyWarning 

20from .state_link import StateLink 

21from ..base import Logging, IllegalStateException, Config, header_keys 

22from ..db import Metadata 

23from ..utils import MP 

24from .fits import Fits 

25 

26from ..base.header_keys import * 

27 

# Module-level logger obtained from the project's Logging helper; used by all FitsBatch methods below.
logger = Logging.get_logger()

29 

30 

class FitsBatch:
    """
    ## FitsBatch
    Provides utilities for loading FITS files in batches.
    Uses multithreading for I/O tasks.

    Each loaded entry is stored in `self.batch` as a dict with the keys
    'file' (path as string), 'header' and 'data' (`None` for header-only loads).
    """

    def __init__(self, config: Config = None, slices: tuple = None):
        """
        Constructor
        Does not perform any heavy or dangerous tasks, to allow lightweight
        creation. However, you need to `load` files to actually do anything
        with it.

        ### Params
        - config: Instance of Config class. Required to load and convert ICS fits files.
        - slices: An optional tuple of slices to cut the data on read. If `None` and a
          config is given, `config.cam.data_shape` is used instead.
        """
        #: Instance of Config class. Required to load_ics and convert fits files
        self.config = config
        #: An optional tuple of slices to cut the data on read
        if slices is None and config is not None:
            slices = config.cam.data_shape
        self.slices = slices
        # The loaded files go here
        self.batch = []
        # Cursor used by the legacy iterator protocol (`__next__`)
        self.next = 0

    @staticmethod
    def read(file: Union[str, StateLink], hdu: int = 0, slices: tuple = None) -> tuple:
        """
        Static resource-safe fits reader.

        ### Params
        - file: the file path of the file to open, or a StateLink. For a StateLink the
          linked file is read and the MOD_STATE header field is overridden with the
          link's state.
        - hdu: the hdu to read (Default=PrimaryHDU=0)
        - slices: an optional tuple of slices to cut the read data to

        ### Returns
        A tuple with header information at 0 and data at 1
        """
        path = file.file if isinstance(file, StateLink) else file
        f = Fits(path).read(hdu=hdu, slices=slices)
        if isinstance(file, StateLink):
            c = "FAKED STATE | No PMU data" if file.fake_state else "Determined from PMU angle"
            f.header = Fits.override_header(f.header, MOD_STATE, file.state, comment=c)
        return f.header, f.data

    @staticmethod
    def read_ics(file: str, config: Config, hdu: int = 0, slices: tuple = None):
        """
        Static resource-safe ics fits reader.

        ### Params
        - file: the file path of the file to open
        - config: Instance of Config class
        - hdu: the hdu to read (Default=PrimaryHDU=0)
        - slices: an optional tuple of slices to cut the read data to

        ### Returns
        A tuple with header information at 0 and data at 1
        """
        f = Fits(file).read_ics(config, hdu=hdu, slices=slices)
        return f.header, f.data

    @staticmethod
    def read_header(file: Union[str, StateLink], hdu: int = 0):
        """
        Static resource safe method to read the fits header only

        ### Params
        - file: the file path of the file to open, or a StateLink (its `file` is read)
        - hdu: the hdu to read (Default=PrimaryHDU=0)

        ### Returns
        The fits header, or `None` if reading failed with an `OSError` (the error is logged)
        """
        # Unwrap StateLinks the same way `read` does, so only paths reach Fits.
        path = file.file if isinstance(file, StateLink) else file
        header = None
        try:
            header = Fits.get_header(path, hdu)
        except OSError as error:
            logger.error("Fits.read: %s: %s", path, error)
        return header

    @staticmethod
    def convert_ics(file: str, config: Config, hdu: int = 0, slices: tuple = None, overwrite: bool = False):
        """
        Static resource-safe fits converter from ICS fits to standard fits.

        ### Params
        - file: the path of the file to convert
        - config: Instance of Config class
        - hdu: the hdu to read (Default=PrimaryHDU=0)
        - slices: If defined only the given slice of the HDU data will be loaded.
          Number of slices must match array shape.
        - overwrite: If `True`, overwrite the output file if it exists. Raises an
          `OSError` if `False` and the output file exists. (Default=`False`)

        ### Returns
        The content of the written Fits file
        """
        return Fits(file).convert_ics(config, hdu=hdu, slices=slices, overwrite=overwrite)

    def load(
        self, f_list: list, workers: int = 10, header_only: bool = False, sort_by: str = None, append: bool = False
    ) -> FitsBatch:
        """
        Load a batch of fits files into memory
        Uses a multithreading approach to load the data from the given file list.

        ### Params
        - f_list: List of file paths as strings (or StateLinks)
        - workers: Number of concurrent threads to use.
          Should be equal or less than the length of the list.
        - header_only: if set only the header will be loaded. (Default: False)
        - sort_by: if given the batch is sorted by the header field specified
          (only if every loaded header contains that field)
        - append: if set the new files are appended to the existing batch.
          Otherwise, loading into a non-empty batch is a no-op.

        ### Returns
        `self`
        """
        if self.batch and not append:
            logger.debug("Batch already loaded. Use append=True to add more files.")
            return self

        loader = self.__load_header if header_only else self.__load_file
        MP.threaded(loader, f_list, workers)
        if sort_by is not None and all(sort_by in e["header"] for e in self.batch):
            self.sort_by(sort_by)
        return self

    def header_copy(self) -> FitsBatch:
        """
        Create and return a header-only copy of the current batch.

        Headers are copied, data is dropped (`None`). The copy keeps the
        `config` and `slices` of this batch.
        """
        clone = FitsBatch(config=self.config, slices=self.slices)
        clone.batch = [{"data": None, "file": e["file"], "header": e["header"].copy()} for e in self.batch]
        return clone

    def is_applied(self, block: str) -> bool:
        """Check if a given block was already applied to the batch (based on the first entry's header)"""
        if len(self.batch) == 0 or BLOCKS_APPLIED not in self.batch[0]["header"]:
            # Not loaded
            return False
        return block in self.batch[0]["header"][BLOCKS_APPLIED]

    def write_to_disk(self, overwrite=False, astype="float32", verify="silentfix+ignore", write_metadata=False) -> bool:
        """
        Writes all entries in the batch to disk.
        A summary is printed to the log.

        ### Params
        - overwrite: bool, optional
            If `True`, overwrite the output file if it exists. In either case
            it will not fail the total process but only skip files that were not written.
        - astype: string, optional
            Define the datatype of the HDU data array. Default: float32
            Use 'same' to avoid changing the dtype of the input
        - verify: string, optional
            `output_verify` option passed to astropy's `HDUList.writeto`. If falsy,
            astropy's default is used. With 'silentfix+ignore' (default) AstropyWarnings
            are suppressed while writing.
        - write_metadata: bool, optional
            If `True`, `write_metadata` is called after the files have been written.

        ### Returns
        `True` if every entry of the batch was written, `False` otherwise.
        """
        count = 0
        # Scope the warning filter to this call instead of changing the process-wide filters for good.
        with warnings.catch_warnings():
            if verify == "silentfix+ignore":
                warnings.filterwarnings("ignore", category=AstropyWarning)
            try:
                args = [[entry, overwrite, astype, verify] for entry in self.batch]
                # Cap the thread pool at 20 writers, but always use at least one.
                workers = min(20, max(1, len(self)))
                count = sum(MP.threaded(FitsBatch._dump_entry, args, workers=workers))
                if write_metadata:
                    self.write_metadata()
            finally:
                logger.debug("%s of %s files written to disk.", count, len(self))
        return count == len(self)

    def write_metadata(self, overwrite=False):
        """
        Insert the metadata of all batch entries into the metadata database.

        First tries a single `Metadata.insert_batch` call. If that is rejected with an
        `IndexError` or `ValueError` (e.g. entries with different times), entries whose
        file exists on disk are inserted one by one via `Metadata.insert`.
        Permission / I/O errors are logged and not raised.

        ### Params
        - overwrite: passed through to `Metadata.insert_batch` / `Metadata.insert`
        """
        try:
            try:
                # Try to add the files as batch
                Metadata.insert_batch(list(self), overwrite=overwrite)
                return
            except (IndexError, ValueError):
                pass

            for fits_file in self:
                # different times, add them individually...
                if os.path.exists(fits_file.path):
                    Metadata.insert(fits_file, overwrite=overwrite)
        except OSError as e:
            # PermissionError and IOError are both OSError
            logger.info("Failed to update metadata database: %s", e)

    @staticmethod
    def _dump_entry(args: list) -> int:
        """
        Write a single batch entry to disk (worker for `write_to_disk`).

        ### Params
        - args: list with
            - 0: entry: dict with 'file', 'header' and 'data'
            - 1: overwrite: bool
            - 2: astype: str, target dtype of the data or 'same' to keep it
            - 3: verify: str, astropy `output_verify` option (falsy for astropy's default)

        ### Returns
        1 if the file was written, 0 if writing failed with an `OSError`
        """
        entry, overwrite, astype, verify = args
        try:
            out_dir = os.path.dirname(entry["file"])
            if out_dir:
                # os.makedirs("") raises, so only create directories for paths that have one.
                os.makedirs(out_dir, exist_ok=True)
            data = entry["data"]
            if astype != "same" and data is not None:
                # Header-only entries (data=None) are written without data instead of crashing.
                data = data.astype(astype)
            hdul = fits.HDUList(fits.PrimaryHDU(data, entry["header"], scale_back=False))
            if verify:
                hdul.writeto(entry["file"], overwrite=overwrite, output_verify=verify)
            else:
                hdul.writeto(entry["file"], overwrite=overwrite)
            return 1
        except OSError as e:
            # Without overwrite, existing files are an expected reason to fail: log quietly.
            if not overwrite:
                logger.debug("Error writing file: %s", e)
            else:
                logger.warning("Error writing file: %s", e)
            return 0

    # TODO: It is kind of awkward to have two dark correction methods lurking around
    def apply_dark(self, dark_file: str) -> None:
        """
        Apply dark correction to all fits files of the batch (deprecated, use Block C).

        Reads the first frame of `dark_file`, cut to `self.config.data_shape`, and
        subtracts it from the data of every entry in place.
        """
        warnings.warn("FitsBatch#apply_dark is deprecated. Use Block C instead", DeprecationWarning, stacklevel=2)
        logger.info("Apply quick and dirty dark correction to all files of batch")
        # NOTE(review): __init__ reads `config.cam.data_shape`, this reads `config.data_shape`,
        # and `.copy()`/`.insert()` require it to be a list — confirm against Config.
        logger.debug("Loading area %s of dark file %s", self.config.data_shape, os.path.basename(dark_file))

        dark_data_shape = self.config.data_shape.copy()
        dark_data_shape.insert(0, slice(0, 1))
        dark = Fits(dark_file).read(slices=dark_data_shape).data
        dark = dark[0, :]

        n = len(self.batch)
        for i, entry in enumerate(self.batch, start=1):
            if i % 200 == 0:
                logger.debug("\t> %s/%s", i, n)
            entry["data"] = np.subtract(entry["data"], dark)

    def load_ics(self, f_list: list, workers: int = 10) -> None:
        """
        Load a batch of ICS fits files into memory
        Uses a multithreading approach to load the data from the given file list.

        ### Params
        - f_list: List of file paths as strings
        - workers: Number of concurrent threads to use.
          Should be equal or less than the length of the list.
        """
        MP.threaded(self.__load_ics_file, f_list, workers)

    def convert(self, f_list: list, workers: int = 10) -> None:
        """
        Converts a batch of ICS fits files
        Uses a multithreading approach to convert the given file list.
        Each Fits is written to self.path, except that Config.collection is replaced by Config.out_path.
        The full output dir structure is created.

        ### Params
        - f_list: List of file paths as strings
        - workers: Number of concurrent threads to use.
          Should be equal or less than the length of the list.
        """
        logger.debug("Converting %s files to level 0.1", len(f_list))
        MP.threaded(self.__convert_ics_file, f_list, workers)

    def __getitem__(self, key: Union[str, int]) -> dict:
        """
        Get the loaded data and header entries by file name via []

        ### Params
        - key: Either the file name of the fits file to access or the position
          of the entry within the batch.

        ### Returns
        The fits as dict with keys 'file', 'data' and 'header'.
        `None` if a file name is given that is not part of the batch.
        """
        if isinstance(key, str):
            return next((item for item in self.batch if item["file"] == key), None)

        return self.batch[key]

    def __iter__(self):
        """
        Make this class iterable (e.g. use with for ... in).

        Yields every entry as `Fits` (see `as_fits`). Each call returns an independent
        iterator, so nested loops over the same batch work. The cursor used by
        `__next__` is reset as well, for callers that use `next(batch)` directly.
        """
        self.next = 0
        return self.__iter_fits()

    def __iter_fits(self):
        # Independent generator over the batch; re-checks the length so appended entries are seen.
        idx = 0
        while idx < len(self.batch):
            yield self.as_fits(idx)
            idx += 1

    def __next__(self) -> Fits:
        """Get the next element in the batch as fits (legacy cursor-based protocol)"""
        try:
            return self.as_fits(self.next)
        except IndexError:
            raise StopIteration()
        finally:
            self.next += 1

    def __repr__(self):
        """Representation string"""
        out = f"{self.__class__}\n Loaded content: {len(self.batch)} files\n"
        for e in self.batch:
            out += f"\t-{e['file']}\t {e['data'].shape if e['data'] is not None else ''}\n"
        return out

    def as_fits(self, idx: Union[int, str]) -> Fits:
        """
        Return the element at the given position (or with the given file name) as Fits

        ### Params
        - idx: position within the batch, or the file name as stored in the batch

        ### Raises
        - IndexError: if an integer position is out of range (used by `__next__` to stop)
        - KeyError: if no entry with the given file name exists
        """
        entry = self[idx]
        if entry is None:
            raise KeyError(idx)
        f = Fits(entry["file"])
        f.header = entry["header"]
        f.data = entry["data"]
        return f

    def data_array(self) -> np.array:
        """
        Get the data arrays from the current batch of FITSes as `fits.data.shape + 1` shaped array
        """
        self.__check_loaded()
        return np.array([np.array(e["data"]) for e in self.batch])

    def file_list(self) -> list[str]:
        """
        Get the list of files from this batch
        """
        self.__check_loaded()
        return [e["file"] for e in self.batch]

    def header_field(self, field_name: str) -> list:
        """
        Access the same header field of all FITSes in the batch

        ### Params
        - field_name: the name of the header field to access

        ### Returns
        List of field_value(s), the list id corresponds to the internal id
        """
        self.__check_loaded()
        return [entry["header"][field_name] for entry in self.batch]

    def header_field_comment(self, field_name: str, split_on: str = None) -> list:
        """
        Access the same header field comment of all FITSes in the batch

        ### Params
        - field_name: The name of the header field to access
        - split_on: If set will return the 0th part of the comment

        ### Returns
        List of field comments, the list id corresponds to the internal id
        """
        self.__check_loaded()
        return [FitsBatch.__comment(entry, field_name, split_on) for entry in self.batch]

    @staticmethod
    def __comment(entry, field, split_on):
        # Header comment of `field`, optionally cut at the first occurrence of `split_on`.
        if split_on:
            return entry["header"].comments[field].split(split_on)[0]
        return entry["header"].comments[field]

    def file_by(self, idx: int, full_path: bool = False) -> str:
        """
        Get the file name corresponding to the id

        ### Params
        - idx: the id to fetch the file name for
        - full_path: if True the full path will be returned otherwise the filename only (Default: False)

        ### Returns
        The file name
        """
        if full_path:
            return self.batch[idx]["file"]
        return self.batch[idx]["file"].split("/")[-1]

    def sort_by(self, field_name: str) -> FitsBatch:
        """
        Sort the batch by the content of a given header field

        ### Params
        - field_name: the header field name to sort by

        ### Returns
        `self`

        ### Raises
        - IllegalStateException: if sorting by DATE_OBS does not yield ascending TIMESTAMP_US
        """
        self.batch = sorted(self.batch, key=lambda e: e["header"][field_name])
        if field_name == DATE_OBS:
            # check that the resulting order corresponds to ascending TIMESTAMP_US
            timestamps = [e["header"][TIMESTAMP_US] for e in self.batch]
            if not all(a <= b for a, b in zip(timestamps, timestamps[1:])):
                raise IllegalStateException(f"Sorting by {DATE_OBS} did not result in ascending TIMESTAMP_US order.")

        return self

    def header_content(self, field_name: str, pos: int = 0):
        """
        Get the value of the header field for the requested batch position
        ### Params
        - field_name: the name of the header field to access
        - pos: the position in the batch `0` for first (default), ` -1` for last

        ### Returns
        Tuple with (value, comment) of the field
        """
        self.__check_loaded()
        return self.batch[pos]["header"][field_name], self.batch[pos]["header"].comments[field_name]

    def __len__(self) -> int:
        """
        Get the number of FITSes in the batch.
        """
        return len(self.batch)

    # Load an individual FITS file (entries without data are skipped)
    def __load_file(self, file, hdu=0):
        content = FitsBatch.read(file, hdu=hdu, slices=self.slices)
        if content[1] is not None:
            self.batch.append(
                {"file": file.file if isinstance(file, StateLink) else file, "header": content[0], "data": content[1]}
            )

    # Load an individual ICS FITS file
    def __load_ics_file(self, file, hdu=0):
        content = FitsBatch.read_ics(file, self.config, hdu=hdu, slices=self.slices)
        self.batch.append({"file": file, "header": content[0], "data": content[1]})

    # Load only the header of a given FITS file (unreadable files are skipped)
    def __load_header(self, file, hdu=0):
        header = FitsBatch.read_header(file, hdu=hdu)
        if header is not None:
            # Store the plain path, consistent with __load_file
            path = file.file if isinstance(file, StateLink) else file
            self.batch.append({"file": path, "header": header, "data": None})

    # Check if data has been loaded and raise exception if not
    def __check_loaded(self):
        if not self.batch:
            raise IllegalStateException("No batch of FITS files loaded.")

    # Convert an individual ics FITS file
    def __convert_ics_file(self, file, overwrite=False, hdu=0):
        FitsBatch.convert_ics(file, self.config, hdu=hdu, slices=self.slices, overwrite=overwrite)