Coverage for src/susi/reduc/shear_distortion/shear.py: 95%
42 statements
« prev ^ index » next coverage.py v7.5.0, created at 2025-06-13 14:15 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2025-06-13 14:15 +0000
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3"""
4Provides image shear distortion
6@author: iglesias
7"""
9import numpy as np
10import logging
11from skimage.transform import warp
13log = logging.getLogger("SUSI")
class ShearDistortion:
    """
    Applies a non-linear row shear distortion to an image.

    The image is resampled with ``skimage.transform.warp`` using
    :meth:`nl_shear` as the inverse coordinate map, i.e. the output pixel
    (row, col) is sampled from the input image at:

        row' = row + shear_factor/1e5 * (row + coord_off[0]) * (col + coord_off[1])
        col' = col

    Pixels sampled from outside the input image are set to NaN and the
    result is afterwards cropped to a rectangular region, see
    :meth:`keep_valid_rectangular_roi`.
    """

    def __init__(self, array):
        """
        :param array: numpy array with the image data [row, col]
        """
        self.data = array  # original image
        self.result = None  # distorted image
        self.shear = None  # shear factor
        self.coord_off = None  # offset to add to the coordinates of the image
        self.out_roi = slice(None, None), slice(None, None)  # output roi

    def run(self, shear_factor, coord_off=(0, 0)):
        """
        Distort the image and return the cropped result.

        :param shear_factor: float, shear strength (scaled by 1e-5 in the map).
            A value of 0 returns the original array object untouched.
        :param coord_off: (row_offset, col_offset) added to the pixel
            coordinates before evaluating the shear term.
        :return: the distorted image, also stored in ``self.result``; the
            matching region is stored in ``self.out_roi``.
        """
        self.shear = shear_factor
        self.coord_off = coord_off
        if self.shear == 0:
            # Nothing to do: hand back the input itself (no copy is made).
            self.result = self.data
            self.out_roi = slice(0, self.data.shape[0]), slice(0, self.data.shape[1])
        else:
            self.result = self.transform_image()
            self.result = self.keep_valid_rectangular_roi()
        return self.result

    def transform_image(self):
        """
        Resample ``self.data`` through the shear map.

        :return: warped image where pixels without a source are NaN.
        """
        transform = ShearDistortion.nl_shear
        transformed_img = warp(
            self.data,
            inverse_map=transform,
            map_args={"shear_factor": self.shear, "coord_off": self.coord_off},
            mode="constant",
            # np.nan instead of the np.NAN alias, which NumPy 2.0 removed
            cval=np.nan,
        )
        return transformed_img

    def keep_valid_rectangular_roi(self):
        """
        Crop ``self.result`` to a rectangular region and store it in ``self.out_roi``.

        First the bounding box of all rows/columns holding at least one
        non-NaN pixel is taken; the row range is then narrowed to the rows
        that are NaN-free within those columns.

        :return: the cropped view of ``self.result``.
        :raises ValueError: if there is no valid row/column to keep.
        """
        is_nan = np.isnan(self.result)
        has_valid_row = ~np.all(is_nan, axis=1)
        has_valid_col = ~np.all(is_nan, axis=0)
        row_first, _ = self._index_span(has_valid_row, "rows")
        col_first, col_last = self._index_span(has_valid_col, "columns")

        # Restrict the NaN mask to the bounding box and look for rows that
        # are completely valid inside it.
        # NOTE(review): the row offset below assumes the selected rows are
        # contiguous -- confirm this always holds for the shear map.
        boxed_nan = is_nan[has_valid_row, :][:, has_valid_col]
        nan_free_row = np.all(~boxed_nan, axis=1)
        full_first, full_last = self._index_span(nan_free_row, "NaN-free rows")

        # NOTE(review): the slice stops are the last valid indices themselves,
        # so the last valid row and column are excluded from the ROI. Kept
        # as-is because callers may rely on the resulting shape -- confirm
        # whether this is intended.
        rows = slice(row_first + full_first, row_first + full_last)
        cols = slice(col_first, col_last)
        self.out_roi = rows, cols
        return self.result[rows, cols]

    @staticmethod
    def _index_span(mask, what):
        """Return (first, last) index where the 1D boolean ``mask`` is True."""
        idx = np.flatnonzero(mask)
        if idx.size == 0:
            raise ValueError(f"Sheared image has no valid {what} to keep")
        return int(idx[0]), int(idx[-1])

    @staticmethod
    def nl_shear(coords, shear_factor, coord_off):
        """
        Coordinate map used as ``inverse_map`` by ``warp``.

        :param coords: (N, 2) array of (col, row) coordinates; not modified.
        :param shear_factor: float, shear strength (scaled by 1e-5).
        :param coord_off: (row_offset, col_offset) added to the coordinates
            inside the shear term.
        :return: new (N, 2) array of (col, row') coordinates.
        """
        x, y = np.asarray(coords).T
        # Build a new array instead of ``y += ...``: x and y are views, so an
        # in-place update would overwrite the caller's coordinates (and fails
        # on integer input).
        sheared_y = y + shear_factor / 1e5 * (y + coord_off[0]) * (x + coord_off[1])
        return np.column_stack((x, sheared_y))