Coverage for src/susi/utils/cropping.py: 58%
60 statements
« prev ^ index » next coverage.py v7.5.0, created at 2025-06-13 14:15 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2025-06-13 14:15 +0000
1from typing import Union
3import numpy as np
4from astropy.io import fits
5from src.susi.plot import roi
7from ..base import ROI_X0, ROI_Y0, ROI_Y1, ROI_X1, Globals
8from ..io import FitsBatch
10"""
11Utility functions to crop data to a common shape.
12NOTE: both the ROI_* header keys and the input roi dict are relative to (0,0) of the original data shape.
13"""
def adapt_shapes(batch: FitsBatch, roi: dict) -> FitsBatch:
    """
    Crop every entry of a batch to the same region of interest, in place.

    TODO: Rename to adapt_shapes_batch

    :param batch: a FitsBatch whose entries hold (Y,X) or (S,Y,X) data plus a header
    :param roi: a 2D ROI dict with keys 'x0', 'x1', 'y0', 'y1' to crop to
    :return: The same FitsBatch object, each entry's 'data' replaced by its cropped version
    """
    for entry in batch.batch:
        # The entry's own header supplies the offset of its data relative to the ROI frame.
        cropped = adapt_shape(entry['data'], entry['header'], roi)
        entry['data'] = cropped
    return batch
def adapt_shapes_fits(fits1, fits2, known_roi=None):
    """
    Crop two fits images to a shared ROI.

    :param fits1: The first fits file (needs ``.data`` and ``.header``)
    :param fits2: The second fits file (needs ``.data`` and ``.header``)
    :param known_roi: A dict with the ROI to crop to; if None, the intersection of the
        ROIs stored in both headers is computed and used
    :return: Tuple of (cropped data of fits1, cropped data of fits2, ROI dict that was applied)
    """
    crop_roi = known_roi
    if crop_roi is None:
        hdr1, hdr2 = fits1.header, fits2.header
        # Intersection of both ROIs: largest lower bound, smallest upper bound.
        crop_roi = {
            'x0': max(hdr1[ROI_X0], hdr2[ROI_X0]),
            'y0': max(hdr1[ROI_Y0], hdr2[ROI_Y0]),
            'x1': min(hdr1[ROI_X1], hdr2[ROI_X1]),
            'y1': min(hdr1[ROI_Y1], hdr2[ROI_Y1]),
        }

    cropped = [adapt_shape(f.data, f.header, crop_roi) for f in (fits1, fits2)]
    return cropped[0], cropped[1], crop_roi
def adapt_shapes_cubes(cube1, cube2, known_roi=None):
    """
    Crop two cubes to a common roi along their last two axes (slit and wl), in place.

    Only the first header of each cube is consulted, i.e. the ROI keys are assumed
    to be the same for all headers of a cube. (3dExt cube format)

    :param cube1: The first cube [stks,scan,slit,wl]
    :param cube2: The second cube [stks,scan,slit,wl]
    :param known_roi: A dict with the ROI to crop to; if None, the intersection of the
        ROIs stored in both cubes' first headers is computed and used
    :return: Tuple of (cube1, cube2, ROI dict that was applied); the cubes are the same
        objects that were passed in, with their ``data`` attribute replaced
    """
    if known_roi is not None:
        crop_roi = known_roi
    else:
        head1, head2 = cube1.header[0], cube2.header[0]
        # Intersection of both ROIs: largest lower bound, smallest upper bound.
        crop_roi = {
            'x0': max(head1[ROI_X0], head2[ROI_X0]),
            'y0': max(head1[ROI_Y0], head2[ROI_Y0]),
            'x1': min(head1[ROI_X1], head2[ROI_X1]),
            'y1': min(head1[ROI_Y1], head2[ROI_Y1]),
        }

    for cube in (cube1, cube2):
        head = cube.header[0]
        # Shift the ROI by the cube's own origin to get indices into its data array.
        off_x, off_y = head[ROI_X0], head[ROI_Y0]
        rows = slice(crop_roi['y0'] - off_y, crop_roi['y1'] - off_y)
        cols = slice(crop_roi['x0'] - off_x, crop_roi['x1'] - off_x)
        cube.data = cube.data[:, :, rows, cols]

    return cube1, cube2, crop_roi
def adapt_shape(data: np.array, header: Union[dict, fits.Header], roi: dict) -> np.array:
    """
    Crop a single data array to the given ROI, accounting for the data's own offset.

    :param data: The (Y,X) or (Z,Y,X) data to crop
    :param header: The header of the data; ROI_X0 / ROI_Y0 give the data's origin and
        are treated as 0 when absent
    :param roi: a 2D ROI dict with keys 'x0', 'x1', 'y0', 'y1' to crop to
    :return: The cropped data
    """
    offset_x = 0
    if ROI_X0 in header:
        offset_x = int(header[ROI_X0])

    offset_y = 0
    if ROI_Y0 in header:
        offset_y = int(header[ROI_Y0])

    # Translate the ROI into index space of this particular array.
    # NOTE: a ROI bound smaller than the offset yields a negative slice bound,
    # which numpy interprets as counting from the end of the axis.
    rows = slice(roi['y0'] - offset_y, roi['y1'] - offset_y)
    cols = slice(roi['x0'] - offset_x, roi['x1'] - offset_x)
    return crop(data, (rows, cols))
def crop(data: np.array, roi: tuple) -> np.array:
    """
    Apply a (Y,X) pair of slices to an array.

    :param data: The data to crop (Y,X), (Z,Y,X) or (Y,X,A,B,..)
    :param roi: The (Y,X) slices to crop to
    :return: The cropped data
    """
    if len(data.shape) != 3:
        # 2D data, or data whose first two axes are (Y,X): index the leading axes directly.
        return data[roi]

    # Exactly three dimensions is treated as (Z,Y,X): keep every plane, crop the last two axes.
    rows, cols = roi[0], roi[1]
    return data[:, rows, cols]
def common_shape(batch: FitsBatch) -> dict:
    """
    Determine the ROI shared by all entries of a batch from their headers.

    :param batch: The FitsBatch whose headers are inspected
    :return: A ROI dict with keys 'y0', 'y1', 'x0', 'x1'; lower bounds are the maximum
        and upper bounds the minimum over all entries
    """
    # (output key, header key, reduction) -- order defines the key order of the result.
    bounds = (
        ('y0', ROI_Y0, 'max'),
        ('y1', ROI_Y1, 'min'),
        ('x0', ROI_X0, 'max'),
        ('x1', ROI_X1, 'min'),
    )
    return {name: _extract_int(batch, header_key, how) for name, header_key, how in bounds}
def check_same_shape(batch: FitsBatch) -> bool:
    """
    Tell whether every entry of the batch holds data of one and the same shape.

    :param batch: The FitsBatch to check
    :return: True if exactly one distinct data shape occurs in the batch
        (an empty batch therefore yields False)
    """
    distinct_shapes = {entry['data'].shape for entry in batch.batch}
    return len(distinct_shapes) == 1
def _extract_int(batch: FitsBatch, key: str, method: str = 'max') -> Union[int, None]:
    """
    Reduce an integer header value over all entries of a batch.

    :param batch: The FitsBatch whose entry headers are read
    :param key: The header key to extract from each entry
    :param method: 'max' or 'min', the reduction applied to the collected values
    :return: The reduced value, or None if the batch is empty, no usable value
        exists, or ``method`` is neither 'max' nor 'min'
    """
    # Entries lacking the key must not influence the result, so they fall back to
    # the neutral element of the reduction: 0 for 'max' (lower bounds) and
    # Globals.MAX_PX for 'min' (upper bounds; presumably the full-frame extent).
    # The previous defaults were swapped, so a single entry without the key forced
    # the result to MAX_PX for 'max' and to 0 for 'min'.
    default = 0 if method == 'max' else Globals.MAX_PX
    values = {int(e['header'][key]) if key in e['header'] else default for e in batch.batch}
    if not values or values == {None}:
        return None
    if method == 'max':
        return max(values)
    if method == 'min':
        return min(values)
    # Unknown reduction: keep the historical behaviour of yielding None.
    return None