# -*- coding: utf-8 -*-

# AROSICS - Automated and Robust Open-Source Image Co-Registration Software
#
# Copyright (C) 2017-2024
# - Daniel Scheffler (GFZ Potsdam, daniel.scheffler@gfz-potsdam.de)
# - Helmholtz Centre Potsdam - GFZ German Research Centre for Geosciences Potsdam,
#   Germany (https://www.gfz-potsdam.de/)
#
# This software was developed within the context of the GeoMultiSens project funded
# by the German Federal Ministry of Education and Research
# (project grant code: 01 IS 14 010 A-C).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import warnings
from time import time
from typing import Optional
from sys import platform

# custom
from osgeo import gdal  # noqa
import numpy as np
from geopandas import GeoDataFrame
from pandas import DataFrame, Series
from shapely.geometry import Point
from matplotlib import pyplot as plt
from scipy.interpolate import RBFInterpolator, RegularGridInterpolator
from joblib import Parallel, delayed

# internal modules
from .CoReg import COREG
from py_tools_ds.geo.projection import isLocal
from py_tools_ds.io.pathgen import get_generic_outpath
from py_tools_ds.processing.progress_mon import ProgressBar
from py_tools_ds.geo.vector.conversion import points_to_raster
from geoarray import GeoArray

from .CoReg import GeoArray_CoReg  # noqa F401  # flake8 issue

__author__ = 'Daniel Scheffler'


class Tie_Point_Grid(object):
    """
    The 'Tie_Point_Grid' class applies the algorithm to detect spatial shifts to the overlap area of the input images.

    Spatial shifts are calculated for each point of a grid whose parameters can be adjusted using keyword
    arguments. Shift correction performs a polynomial transformation using the calculated shifts of each point in the
    grid as GCPs. Thus, 'Tie_Point_Grid' can be used to correct for locally varying geometric distortions of the
    target image.

    See help(Tie_Point_Grid) for documentation!
    """
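
    # A minimal usage sketch (file paths and parameter values below are
    # hypothetical placeholders, shown for illustration only):
    #
    #     from arosics import COREG
    #     CR = COREG('reference.tif', 'target.tif', ws=(256, 256))
    #     tpg = Tie_Point_Grid(CR, grid_res=200, max_points=2500)
    #     tpg.CoRegPoints_table   # GeoDataFrame with one row per tie point
    #     gcps = tpg.to_GCPList()  # GDAL-compatible GCPs usable for warping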

    def __init__(self,
                 COREG_obj: COREG,
                 grid_res: float,
                 max_points: int = None,
                 outFillVal: int = -9999,
                 resamp_alg_calc: str = 'cubic',
                 tieP_filter_level: int = 3,
                 outlDetect_settings: dict = None,
                 dir_out: str = None,
                 CPUs: int = None,
                 progress: bool = True,
                 v: bool = False,
                 q: bool = False):
        """Get an instance of the 'Tie_Point_Grid' class.

        :param COREG_obj:
            an instance of the COREG class

        :param grid_res:
            grid resolution in pixels of the target image (x-direction)

        :param max_points:
            maximum number of points used to find coregistration tie points

            NOTE: Points are selected randomly from the given point grid (specified by 'grid_res'). If the grid does
            not provide enough points, all available points are chosen.

        :param outFillVal:
            if given, the generated tie point grid is filled with this value in case no match could be found during
            co-registration (default: -9999)

        :param resamp_alg_calc:
            the resampling algorithm to be used for all warping processes during the calculation of spatial shifts
            (valid algorithms: nearest, bilinear, cubic, cubic_spline, lanczos, average, mode, max, min, med, q1, q3)
            default: cubic (highly recommended)

        :param tieP_filter_level:
            filter tie points used for shift correction in different levels (default: 3).
            NOTE: lower levels are also included if a higher level is chosen

            - Level 0: no tie point filtering
            - Level 1: Reliability filtering
                - filter out all tie points that have a low reliability according to internal tests
            - Level 2: SSIM filtering
                - filter out all tie points where shift correction does not increase image similarity within the
                  matching window (measured by the mean structural similarity index)
            - Level 3: RANSAC outlier detection

        :param outlDetect_settings:
            a dictionary with the settings to be passed to arosics.TiePointGrid.Tie_Point_Refiner.
            Available keys: min_reliability, rs_max_outlier, rs_tolerance, rs_max_iter, rs_exclude_previous_outliers,
            rs_timeout, rs_random_state, q. See the documentation there.

        :param dir_out:
            output directory to be used for all outputs if nothing else is given to the individual methods

        :param CPUs:
            number of CPUs to use during the calculation of the tie point grid
            (default: None, which means 'all CPUs available')

        :param progress:
            show progress bars (default: True)

        :param v:
            verbose mode (default: False)

        :param q:
            quiet mode (default: False)
        """

        if not isinstance(COREG_obj, COREG):
            raise ValueError("'COREG_obj' must be an instance of COREG class.")

        self.COREG_obj = COREG_obj  # type: COREG
        self.grid_res = grid_res
        self.max_points = max_points
        self.outFillVal = outFillVal
        self.rspAlg_calc = resamp_alg_calc
        self.tieP_filter_level = tieP_filter_level
        self.outlDetect_settings = outlDetect_settings or dict()
        self.dir_out = dir_out
        self.CPUs = CPUs
        self.v = v
        self.q = q if not v else False  # overridden by v
        self.progress = progress if not q else False  # overridden by q

        if 'q' not in self.outlDetect_settings:
            self.outlDetect_settings['q'] = self.q

        self.ref = self.COREG_obj.ref  # type: GeoArray_CoReg
        self.shift = self.COREG_obj.shift  # type: GeoArray_CoReg

        self.XY_points, self.XY_mapPoints = self._get_imXY__mapXY_points(self.grid_res)
        self._CoRegPoints_table = None  # set by self.CoRegPoints_table
        self._GCPList = None  # set by self.to_GCPList()

    @property
    def mean_x_shift_px(self):
        return self.CoRegPoints_table['X_SHIFT_PX'][self.CoRegPoints_table['X_SHIFT_PX'] != self.outFillVal].mean()

    @property
    def mean_y_shift_px(self):
        return self.CoRegPoints_table['Y_SHIFT_PX'][self.CoRegPoints_table['Y_SHIFT_PX'] != self.outFillVal].mean()

    @property
    def mean_x_shift_map(self):
        return self.CoRegPoints_table['X_SHIFT_M'][self.CoRegPoints_table['X_SHIFT_M'] != self.outFillVal].mean()

    @property
    def mean_y_shift_map(self):
        return self.CoRegPoints_table['Y_SHIFT_M'][self.CoRegPoints_table['Y_SHIFT_M'] != self.outFillVal].mean()

    @property
    def CoRegPoints_table(self):
        """Return a GeoDataFrame containing all the results from coregistration for all points in the tie point grid.

        Columns of the GeoDataFrame: 'geometry', 'POINT_ID', 'X_IM', 'Y_IM', 'X_MAP', 'Y_MAP', 'X_WIN_SIZE',
        'Y_WIN_SIZE', 'X_SHIFT_PX', 'Y_SHIFT_PX', 'X_SHIFT_M', 'Y_SHIFT_M', 'ABS_SHIFT' and 'ANGLE'.
        """
        if self._CoRegPoints_table is not None:
            return self._CoRegPoints_table
        else:
            self._CoRegPoints_table = self.get_CoRegPoints_table()
            return self._CoRegPoints_table

    @CoRegPoints_table.setter
    def CoRegPoints_table(self, CoRegPoints_table):
        self._CoRegPoints_table = CoRegPoints_table

    @property
    def GCPList(self):
        """Return a list of GDAL-compatible GCP objects."""
        if self._GCPList:
            return self._GCPList
        else:
            self._GCPList = self.to_GCPList()
            return self._GCPList

    @GCPList.setter
    def GCPList(self, GCPList):
        self._GCPList = GCPList

    def _get_imXY__mapXY_points(self, grid_res):
        """Return numpy arrays containing possible positions for coregistration tie points.

        NOTE: The returned positions depend on the given grid resolution.

        :param grid_res:
        :return:
        """
        if not self.q:
            print('Initializing tie points grid...')

        Xarr, Yarr = np.meshgrid(np.arange(0, self.shift.shape[1] + grid_res, grid_res),
                                 np.arange(0, self.shift.shape[0] + grid_res, grid_res))

        mapXarr = np.full_like(Xarr, self.shift.gt[0], dtype=np.float64) + Xarr * self.shift.gt[1]
        mapYarr = np.full_like(Yarr, self.shift.gt[3], dtype=np.float64) - Yarr * abs(self.shift.gt[5])

        XY_points = np.empty((Xarr.size, 2), Xarr.dtype)
        XY_points[:, 0] = Xarr.flat
        XY_points[:, 1] = Yarr.flat

        XY_mapPoints = np.empty((mapXarr.size, 2), mapXarr.dtype)
        XY_mapPoints[:, 0] = mapXarr.flat
        XY_mapPoints[:, 1] = mapYarr.flat

        assert XY_points.shape == XY_mapPoints.shape

        return XY_points, XY_mapPoints
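
    # Worked example for the image-to-map conversion above (illustrative numbers
    # only): with a GDAL geotransform gt = (300000.0, 10.0, 0.0, 5900000.0, 0.0,
    # -10.0) and grid_res=200, grid points are placed every 200 pixels, and the
    # point at image position (x=200, y=400) maps to
    # mapX = 300000 + 200 * 10 = 302000 and mapY = 5900000 - 400 * 10 = 5896000.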

    def _exclude_bad_XYpos(self, GDF):
        """Exclude all points outside the image overlap area and where the bad data mask is True (if given).

        :param GDF: <geopandas.GeoDataFrame> must include the columns 'X_MAP' and 'Y_MAP'
        :return:
        """
        from skimage.measure import points_in_poly  # import here to avoid static TLS ImportError

        # exclude all points outside the overlap area
        inliers = points_in_poly(self.XY_mapPoints,
                                 np.swapaxes(np.array(self.COREG_obj.overlap_poly.exterior.coords.xy), 0, 1))
        GDF = GDF[inliers].copy()
        # GDF = GDF[GDF['geometry'].within(self.COREG_obj.overlap_poly.simplify(tolerance=15))]  # works but much slower

        assert not GDF.empty, 'No coregistration point could be placed within the overlap area. Check your input data!'

        # exclude all points where the bad data mask is True (e.g., points on clouds)
        orig_len_GDF = len(GDF)  # length of GDF after dropping all points outside the overlap polygon
        mapXY = np.array(GDF.loc[:, ['X_MAP', 'Y_MAP']])
        GDF['REF_BADDATA'] = self.COREG_obj.ref.mask_baddata.read_pointData(mapXY) \
            if self.COREG_obj.ref.mask_baddata is not None else False
        GDF['TGT_BADDATA'] = self.COREG_obj.shift.mask_baddata.read_pointData(mapXY).flatten().astype(bool) \
            if self.COREG_obj.shift.mask_baddata is not None else False
        GDF = GDF[(~GDF['REF_BADDATA']) & (~GDF['TGT_BADDATA'])]
        if self.COREG_obj.ref.mask_baddata is not None or self.COREG_obj.shift.mask_baddata is not None:
            if not self.q:
                if not GDF.empty:
                    print('With respect to the provided bad data mask(s), %s of initially %s points have been '
                          'excluded.' % (orig_len_GDF - len(GDF), orig_len_GDF))
                else:
                    warnings.warn('With respect to the provided bad data mask(s), no coregistration point could be '
                                  'placed within an image area usable for coregistration.')

        return GDF

    @staticmethod
    def _get_spatial_shifts(imref, im2shift, point_id, fftw_works, **coreg_kwargs):
        # run CoReg
        CR = COREG(imref, im2shift, CPUs=1, **coreg_kwargs)
        CR.fftw_works = fftw_works
        CR.calculate_spatial_shifts()

        # fetch results
        last_err = CR.tracked_errors[-1] if CR.tracked_errors else None
        win_sz_y, win_sz_x = CR.matchBox.imDimsYX if CR.matchBox else (None, None)
        CR_res = [win_sz_x, win_sz_y, CR.x_shift_px, CR.y_shift_px, CR.x_shift_map, CR.y_shift_map,
                  CR.vec_length_map, CR.vec_angle_deg, CR.ssim_orig, CR.ssim_deshifted, CR.ssim_improved,
                  CR.shift_reliability, last_err]

        return [point_id] + CR_res

    def get_CoRegPoints_table(self):
        assert self.XY_points is not None and self.XY_mapPoints is not None

        # create a dataframe containing 'geometry', 'POINT_ID', 'X_IM', 'Y_IM', 'X_MAP', 'Y_MAP'
        # (convert imCoords to mapCoords)
        XYarr2PointGeom = np.vectorize(lambda X, Y: Point(X, Y), otypes=[Point])
        geomPoints = np.array(XYarr2PointGeom(self.XY_mapPoints[:, 0], self.XY_mapPoints[:, 1]))

        crs = self.COREG_obj.shift.prj if not isLocal(self.COREG_obj.shift.prj) else None

        GDF = GeoDataFrame(index=range(len(geomPoints)),
                           crs=crs,
                           columns=['geometry', 'POINT_ID', 'X_IM', 'Y_IM', 'X_MAP', 'Y_MAP'])
        GDF['geometry'] = geomPoints
        GDF['POINT_ID'] = range(len(geomPoints))
        GDF[['X_IM', 'Y_IM']] = self.XY_points
        GDF[['X_MAP', 'Y_MAP']] = self.XY_mapPoints

        # exclude offsite points and points on the bad data mask
        GDF = self._exclude_bad_XYpos(GDF)
        if GDF.empty:
            self.CoRegPoints_table = GDF
            return self.CoRegPoints_table

        # choose a random subset of points if a maximum number has been given
        if self.max_points and len(GDF) > self.max_points:
            GDF = GDF.sample(self.max_points).copy()

        # equalize pixel grids in order to save warping time
        if len(GDF) > 100:
            # NOTE: actually the grid resolution should also be changed here because self.shift.xgsd changes
            #       and the grid resolution is tied to that
            self.COREG_obj.equalize_pixGrids()
            self.ref = self.COREG_obj.ref
            self.shift = self.COREG_obj.shift

        # validate reference and target image inputs
        assert self.ref.footprint_poly  # this also checks for mask_nodata and nodata value
        assert self.shift.footprint_poly

        # ensure the input arrays for CoReg are in memory -> otherwise the code will get stuck in multiprocessing
        # if neighbouring matching windows overlap during reading from disk!!
        self.ref.cache_array_subset(
            [self.COREG_obj.ref.band4match])  # only sets geoArr._arr_cache; does not change the number of bands
        self.shift.cache_array_subset([self.COREG_obj.shift.band4match])

        print(f"Calculating tie point grid ({len(GDF)} points) using {self.CPUs} CPU cores...")
        results = []
        bar = ProgressBar(prefix='\tprogress:')

        # the multiprocessing backend is slightly faster on Linux but can only return a list (no progress)
        if platform != 'win32' and (not self.progress or self.q):
            kw_parallel = dict(backend='multiprocessing', return_as='list')
        else:
            kw_parallel = dict(backend='loky', return_as='generator')

        for i, res in enumerate(
                Parallel(n_jobs=self.CPUs, **kw_parallel)(
                    delayed(self._get_spatial_shifts)(
                        self.ref,
                        self.shift,
                        point_id,
                        self.COREG_obj.fftw_works,
                        wp=self.XY_mapPoints[point_id],
                        ws=self.COREG_obj.win_size_XY,
                        resamp_alg_calc=self.rspAlg_calc,
                        footprint_poly_ref=self.COREG_obj.ref.poly,
                        footprint_poly_tgt=self.COREG_obj.shift.poly,
                        r_b4match=self.ref.band4match + 1,  # internally indexing from 0
                        s_b4match=self.shift.band4match + 1,  # internally indexing from 0
                        max_iter=self.COREG_obj.max_iter,
                        max_shift=self.COREG_obj.max_shift,
                        nodata=(self.COREG_obj.ref.nodata, self.COREG_obj.shift.nodata),
                        force_quadratic_win=self.COREG_obj.force_quadratic_win,
                        binary_ws=self.COREG_obj.bin_ws,
                        v=False,  # True leads to massive STDOUT
                        q=True,  # True leads to massive STDOUT
                        ignore_errors=True
                    ) for point_id in GDF.index
                )
        ):
            results.append(res)

            if self.progress and not self.q:
                bar.print_progress(percent=(i + 1) / len(GDF) * 100)

        # merge results with GDF
        # NOTE: We use a pandas.DataFrame here because the geometry column is missing.
        #       GDF.astype(...) fails with geopandas>0.6.0 if the geometry column is missing.
        records = DataFrame(results,
                            columns=['POINT_ID', 'X_WIN_SIZE', 'Y_WIN_SIZE', 'X_SHIFT_PX', 'Y_SHIFT_PX', 'X_SHIFT_M',
                                     'Y_SHIFT_M', 'ABS_SHIFT', 'ANGLE', 'SSIM_BEFORE', 'SSIM_AFTER',
                                     'SSIM_IMPROVED', 'RELIABILITY', 'LAST_ERR'])

        # merge DataFrames (dtype must be equal to records.dtypes; we need np.object due to None values)
        GDF = GDF.astype(object).merge(records.astype(object), on='POINT_ID', how="inner")
        GDF = GDF.replace([np.nan, None], int(self.outFillVal))  # fillna fails with geopandas==0.6.0
        GDF.crs = crs  # gets lost when using GDF.astype(np.object), so we have to reassign it

        n_matches = len(GDF[GDF.LAST_ERR == int(self.outFillVal)])

        if not self.q:
            print(f"Found {n_matches} matches.")

        # filter tie points according to the given filter level
        if n_matches > 0 and self.tieP_filter_level > 0:
            if not self.q:
                print('Performing validity checks...')
            TPR = Tie_Point_Refiner(GDF[GDF.ABS_SHIFT != self.outFillVal], **self.outlDetect_settings)
            GDF_filt, new_columns = TPR.run_filtering(level=self.tieP_filter_level)
            GDF = GDF.merge(GDF_filt[['POINT_ID'] + new_columns], on='POINT_ID', how="outer")

            GDF = GDF.replace([np.nan, None], int(self.outFillVal))  # fillna fails with geopandas==0.6.0

        self.CoRegPoints_table = GDF

        if not self.q:
            if n_matches == 0 or GDF.empty:
                warnings.warn('No valid GCPs could be identified.')
            else:
                if self.tieP_filter_level > 0:
                    print("%d valid tie points remain after filtering." % len(GDF[GDF.OUTLIER.__eq__(False)]))

        return self.CoRegPoints_table
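
    # Sketch of querying the result table (assuming `tpg` is a Tie_Point_Grid
    # instance and the default fill value -9999 is used):
    #
    #     tbl = tpg.CoRegPoints_table
    #     valid = tbl[(tbl['ABS_SHIFT'] != -9999) & (tbl['OUTLIER'].__eq__(False))]
    #     valid[['X_SHIFT_M', 'Y_SHIFT_M', 'RELIABILITY']].describe()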

    def calc_rmse(self, include_outliers: bool = False) -> float:
        """Calculate the root mean square error of absolute shifts from the tie point grid.

        :param include_outliers: whether to include tie points that have been marked as false-positives (if present)
        """
        if self.CoRegPoints_table.empty:
            raise RuntimeError('Cannot compute the RMSE because no tie points were found at all.')

        tbl = self.CoRegPoints_table
        tbl = tbl if include_outliers else tbl[tbl['OUTLIER'] == 0].copy() if 'OUTLIER' in tbl.columns else tbl

        if not include_outliers and tbl.empty:
            raise RuntimeError('Cannot compute the RMSE because all tie points are flagged as false-positives.')

        shifts = np.array(tbl['ABS_SHIFT'])
        shifts_sq = [i * i for i in shifts if i != self.outFillVal]

        return np.sqrt(sum(shifts_sq) / len(shifts_sq))
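
    # Worked example for the RMSE above (illustrative numbers only): for valid
    # absolute shifts of 3 m and 4 m, RMSE = sqrt((3**2 + 4**2) / 2) ≈ 3.54 m.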

    def calc_overall_ssim(self,
                          include_outliers: bool = False,
                          after_correction: bool = True
                          ) -> float:
        """Calculate the median of all SSIM values contained in the tie point grid.

        :param include_outliers: whether to include tie points that have been marked as false-positives
        :param after_correction: whether to compute the median SSIM after correction (True) or before (False)
        """
        if self.CoRegPoints_table.empty:
            raise RuntimeError('Cannot compute the overall SSIM because no tie points were found at all.')

        tbl = self.CoRegPoints_table
        tbl = tbl if include_outliers else tbl[tbl['OUTLIER'] == 0].copy() if 'OUTLIER' in tbl.columns else tbl

        if not include_outliers and tbl.empty:
            raise RuntimeError('Cannot compute the overall SSIM because all tie points are flagged as '
                               'false-positives.')

        ssim_col = np.array(tbl['SSIM_AFTER' if after_correction else 'SSIM_BEFORE'])
        ssim_col = [i for i in ssim_col if i != self.outFillVal]  # drop fill values

        return float(np.median(ssim_col))

    def calc_overall_stats(self, include_outliers: bool = False) -> dict:
        """Calculate statistics like RMSE, MSE, MAE, ... from the tie point grid.

        Full list of returned statistics:

        - N_TP: number of tie points
        - N_VALID_TP: number of valid tie points
        - N_INVALID_TP: number of invalid tie points (false-positives)
        - PERC_VALID_TP: percentage of valid tie points
        - RMSE_M: root mean squared error of absolute shift vector length in map units
        - RMSE_X_M: root mean squared error of shift vector length in x-direction in map units
        - RMSE_Y_M: root mean squared error of shift vector length in y-direction in map units
        - RMSE_X_PX: root mean squared error of shift vector length in x-direction in pixel units
        - RMSE_Y_PX: root mean squared error of shift vector length in y-direction in pixel units
        - MSE_M: mean squared error of absolute shift vector length in map units
        - MSE_X_M: mean squared error of shift vector length in x-direction in map units
        - MSE_Y_M: mean squared error of shift vector length in y-direction in map units
        - MSE_X_PX: mean squared error of shift vector length in x-direction in pixel units
        - MSE_Y_PX: mean squared error of shift vector length in y-direction in pixel units
        - MAE_M: mean absolute error of absolute shift vector length in map units
        - MAE_X_M: mean absolute error of shift vector length in x-direction in map units
        - MAE_Y_M: mean absolute error of shift vector length in y-direction in map units
        - MAE_X_PX: mean absolute error of shift vector length in x-direction in pixel units
        - MAE_Y_PX: mean absolute error of shift vector length in y-direction in pixel units
        - MEAN_ABS_SHIFT: mean absolute shift vector length in map units
        - MEAN_X_SHIFT_M: mean shift vector length in x-direction in map units
        - MEAN_Y_SHIFT_M: mean shift vector length in y-direction in map units
        - MEAN_X_SHIFT_PX: mean shift vector length in x-direction in pixel units
        - MEAN_Y_SHIFT_PX: mean shift vector length in y-direction in pixel units
        - MEAN_ANGLE: mean direction of the shift vectors in degrees from north
        - MEAN_SSIM_BEFORE: mean structural similarity index within each matching window before co-registration
        - MEAN_SSIM_AFTER: mean structural similarity index within each matching window after co-registration
        - MEAN_RELIABILITY: mean tie point reliability in percent
        - MEDIAN_ABS_SHIFT: median absolute shift vector length in map units
        - MEDIAN_X_SHIFT_M: median shift vector length in x-direction in map units
        - MEDIAN_Y_SHIFT_M: median shift vector length in y-direction in map units
        - MEDIAN_X_SHIFT_PX: median shift vector length in x-direction in pixel units
        - MEDIAN_Y_SHIFT_PX: median shift vector length in y-direction in pixel units
        - MEDIAN_ANGLE: median direction of the shift vectors in degrees from north
        - MEDIAN_SSIM_BEFORE: median structural similarity index within each matching window before co-registration
        - MEDIAN_SSIM_AFTER: median structural similarity index within each matching window after co-registration
        - MEDIAN_RELIABILITY: median tie point reliability in percent
        - STD_ABS_SHIFT: standard deviation of absolute shift vector length in map units
        - STD_X_SHIFT_M: standard deviation of shift vector length in x-direction in map units
        - STD_Y_SHIFT_M: standard deviation of shift vector length in y-direction in map units
        - STD_X_SHIFT_PX: standard deviation of shift vector length in x-direction in pixel units
        - STD_Y_SHIFT_PX: standard deviation of shift vector length in y-direction in pixel units
        - STD_ANGLE: standard deviation of direction of the shift vectors in degrees from north
        - STD_SSIM_BEFORE: standard deviation of the structural similarity index within each matching window before
          co-registration
        - STD_SSIM_AFTER: standard deviation of the structural similarity index within each matching window after
          co-registration
        - STD_RELIABILITY: standard deviation of tie point reliability in percent
        - MIN_ABS_SHIFT: minimal absolute shift vector length in map units
        - MIN_X_SHIFT_M: minimal shift vector length in x-direction in map units
        - MIN_Y_SHIFT_M: minimal shift vector length in y-direction in map units
        - MIN_X_SHIFT_PX: minimal shift vector length in x-direction in pixel units
        - MIN_Y_SHIFT_PX: minimal shift vector length in y-direction in pixel units
        - MIN_ANGLE: minimal direction of the shift vectors in degrees from north
        - MIN_SSIM_BEFORE: minimal structural similarity index within each matching window before co-registration
        - MIN_SSIM_AFTER: minimal structural similarity index within each matching window after co-registration
        - MIN_RELIABILITY: minimal tie point reliability in percent
        - MAX_ABS_SHIFT: maximal absolute shift vector length in map units
        - MAX_X_SHIFT_M: maximal shift vector length in x-direction in map units
        - MAX_Y_SHIFT_M: maximal shift vector length in y-direction in map units
        - MAX_X_SHIFT_PX: maximal shift vector length in x-direction in pixel units
        - MAX_Y_SHIFT_PX: maximal shift vector length in y-direction in pixel units
        - MAX_ANGLE: maximal direction of the shift vectors in degrees from north
        - MAX_SSIM_BEFORE: maximal structural similarity index within each matching window before co-registration
        - MAX_SSIM_AFTER: maximal structural similarity index within each matching window after co-registration
        - MAX_RELIABILITY: maximal tie point reliability in percent

        :param include_outliers: whether to include tie points that have been marked as false-positives (if present)
        """
        if self.CoRegPoints_table.empty:
            raise RuntimeError('Cannot compute overall statistics because no tie points were found at all.')

        tbl = self.CoRegPoints_table

        n_tiepoints = sum(tbl['ABS_SHIFT'] != self.outFillVal)
        n_outliers = sum(tbl['OUTLIER'] == 1) if 'OUTLIER' in tbl.columns else 0

        tbl = tbl if include_outliers else tbl[tbl['OUTLIER'] == 0].copy() if 'OUTLIER' in tbl.columns else tbl
        tbl = tbl.copy().replace(self.outFillVal, np.nan)

        if not include_outliers and tbl.empty:
            raise RuntimeError('Cannot compute overall statistics '
                               'because all tie points are flagged as false-positives.')

        def RMSE(shifts):
            shifts_sq = shifts ** 2
            return np.sqrt(sum(shifts_sq) / len(shifts_sq))

        def MSE(shifts):
            shifts_sq = shifts ** 2
            return sum(shifts_sq) / len(shifts_sq)

        def MAE(shifts):
            shifts_abs = np.abs(shifts)
            return sum(shifts_abs) / len(shifts_abs)

        abs_shift, x_shift_m, y_shift_m, x_shift_px, y_shift_px, angle, ssim_before, ssim_after, reliability = \
            [tbl[k].dropna().values for k in ['ABS_SHIFT', 'X_SHIFT_M', 'Y_SHIFT_M', 'X_SHIFT_PX', 'Y_SHIFT_PX',
                                              'ANGLE', 'SSIM_BEFORE', 'SSIM_AFTER', 'RELIABILITY']]

        stats = dict(
            N_TP=n_tiepoints,
            N_VALID_TP=len(abs_shift),
            N_INVALID_TP=n_outliers,
            PERC_VALID_TP=(n_tiepoints - n_outliers) / n_tiepoints * 100,

            RMSE_M=RMSE(abs_shift),
            RMSE_X_M=RMSE(x_shift_m),
            RMSE_Y_M=RMSE(y_shift_m),
            RMSE_X_PX=RMSE(x_shift_px),
            RMSE_Y_PX=RMSE(y_shift_px),

            MSE_M=MSE(abs_shift),
            MSE_X_M=MSE(x_shift_m),
            MSE_Y_M=MSE(y_shift_m),
            MSE_X_PX=MSE(x_shift_px),
            MSE_Y_PX=MSE(y_shift_px),

            MAE_M=MAE(abs_shift),
            MAE_X_M=MAE(x_shift_m),
            MAE_Y_M=MAE(y_shift_m),
            MAE_X_PX=MAE(x_shift_px),
            MAE_Y_PX=MAE(y_shift_px),
        )

        for stat, func in zip(['mean', 'median', 'std', 'min', 'max'],
                              [np.mean, np.median, np.std, np.min, np.max]):
            for n in ['abs_shift', 'x_shift_m', 'y_shift_m', 'x_shift_px', 'y_shift_px',
                      'angle', 'ssim_before', 'ssim_after', 'reliability']:

                vals = locals()[n]
                stats[f'{stat}_{n}'.upper()] = func(vals)

        return stats
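
    # Sketch of reading the returned statistics (keys as listed in the docstring
    # above; `tpg` is a hypothetical Tie_Point_Grid instance):
    #
    #     stats = tpg.calc_overall_stats()
    #     print(stats['RMSE_M'], stats['MEAN_ABS_SHIFT'], stats['PERC_VALID_TP'])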

    def plot_shift_distribution(self,
                                include_outliers: bool = True,
                                unit: str = 'm',
                                interactive: bool = False,
                                figsize: tuple = None,
                                xlim: list = None,
                                ylim: list = None,
                                fontsize: int = 12,
                                title: str = 'shift distribution',
                                savefigPath: str = '',
                                savefigDPI: int = 96,
                                showFig: bool = True,
                                return_fig: bool = False
                                ) -> tuple:
        """Create a 2D scatterplot containing the distribution of calculated X/Y-shifts.

        :param include_outliers: whether to include tie points that have been marked as false-positives
        :param unit:             'm' for meters or 'px' for pixels (default: 'm')
        :param interactive:      whether to use interactive mode (uses plotly for visualization)
        :param figsize:          (xdim, ydim)
        :param xlim:             [xmin, xmax]
        :param ylim:             [ymin, ymax]
        :param fontsize:         size of all used fonts
        :param title:            the title to be plotted above the figure
        :param savefigPath:      path where to save the figure
        :param savefigDPI:       DPI resolution of the output figure when saved to disk
        :param showFig:          whether to show or to hide the figure
        :param return_fig:       whether to return the figure and axis objects
        """
        if unit not in ['m', 'px']:
            raise ValueError("Parameter 'unit' must have the value 'm' (meters) or 'px' (pixels)! Got %s." % unit)

        if self.CoRegPoints_table.empty:
            raise RuntimeError('The shift distribution cannot be plotted because no tie points were found at all.')

        tbl = self.CoRegPoints_table
        tbl = tbl[tbl['ABS_SHIFT'] != self.outFillVal]
        tbl_il = tbl[tbl['OUTLIER'] == 0].copy() if 'OUTLIER' in tbl.columns else tbl
        tbl_ol = tbl[tbl['OUTLIER']].copy() if 'OUTLIER' in tbl.columns else None
        x_attr = 'X_SHIFT_M' if unit == 'm' else 'X_SHIFT_PX'
        y_attr = 'Y_SHIFT_M' if unit == 'm' else 'Y_SHIFT_PX'
        rmse = self.calc_rmse(include_outliers=False)  # always exclude outliers when calculating the RMSE
        figsize = figsize if figsize else (10, 10)

        if interactive:
            from plotly.offline import iplot, init_notebook_mode
            import plotly.graph_objs as go
            # FIXME outliers are not plotted

            init_notebook_mode(connected=True)

            # create a trace
            trace = go.Scatter(
                x=tbl_il[x_attr],
                y=tbl_il[y_attr],
                mode='markers'
            )

            data = [trace]

            # plot and embed in the IPython notebook
            iplot(data, filename='basic-scatter')

            return None, None

        else:
            fig = plt.figure(figsize=figsize)
            ax = fig.add_subplot(111)

            if include_outliers and 'OUTLIER' in tbl.columns:
                ax.scatter(tbl_ol[x_attr], tbl_ol[y_attr], marker='+', c='r', label='false-positives')
            ax.scatter(tbl_il[x_attr], tbl_il[y_attr], marker='+', c='g', label='valid tie points')

            # set axis limits
            if not xlim:
                xmax = np.abs(tbl_il[x_attr]).max()
                xlim = [-xmax, xmax]
            if not ylim:
                ymax = np.abs(tbl_il[y_attr]).max()
                ylim = [-ymax, ymax]
            ax.set_xlim(xlim)
            ax.set_ylim(ylim)

            # add a text box containing the RMSE of the plotted shifts
            xlim, ylim = ax.get_xlim(), ax.get_ylim()
            ax.text(xlim[1] - (xlim[1] / 20), -ylim[1] + (ylim[1] / 20),
                    'RMSE: %s m / %s px' % (np.round(rmse, 2), np.round(rmse / self.shift.xgsd, 2)),
                    ha='right', va='bottom', fontsize=fontsize, bbox=dict(facecolor='w', pad=None, alpha=0.8))

            # add a grid and increase the linewidth of the middle line
            ax.grid(visible=True)
            ax.spines["right"].set_visible(True)
            ax.spines["top"].set_visible(True)
            xgl = ax.get_xgridlines()
            middle_xgl = xgl[int(np.median(np.array(range(len(xgl)))))]
            middle_xgl.set_linewidth(2)
            middle_xgl.set_linestyle('-')
            ygl = ax.get_ygridlines()
            middle_ygl = ygl[int(np.median(np.array(range(len(ygl)))))]
            middle_ygl.set_linewidth(2)
            middle_ygl.set_linestyle('-')

            # set the title and adjust the tick labels
            ax.set_title(title, fontsize=fontsize)
            for tick in ax.xaxis.get_major_ticks():
                tick.label1.set_fontsize(fontsize)
            for tick in ax.yaxis.get_major_ticks():
                tick.label1.set_fontsize(fontsize)
            ax.set_xlabel('x-shift [%s]' % ('meters' if unit == 'm' else 'pixels'), fontsize=fontsize)
            ax.set_ylabel('y-shift [%s]' % ('meters' if unit == 'm' else 'pixels'), fontsize=fontsize)

            # add a legend with the labels in the right order
            handles, labels = ax.get_legend_handles_labels()
            leg = plt.legend(reversed(handles), reversed(labels), fontsize=fontsize, loc='upper right',
                             scatterpoints=3)
            leg.get_frame().set_edgecolor('black')

            # remove white space around the figure
            fig.subplots_adjust(top=.94, bottom=.06, right=.96, left=.09)

            if savefigPath:
                fig.savefig(savefigPath, dpi=savefigDPI, pad_inches=0.3, bbox_inches='tight')

            if return_fig:
                return fig, ax

            if showFig and not self.q:
                plt.show(block=True)
            else:
                plt.close(fig)

    def dump_CoRegPoints_table(self, path_out=None):
        if self.CoRegPoints_table.empty:
            raise RuntimeError('Cannot dump the tie point table because it is empty.')

        path_out = path_out if path_out else \
            get_generic_outpath(dir_out=self.dir_out,
                                fName_out="CoRegPoints_table_grid%s_ws(%s_%s)__T_%s__R_%s.pkl"
                                          % (self.grid_res, self.COREG_obj.win_size_XY[0],
                                             self.COREG_obj.win_size_XY[1], self.shift.basename,
                                             self.ref.basename))
        if not self.q:
            print('Writing %s ...' % path_out)
        self.CoRegPoints_table.to_pickle(path_out)

    def to_GCPList(self):
        # get a copy of the tie point grid without no-data records
        try:
            GDF = self.CoRegPoints_table.loc[self.CoRegPoints_table.ABS_SHIFT != self.outFillVal, :].copy()
        except AttributeError:
            # self.CoRegPoints_table has no attribute 'ABS_SHIFT' because all points have been excluded
            return []

        if GDF.empty:
            return []
        else:
            # exclude all points flagged as outliers
            if 'OUTLIER' in GDF.columns:
                GDF = GDF[GDF.OUTLIER.__eq__(False)].copy()
            avail_TP = len(GDF)

            if not avail_TP:
                # no point passed all validity checks
                return []

            if avail_TP > 7000:
                GDF = GDF.sample(7000)
                warnings.warn('Not more than 7000 tie points can be used for warping within a limited computation '
                              'time (due to a GDAL bottleneck). Thus, 7000 points are randomly chosen out of the %s '
                              'available tie points.' % avail_TP)

            # calculate the GCPs
            GDF['X_MAP_new'] = GDF.X_MAP + GDF.X_SHIFT_M
            GDF['Y_MAP_new'] = GDF.Y_MAP + GDF.Y_SHIFT_M
            GDF['GCP'] = GDF.apply(lambda GDF_row: gdal.GCP(GDF_row.X_MAP_new,
                                                            GDF_row.Y_MAP_new,
                                                            0,
                                                            GDF_row.X_IM,
                                                            GDF_row.Y_IM),
                                   axis=1)
            self.GCPList = GDF.GCP.tolist()

            return self.GCPList
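
    # Sketch of applying the GCPs with GDAL (hedged example: gdal.Translate with
    # the 'GCPs' option and gdal.Warp with 'tps=True' are standard GDAL usage,
    # but the paths and the projection below are placeholders):
    #
    #     gcps = tpg.to_GCPList()
    #     gdal.Translate('/tmp/with_gcps.tif', 'target.tif', GCPs=gcps,
    #                    outputSRS='EPSG:32633')
    #     gdal.Warp('/tmp/corrected.tif', '/tmp/with_gcps.tif', tps=True)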

    def test_if_singleprocessing_equals_multiprocessing_result(self):
        # RANSAC filtering always produces different results because it includes random sampling
        self.tieP_filter_level = 1

        self.CPUs = None
        dataframe = self.get_CoRegPoints_table()
        mp_out = np.empty_like(dataframe.values)
        mp_out[:] = dataframe.values
        self.CPUs = 1
        dataframe = self.get_CoRegPoints_table()
        sp_out = np.empty_like(dataframe.values)
        sp_out[:] = dataframe.values

        return np.array_equal(sp_out, mp_out)

    def _get_line_by_PID(self, PID):
        return self.CoRegPoints_table.loc[PID, :]

    def _get_lines_by_PIDs(self, PIDs):
        assert isinstance(PIDs, list)
        lines = np.zeros((len(PIDs), self.CoRegPoints_table.shape[1]))
        for i, PID in enumerate(PIDs):
            lines[i, :] = self.CoRegPoints_table[self.CoRegPoints_table['POINT_ID'] == PID]
        return lines

    def to_PointShapefile(self,
                          path_out: str = None,
                          skip_nodata: bool = True,
                          skip_nodata_col: str = 'ABS_SHIFT',
                          skip_outliers: bool = False
                          ) -> None:
        """Write the calculated tie point grid to a point shapefile (e.g., for visualization in GIS software).

        NOTE: The shapefile uses Tie_Point_Grid.CoRegPoints_table as its attribute table.

        :param path_out:        <str> the output path. If not given, it is automatically defined.
        :param skip_nodata:     <bool> whether to skip all points where no valid match could be found
        :param skip_nodata_col: <str> determines which column of Tie_Point_Grid.CoRegPoints_table is used to
                                identify points where no valid match could be found
        :param skip_outliers:   <bool> whether to exclude all tie points that have been flagged as outliers
                                (false-positives)
        """
        if self.CoRegPoints_table.empty:
            raise RuntimeError('Cannot save a point shapefile because no tie points were found at all.')

        GDF2pass = self.CoRegPoints_table

        if skip_nodata:
            GDF2pass = GDF2pass[GDF2pass[skip_nodata_col] != self.outFillVal].copy()
        else:
            # use the error representation (including the error type) instead of only the error message
            GDF2pass.LAST_ERR = GDF2pass.apply(lambda GDF_row: repr(GDF_row.LAST_ERR), axis=1)

        if skip_outliers:
            GDF2pass = GDF2pass[~GDF2pass['OUTLIER'].__eq__(True)].copy()

        # replace boolean values (cannot be written)
        GDF2pass = GDF2pass.replace(False, 0).copy()  # replace booleans where the column dtype is object, not bool
        GDF2pass = GDF2pass.replace(True, 1).copy()
        for col in GDF2pass.columns:
            if GDF2pass[col].dtype == bool:
                GDF2pass[col] = GDF2pass[col].astype(int)

        path_out = path_out if path_out else \
            get_generic_outpath(dir_out=os.path.join(self.dir_out, 'CoRegPoints'),
                                fName_out="CoRegPoints_grid%s_ws(%s_%s)__T_%s__R_%s.shp"
                                          % (self.grid_res, self.COREG_obj.win_size_XY[0],
                                             self.COREG_obj.win_size_XY[1], self.shift.basename, self.ref.basename))
        if not self.q:
            print('Writing %s ...' % path_out)
        GDF2pass.to_file(path_out)

    def to_vectorfield(self, path_out: str = None, fmt: str = None, mode: str = 'md') -> GeoArray:
        """Save the calculated X-/Y-shifts to a 2-band raster file that can be used to visualize a vector field.

        NOTE: ArcGIS, for example, is able to visualize such 2-band raster files as a vector field.

        :param path_out: the output path. If not given, it is automatically defined.
        :param fmt:      output raster format string
        :param mode:     how the output is written ('uv' or 'md'; default: 'md')

                         - 'uv': outputs X-/Y-shifts
                         - 'md': outputs magnitude and direction
        """
        assert mode in ['uv', 'md'], "'mode' must be either 'uv' (outputs X-/Y-shifts) or 'md' " \
                                     "(outputs magnitude and direction). Got %s." % mode
        attr_b1 = 'X_SHIFT_M' if mode == 'uv' else 'ABS_SHIFT'
        attr_b2 = 'Y_SHIFT_M' if mode == 'uv' else 'ANGLE'

        if self.CoRegPoints_table.empty:
            raise RuntimeError('Cannot save the vector field because no tie points were found at all.')

        xshift_arr, gt, prj = points_to_raster(points=self.CoRegPoints_table['geometry'],
                                               values=self.CoRegPoints_table[attr_b1],
                                               tgt_res=self.shift.xgsd * self.grid_res,
                                               prj=self.CoRegPoints_table.crs.to_wkt(),
                                               fillVal=self.outFillVal)

        yshift_arr, gt, prj = points_to_raster(points=self.CoRegPoints_table['geometry'],
                                               values=self.CoRegPoints_table[attr_b2],
                                               tgt_res=self.shift.xgsd * self.grid_res,
                                               prj=self.CoRegPoints_table.crs.to_wkt(),
                                               fillVal=self.outFillVal)

        out_GA = GeoArray(np.dstack([xshift_arr, yshift_arr]), gt, prj, nodata=self.outFillVal)

        path_out = path_out if path_out else \
            get_generic_outpath(dir_out=os.path.join(self.dir_out, 'CoRegPoints'),
                                fName_out="CoRegVectorfield%s_ws(%s_%s)__T_%s__R_%s.tif"
                                          % (self.grid_res, self.COREG_obj.win_size_XY[0],
                                             self.COREG_obj.win_size_XY[1], self.shift.basename, self.ref.basename))

        out_GA.save(path_out, fmt=fmt if fmt else 'Gtiff')

        return out_GA

    def to_interpolated_raster(self,
                               metric: str = 'ABS_SHIFT',
                               method: str = 'RBF',
                               plot_result: bool = False,
                               lowres_spacing: int = 5,
                               v: bool = False
                               ) -> np.ndarray:
        """Interpolate the point data of the given metric into space.

        :param metric:         name of the metric to interpolate, i.e., one of the column names of
                               Tie_Point_Grid.CoRegPoints_table, e.g., 'ABS_SHIFT'
        :param method:         interpolation algorithm

                               - 'RBF' (Radial Basis Function)
                               - 'GPR' (Gaussian Process Regression; equivalent to Simple Kriging)
                               - 'Kriging' (Ordinary Kriging based on pykrige)
        :param plot_result:    plot the result to assess the interpolation quality
        :param lowres_spacing: by default, RBF, GPR, and Kriging run at a lower resolution which is then linearly
                               interpolated to the full output image resolution. lowres_spacing defines the number
                               of pixels between the low-resolution grid points
                               (higher values are faster but less accurate; default: 5)
        :param v:              enable verbose mode
        :return:               interpolation result as a numpy array in the X/Y dimensions of the target image of
                               the co-registration
        """
        TPGI = Tie_Point_Grid_Interpolator(self, v=v)

        return TPGI.interpolate(metric=metric, method=method, plot_result=plot_result, lowres_spacing=lowres_spacing)
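
    # Sketch (assumption: `tpg` is a Tie_Point_Grid instance with computed shifts):
    #
    #     shift_surface = tpg.to_interpolated_raster(metric='ABS_SHIFT',
    #                                                method='RBF',
    #                                                plot_result=True)
    #     shift_surface.shape  # matches tpg.shift.shape[:2]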


class Tie_Point_Refiner(object):
    """A class for performing outlier detection on tie point grids."""

    def __init__(self, GDF,
                 min_reliability: float = 60,
                 rs_max_outlier: float = 10,
                 rs_tolerance: float = 2.5,
                 rs_max_iter: int = 15,
                 rs_exclude_previous_outliers: bool = True,
                 rs_timeout: float = 20,
                 rs_random_state: Optional[int] = 0,
                 q: bool = False):
        """Get an instance of Tie_Point_Refiner.

        :param GDF:                          GeoDataFrame like Tie_Point_Grid.CoRegPoints_table containing all tie
                                             points to be filtered and the corresponding metadata
        :param min_reliability:              minimum threshold for the previously computed tie point X/Y shift
                                             reliability (default: 60%)
        :param rs_max_outlier:               RANSAC: maximum percentage of outliers to be detected
                                             (default: 10%)
        :param rs_tolerance:                 RANSAC: percentage tolerance for max_outlier_percentage
                                             (default: 2.5%)
        :param rs_max_iter:                  RANSAC: maximum number of iterations for finding the best RANSAC
                                             threshold (default: 15)
        :param rs_exclude_previous_outliers: RANSAC: whether to exclude points that have been flagged as outliers
                                             by earlier filtering (default: True)
        :param rs_timeout:                   RANSAC: timeout for the iteration loop in seconds (default: 20)
        :param rs_random_state:              RANSAC random state (an integer corresponds to a fixed/pseudo-random
                                             state, None randomizes the result)
        :param q:                            quiet mode (default: False)
        """
        self.GDF = GDF.copy()
        self.min_reliability = min_reliability
        self.rs_max_outlier_percentage = rs_max_outlier
        self.rs_tolerance = rs_tolerance
        self.rs_max_iter = rs_max_iter
        self.rs_exclude_previous_outliers = rs_exclude_previous_outliers
        self.rs_timeout = rs_timeout
        self.rs_random_state = rs_random_state
        self.q = q
        self.new_cols = []
        self.ransac_model_robust = None
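
    # Sketch (assumption: `tpg` is a Tie_Point_Grid whose table already contains
    # valid matches; -9999 is the default fill value):
    #
    #     tbl = tpg.CoRegPoints_table
    #     TPR = Tie_Point_Refiner(tbl[tbl.ABS_SHIFT != -9999], min_reliability=60)
    #     GDF_filt, new_cols = TPR.run_filtering(level=3)
    #     outliers = GDF_filt[GDF_filt['OUTLIER']]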

    def run_filtering(self, level=3):
        """Filter the tie points used for shift correction.

        :param level: tie point filter level (default: 3).
                      NOTE: lower levels are also included if a higher level is chosen

                      - Level 0: no tie point filtering
                      - Level 1: Reliability filtering
                          - filter out all tie points that have a low reliability according to internal tests
                      - Level 2: SSIM filtering
                          - filter out all tie points where shift correction does not increase image
                            similarity within the matching window (measured by the mean structural similarity index)
                      - Level 3: RANSAC outlier detection

        :return:
        """
        # TODO catch empty GDF

        # RELIABILITY filtering
        if level > 0:
            marked_recs = self._reliability_thresholding()  # type: Series
            self.GDF['L1_OUTLIER'] = marked_recs
            self.new_cols.append('L1_OUTLIER')

            n_flagged = len(marked_recs[marked_recs])
            perc40 = np.percentile(self.GDF.RELIABILITY, 40)

            if n_flagged / len(self.GDF) > .7:
                warnings.warn(r"More than 70%% of the found tie points have a reliability lower than %s%% and are "
                              r"therefore marked as false-positives. Consider relaxing the minimum reliability "
                              r"(parameter 'min_reliability') to avoid that. For example, min_reliability=%d would "
                              r"only flag 40%% of the tie points for your input data."
                              % (self.min_reliability, perc40))

            if not self.q:
                print('%s tie points flagged by level 1 filtering (reliability).'
                      % n_flagged)

        # SSIM filtering
        if level > 1:
            marked_recs = self._SSIM_filtering()
            self.GDF['L2_OUTLIER'] = marked_recs  # type: Series
            self.new_cols.append('L2_OUTLIER')

            if not self.q:
                print('%s tie points flagged by level 2 filtering (SSIM).' % (len(marked_recs[marked_recs])))

        # RANSAC filtering
        if level > 2:
            # exclude previous outliers
            ransacInGDF = self.GDF[~self.GDF[self.new_cols].any(axis=1)].copy() \
                if self.rs_exclude_previous_outliers else self.GDF

            if len(ransacInGDF) > 4:
                # running RANSAC with four or fewer tie points makes no sense

                marked_recs = self._RANSAC_outlier_detection(ransacInGDF)  # type: Series
                # we need to join a list here because otherwise it would be merged by the 'index' column
                self.GDF['L3_OUTLIER'] = marked_recs.tolist()

                if not self.q:
                    print('%s tie points flagged by level 3 filtering (RANSAC).'
                          % (len(marked_recs[marked_recs])))
            else:
                print('RANSAC skipped because too few valid tie points have been found.')
                self.GDF['L3_OUTLIER'] = False

            self.new_cols.append('L3_OUTLIER')

        self.GDF['OUTLIER'] = self.GDF[self.new_cols].any(axis=1)
        self.new_cols.append('OUTLIER')

        return self.GDF, self.new_cols

    def _reliability_thresholding(self):
        """Exclude all records where the estimated reliability of the calculated shifts is below the threshold."""
        return self.GDF.RELIABILITY < self.min_reliability

    def _SSIM_filtering(self):
        """Exclude all records where the SSIM decreased."""
        # ssim_diff = np.median(self.GDF['SSIM_AFTER']) - np.median(self.GDF['SSIM_BEFORE'])

        # self.GDF.SSIM_IMPROVED = \
        #     self.GDF.apply(lambda GDF_row: GDF_row['SSIM_AFTER'] > GDF_row['SSIM_BEFORE'] + ssim_diff, axis=1)

        return ~self.GDF.SSIM_IMPROVED

    def _RANSAC_outlier_detection(self, inGDF):
        """Detect geometric outliers between the point clouds of source and estimated coordinates using RANSAC."""
        # from skimage.transform import PolynomialTransform  # import here to avoid static TLS ImportError

        src_coords = np.array(inGDF[['X_MAP', 'Y_MAP']])
        xyShift = np.array(inGDF[['X_SHIFT_M', 'Y_SHIFT_M']])
        est_coords = src_coords + xyShift

        for co, n in zip([src_coords, est_coords], ['src_coords', 'est_coords']):
            assert co.ndim == 2 and co.shape[1] == 2, "'%s' must have shape [Nx2]. Got shape %s." % (n, co.shape)

        if not 0 < self.rs_max_outlier_percentage < 100:
            raise ValueError("'rs_max_outlier' must be between 0 and 100.")
        min_inlier_percentage = 100 - self.rs_max_outlier_percentage

        # class PolyTF_1(PolynomialTransform):  # pragma: no cover
        #     def estimate(*data):
        #         return PolynomialTransform.estimate(*data, order=1)

        # robustly estimate an affine transform model with RANSAC
        # eliminates not more than the given maximum outlier percentage of the tie points

        model_robust, inliers = None, None
        count_inliers = None
        th = 5  # start RANSAC threshold
        th_checked = {}  # dict of thresholds that have already been tried + calculated inlier percentage
        th_substract = 2
        count_iter = 0
        time_start = time()
        ideal_count = min_inlier_percentage * src_coords.shape[0] / 100

        # optimize the RANSAC threshold so that it marks not much more or less than the given outlier percentage
        while True:
            if th_checked:
                th_too_strict = count_inliers < ideal_count  # True if too few inliers remain

                # calculate a new threshold using the old increment
                # (but ensure th_new > 0 by adjusting the increment if needed)
                th_new = 0
                while th_new <= 0:
                    th_new = th + th_substract if th_too_strict else th - th_substract
                    if th_new <= 0:
                        th_substract /= 2

                # check if the newly calculated threshold has been used before
                th_already_checked = th_new in th_checked.keys()

                # if yes, decrease the increment and recalculate the new threshold
                th_substract = \
                    th_substract if not th_already_checked else \
                    th_substract / 2
                th = th_new if not th_already_checked else \
                    (th + th_substract if th_too_strict else
                     th - th_substract)

            ###############
            # RANSAC call #
            ###############

            # model_robust, inliers = ransac((src, dst),
            #                                PolynomialTransform,
            #                                min_samples=3)
            if src_coords.size and \
               est_coords.size and \
               src_coords.shape[0] > 6:
                # import here to avoid static TLS ImportError
                from skimage.measure import ransac
                from skimage.transform import AffineTransform

                model_robust, inliers = \
                    ransac((src_coords, est_coords),
                           AffineTransform,
                           min_samples=6,
                           residual_threshold=th,
                           max_trials=2000,
                           stop_sample_num=int(
                               (min_inlier_percentage - self.rs_tolerance) /
                               100 * src_coords.shape[0]
                           ),
                           stop_residuals_sum=int(
                               (self.rs_max_outlier_percentage - self.rs_tolerance) /
                               100 * src_coords.shape[0]),
                           rng=self.rs_random_state
                           )
            else:
                warnings.warn('RANSAC filtering could not be applied '
                              'because there were too few tie points to fit a model.')
                inliers = np.array([])
                break

            count_inliers = np.count_nonzero(inliers)

            th_checked[th] = count_inliers / src_coords.shape[0] * 100
            # print(th, '\t', th_checked[th])

            if min_inlier_percentage - self.rs_tolerance < \
               th_checked[th] < \
               min_inlier_percentage + self.rs_tolerance:
                # print('in tolerance')
                break

            if count_iter > self.rs_max_iter or \
               time() - time_start > self.rs_timeout:
                break  # keep the last values and break the while loop

            count_iter += 1

        outliers = inliers.__eq__(False) if inliers is not None and inliers.size else np.array([])

        if inGDF.empty or outliers is None or \
           (isinstance(outliers, list) and not outliers) or \
           (isinstance(outliers, np.ndarray) and not outliers.size):
            outseries = Series([False] * len(self.GDF))

        elif len(inGDF) < len(self.GDF):
            inGDF['outliers'] = outliers
            fullGDF = GeoDataFrame(self.GDF['POINT_ID'])
            fullGDF = fullGDF.merge(inGDF[['POINT_ID', 'outliers']],
                                    on='POINT_ID',
                                    how="outer")
            # fullGDF.outliers.copy()[~fullGDF.POINT_ID.isin(GDF.POINT_ID)] = False
            fullGDF = fullGDF.fillna(False)  # NaNs are due to exclude_previous_outliers
            outseries = fullGDF['outliers']

        else:
            outseries = Series(outliers)

        assert len(outseries) == len(self.GDF), \
            'RANSAC output validation failed.'

        self.ransac_model_robust = model_robust

        return outseries


class Tie_Point_Grid_Interpolator(object):
    """Class to interpolate tie point data into space."""

    def __init__(self, tiepointgrid: Tie_Point_Grid, v: bool = False) -> None:
        """Get an instance of Tie_Point_Grid_Interpolator.

        :param tiepointgrid: instance of Tie_Point_Grid after computing spatial shifts
        :param v:            enable verbose mode
        """
        self.tpg = tiepointgrid
        self.v = v

    def interpolate(self,
                    metric: str,
                    method: str = 'RBF',
                    plot_result: bool = False,
                    lowres_spacing: int = 5
                    ) -> np.ndarray:
        """Interpolate the point data of the given metric into space.

        :param metric:         name of the metric to interpolate, i.e., one of the column names of
                               Tie_Point_Grid.CoRegPoints_table, e.g., 'ABS_SHIFT'
        :param method:         interpolation algorithm

                               - 'RBF' (Radial Basis Function)
                               - 'GPR' (Gaussian Process Regression; equivalent to Simple Kriging)
                               - 'Kriging' (Ordinary Kriging based on pykrige)
        :param plot_result:    plot the result to assess the interpolation quality
        :param lowres_spacing: by default, RBF, GPR, and Kriging run at a lower resolution which is then linearly
                               interpolated to the full output image resolution. lowres_spacing defines the number
                               of pixels between the low-resolution grid points
                               (higher values are faster but less accurate; default: 5)
        :return:               interpolation result as a numpy array in the X/Y dimensions of the target image of
                               the co-registration
        """
        t0 = time()

        rows, cols, data = self._get_pointdata(metric)
        nrows_out, ncols_out = self.tpg.shift.shape[:2]

        rows_lowres = np.arange(0, nrows_out + lowres_spacing, lowres_spacing)
        cols_lowres = np.arange(0, ncols_out + lowres_spacing, lowres_spacing)
        args = rows, cols, data, rows_lowres, cols_lowres

        if method == 'RBF':
            data_lowres = self._interpolate_via_rbf(*args)
        elif method == 'GPR':
            data_lowres = self._interpolate_via_gpr(*args)
        elif method == 'Kriging':
            data_lowres = self._interpolate_via_kriging(*args)
        else:
            raise ValueError(method)

        if lowres_spacing > 1:
            rows_full = np.arange(nrows_out)
            cols_full = np.arange(ncols_out)
            data_full = self._interpolate_regulargrid(rows_lowres, cols_lowres, data_lowres, rows_full, cols_full)
        else:
            data_full = data_lowres

        if self.v:
            print('interpolation runtime: %.2fs' % (time() - t0))
        if plot_result:
            self._plot_interpolation_result(data_full, rows, cols, data, metric)

        return data_full

    def _get_pointdata(self, metric: str):
        """Get the point data of the given metric from Tie_Point_Grid.CoRegPoints_table while ignoring outliers."""
        tiepoints = self.tpg.CoRegPoints_table[self.tpg.CoRegPoints_table.OUTLIER.__eq__(False)].copy()

        rows = np.array(tiepoints.Y_IM)
        cols = np.array(tiepoints.X_IM)
        data = np.array(tiepoints[metric])

        return rows, cols, data

    @staticmethod
    def _plot_interpolation_result(data_full: np.ndarray,
                                   rows: np.ndarray,
                                   cols: np.ndarray,
                                   data: np.ndarray,
                                   metric: str
                                   ):
        """Plot the interpolation result together with the input point data."""
        plt.figure(figsize=(7, 7))
        im = plt.imshow(data_full)
        plt.colorbar(im)
        plt.scatter(cols, rows, c=data, edgecolors='black')
        plt.title(metric)
        plt.show()

    @staticmethod
    def _interpolate_regulargrid(rows: np.ndarray,
                                 cols: np.ndarray,
                                 data: np.ndarray,
                                 rows_full: np.ndarray,
                                 cols_full: np.ndarray
                                 ):
        """Run linear regular-grid interpolation."""
        RGI = RegularGridInterpolator(points=[cols, rows],
                                      values=data.T,  # must be in shape [x, y]
                                      method='linear',
                                      bounds_error=False)
        data_full = RGI(np.dstack(np.meshgrid(cols_full, rows_full)))

        return data_full

    @staticmethod
    def _interpolate_via_rbf(rows: np.ndarray,
                             cols: np.ndarray,
                             data: np.ndarray,
                             rows_full: np.ndarray,
                             cols_full: np.ndarray
                             ):
        """Run Radial Basis Function (RBF) interpolation.

        -> https://github.com/agile-geoscience/xlines/blob/master/notebooks/11_Gridding_map_data.ipynb
           (documents the legacy scipy.interpolate.Rbf)
        """
        rbf = RBFInterpolator(
            np.column_stack([cols, rows]), data,
            kernel="linear",
            # kernel="thin_plate_spline",
        )
        cols_grid, rows_grid = np.meshgrid(cols_full, rows_full)
        data_full = \
            rbf(np.column_stack([cols_grid.flat, rows_grid.flat]))\
            .reshape(rows_grid.shape)

        return data_full

    @staticmethod
    def _interpolate_via_gpr(rows: np.ndarray,
                             cols: np.ndarray,
                             data: np.ndarray,
                             rows_full: np.ndarray,
                             cols_full: np.ndarray
                             ):
        """Run Gaussian Process Regression (GPR) interpolation.

        -> https://stackoverflow.com/questions/24978052/interpolation-over-regular-grid-in-python
        """
        try:
            import sklearn  # noqa F401
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "GPR interpolation requires the optional package 'scikit-learn' to be installed. You may install it "
                "with Conda (conda install -c conda-forge scikit-learn) or Pip (pip install scikit-learn)."
            )

        from sklearn.gaussian_process.kernels import RBF
        from sklearn.gaussian_process import GaussianProcessRegressor

        gp = GaussianProcessRegressor(
            normalize_y=False,
            alpha=0.001,  # larger values imply more input noise and result in smoother grids; default: 1e-10
            kernel=RBF(length_scale=100))
        gp.fit(np.column_stack([cols, rows]), data.T)

        cols_grid, rows_grid = np.meshgrid(cols_full, rows_full)
        data_full = \
            gp.predict(np.column_stack([cols_grid.flat, rows_grid.flat]))\
            .reshape(rows_grid.shape)

        return data_full

    @staticmethod
    def _interpolate_via_kriging(rows: np.ndarray,
                                 cols: np.ndarray,
                                 data: np.ndarray,
                                 rows_full: np.ndarray,
                                 cols_full: np.ndarray
                                 ):
        """Run Ordinary Kriging interpolation based on pykrige.

        Reference: P.K. Kitanidis, Introduction to Geostatistics: Applications in Hydrogeology,
                   (Cambridge University Press, 1997) 272 p.
        """
        try:
            import pykrige  # noqa F401
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "Ordinary Kriging requires the optional package 'pykrige' to be installed. You may install it with "
                "Conda (conda install -c conda-forge pykrige) or Pip (pip install pykrige)."
            )

        from pykrige.ok import OrdinaryKriging

        OK = OrdinaryKriging(cols.astype(float), rows.astype(float), data.astype(float),
                             variogram_model='spherical',
                             verbose=False)

        data_full, sigmasq = \
            OK.execute('grid',
                       cols_full.astype(float),
                       rows_full.astype(float),
                       backend='C',
                       # n_closest_points=12
                       )

        return data_full