Coverage for src/susi/io/fits_batch.py: 74%
190 statements
« prev ^ index » next coverage.py v7.5.0, created at 2025-06-13 14:15 +0000
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3"""
4module provides FitsBatch
6@author: hoelken, iglesias
7"""
from __future__ import annotations

import os
import warnings
from datetime import timedelta
from typing import Union

import numpy as np
from astropy.io import fits
from astropy.utils.exceptions import AstropyWarning

from .fits import Fits
from .state_link import StateLink
from ..base import Logging, IllegalStateException, Config
from ..db import Metadata
from ..utils import MP
from ..base.header_keys import *
28logger = Logging.get_logger()
31class FitsBatch:
32 """
33 ## FitsBatch
34 Provides utilities for loading FITS files in batches.
35 Uses multithreading for I/O tasks.
36 """
38 def __init__(self, config: Config = None, slices: tuple = None):
39 """
40 Constructor
41 Does not perform any heavy or dangerous tasks, to allow lightweight
42 creation. However, you need to `load` files to actually do anything
43 with it.
44 """
45 #: Instance of Config class. Required to load_ics and convert fits files
46 self.config = config
47 #: An optional tuple of slices to cut the data on read
48 self.slices = slices
49 self.slices = config.cam.data_shape if self.slices is None and config is not None else self.slices
50 # The loaded files go here
51 self.batch = []
52 self.next = 0
54 @staticmethod
55 def read(file: str, hdu: int = 0, slices: tuple = None) -> tuple:
56 """
57 Static resource-safe fits reader.
59 ### Params
60 - file: the file path of the file to open
61 - hdu: the hdu to read (Default=PrimaryHDU=0)
62 - slices: an optional tuple of slices to cut the read data to
64 ### Returns
65 A tuple with header information at 0 and data at 1
66 """
67 path = file.file if isinstance(file, StateLink) else file
68 f = Fits(path).read(hdu=hdu, slices=slices)
69 if isinstance(file, StateLink):
70 c = "FAKED STATE | No PMU data" if file.fake_state else "Determined from PMU angle"
71 f.header = Fits.override_header(f.header, MOD_STATE, file.state, comment=c)
72 return f.header, f.data
74 @staticmethod
75 def read_ics(file: str, config: Config, hdu: int = 0, slices: tuple = None):
76 """
77 Static resource-safe ics fits reader.
79 ### Params
80 - file: the file path of the file to open
81 - config: Instance of Config class
82 - hdu: the hdu to read (Default=PrimaryHDU=0)
83 - slices: an optional tuple of slices to cut the read data to
85 ### Returns
86 A tuple with header information at 0 and data at 1
87 """
88 f = Fits(file).read_ics(config, hdu=hdu, slices=slices)
89 return f.header, f.data
91 @staticmethod
92 def read_header(file: str, hdu: int = 0):
93 """
94 Static resource safe method to read the fits header only
96 ### Params
97 - file: the file path of the file to open
98 - hdu: the hdu to read (Default=PrimaryHDU=0)
100 ### Returns
101 The fits header
102 """
103 header = None
104 try:
105 header = Fits.get_header(file, hdu)
106 except OSError as error:
107 logger.error("Fits.read: %s: %s", file, error)
108 return header
110 @staticmethod
111 def convert_ics(file: str, config: Config, hdu: int = 0, slices: tuple = None, overwrite: bool = False):
112 """
113 Static resource-safe fits converter from ICS fits to standard fits.
115 ### Params
116 - file: the path of the file to convert
117 - config: Instance of Config class
118 - hdu: the hdu to read (Default=PrimaryHDU=0)
119 - slices: If defined only the given slice of the HDU data will be loaded.
120 Number of slices must match array shape.
121 - overwrite: If `True`, overwrite the output file if it exists. Raises an
122 `OSError` if `False` and the output file exists. (Default=`False`)
124 ### Returns
125 The content of the written Fits file
126 """
127 return Fits(file).convert_ics(config, hdu=hdu, slices=slices, overwrite=overwrite)
129 def load(self, f_list: list, workers: int = 10, header_only: bool = False, sort_by: str = None,
130 append: bool = False) -> FitsBatch:
131 """
132 Load a batch of fits files into memory
133 Uses a multithreading approach to load the data from the given file list.
135 ### Params
136 - f_list: List of file paths as strings
137 - workers: Number of concurrent threads to use.
138 Should be equal or less than the length of the list.
139 - header_only: if set only the header will be loaded. (Default: False)
140 - sort_by: if given the batch is sorted by the header field specified
141 - append: if set the new files are appended to the existing batch
142 """
143 if len(self.batch) > 0 and not append:
144 logger.debug("Batch already loaded. Use append=True to add more files.")
145 return self
147 if header_only:
148 MP.threaded(self.__load_header, f_list, workers)
149 else:
150 MP.threaded(self.__load_file, f_list, workers)
151 if sort_by is not None and all(sort_by in e["header"] for e in self.batch):
152 self.sort_by(sort_by)
153 return self
155 def header_copy(self) -> FitsBatch:
156 """Creates and returns a header_only copy of the current batch"""
157 copy = FitsBatch()
158 copy.batch = [{"data": None, "file": e["file"], "header": e["header"].copy()} for e in self.batch]
159 return copy
161 def is_applied(self, block: str) -> bool:
162 """Check if a given block was already applied to the batch"""
163 if len(self.batch) == 0 or BLOCKS_APPLIED not in self.batch[0]["header"]:
164 # Not loaded
165 return False
166 return block in self.batch[0]["header"][BLOCKS_APPLIED]
168 def write_to_disk(self, overwrite=False, astype="float32", verify="silentfix+ignore", write_metadata=False) -> bool:
169 """
170 Writes all entries in the batch to disk.
171 A summary is printed to the log.
173 ### Params
174 - overwrite: bool, optional
175 If `True`, overwrite the output file if it exists. In either case
176 it will not fail the total process but only skip files that where not written.
177 - astype: string, optional
178 Define the datatype of the HDU data array. Default: float32
179 Use 'same' to avoid changing the dtype of the input
180 """
181 count = 0
182 if verify == "silentfix+ignore":
183 warnings.filterwarnings('ignore', category=AstropyWarning)
184 try:
185 args = [[entry, overwrite, astype, verify] for entry in self.batch]
186 count = sum(MP.threaded(FitsBatch._dump_entry, args, workers=max(20, len(self))))
187 if write_metadata:
188 self.write_metadata()
189 finally:
190 logger.debug(f"{count} of {len(self)} files written to disk.")
191 return count == len(self)
193 def write_metadata(self, overwrite=False):
194 try:
195 try:
196 # Try to add the files as batch
197 Metadata.insert_batch([file for file in self], overwrite=overwrite)
198 return
199 except (IndexError, ValueError):
200 pass
202 for file in self:
203 # different times, add them individually...
204 if os.path.exists(file.path):
205 Metadata.insert(file, overwrite=overwrite)
206 except (PermissionError, IOError) as e:
207 logger.info("Failed to update metadata database: %s", e)
208 return
210 @staticmethod
211 def _dump_entry(args: list) -> int:
212 # args
213 # - 0: entry: dict,
214 # - 1: overwrite: bool,
215 # - 2: astype: str',
216 # - 3: verify:str
217 try:
218 os.makedirs(os.path.dirname(args[0]["file"]), exist_ok=True)
219 data = args[0]["data"] if args[2] == "same" else args[0]["data"].astype(args[2])
220 hdul = fits.HDUList(fits.PrimaryHDU(data, args[0]["header"], scale_back=False))
221 if args[3]:
222 hdul.writeto(args[0]["file"], overwrite=args[1], output_verify=args[3])
223 else:
224 hdul.writeto(args[0]["file"], overwrite=args[1])
226 return 1
227 except OSError as e:
228 if not args[1]:
229 logger.debug("Error writing file: %s", e)
230 else:
231 logger.warning("Error writing file: %s", e)
232 return 0
234 # TODO: It is kind of awkward to have two dark correction methods lurking around
235 def apply_dark(self, dark_file: str) -> None:
236 """
237 Apply dark correction to all fits files of the batch
238 """
239 warnings.warn("FitsBatch#apply_dark is deprecated. Use Block C instead", DeprecationWarning, stacklevel=2)
240 logger.info("Apply quick and dirty dark correction to all files of batch")
241 logger.debug("Loading area %s of dark file %s", self.config.data_shape, os.path.basename(dark_file))
243 dark_data_shape = self.config.data_shape.copy()
244 dark_data_shape.insert(0, slice(0, 1))
245 dark = Fits(dark_file).read(slices=dark_data_shape).data
246 dark = dark[0, :]
248 n = len(self.batch)
249 for i in range(n):
250 if (i + 1) % 200 == 0:
251 logger.debug("\t> %s/%s", i + 1, n)
252 self.batch[i]["data"] = np.subtract(self.batch[i]["data"], dark)
254 def load_ics(self, f_list: list, workers: int = 10) -> None:
255 """
256 Load a batch of ICS fits files into memory
257 Uses a multithreading approach to load the data from the given file list.
259 ### Params
260 - f_list: List of file paths as strings
261 - workers: Number of concurrent threads to use.
262 Should be equal or less than the length of the list.
263 """
264 MP.threaded(self.__load_ics_file, f_list, workers)
266 def convert(self, f_list: list, workers: int = 10) -> None:
267 """
268 Converts a batch of ICS fits files
269 Uses a multithreading approach to convert the given file list.
270 Each Fits is written to self.path, except that Config.collection is replaced by Config.out_path.
271 The full output dir structure is created.
273 ### Params
274 - f_list: List of file paths as strings
275 - workers: Number of concurrent threads to use.
276 Should be equal or less than the length of the list.
277 """
278 logger.debug("Converting %s files to level 0.1" % len(f_list))
279 MP.threaded(self.__convert_ics_file, f_list, workers)
281 def __getitem__(self, key: Union[str, int]) -> dict:
282 """
283 Get the loaded data and header entries by file name via []
285 ### Params
286 - key: Either the file name of the fits file to access or the position
287 of the entry within the batch.
289 ### Returns
290 The fits as dict with keys 'file', 'data' and 'header'
291 """
292 if isinstance(key, str):
293 return next((item for item in self.batch if item["file"] == key), None)
295 return self.batch[key]
297 def __iter__(self):
298 """Make this class iterable (e.g. use with for ... in)"""
299 self.next = 0
300 return self
302 def __next__(self) -> Fits:
303 """Get the next element in the batch as fits"""
304 try:
305 return self.as_fits(self.next)
306 except IndexError:
307 raise StopIteration()
308 finally:
309 self.next += 1
311 def __repr__(self):
312 """Representation string"""
313 out = f"{self.__class__}\n Loaded content: {len(self.batch)} files\n"
314 for e in self.batch:
315 out += f"\t-{e['file']}\t {e['data'].shape if e['data'] is not None else ''}\n"
316 return out
318 def as_fits(self, idx: Union[int, str]) -> Fits:
319 """Return the element at the given position as Fits"""
320 f = Fits(self.batch[idx]["file"])
321 f.header = self.batch[idx]["header"]
322 f.data = self.batch[idx]["data"]
323 return f
325 def data_array(self) -> np.array:
326 """
327 Get the data arrays from the current batch of FITSes as `fits.data.shape + 1` shaped array
328 """
329 self.__check_loaded()
330 return np.array([np.array(e["data"]) for e in self.batch])
332 def file_list(self) -> list[str]:
333 """
334 Get the list of files from this batch
335 """
336 self.__check_loaded()
337 return [e["file"] for e in self.batch]
339 def header_field(self, field_name: str) -> list:
340 """
341 Access the same header field of all FITSes in the batch
343 ### Params
344 - field_name: the name of the header field to access
346 ### Returns
347 List of field_value(s), the list id corresponds to the internal id
348 """
349 self.__check_loaded()
350 return [entry["header"][field_name] for entry in self.batch]
352 def header_field_comment(self, field_name: str, split_on: str = None) -> list:
353 """
354 Access the same header field comment of all FITSes in the batch
356 ### Params
357 - field_name: The name of the header field to access
358 - split_on: If set will return the 0th part of the comment
360 ### Returns
361 List of dictionaries with field_comments, the list id corresponds to the internal id
362 """
363 self.__check_loaded()
364 return [FitsBatch.__comment(entry, field_name, split_on) for entry in self.batch]
366 @staticmethod
367 def __comment(entry, field, split_on):
368 if split_on:
369 return entry["header"].comments[field].split(split_on)[0]
370 else:
371 return entry["header"].comments[field]
373 def file_by(self, idx: int, full_path: bool = False) -> str:
374 """
375 Get the file name corresponding to the id
377 ### Params
378 - idx: the id to fetch the file name for
379 - full_path: if True the full path will be returned otherwise the filename only (Default: False)
381 ### Returns
382 The file name
383 """
384 if full_path:
385 return self.batch[idx]["file"]
386 return self.batch[idx]["file"].split("/")[-1]
388 def sort_by(self, field_name: str) -> object:
389 """
390 function sorts the batch by a given header fields content
392 ### Params
393 - field_name: the header field name to sort by
395 ### Returns
396 `self`
397 """
398 self.batch = sorted(self.batch, key=lambda e: e["header"][field_name])
399 return self
401 def header_content(self, field_name: str, pos: int = 0):
402 """
403 Get the value of the header field for the requested batch position
404 ### Params
405 - field_name: the name of the header field to access
406 - pos: the position in the batch `0` for first (default), ` -1` for last
408 ### Returns
409 Tuple with (value, comment) of the field
410 """
411 self.__check_loaded()
412 return self.batch[pos]["header"][field_name], self.batch[pos]["header"].comments[field_name]
414 def __len__(self) -> int:
415 """
416 Get the number of FITSes in the batch.
417 """
418 return len(self.batch)
420 # Load an individual FITS file
421 def __load_file(self, file, hdu=0):
422 content = FitsBatch.read(file, hdu=hdu, slices=self.slices)
423 if content[1] is not None:
424 self.batch.append(
425 {"file": file.file if isinstance(file, StateLink) else file, "header": content[0], "data": content[1]}
426 )
428 # Load an individual ICS FITS file
429 def __load_ics_file(self, file, hdu=0):
430 content = FitsBatch.read_ics(file, self.config, hdu=hdu, slices=self.slices)
431 self.batch.append({"file": file, "header": content[0], "data": content[1]})
433 # Load only the header of a given FITS file
434 def __load_header(self, file, hdu=0):
435 header = FitsBatch.read_header(file, hdu=hdu)
436 if header is not None:
437 self.batch.append({"file": file, "header": header, "data": None})
439 # Check if data has been loaded and raise exception if not
440 def __check_loaded(self):
441 if not self.batch:
442 raise IllegalStateException("No batch of FITS files loaded.")
444 # Convert an individual ics FITS file
445 def __convert_ics_file(self, file, overwrite=False, hdu=0):
446 FitsBatch.convert_ics(file, self.config, hdu=hdu, slices=self.slices, overwrite=overwrite)