Coverage for src/susi/utils/cropping.py: 58%

60 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2025-06-13 14:15 +0000

1from typing import Union 

2 

3import numpy as np 

4from astropy.io import fits 

5from src.susi.plot import roi 

6 

7from ..base import ROI_X0, ROI_Y0, ROI_Y1, ROI_X1, Globals 

8from ..io import FitsBatch 

9 

10""" 

11Utility functions to crop data to a common shape. 

NOTE: both the ROI_* header keys and the input roi dict are relative to (0, 0), i.e. the origin of the original (uncropped) data.

13""" 

14 

15 

def adapt_shapes(batch: FitsBatch, roi: dict) -> FitsBatch:
    """
    Crop the data of every entry in a FitsBatch to one shared 2D ROI.

    TODO: Rename to adapt_shapes_batch

    Each entry's ``'data'`` is replaced by ``adapt_shape(data, header, roi)``, using the
    entry's own ``'header'`` to locate its data within the full frame. The entries are
    modified in place and the same batch object is returned.

    :param batch: a FitsBatch whose entries hold (Y,X) or (Z,Y,X) data (see adapt_shape)
    :param roi: a 2D ROI dict with keys 'x0', 'y0', 'x1', 'y1' to crop to
    :return: The same FitsBatch, with every entry's data cropped
    """
    for entry in batch.batch:
        cropped = adapt_shape(entry['data'], entry['header'], roi)
        entry['data'] = cropped
    return batch

26 

27 

def adapt_shapes_fits(fits1, fits2, known_roi=None):
    """
    Crop two fits images to a common ROI.

    If ``known_roi`` is None, the common ROI is the intersection of the two header ROIs:
    the larger ROI_X0/ROI_Y0 and the smaller ROI_X1/ROI_Y1. Header values are cast to
    int, as ``adapt_shape`` does for its offsets, so the ROI is usable as slice bounds.

    :param fits1: The first fits object (``.data`` and ``.header`` are read)
    :param fits2: The second fits object
    :param known_roi: A dict ('x0', 'y0', 'x1', 'y1') with the ROI to crop to; if None the
                      common ROI is computed from the headers and used
    :return: tuple (img1, img2, crop_roi): the cropped data of fits1 and fits2 and the ROI
             that was applied. fits1/fits2 themselves are not modified.
    :raises KeyError: if known_roi is None and a header lacks one of the ROI_* keys
    """
    if known_roi is None:
        h1, h2 = fits1.header, fits2.header
        # Intersection of both ROIs: latest start, earliest end on each axis.
        crop_roi = {
            'x0': max(int(h1[ROI_X0]), int(h2[ROI_X0])),
            'y0': max(int(h1[ROI_Y0]), int(h2[ROI_Y0])),
            'x1': min(int(h1[ROI_X1]), int(h2[ROI_X1])),
            'y1': min(int(h1[ROI_Y1]), int(h2[ROI_Y1])),
        }
    else:
        crop_roi = known_roi

    img1 = adapt_shape(fits1.data, fits1.header, crop_roi)
    img2 = adapt_shape(fits2.data, fits2.header, crop_roi)
    return img1, img2, crop_roi

47 

48 

def _crop_cube_to_roi(cube, crop_roi: dict) -> None:
    """Replace ``cube.data`` [stks,scan,slit,wl] by its slice matching ``crop_roi``, offset by ``cube.header[0]``."""
    dx = int(cube.header[0][ROI_X0])
    dy = int(cube.header[0][ROI_Y0])
    # crop_roi is in full-frame coordinates; shift it into this cube's own pixel grid.
    rows = slice(crop_roi['y0'] - dy, crop_roi['y1'] - dy)
    cols = slice(crop_roi['x0'] - dx, crop_roi['x1'] - dx)
    cube.data = cube.data[:, :, rows, cols]


def adapt_shapes_cubes(cube1, cube2, known_roi=None):
    """
    Crop the cubes to a common roi in dimensions 2 and 3 (slit and wl).

    Assumes the ROI keys are the same for all headers (3dExt cube format); only
    ``header[0]`` of each cube is read. If ``known_roi`` is None, the common ROI is the
    intersection of the two header ROIs (larger ROI_X0/ROI_Y0, smaller ROI_X1/ROI_Y1),
    with header values cast to int as ``adapt_shape`` does.

    NOTE: the cubes are modified in place — their ``.data`` attribute is replaced by the
    cropped array — and the same objects are returned.

    :param cube1: The first cube [stks,scan,slit,wl]
    :param cube2: The second cube [stks,scan,slit,wl]
    :param known_roi: A dict ('x0', 'y0', 'x1', 'y1') with the ROI to crop to; if None the
                      common ROI is computed and used
    :return: tuple (cube1, cube2, crop_roi)
    :raises KeyError: if a header lacks one of the ROI_* keys that is read
    """
    if known_roi is None:
        h1, h2 = cube1.header[0], cube2.header[0]
        # Intersection of both ROIs: latest start, earliest end on each axis.
        crop_roi = {
            'x0': max(int(h1[ROI_X0]), int(h2[ROI_X0])),
            'y0': max(int(h1[ROI_Y0]), int(h2[ROI_Y0])),
            'x1': min(int(h1[ROI_X1]), int(h2[ROI_X1])),
            'y1': min(int(h1[ROI_Y1]), int(h2[ROI_Y1])),
        }
    else:
        crop_roi = known_roi

    _crop_cube_to_roi(cube1, crop_roi)
    _crop_cube_to_roi(cube2, crop_roi)
    return cube1, cube2, crop_roi

75 

76 

def adapt_shape(data: np.array, header: Union[dict, fits.Header], roi: dict) -> np.array:
    """
    Crop data to a ROI that is given in full-frame coordinates.

    The header's ROI_X0/ROI_Y0 locate ``data`` within the full frame (treated as 0 when
    the key is absent); the ROI is shifted by that offset before slicing.

    :param data: The (Y,X) or (Z,Y,X) data to crop
    :param header: The header of the data
    :param roi: a 2D ROI dict with keys 'x0', 'y0', 'x1', 'y1' (full-frame coordinates)
    :return: The cropped data
    """
    offset_x = int(header[ROI_X0]) if ROI_X0 in header else 0
    offset_y = int(header[ROI_Y0]) if ROI_Y0 in header else 0
    rows = slice(roi['y0'] - offset_y, roi['y1'] - offset_y)
    cols = slice(roi['x0'] - offset_x, roi['x1'] - offset_x)
    return crop(data, (rows, cols))

88 

89 

def crop(data: np.array, roi: tuple) -> np.array:
    """
    Apply a (Y, X) pair of slices to an array.

    For 3D input the leading axis is kept whole and the slices act on axes 1 and 2;
    for any other rank the slices act on the first two axes.

    :param data: The data to crop (Y,X), (Z,Y,X) or (Y,X,A,B,..)
    :param roi: A (y_slice, x_slice) tuple
    :return: The cropped data
    """
    rank = len(data.shape)
    if rank != 3:
        return data[roi]
    return data[:, roi[0], roi[1]]

100 

101 

def common_shape(batch: FitsBatch) -> dict:
    """
    Compute the ROI shared by all entries of a batch.

    The common region starts at the largest ROI_Y0/ROI_X0 and ends at the smallest
    ROI_Y1/ROI_X1 found in the entry headers (see ``_extract_int`` for how entries
    without these keys are handled).

    :param batch: The FitsBatch to inspect
    :return: A dict with keys 'y0', 'y1', 'x0', 'x1'
    """
    spec = (
        ('y0', ROI_Y0, 'max'),
        ('y1', ROI_Y1, 'min'),
        ('x0', ROI_X0, 'max'),
        ('x1', ROI_X1, 'min'),
    )
    return {name: _extract_int(batch, key, method) for name, key, method in spec}

109 

110 

def check_same_shape(batch: FitsBatch) -> bool:
    """
    Tell whether every entry of the batch holds data of one identical shape.

    :param batch: The FitsBatch to check
    :return: True if exactly one distinct data shape occurs; False otherwise
             (including for an empty batch)
    """
    distinct_shapes = {entry['data'].shape for entry in batch.batch}
    return len(distinct_shapes) == 1

119 

120 

def _extract_int(batch: FitsBatch, key: str, method: str = 'max') -> Union[int, None]:
    """
    Reduce an integer header value over all entries of a batch.

    Entries whose header lacks ``key`` contribute a neutral value that cannot win the
    reduction:
    - 0 for ``'max'`` (used for lower ROI bounds). This is consistent with ``adapt_shape``,
      which treats a missing ROI_X0/ROI_Y0 as 0.
    - ``Globals.MAX_PX`` for ``'min'`` (used for upper ROI bounds). NOTE(review): this
      assumes MAX_PX is at least every real ROI_X1/ROI_Y1 — confirm.

    :param batch: The FitsBatch whose entry headers are inspected
    :param key: The header key to read, e.g. ROI_Y0
    :param method: 'max' or 'min', the reduction applied to the collected values
    :return: The reduced value, or None if the batch is empty
    :raises ValueError: if method is neither 'max' nor 'min'
    """
    reducers = {'max': max, 'min': min}
    if method not in reducers:
        raise ValueError(f"method must be 'max' or 'min', got {method!r}")
    # Previously the defaults were swapped (MAX_PX for 'max', 0 for 'min'), so a single
    # entry without ROI keys collapsed the common ROI to an empty region.
    default = 0 if method == 'max' else Globals.MAX_PX
    values = {int(e['header'][key]) if key in e['header'] else default for e in batch.batch}
    if not values:
        return None
    return reducers[method](values)