Coverage for src/susi/reduc/shear_distortion/shear_and_rot_correction.py: 98%

51 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2025-06-13 14:15 +0000

1#!/usr/bin/env python3 

2# -*- coding: utf-8 -*- 

3""" 

4provides ShearAndRotCorrector correction class 

5 

6@author: iglesias 

7""" 

8import numpy as np 

9from scipy import ndimage 

10 

11from ...base import Logging, Config, Globals 

12from ...io import Fits 

13from .shear import ShearDistortion 

14from ...analyse.img_rotation import RotationCorrection 

15from ...base.header_keys import * 

16 

17logger = Logging.get_logger() 

18 

19 

class ShearAndRotCorrector:
    """
    Implements shear distortion correction (req. scalar Shear Factor)
    followed by global rigid rotation (req. scalar rot. angle)
    Crops output to a rectangular ROI of only usable data

    The same shear factor and rotation angle are applied to every frame of
    the input stack, so all frames are expected to yield the same output ROI.
    """

    def __init__(self, data: Fits, shear_fact: float, rot_ang: float):
        """
        :param data: object whose ``data`` attribute holds the frame stack;
            it is copied and cast to float, the input is not modified
        :param shear_fact: scalar shear factor handed to ``ShearDistortion.run``
        :param rot_ang: scalar angle handed to ``RotationCorrection``
        """
        # Frame stack, indexed as (frame, row, col) by run()
        self.data = data.data.astype(float)
        #: rotation angle to apply
        self.rot_ang = rot_ang
        #: shear factor to apply
        self.shear_fact = shear_fact
        #: the output roi wrt the input image shape (row, col)
        self.out_roi = slice(None, None), slice(None, None)

    def run(self) -> np.ndarray:
        """
        Correct every frame and crop the stack to the common output ROI.

        :return: the corrected stack, cropped to ``self.out_roi``
        :raises ValueError: if a frame yields a ROI different from frame 0
        """
        # Frame 0 defines the ROI; reuse its corrected image instead of
        # running the (expensive) correction on it a second time.
        first_result, self.out_roi = self._correct_frame(0, get_out_roi=True)
        self.data[0][self.out_roi] = first_result
        for i in range(1, self.data.shape[0]):
            self._correct_frame(i, modify_data_roi=self.out_roi)
        self.data = self.data[:, self.out_roi[0], self.out_roi[1]]
        return self.data

    @staticmethod
    def _combined_roi(shear_roi, rot_roi):
        """
        Express the rotation ROI in the coordinates of the input frame.

        NOTE(review): presumably ``rot_roi`` is relative to the image returned
        by the shear correction, hence the offset by the start of
        ``shear_roi`` - confirm against ShearDistortion / RotationCorrection.
        """
        row_offset = shear_roi[0].start
        col_offset = shear_roi[1].start
        rows = slice(row_offset + rot_roi[0].start, row_offset + rot_roi[0].stop)
        cols = slice(col_offset + rot_roi[1].start, col_offset + rot_roi[1].stop)
        return rows, cols

    def _correct_frame(self, frame_no, get_out_roi=False, modify_data_roi=None):
        """
        Apply shear correction followed by rotation to a single frame.

        :param frame_no: index of the frame in ``self.data``
        :param get_out_roi: if True, return ``(result, roi)`` and leave
            ``self.data`` untouched
        :param modify_data_roi: if given (and ``get_out_roi`` is False), write
            the corrected image into this ROI of ``self.data[frame_no]``
        :return: ``(result, roi)`` when ``get_out_roi`` is True, else None
        :raises ValueError: if the ROI of this frame differs from
            ``modify_data_roi``
        """
        # TODO: Combine shear and rot in a single transformation
        scorr = ShearDistortion(self.data[frame_no])
        result = scorr.run(self.shear_fact)
        rot = RotationCorrection(img=result, angle=self.rot_ang)
        result = rot.bicubic()
        if get_out_roi:
            return result, self._combined_roi(scorr.out_roi, rot.out_roi)
        if modify_data_roi is not None:
            # Compare in the same coordinate frame: modify_data_roi is the
            # combined ROI wrt the input image, not the bare rotation ROI.
            if self._combined_roi(scorr.out_roi, rot.out_roi) != modify_data_roi:
                raise ValueError("Fixed shear and rotation correction produced variable ROI for the input time series")
            self.data[frame_no][modify_data_roi] = result

60 

61 

class ShearAndRotLoader:
    """
    Estimates the rotation and shear factors from a set of level 2 target files
    Current version:
    - Simply use the average of the values given in the files headers
    """

    def __init__(self, files: list):
        """
        :param files: files whose headers are read by run()
        """
        self.files = files
        #: mean shear factor over all files, None until run() is called
        self.shear_factor = None
        #: negated mean rotation angle over all files, None until run() is called
        self.rotation = None

    def run(self):
        """
        Read the header of every file and average the stored values.

        :return: tuple ``(shear_factor, rotation)``; both are also stored
            on the instance
        """
        header_values = []
        for path in self.files:
            # Only the header is needed, the image data is not loaded
            entry = Fits(path).read(header_only=True)
            header_values.append((entry.header[SHEAR_CORR], entry.header[ROTATION_ANG]))

        shear_values = [pair[0] for pair in header_values]
        rot_values = [pair[1] for pair in header_values]

        self.shear_factor = np.mean(shear_values)
        # we saved the original rotation angle when processing grid to lvl2
        self.rotation = -np.mean(rot_values)
        return self.shear_factor, self.rotation