Coverage for src/susi/io/fits_batch.py: 74%

190 statements  


#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
This module provides the FitsBatch class.

@author: hoelken, iglesias
"""

from __future__ import annotations

import warnings
from datetime import timedelta
from typing import Union

import numpy as np
import os

from astropy.io import fits
from astropy.utils.exceptions import AstropyWarning
from .state_link import StateLink
from ..base import Logging, IllegalStateException, Config
from ..db import Metadata
from ..utils import MP
from .fits import Fits

from ..base.header_keys import *

logger = Logging.get_logger()


class FitsBatch:
    """
    ## FitsBatch
    Provides utilities for loading FITS files in batches.
    Uses multithreading for I/O tasks.
    """

    def __init__(self, config: Config = None, slices: tuple = None):
        """
        Constructor
        Does not perform any heavy or dangerous tasks, to allow lightweight
        creation. However, you need to `load` files to actually do anything
        with it.
        """
        #: Instance of Config class. Required to load_ics and convert fits files
        self.config = config
        #: An optional tuple of slices to cut the data on read
        self.slices = slices
        self.slices = config.cam.data_shape if self.slices is None and config is not None else self.slices
        # The loaded files go here
        self.batch = []
        self.next = 0
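
    # Usage sketch (illustrative, not part of the original module): construct a batch
    # with an existing `Config` instance, or without one for ad-hoc reads. The config
    # object and slice values below are hypothetical.
    #
    #   batch = FitsBatch(config=my_config)            # slices default to config.cam.data_shape
    #   plain = FitsBatch(slices=(slice(0, 100),))     # explicit slicing, no config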

    @staticmethod
    def read(file: str, hdu: int = 0, slices: tuple = None) -> tuple:
        """
        Static resource-safe fits reader.

        ### Params
        - file: the path of the file to open, or a `StateLink` pointing to it
        - hdu: the hdu to read (Default=PrimaryHDU=0)
        - slices: an optional tuple of slices to cut the read data to

        ### Returns
        A tuple with header information at 0 and data at 1
        """
        path = file.file if isinstance(file, StateLink) else file
        f = Fits(path).read(hdu=hdu, slices=slices)
        if isinstance(file, StateLink):
            c = "FAKED STATE | No PMU data" if file.fake_state else "Determined from PMU angle"
            f.header = Fits.override_header(f.header, MOD_STATE, file.state, comment=c)
        return f.header, f.data
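
    # Hedged example of the static reader (the file path is hypothetical):
    #
    #   header, data = FitsBatch.read("/data/scan_0001.fits", hdu=0,
    #                                 slices=(slice(0, 512), slice(0, 512)))
    #
    # Passing a StateLink instead of a plain path additionally stamps the modulation
    # state (MOD_STATE) into the returned header, as implemented above.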

    @staticmethod
    def read_ics(file: str, config: Config, hdu: int = 0, slices: tuple = None):
        """
        Static resource-safe ics fits reader.

        ### Params
        - file: the file path of the file to open
        - config: Instance of Config class
        - hdu: the hdu to read (Default=PrimaryHDU=0)
        - slices: an optional tuple of slices to cut the read data to

        ### Returns
        A tuple with header information at 0 and data at 1
        """
        f = Fits(file).read_ics(config, hdu=hdu, slices=slices)
        return f.header, f.data

    @staticmethod
    def read_header(file: str, hdu: int = 0):
        """
        Static resource-safe method to read the fits header only

        ### Params
        - file: the file path of the file to open
        - hdu: the hdu to read (Default=PrimaryHDU=0)

        ### Returns
        The fits header, or `None` if the file could not be read
        """
        header = None
        try:
            header = Fits.get_header(file, hdu)
        except OSError as error:
            logger.error("Fits.read: %s: %s", file, error)
        return header

    @staticmethod
    def convert_ics(file: str, config: Config, hdu: int = 0, slices: tuple = None, overwrite: bool = False):
        """
        Static resource-safe fits converter from ICS fits to standard fits.

        ### Params
        - file: the path of the file to convert
        - config: Instance of Config class
        - hdu: the hdu to read (Default=PrimaryHDU=0)
        - slices: If defined, only the given slice of the HDU data will be loaded.
            The number of slices must match the array shape.
        - overwrite: If `True`, overwrite the output file if it exists. Raises an
            `OSError` if `False` and the output file exists. (Default=`False`)

        ### Returns
        The content of the written Fits file
        """
        return Fits(file).convert_ics(config, hdu=hdu, slices=slices, overwrite=overwrite)

    def load(self, f_list: list, workers: int = 10, header_only: bool = False, sort_by: str = None,
             append: bool = False) -> FitsBatch:
        """
        Load a batch of fits files into memory.
        Uses a multithreading approach to load the data from the given file list.

        ### Params
        - f_list: List of file paths as strings
        - workers: Number of concurrent threads to use.
            Should be equal to or less than the length of the list.
        - header_only: if set, only the header will be loaded. (Default: False)
        - sort_by: if given, the batch is sorted by the specified header field
        - append: if set, the new files are appended to the existing batch

        ### Returns
        `self`
        """
        if len(self.batch) > 0 and not append:
            logger.debug("Batch already loaded. Use append=True to add more files.")
            return self

        if header_only:
            MP.threaded(self.__load_header, f_list, workers)
        else:
            MP.threaded(self.__load_file, f_list, workers)
        if sort_by is not None and all(sort_by in e["header"] for e in self.batch):
            self.sort_by(sort_by)
        return self
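
    # Usage sketch for batch loading (paths and the "DATE-OBS" key are hypothetical
    # and assume the key exists in every header):
    #
    #   batch = FitsBatch(config).load(["/data/a.fits", "/data/b.fits"],
    #                                  workers=4, sort_by="DATE-OBS")
    #   batch.load(["/data/c.fits"], append=True)   # append further files later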

    def header_copy(self) -> FitsBatch:
        """Creates and returns a header_only copy of the current batch"""
        copy = FitsBatch()
        copy.batch = [{"data": None, "file": e["file"], "header": e["header"].copy()} for e in self.batch]
        return copy

    def is_applied(self, block: str) -> bool:
        """Check if a given block was already applied to the batch"""
        if len(self.batch) == 0 or BLOCKS_APPLIED not in self.batch[0]["header"]:
            # Not loaded
            return False
        return block in self.batch[0]["header"][BLOCKS_APPLIED]

    def write_to_disk(self, overwrite=False, astype="float32", verify="silentfix+ignore", write_metadata=False) -> bool:
        """
        Writes all entries in the batch to disk.
        A summary is printed to the log.

        ### Params
        - overwrite: bool, optional
            If `True`, overwrite the output file if it exists. In either case the
            overall process will not fail; files that were not written are skipped.
        - astype: string, optional
            Define the datatype of the HDU data array. Default: float32.
            Use 'same' to avoid changing the dtype of the input.
        - verify: string, optional
            The `output_verify` mode passed on when writing. Default: 'silentfix+ignore'
            (Astropy warnings are silenced in this case).
        - write_metadata: bool, optional
            If `True`, also update the metadata database after writing. (Default: False)

        ### Returns
        `True` if all entries were written, `False` otherwise.
        """
        count = 0
        if verify == "silentfix+ignore":
            warnings.filterwarnings('ignore', category=AstropyWarning)
        try:
            args = [[entry, overwrite, astype, verify] for entry in self.batch]
            count = sum(MP.threaded(FitsBatch._dump_entry, args, workers=max(20, len(self))))
            if write_metadata:
                self.write_metadata()
        finally:
            logger.debug(f"{count} of {len(self)} files written to disk.")
        return count == len(self)
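
    # Hedged example: write the batch back to disk, keeping the input dtype and
    # allowing existing files to be replaced; the return value reports full success.
    #
    #   ok = batch.write_to_disk(overwrite=True, astype="same")
    #   if not ok:
    #       logger.warning("Some files were skipped")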

    def write_metadata(self, overwrite=False):
        try:
            try:
                # Try to add the files as batch
                Metadata.insert_batch([file for file in self], overwrite=overwrite)
                return
            except (IndexError, ValueError):
                pass

            # Batch insert failed (e.g. different times); add the files individually...
            for file in self:
                if os.path.exists(file.path):
                    Metadata.insert(file, overwrite=overwrite)
        except (PermissionError, IOError) as e:
            logger.info("Failed to update metadata database: %s", e)
            return

    @staticmethod
    def _dump_entry(args: list) -> int:
        # args
        # - 0: entry: dict
        # - 1: overwrite: bool
        # - 2: astype: str
        # - 3: verify: str
        try:
            os.makedirs(os.path.dirname(args[0]["file"]), exist_ok=True)
            data = args[0]["data"] if args[2] == "same" else args[0]["data"].astype(args[2])
            hdul = fits.HDUList(fits.PrimaryHDU(data, args[0]["header"], scale_back=False))
            if args[3]:
                hdul.writeto(args[0]["file"], overwrite=args[1], output_verify=args[3])
            else:
                hdul.writeto(args[0]["file"], overwrite=args[1])

            return 1
        except OSError as e:
            if not args[1]:
                logger.debug("Error writing file: %s", e)
            else:
                logger.warning("Error writing file: %s", e)
            return 0

    # TODO: It is kind of awkward to have two dark correction methods lurking around
    def apply_dark(self, dark_file: str) -> None:
        """
        Apply dark correction to all fits files of the batch
        """
        warnings.warn("FitsBatch#apply_dark is deprecated. Use Block C instead", DeprecationWarning, stacklevel=2)
        logger.info("Apply quick and dirty dark correction to all files of batch")
        logger.debug("Loading area %s of dark file %s", self.config.data_shape, os.path.basename(dark_file))

        dark_data_shape = self.config.data_shape.copy()
        dark_data_shape.insert(0, slice(0, 1))
        dark = Fits(dark_file).read(slices=dark_data_shape).data
        dark = dark[0, :]

        n = len(self.batch)
        for i in range(n):
            if (i + 1) % 200 == 0:
                logger.debug("\t> %s/%s", i + 1, n)
            self.batch[i]["data"] = np.subtract(self.batch[i]["data"], dark)

    def load_ics(self, f_list: list, workers: int = 10) -> None:
        """
        Load a batch of ICS fits files into memory.
        Uses a multithreading approach to load the data from the given file list.

        ### Params
        - f_list: List of file paths as strings
        - workers: Number of concurrent threads to use.
            Should be equal to or less than the length of the list.
        """
        MP.threaded(self.__load_ics_file, f_list, workers)

    def convert(self, f_list: list, workers: int = 10) -> None:
        """
        Converts a batch of ICS fits files.
        Uses a multithreading approach to convert the given file list.
        Each Fits is written to its own path, except that Config.collection is replaced by Config.out_path.
        The full output directory structure is created.

        ### Params
        - f_list: List of file paths as strings
        - workers: Number of concurrent threads to use.
            Should be equal to or less than the length of the list.
        """
        logger.debug("Converting %s files to level 0.1", len(f_list))
        MP.threaded(self.__convert_ics_file, f_list, workers)

    def __getitem__(self, key: Union[str, int]) -> dict:
        """
        Get the loaded data and header entries by file name or batch position via []

        ### Params
        - key: Either the file name of the fits file to access or the position
            of the entry within the batch.

        ### Returns
        The fits as dict with keys 'file', 'data' and 'header', or `None` if a
        file name was given and not found.
        """
        if isinstance(key, str):
            return next((item for item in self.batch if item["file"] == key), None)

        return self.batch[key]

    def __iter__(self):
        """Make this class iterable (e.g. use with for ... in)"""
        self.next = 0
        return self

    def __next__(self) -> Fits:
        """Get the next element in the batch as fits"""
        try:
            return self.as_fits(self.next)
        except IndexError:
            raise StopIteration()
        finally:
            self.next += 1
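
    # Access sketch (the file name below is hypothetical): entries can be fetched
    # by position or file name via [], or iterated as Fits objects.
    #
    #   first = batch[0]                      # dict with 'file', 'header', 'data'
    #   entry = batch["/data/a.fits"]         # lookup by file name, None if absent
    #   for fits_obj in batch:                # yields Fits instances via as_fits()
    #       print(fits_obj.path)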

    def __repr__(self):
        """Representation string"""
        out = f"{self.__class__}\n Loaded content: {len(self.batch)} files\n"
        for e in self.batch:
            out += f"\t-{e['file']}\t {e['data'].shape if e['data'] is not None else ''}\n"
        return out

    def as_fits(self, idx: Union[int, str]) -> Fits:
        """Return the element at the given position (or file name) as Fits"""
        entry = self[idx]
        f = Fits(entry["file"])
        f.header = entry["header"]
        f.data = entry["data"]
        return f

    def data_array(self) -> np.ndarray:
        """
        Get the data arrays from the current batch of FITSes as a `(len(batch),) + fits.data.shape` shaped array
        """
        self.__check_loaded()
        return np.array([np.array(e["data"]) for e in self.batch])
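
    # Example (illustrative): stack the whole batch into a single numpy cube,
    # e.g. to average all frames.
    #
    #   cube = batch.data_array()      # shape: (len(batch),) + frame shape
    #   mean_frame = cube.mean(axis=0)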

    def file_list(self) -> list[str]:
        """
        Get the list of files from this batch
        """
        self.__check_loaded()
        return [e["file"] for e in self.batch]

    def header_field(self, field_name: str) -> list:
        """
        Access the same header field of all FITSes in the batch

        ### Params
        - field_name: the name of the header field to access

        ### Returns
        List of field values; the list index corresponds to the internal batch index
        """
        self.__check_loaded()
        return [entry["header"][field_name] for entry in self.batch]

    def header_field_comment(self, field_name: str, split_on: str = None) -> list:
        """
        Access the same header field comment of all FITSes in the batch

        ### Params
        - field_name: The name of the header field to access
        - split_on: If set, only the 0th part of the comment (split on this string) is returned

        ### Returns
        List of field comments; the list index corresponds to the internal batch index
        """
        self.__check_loaded()
        return [FitsBatch.__comment(entry, field_name, split_on) for entry in self.batch]
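
    # Hedged example of header access across the batch ("EXPTIME" is a hypothetical
    # key, assumed to be present in all headers):
    #
    #   exposure_times = batch.header_field("EXPTIME")
    #   comments = batch.header_field_comment("EXPTIME", split_on="|")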

    @staticmethod
    def __comment(entry, field, split_on):
        if split_on:
            return entry["header"].comments[field].split(split_on)[0]
        else:
            return entry["header"].comments[field]

    def file_by(self, idx: int, full_path: bool = False) -> str:
        """
        Get the file name corresponding to the id

        ### Params
        - idx: the id to fetch the file name for
        - full_path: if True the full path will be returned, otherwise the file name only (Default: False)

        ### Returns
        The file name
        """
        if full_path:
            return self.batch[idx]["file"]
        return self.batch[idx]["file"].split("/")[-1]

    def sort_by(self, field_name: str) -> FitsBatch:
        """
        Sorts the batch by the content of the given header field

        ### Params
        - field_name: the header field name to sort by

        ### Returns
        `self`
        """
        self.batch = sorted(self.batch, key=lambda e: e["header"][field_name])
        return self

    def header_content(self, field_name: str, pos: int = 0):
        """
        Get the value of the header field for the requested batch position

        ### Params
        - field_name: the name of the header field to access
        - pos: the position in the batch; `0` for first (default), `-1` for last

        ### Returns
        Tuple with (value, comment) of the field
        """
        self.__check_loaded()
        return self.batch[pos]["header"][field_name], self.batch[pos]["header"].comments[field_name]
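
    # Sketch combining sort_by and header_content ("DATE-OBS" is hypothetical):
    #
    #   batch.sort_by("DATE-OBS")
    #   first_value, first_comment = batch.header_content("DATE-OBS", pos=0)
    #   last_value, _ = batch.header_content("DATE-OBS", pos=-1)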

    def __len__(self) -> int:
        """
        Get the number of FITSes in the batch.
        """
        return len(self.batch)

    # Load an individual FITS file
    def __load_file(self, file, hdu=0):
        content = FitsBatch.read(file, hdu=hdu, slices=self.slices)
        if content[1] is not None:
            self.batch.append(
                {"file": file.file if isinstance(file, StateLink) else file, "header": content[0], "data": content[1]}
            )

    # Load an individual ICS FITS file
    def __load_ics_file(self, file, hdu=0):
        content = FitsBatch.read_ics(file, self.config, hdu=hdu, slices=self.slices)
        self.batch.append({"file": file, "header": content[0], "data": content[1]})

    # Load only the header of a given FITS file
    def __load_header(self, file, hdu=0):
        header = FitsBatch.read_header(file, hdu=hdu)
        if header is not None:
            self.batch.append({"file": file, "header": header, "data": None})

    # Check if data has been loaded and raise exception if not
    def __check_loaded(self):
        if not self.batch:
            raise IllegalStateException("No batch of FITS files loaded.")

    # Convert an individual ics FITS file
    def __convert_ics_file(self, file, overwrite=False, hdu=0):
        FitsBatch.convert_ics(file, self.config, hdu=hdu, slices=self.slices, overwrite=overwrite)