Coverage for src/susi/reduc/shear_distortion/shear.py: 95%

42 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2025-06-13 14:15 +0000

1#!/usr/bin/env python3 

2# -*- coding: utf-8 -*- 

3""" 

4Provides image shear distortion 

5 

6@author: iglesias 

7""" 

8 

9import numpy as np 

10import logging 

11from skimage.transform import warp 

12 

# Logger registered under the shared "SUSI" name; not referenced by ShearDistortion below.
log = logging.getLogger("SUSI")

14 

15 

class ShearDistortion:
    """
    Applies a row-wise non-linear shear distortion to an image.

    The output pixel at (row, col) is sampled from the input image at

        row_in = row + shear_factor/1e5 * (row + coord_off[0]) * (col + coord_off[1])
        col_in = col

    i.e. columns are unchanged and each row is displaced by an amount proportional
    to the product of its (offset) row and column coordinates. Output pixels whose
    source position lies outside the input image are NaN; the result is then cropped
    to the largest NaN-free rectangle (see ``keep_valid_rectangular_roi``).
    """

    def __init__(self, array):
        """
        :param array: numpy array with the image data [row, col]
        """
        self.data = array  # original image
        self.result = None  # distorted image
        self.shear = None  # shear factor (the applied coefficient is shear / 1e5)
        self.coord_off = None  # (row, col) offset added to the image coordinates
        self.out_roi = slice(None, None), slice(None, None)  # output roi (row_slice, col_slice)

    def run(self, shear_factor, coord_off=(0, 0)):
        """
        Distort ``self.data`` and crop the result to its fully valid region.

        :param shear_factor: float, shear strength; the coefficient actually applied
            is ``shear_factor / 1e5``. With 0 the input array itself is returned
            (not a copy) and ``out_roi`` covers the whole image.
        :param coord_off: (row_offset, col_offset) added to the pixel coordinates
            before the shear is computed
        :return: the distorted, cropped image (also stored in ``self.result``).
            ``self.out_roi`` holds the (row_slice, col_slice) of the distorted
            frame that was kept; stops are exclusive.
        """
        self.shear = shear_factor
        self.coord_off = coord_off
        if self.shear == 0:
            self.result = self.data
            self.out_roi = slice(0, self.data.shape[0]), slice(0, self.data.shape[1])
        else:
            self.result = self.transform_image()
            self.result = self.keep_valid_rectangular_roi()
        return self.result

    def transform_image(self):
        """
        Warp ``self.data`` using :meth:`nl_shear` as the inverse coordinate map.

        Output pixels whose source position falls outside the input are set to NaN.

        :return: the warped (uncropped) image
        """
        return warp(
            self.data,
            inverse_map=ShearDistortion.nl_shear,
            map_args={"shear_factor": self.shear, "coord_off": self.coord_off},
            mode="constant",
            # np.nan is the canonical spelling and is valid on NumPy 1.x and 2.x;
            # the legacy uppercase alias np.NAN is not portable to NumPy 2.
            cval=np.nan,
        )

    def keep_valid_rectangular_roi(self):
        """
        Crop ``self.result`` to the largest rectangle that contains no NaN pixels.

        First, the bounding box is found of the rows and columns that contain at
        least one valid pixel. Then, within that box, only the rows that are valid
        over the whole column range are kept. Columns are not trimmed further.

        Sets ``self.out_roi`` to the (row_slice, col_slice) used (exclusive stops,
        so the last valid row/column is kept) and returns that view of ``self.result``.

        :raises ValueError: if the image has no valid pixel, or no row of the
            bounding box is entirely valid
        """
        is_nan = np.isnan(self.result)
        rows_any = np.flatnonzero(~np.all(is_nan, axis=1))
        cols_any = np.flatnonzero(~np.all(is_nan, axis=0))
        if rows_any.size == 0 or cols_any.size == 0:
            raise ValueError("Sheared image contains no valid (non-NaN) pixels")

        # +1: slice stops are exclusive, so the last valid index must be included
        row_start, row_stop = int(rows_any[0]), int(rows_any[-1]) + 1
        col_start, col_stop = int(cols_any[0]), int(cols_any[-1]) + 1

        # rows with no NaN across exactly the columns that will be returned
        box_nan = is_nan[row_start:row_stop, col_start:col_stop]
        full_rows = np.flatnonzero(np.all(~box_nan, axis=1))
        if full_rows.size == 0:
            raise ValueError("Sheared image has no fully valid row; shear may be too strong")

        self.out_roi = (
            slice(row_start + int(full_rows[0]), row_start + int(full_rows[-1]) + 1),
            slice(col_start, col_stop),
        )
        return self.result[self.out_roi]

    @staticmethod
    def nl_shear(coords, shear_factor, coord_off):
        """
        Inverse coordinate map for skimage ``warp``.

        :param coords: (N, 2) array of (col, row) output coordinates, following
            skimage's ``inverse_map`` convention. It is not modified.
        :param shear_factor: shear strength; the applied coefficient is ``shear_factor / 1e5``
        :param coord_off: (row_offset, col_offset) added before computing the shear
        :return: (N, 2) array of (col, row) input coordinates
        """
        coords = np.asarray(coords)
        col = coords[:, 0]
        row = coords[:, 1]
        # Build a new array instead of `row += ...` on a view, which would mutate
        # the caller's coords (and fail on integer input).
        row_src = row + shear_factor / 1e5 * (row + coord_off[0]) * (col + coord_off[1])
        return np.column_stack((col, row_src))