#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
This module provides FitsBatch.

@author: hoelken, iglesias
"""
from __future__ import annotations

import os
import warnings
from typing import Union

import numpy as np

from astropy.io import fits
from astropy.utils.exceptions import AstropyWarning

from .state_link import StateLink
from ..base import Logging, IllegalStateException, Config
from ..db import Metadata
from ..utils import MP
from .fits import Fits

from ..base.header_keys import *
logger = Logging.get_logger()


class FitsBatch:
    """
    ## FitsBatch
    Provides utilities for loading FITS files in batches.
    Uses multithreading for I/O tasks.
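
    Example (a minimal sketch; `config` and the file names are hypothetical):
        batch = FitsBatch(config).load(["scan_000.fits", "scan_001.fits"], workers=2)
        for f in batch:  # iterates the batch as Fits objects
            print(f.header)
        cube = batch.data_array()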
36 """

    def __init__(self, config: Config = None, slices: tuple = None):
        """
        Constructor
        Does not perform any heavy or dangerous tasks, to allow lightweight
        creation. However, you need to `load` files to actually do anything
        with it.
        """
        #: Instance of Config class. Required for `load_ics` and to convert FITS files
        self.config = config
        #: An optional tuple of slices to cut the data on read
        self.slices = slices
        if self.slices is None and config is not None:
            self.slices = config.cam.data_shape
        # The loaded files go here
        self.batch = []
        self.next = 0

    @staticmethod
    def read(file: str, hdu: int = 0, slices: tuple = None) -> tuple:
        """
        Static resource-safe FITS reader.

        ### Params
        - file: the file path of the file to open (may also be a `StateLink`)
        - hdu: the HDU to read (Default=PrimaryHDU=0)
        - slices: an optional tuple of slices to cut the read data to

        ### Returns
        A tuple with the header information at 0 and the data at 1
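
        ### Example
        A minimal sketch (the file name is hypothetical):
            header, data = FitsBatch.read("scan_000.fits", slices=(slice(0, 512), slice(None)))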
66 """
67 path = file.file if isinstance(file, StateLink) else file
68 f = Fits(path).read(hdu=hdu, slices=slices)
69 if isinstance(file, StateLink):
70 c = "FAKED STATE | No PMU data" if file.fake_state else "Determined from PMU angle"
71 f.header = Fits.override_header(f.header, MOD_STATE, file.state, comment=c)
72 return f.header, f.data

    @staticmethod
    def read_ics(file: str, config: Config, hdu: int = 0, slices: tuple = None):
        """
        Static resource-safe ICS FITS reader.

        ### Params
        - file: the file path of the file to open
        - config: Instance of Config class
        - hdu: the HDU to read (Default=PrimaryHDU=0)
        - slices: an optional tuple of slices to cut the read data to

        ### Returns
        A tuple with the header information at 0 and the data at 1
        """
        f = Fits(file).read_ics(config, hdu=hdu, slices=slices)
        return f.header, f.data

    @staticmethod
    def read_header(file: str, hdu: int = 0):
        """
        Static resource-safe method to read the FITS header only.

        ### Params
        - file: the file path of the file to open
        - hdu: the HDU to read (Default=PrimaryHDU=0)

        ### Returns
        The FITS header, or `None` if the file could not be read
        """
        header = None
        try:
            header = Fits.get_header(file, hdu)
        except OSError as error:
            logger.error("Fits.read: %s: %s", file, error)
        return header

    @staticmethod
    def convert_ics(file: str, config: Config, hdu: int = 0, slices: tuple = None, overwrite: bool = False):
        """
        Static resource-safe FITS converter from ICS FITS to standard FITS.

        ### Params
        - file: the path of the file to convert
        - config: Instance of Config class
        - hdu: the HDU to read (Default=PrimaryHDU=0)
        - slices: If defined, only the given slice of the HDU data will be loaded.
            The number of slices must match the array shape.
        - overwrite: If `True`, overwrite the output file if it exists. Raises an
            `OSError` if `False` and the output file exists. (Default=`False`)

        ### Returns
        The content of the written FITS file
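
        ### Example
        A minimal sketch (the file name is hypothetical):
            FitsBatch.convert_ics("scan_000_ics.fits", config, overwrite=True)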
126 """
127 return Fits(file).convert_ics(config, hdu=hdu, slices=slices, overwrite=overwrite)

    def load(
        self, f_list: list, workers: int = 10, header_only: bool = False, sort_by: str = None, append: bool = False
    ) -> FitsBatch:
        """
        Load a batch of FITS files into memory.
        Uses a multithreading approach to load the data from the given file list.

        ### Params
        - f_list: List of file paths as strings
        - workers: Number of concurrent threads to use.
            Should be equal to or less than the length of the list.
        - header_only: if set, only the headers will be loaded (Default: False)
        - sort_by: if given, the batch is sorted by the specified header field
        - append: if set, the new files are appended to the existing batch
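
        ### Example
        A minimal sketch (file names and the header field are illustrative):
            batch = FitsBatch(config)
            batch.load(["scan_000.fits", "scan_001.fits"], workers=2, sort_by="DATE-OBS")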
143 """
144 if len(self.batch) > 0 and not append:
145 logger.debug("Batch already loaded. Use append=True to add more files.")
146 return self
148 if header_only:
149 MP.threaded(self.__load_header, f_list, workers)
150 else:
151 MP.threaded(self.__load_file, f_list, workers)
152 if sort_by is not None and all(sort_by in e["header"] for e in self.batch):
153 self.sort_by(sort_by)
154 return self

    def header_copy(self) -> FitsBatch:
        """Creates and returns a header-only copy of the current batch"""
        copy = FitsBatch()
        copy.batch = [{"data": None, "file": e["file"], "header": e["header"].copy()} for e in self.batch]
        return copy

    def is_applied(self, block: str) -> bool:
        """Check if a given block was already applied to the batch"""
        if len(self.batch) == 0 or BLOCKS_APPLIED not in self.batch[0]["header"]:
            # Not loaded, or no blocks applied yet
            return False
        return block in self.batch[0]["header"][BLOCKS_APPLIED]

    def write_to_disk(self, overwrite=False, astype="float32", verify="silentfix+ignore", write_metadata=False) -> bool:
        """
        Writes all entries in the batch to disk.
        A summary is printed to the log.

        ### Params
        - overwrite: bool, optional
            If `True`, overwrite the output file if it exists. In either case the
            overall process will not fail; files that were not written are skipped.
        - astype: string, optional
            Defines the datatype of the HDU data array. Default: float32.
            Use 'same' to avoid changing the dtype of the input.
        - verify: string, optional
            Verification mode passed to `astropy.io.fits` on write. When left at the
            default ('silentfix+ignore'), `AstropyWarning`s are silenced as well.
        - write_metadata: bool, optional
            If `True`, the metadata database is updated after writing. Default: False.

        ### Returns
        `True` if all entries were written, `False` otherwise.
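
        ### Example
        A minimal sketch (writes every loaded entry back to its `file` path):
            ok = batch.write_to_disk(overwrite=True, astype="float32")
            assert ok, "some files were not written"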
181 """
182 count = 0
183 if verify == "silentfix+ignore":
184 warnings.filterwarnings('ignore', category=AstropyWarning)
185 try:
186 args = [[entry, overwrite, astype, verify] for entry in self.batch]
187 count = sum(MP.threaded(FitsBatch._dump_entry, args, workers=max(20, len(self))))
188 if write_metadata:
189 self.write_metadata()
190 finally:
191 logger.debug(f"{count} of {len(self)} files written to disk.")
192 return count == len(self)

    def write_metadata(self, overwrite=False):
        try:
            try:
                # Try to add the files as one batch
                Metadata.insert_batch([file for file in self], overwrite=overwrite)
                return
            except (IndexError, ValueError):
                pass

            # Batch insert failed (e.g. mixed observation times): add the files individually...
            for file in self:
                if os.path.exists(file.path):
                    Metadata.insert(file, overwrite=overwrite)
        except (PermissionError, IOError) as e:
            logger.info("Failed to update metadata database: %s", e)
            return

    @staticmethod
    def _dump_entry(args: list) -> int:
        # args:
        # - 0: entry: dict
        # - 1: overwrite: bool
        # - 2: astype: str
        # - 3: verify: str
        try:
            os.makedirs(os.path.dirname(args[0]["file"]), exist_ok=True)
            data = args[0]["data"] if args[2] == "same" else args[0]["data"].astype(args[2])
            hdul = fits.HDUList(fits.PrimaryHDU(data, args[0]["header"], scale_back=False))
            if args[3]:
                hdul.writeto(args[0]["file"], overwrite=args[1], output_verify=args[3])
            else:
                hdul.writeto(args[0]["file"], overwrite=args[1])

            return 1
        except OSError as e:
            if not args[1]:
                logger.debug("Error writing file: %s", e)
            else:
                logger.warning("Error writing file: %s", e)
            return 0

    # TODO: It is kind of awkward to have two dark correction methods lurking around
    def apply_dark(self, dark_file: str) -> None:
        """
        Apply dark correction to all FITS files of the batch
        """
        warnings.warn("FitsBatch#apply_dark is deprecated. Use Block C instead", DeprecationWarning, stacklevel=2)
        logger.info("Apply quick and dirty dark correction to all files of batch")
        logger.debug("Loading area %s of dark file %s", self.config.data_shape, os.path.basename(dark_file))

        dark_data_shape = self.config.data_shape.copy()
        dark_data_shape.insert(0, slice(0, 1))
        dark = Fits(dark_file).read(slices=dark_data_shape).data
        dark = dark[0, :]

        n = len(self.batch)
        for i in range(n):
            if (i + 1) % 200 == 0:
                logger.debug("\t> %s/%s", i + 1, n)
            self.batch[i]["data"] = np.subtract(self.batch[i]["data"], dark)

    def load_ics(self, f_list: list, workers: int = 10) -> None:
        """
        Load a batch of ICS FITS files into memory.
        Uses a multithreading approach to load the data from the given file list.

        ### Params
        - f_list: List of file paths as strings
        - workers: Number of concurrent threads to use.
            Should be equal to or less than the length of the list.
        """
        MP.threaded(self.__load_ics_file, f_list, workers)

    def convert(self, f_list: list, workers: int = 10) -> None:
        """
        Converts a batch of ICS FITS files.
        Uses a multithreading approach to convert the given file list.
        Each Fits is written to its own path, except that Config.collection is replaced by Config.out_path.
        The full output directory structure is created.

        ### Params
        - f_list: List of file paths as strings
        - workers: Number of concurrent threads to use.
            Should be equal to or less than the length of the list.
        """
        logger.debug("Converting %s files to level 0.1", len(f_list))
        MP.threaded(self.__convert_ics_file, f_list, workers)

    def __getitem__(self, key: Union[str, int]) -> dict:
        """
        Get the loaded data and header entries by file name or batch position via []

        ### Params
        - key: Either the file name of the FITS file to access or the position
            of the entry within the batch.

        ### Returns
        The FITS as dict with keys 'file', 'data' and 'header' (or `None` if a
        file name is given and not found)
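
        ### Example
        A minimal sketch (the file name is hypothetical):
            first = batch[0]
            by_name = batch["scan_000.fits"]
            data, header = first["data"], first["header"]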
292 """
293 if isinstance(key, str):
294 return next((item for item in self.batch if item["file"] == key), None)
296 return self.batch[key]

    def __iter__(self):
        """Make this class iterable (e.g. use with for ... in)"""
        self.next = 0
        return self

    def __next__(self) -> Fits:
        """Get the next element in the batch as Fits"""
        try:
            return self.as_fits(self.next)
        except IndexError:
            raise StopIteration()
        finally:
            self.next += 1

    def __repr__(self):
        """Representation string"""
        out = f"{self.__class__}\n Loaded content: {len(self.batch)} files\n"
        for e in self.batch:
            out += f"\t-{e['file']}\t {e['data'].shape if e['data'] is not None else ''}\n"
        return out

    def as_fits(self, idx: Union[int, str]) -> Fits:
        """Return the element at the given position (or file name) as Fits"""
        entry = self[idx]  # delegate to __getitem__ so both int and str keys work
        f = Fits(entry["file"])
        f.header = entry["header"]
        f.data = entry["data"]
        return f

    def data_array(self) -> np.ndarray:
        """
        Get the data arrays of the current batch of FITSes, stacked into an
        array of shape `(len(batch),) + data.shape`
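
        ### Example
        A minimal sketch (assumes 10 loaded frames of shape (2048, 2048)):
            cube = batch.data_array()  # -> shape (10, 2048, 2048)
            mean_frame = cube.mean(axis=0)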
329 """
330 self.__check_loaded()
331 return np.array([np.array(e["data"]) for e in self.batch])

    def file_list(self) -> list[str]:
        """
        Get the list of files from this batch
        """
        self.__check_loaded()
        return [e["file"] for e in self.batch]

    def header_field(self, field_name: str) -> list:
        """
        Access the same header field of all FITSes in the batch

        ### Params
        - field_name: the name of the header field to access

        ### Returns
        List of field values; the list index corresponds to the batch index
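
        ### Example
        A minimal sketch (the header keyword is illustrative):
            exposure_times = batch.header_field("EXPTIME")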
349 """
350 self.__check_loaded()
351 return [entry["header"][field_name] for entry in self.batch]

    def header_field_comment(self, field_name: str, split_on: str = None) -> list:
        """
        Access the same header field comment of all FITSes in the batch

        ### Params
        - field_name: The name of the header field to access
        - split_on: If set, only the part of the comment before the first
            occurrence of this separator is returned

        ### Returns
        List of field comments; the list index corresponds to the batch index
        """
        self.__check_loaded()
        return [FitsBatch.__comment(entry, field_name, split_on) for entry in self.batch]

    @staticmethod
    def __comment(entry, field, split_on):
        if split_on:
            return entry["header"].comments[field].split(split_on)[0]
        else:
            return entry["header"].comments[field]

    def file_by(self, idx: int, full_path: bool = False) -> str:
        """
        Get the file name corresponding to the given batch index

        ### Params
        - idx: the index to fetch the file name for
        - full_path: if True, the full path is returned, otherwise the file name only (Default: False)

        ### Returns
        The file name
        """
        if full_path:
            return self.batch[idx]["file"]
        return os.path.basename(self.batch[idx]["file"])

    def sort_by(self, field_name: str) -> object:
        """
        Sorts the batch by the content of a given header field

        ### Params
        - field_name: the header field name to sort by

        ### Returns
        `self`
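
        ### Example
        A minimal sketch (sorting by the observation date keyword):
            batch.sort_by(DATE_OBS)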
398 """
399 self.batch = sorted(self.batch, key=lambda e: e["header"][field_name])
400 if field_name == DATE_OBS:
401 # check that the resulting order corresponds to ascending TIMESTAMP_US
402 timestamps = [e["header"][TIMESTAMP_US] for e in self.batch]
403 if not all(timestamps[i] <= timestamps[i + 1] for i in range(len(timestamps) - 1)):
404 raise IllegalStateException(f"Sorting by {DATE_OBS} did not result in ascending TIMESTAMP_US order.")
406 return self

    def header_content(self, field_name: str, pos: int = 0):
        """
        Get the value of the header field at the requested batch position

        ### Params
        - field_name: the name of the header field to access
        - pos: the position in the batch; `0` for the first (default), `-1` for the last

        ### Returns
        Tuple with (value, comment) of the field
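
        ### Example
        A minimal sketch (the header keyword is illustrative):
            value, comment = batch.header_content("EXPTIME", pos=-1)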
417 """
418 self.__check_loaded()
419 return self.batch[pos]["header"][field_name], self.batch[pos]["header"].comments[field_name]

    def __len__(self) -> int:
        """
        Get the number of FITSes in the batch.
        """
        return len(self.batch)

    # Load an individual FITS file
    def __load_file(self, file, hdu=0):
        content = FitsBatch.read(file, hdu=hdu, slices=self.slices)
        if content[1] is not None:
            self.batch.append(
                {"file": file.file if isinstance(file, StateLink) else file, "header": content[0], "data": content[1]}
            )

    # Load an individual ICS FITS file
    def __load_ics_file(self, file, hdu=0):
        content = FitsBatch.read_ics(file, self.config, hdu=hdu, slices=self.slices)
        self.batch.append({"file": file, "header": content[0], "data": content[1]})

    # Load only the header of a given FITS file
    def __load_header(self, file, hdu=0):
        header = FitsBatch.read_header(file, hdu=hdu)
        if header is not None:
            self.batch.append({"file": file, "header": header, "data": None})

    # Check if data has been loaded and raise an exception if not
    def __check_loaded(self):
        if not self.batch:
            raise IllegalStateException("No batch of FITS files loaded.")

    # Convert an individual ICS FITS file
    def __convert_ics_file(self, file, overwrite=False, hdu=0):
        FitsBatch.convert_ics(file, self.config, hdu=hdu, slices=self.slices, overwrite=overwrite)