Coverage for arosics/Tie_Point_Grid.py: 87%
534 statements
« prev ^ index » next coverage.py v7.3.2, created at 2024-02-12 23:30 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2024-02-12 23:30 +0000
1# -*- coding: utf-8 -*-
3# AROSICS - Automated and Robust Open-Source Image Co-Registration Software
4#
5# Copyright (C) 2017-2024
6# - Daniel Scheffler (GFZ Potsdam, daniel.scheffler@gfz-potsdam.de)
7# - Helmholtz Centre Potsdam - GFZ German Research Centre for Geosciences Potsdam,
8# Germany (https://www.gfz-potsdam.de/)
9#
10# This software was developed within the context of the GeoMultiSens project funded
11# by the German Federal Ministry of Education and Research
12# (project grant code: 01 IS 14 010 A-C).
13#
14# Licensed under the Apache License, Version 2.0 (the "License");
15# you may not use this file except in compliance with the License.
16# You may obtain a copy of the License at
17#
18# https://www.apache.org/licenses/LICENSE-2.0
19#
20# Unless required by applicable law or agreed to in writing, software
21# distributed under the License is distributed on an "AS IS" BASIS,
22# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23# See the License for the specific language governing permissions and
24# limitations under the License.
26import os
27import warnings
28from time import time
29from typing import Optional
30from sys import platform
32# custom
33from osgeo import gdal # noqa
34import numpy as np
35from geopandas import GeoDataFrame
36from pandas import DataFrame, Series
37from shapely.geometry import Point
38from matplotlib import pyplot as plt
39from scipy.interpolate import RBFInterpolator, RegularGridInterpolator
40from joblib import Parallel, delayed
42# internal modules
43from .CoReg import COREG
44from py_tools_ds.geo.projection import isLocal
45from py_tools_ds.io.pathgen import get_generic_outpath
46from py_tools_ds.processing.progress_mon import ProgressBar
47from py_tools_ds.geo.vector.conversion import points_to_raster
48from geoarray import GeoArray
50from .CoReg import GeoArray_CoReg # noqa F401 # flake8 issue
52__author__ = 'Daniel Scheffler'
55class Tie_Point_Grid(object):
56 """
57 The 'Tie_Point_Grid' class applies the algorithm to detect spatial shifts to the overlap area of the input images.
    Spatial shifts are calculated for each point of a grid whose parameters can be adjusted using keyword
    arguments. Shift correction performs a polynomial transformation using the calculated shifts of each point in the
    grid as GCPs. Thus, 'Tie_Point_Grid' can be used to correct for locally varying geometric distortions of the target
    image.
64 See help(Tie_Point_Grid) for documentation!
65 """
def __init__(self,
             COREG_obj: COREG,
             grid_res: float,
             max_points: Optional[int] = None,
             outFillVal: int = -9999,
             resamp_alg_calc: str = 'cubic',
             tieP_filter_level: int = 3,
             outlDetect_settings: Optional[dict] = None,
             dir_out: Optional[str] = None,
             CPUs: Optional[int] = None,
             progress: bool = True,
             v: bool = False,
             q: bool = False):
    """Get an instance of the 'Tie_Point_Grid' class.

    :param COREG_obj:
        an instance of COREG class

    :param grid_res:
        grid resolution in pixels of the target image (x-direction)

    :param max_points:
        maximum number of points used to find coregistration tie points

        NOTE: Points are selected randomly from the given point grid (specified by 'grid_res'). If the point grid
              does not provide enough points, all available points are chosen.

    :param outFillVal:
        if given the generated tie points grid is filled with this value in case no match could be found during
        co-registration (default: -9999)

    :param resamp_alg_calc:
        the resampling algorithm to be used for all warping processes during calculation of spatial shifts
        (valid algorithms: nearest, bilinear, cubic, cubic_spline, lanczos, average, mode, max, min, med, q1, q3)
        default: cubic (highly recommended)

    :param tieP_filter_level:
        filter tie points used for shift correction in different levels (default: 3).
        NOTE: lower levels are also included if a higher level is chosen

        - Level 0: no tie point filtering
        - Level 1: Reliability filtering
            - filter all tie points out that have a low reliability according to internal tests
        - Level 2: SSIM filtering
            - filters all tie points out where shift correction does not increase image similarity within
              matching window (measured by mean structural similarity index)
        - Level 3: RANSAC outlier detection

    :param outlDetect_settings:
        a dictionary with the settings to be passed to arosics.Tie_Point_Grid.Tie_Point_Refiner.
        Available keys: min_reliability, rs_max_outlier, rs_tolerance, rs_max_iter, rs_exclude_previous_outliers,
        rs_timeout, rs_random_state, q. See documentation there.
        NOTE: The given dictionary is copied, i.e., the passed object is not modified.

    :param dir_out:
        output directory to be used for all outputs if nothing else is given to the individual methods

    :param CPUs:
        number of CPUs to use during calculation of tie points grid
        (default: None, which means 'all CPUs available')

    :param progress:
        show progress bars (default: True)

    :param v:
        verbose mode (default: False)

    :param q:
        quiet mode (default: False)

    :raises ValueError: if COREG_obj is not an instance of COREG
    """
    if not isinstance(COREG_obj, COREG):
        raise ValueError("'COREG_obj' must be an instance of COREG class.")

    self.COREG_obj = COREG_obj  # type: COREG
    self.grid_res = grid_res
    self.max_points = max_points
    self.outFillVal = outFillVal
    self.rspAlg_calc = resamp_alg_calc
    self.tieP_filter_level = tieP_filter_level
    # copy the settings, so that adding the 'q' key below does not mutate the dictionary of the caller
    self.outlDetect_settings = dict(outlDetect_settings) if outlDetect_settings else dict()
    self.dir_out = dir_out
    self.CPUs = CPUs
    self.v = v
    self.q = q if not v else False  # overridden by v
    # use the effective quiet flag (self.q, which is overridden by v) instead of the raw 'q' argument
    self.progress = progress if not self.q else False  # overridden by q

    if 'q' not in self.outlDetect_settings:
        self.outlDetect_settings['q'] = self.q

    self.ref = self.COREG_obj.ref  # type: GeoArray_CoReg
    self.shift = self.COREG_obj.shift  # type: GeoArray_CoReg

    self.XY_points, self.XY_mapPoints = self._get_imXY__mapXY_points(self.grid_res)
    self._CoRegPoints_table = None  # set by self.CoRegPoints_table
    self._GCPList = None  # set by self.to_GCPList()
@property
def mean_x_shift_px(self):
    """Return the mean x-shift in pixels of all tie points whose value differs from the fill value."""
    x_shifts = self.CoRegPoints_table['X_SHIFT_PX']
    is_valid = x_shifts != self.outFillVal
    return x_shifts[is_valid].mean()
@property
def mean_y_shift_px(self):
    """Return the mean y-shift in pixels of all tie points whose value differs from the fill value."""
    y_shifts = self.CoRegPoints_table['Y_SHIFT_PX']
    is_valid = y_shifts != self.outFillVal
    return y_shifts[is_valid].mean()
@property
def mean_x_shift_map(self):
    """Return the mean x-shift in map units of all tie points whose value differs from the fill value."""
    x_shifts = self.CoRegPoints_table['X_SHIFT_M']
    is_valid = x_shifts != self.outFillVal
    return x_shifts[is_valid].mean()
@property
def mean_y_shift_map(self):
    """Return the mean y-shift in map units of all tie points whose value differs from the fill value."""
    y_shifts = self.CoRegPoints_table['Y_SHIFT_M']
    is_valid = y_shifts != self.outFillVal
    return y_shifts[is_valid].mean()
@property
def CoRegPoints_table(self):
    """Return a GeoDataFrame containing all the results from coregistration for all points in the tie point grid.

    The table is computed lazily by get_CoRegPoints_table() on first access and cached afterwards.

    Columns of the GeoDataFrame: 'geometry','POINT_ID','X_IM','Y_IM','X_MAP','Y_MAP','X_WIN_SIZE', 'Y_WIN_SIZE',
    'X_SHIFT_PX','Y_SHIFT_PX', 'X_SHIFT_M', 'Y_SHIFT_M', 'ABS_SHIFT' and 'ANGLE'
    """
    if self._CoRegPoints_table is None:
        self._CoRegPoints_table = self.get_CoRegPoints_table()
    return self._CoRegPoints_table

@CoRegPoints_table.setter
def CoRegPoints_table(self, CoRegPoints_table):
    """Replace the cached tie point table."""
    self._CoRegPoints_table = CoRegPoints_table
@property
def GCPList(self):
    """Return a list of GDAL compatible GCP objects.

    The list is created via to_GCPList(). Only a non-empty list is reused on later accesses; an empty or
    missing cached list triggers a new computation.
    """
    cached = self._GCPList
    if not cached:
        cached = self._GCPList = self.to_GCPList()
    return cached

@GCPList.setter
def GCPList(self, GCPList):
    """Replace the cached list of GCPs."""
    self._GCPList = GCPList
def _get_imXY__mapXY_points(self, grid_res):
    """Return the candidate positions for coregistration tie points in image and in map coordinates.

    NOTE: The returned positions depend on the given grid resolution. The grid starts at pixel (0, 0) and
          extends by one grid step beyond the image dimensions in x and y.

    :param grid_res: grid spacing in pixels of the target image
    :return: tuple (XY_points, XY_mapPoints) of two arrays with one (x, y) pair per row - the image
             coordinates and the corresponding map coordinates (derived from self.shift.gt)
    """
    if not self.q:
        print('Initializing tie points grid...')

    n_rows, n_cols = self.shift.shape[0], self.shift.shape[1]
    Xarr, Yarr = np.meshgrid(np.arange(0, n_cols + grid_res, grid_res),
                             np.arange(0, n_rows + grid_res, grid_res))

    # image -> map coordinates (the y-axis of the map points downwards in image space)
    gt = self.shift.gt
    mapXarr = np.full_like(Xarr, gt[0], dtype=np.float64) + Xarr * gt[1]
    mapYarr = np.full_like(Yarr, gt[3], dtype=np.float64) - Yarr * abs(gt[5])

    # flatten the 2D grids into (N, 2) coordinate tables (row-major order)
    XY_points = np.column_stack([Xarr.ravel(), Yarr.ravel()])
    XY_mapPoints = np.column_stack([mapXarr.ravel(), mapYarr.ravel()])

    assert XY_points.shape == XY_mapPoints.shape

    return XY_points, XY_mapPoints
def _exclude_bad_XYpos(self, GDF):
    """Exclude all points outside the image overlap area and where the bad data mask is True (if given).

    :param GDF: <geopandas.GeoDataFrame> must include the columns 'X_MAP' and 'Y_MAP'
    :return: filtered copy of GDF with the additional boolean columns 'REF_BADDATA' and 'TGT_BADDATA'
    """
    from skimage.measure import points_in_poly  # import here to avoid static TLS ImportError

    # exclude all points outside of overlap area
    # (test the map coordinates stored in GDF itself, so that the boolean mask always matches the rows of GDF)
    overlap_coords = np.swapaxes(np.array(self.COREG_obj.overlap_poly.exterior.coords.xy), 0, 1)
    inliers = points_in_poly(np.asarray(GDF[['X_MAP', 'Y_MAP']], dtype=np.float64), overlap_coords)
    GDF = GDF[inliers].copy()
    # GDF = GDF[GDF['geometry'].within(self.COREG_obj.overlap_poly.simplify(tolerance=15))] # works but much slower

    assert not GDF.empty, 'No coregistration point could be placed within the overlap area. Check your input data!'

    # exclude all points where bad data mask is True (e.g. points on clouds etc.)
    orig_len_GDF = len(GDF)  # length of GDF after dropping all points outside the overlap polygon
    mapXY = np.array(GDF.loc[:, ['X_MAP', 'Y_MAP']])
    ref_mask = self.COREG_obj.ref.mask_baddata
    tgt_mask = self.COREG_obj.shift.mask_baddata
    # normalize both mask readings the same way (flattened boolean arrays), so that '~' below is a logical
    # negation even if read_pointData() returns a non-1D or non-boolean array
    GDF['REF_BADDATA'] = ref_mask.read_pointData(mapXY).flatten().astype(bool) if ref_mask is not None else False
    GDF['TGT_BADDATA'] = tgt_mask.read_pointData(mapXY).flatten().astype(bool) if tgt_mask is not None else False
    GDF = GDF[(~GDF['REF_BADDATA']) & (~GDF['TGT_BADDATA'])]

    if ref_mask is not None or tgt_mask is not None:
        if not self.q:
            if not GDF.empty:
                print('With respect to the provided bad data mask(s) %s points of initially %s have been excluded.'
                      % (orig_len_GDF - len(GDF), orig_len_GDF))
            else:
                warnings.warn('With respect to the provided bad data mask(s) no coregistration point could be '
                              'placed within an image area usable for coregistration.')

    return GDF
@staticmethod
def _get_spatial_shifts(imref, im2shift, point_id, fftw_works, **coreg_kwargs):
    """Calculate the spatial shift at a single tie point position.

    A separate single-CPU COREG instance is created for the given point. Its results are returned in the
    column order of the tie point table built in get_CoRegPoints_table().

    :param imref:        reference image
    :param im2shift:     target image
    :param point_id:     ID of the tie point (returned as first list element)
    :param fftw_works:   value assigned to COREG.fftw_works before the shift calculation
    :param coreg_kwargs: further keyword arguments passed to COREG
    :return: [point_id, X_WIN_SIZE, Y_WIN_SIZE, X_SHIFT_PX, Y_SHIFT_PX, X_SHIFT_M, Y_SHIFT_M, ABS_SHIFT,
              ANGLE, SSIM_BEFORE, SSIM_AFTER, SSIM_IMPROVED, RELIABILITY, LAST_ERR]
    """
    coreg = COREG(imref, im2shift, CPUs=1, **coreg_kwargs)
    coreg.fftw_works = fftw_works
    coreg.calculate_spatial_shifts()

    # the most recently tracked error (if any) ends up in the 'LAST_ERR' column
    last_err = coreg.tracked_errors[-1] if coreg.tracked_errors else None

    # the matching window size is only available if a matching box could be set up
    if coreg.matchBox:
        win_sz_y, win_sz_x = coreg.matchBox.imDimsYX
    else:
        win_sz_y, win_sz_x = None, None

    return [
        point_id,
        win_sz_x, win_sz_y,
        coreg.x_shift_px, coreg.y_shift_px,
        coreg.x_shift_map, coreg.y_shift_map,
        coreg.vec_length_map, coreg.vec_angle_deg,
        coreg.ssim_orig, coreg.ssim_deshifted, coreg.ssim_improved,
        coreg.shift_reliability,
        last_err,
    ]
def get_CoRegPoints_table(self):
    """Compute the co-registration results for all points of the tie point grid.

    Points outside the image overlap area or on bad data mask pixels are excluded first. The spatial shift
    is then computed for each remaining point in parallel (see self.CPUs). The results are optionally
    filtered according to self.tieP_filter_level.

    :return: GeoDataFrame with one row per tie point (also stored in self.CoRegPoints_table); missing values
             are filled with self.outFillVal
    """
    assert self.XY_points is not None and self.XY_mapPoints is not None

    # create a dataframe containing 'geometry','POINT_ID','X_IM','Y_IM','X_MAP','Y_MAP'
    # (convert imCoords to mapCoords
    XYarr2PointGeom = np.vectorize(lambda X, Y: Point(X, Y), otypes=[Point])
    geomPoints = np.array(XYarr2PointGeom(self.XY_mapPoints[:, 0], self.XY_mapPoints[:, 1]))

    crs = self.COREG_obj.shift.prj if not isLocal(self.COREG_obj.shift.prj) else None

    GDF = GeoDataFrame(index=range(len(geomPoints)),
                       crs=crs,
                       columns=['geometry', 'POINT_ID', 'X_IM', 'Y_IM', 'X_MAP', 'Y_MAP'])
    GDF['geometry'] = geomPoints
    GDF['POINT_ID'] = range(len(geomPoints))
    GDF[['X_IM', 'Y_IM']] = self.XY_points
    GDF[['X_MAP', 'Y_MAP']] = self.XY_mapPoints

    # exclude offsite points and points on bad data mask
    GDF = self._exclude_bad_XYpos(GDF)
    if GDF.empty:
        self.CoRegPoints_table = GDF
        return self.CoRegPoints_table

    # choose a random subset of points if a maximum number has been given
    if self.max_points and len(GDF) > self.max_points:
        GDF = GDF.sample(self.max_points).copy()

    # equalize pixel grids in order to save warping time
    if len(GDF) > 100:
        # NOTE: actually grid res should be also changed here because self.shift.xgsd changes and grid res is
        #       connected to that
        self.COREG_obj.equalize_pixGrids()
        self.ref = self.COREG_obj.ref
        self.shift = self.COREG_obj.shift

    # validate reference and target image inputs
    assert self.ref.footprint_poly  # this also checks for mask_nodata and nodata value
    assert self.shift.footprint_poly

    # ensure the input arrays for CoReg are in memory -> otherwise the code will get stuck in multiprocessing if
    # neighboured matching windows overlap during reading from disk!!
    self.ref.cache_array_subset(
        [self.COREG_obj.ref.band4match])  # only sets geoArr._arr_cache; does not change number of bands
    self.shift.cache_array_subset([self.COREG_obj.shift.band4match])

    # CPUs=None is documented as 'all CPUs available', but joblib interprets n_jobs=None as a single job
    # -> translate it to joblib's 'all CPUs' value (-1)
    n_jobs = self.CPUs if self.CPUs is not None else -1

    if not self.q:
        n_cpus_str = self.CPUs if self.CPUs is not None else 'all available'
        print(f"Calculating tie point grid ({len(GDF)} points) using {n_cpus_str} CPU cores...")
    results = []
    bar = ProgressBar(prefix='\tprogress:')

    # multiprocessing backend is slightly faster on Linux but can only return a list (no progress)
    if platform != 'win32' and (not self.progress or self.q):
        kw_parallel = dict(backend='multiprocessing', return_as='list')
    else:
        kw_parallel = dict(backend='loky', return_as='generator')

    for i, res in enumerate(
        Parallel(n_jobs=n_jobs, **kw_parallel)(
            delayed(self._get_spatial_shifts)(
                self.ref,
                self.shift,
                point_id,
                self.COREG_obj.fftw_works,
                wp=self.XY_mapPoints[point_id],
                ws=self.COREG_obj.win_size_XY,
                resamp_alg_calc=self.rspAlg_calc,
                footprint_poly_ref=self.COREG_obj.ref.poly,
                footprint_poly_tgt=self.COREG_obj.shift.poly,
                r_b4match=self.ref.band4match + 1,  # internally indexing from 0
                s_b4match=self.shift.band4match + 1,  # internally indexing from 0
                max_iter=self.COREG_obj.max_iter,
                max_shift=self.COREG_obj.max_shift,
                nodata=(self.COREG_obj.ref.nodata, self.COREG_obj.shift.nodata),
                force_quadratic_win=self.COREG_obj.force_quadratic_win,
                binary_ws=self.COREG_obj.bin_ws,
                v=False,  # True leads to massive STDOUT
                q=True,  # True leads to massive STDOUT
                ignore_errors=True
            ) for point_id in GDF.index
        )
    ):
        results.append(res)

        if self.progress and not self.q:
            bar.print_progress(percent=(i + 1) / len(GDF) * 100)

    # merge results with GDF
    # NOTE: We use a pandas.DataFrame here because the geometry column is missing.
    #       GDF.astype(...) fails with geopandas>0.6.0 if the geometry columns is missing.
    records = DataFrame(results,
                        columns=['POINT_ID', 'X_WIN_SIZE', 'Y_WIN_SIZE', 'X_SHIFT_PX', 'Y_SHIFT_PX', 'X_SHIFT_M',
                                 'Y_SHIFT_M', 'ABS_SHIFT', 'ANGLE', 'SSIM_BEFORE', 'SSIM_AFTER',
                                 'SSIM_IMPROVED', 'RELIABILITY', 'LAST_ERR'])

    # merge DataFrames (dtype must be equal to records.dtypes; we need dtype object due to None values)
    GDF = GDF.astype(object).merge(records.astype(object), on='POINT_ID', how="inner")
    GDF = GDF.replace([np.nan, None], int(self.outFillVal))  # fillna fails with geopandas==0.6.0
    GDF.crs = crs  # gets lost when using GDF.astype(object), so we have to reassign that

    # a point counts as a match if no error was tracked during its shift calculation
    n_matches = len(GDF[GDF.LAST_ERR == int(self.outFillVal)])

    if not self.q:
        print(f"Found {n_matches} matches.")

    # filter tie points according to given filter level
    if n_matches > 0 and self.tieP_filter_level > 0:
        if not self.q:
            print('Performing validity checks...')
        TPR = Tie_Point_Refiner(GDF[GDF.ABS_SHIFT != self.outFillVal], **self.outlDetect_settings)
        GDF_filt, new_columns = TPR.run_filtering(level=self.tieP_filter_level)
        GDF = GDF.merge(GDF_filt[['POINT_ID'] + new_columns], on='POINT_ID', how="outer")

    GDF = GDF.replace([np.nan, None], int(self.outFillVal))  # fillna fails with geopandas==0.6.0

    self.CoRegPoints_table = GDF

    if not self.q:
        if n_matches == 0 or GDF.empty:
            warnings.warn('No valid GCPs could by identified.')
        else:
            if self.tieP_filter_level > 0:
                n_valid = int(GDF['OUTLIER'].eq(False).sum())
                print("%d valid tie points remain after filtering." % n_valid)

    return self.CoRegPoints_table
def calc_rmse(self, include_outliers: bool = False) -> float:
    """Calculate root-mean-square error of absolute shifts from the tie point grid.

    Tie points without a valid shift (ABS_SHIFT equal to self.outFillVal) are ignored.

    :param include_outliers: whether to include tie points that have been marked as false-positives (if present)
    :return: RMSE of the absolute shift vector lengths ('ABS_SHIFT' column)
    :raises RuntimeError: if the table is empty, if all tie points are flagged as false-positives or if no
                          tie point carries a valid shift
    """
    if self.CoRegPoints_table.empty:
        raise RuntimeError('Cannot compute the RMSE because no tie points were found at all.')

    tbl = self.CoRegPoints_table
    if not include_outliers and 'OUTLIER' in tbl.columns:
        tbl = tbl[tbl['OUTLIER'] == 0]

    if not include_outliers and tbl.empty:
        raise RuntimeError('Cannot compute the RMSE because all tie points are flagged as false-positives.')

    # the table may have dtype object (due to former None values) -> cast to float for vectorized math
    shifts = np.asarray(tbl['ABS_SHIFT'], dtype=np.float64)
    shifts = shifts[shifts != self.outFillVal]

    if shifts.size == 0:
        # avoids a division by zero
        raise RuntimeError('Cannot compute the RMSE because none of the tie points has a valid shift.')

    return float(np.sqrt(np.mean(shifts ** 2)))
def calc_overall_ssim(self,
                      include_outliers: bool = False,
                      after_correction: bool = True
                      ) -> float:
    """Calculate the median value of all SSIM values contained in tie point grid.

    Tie points without an SSIM value (equal to self.outFillVal) are ignored.

    :param include_outliers: whether to include tie points that have been marked as false-positives (if present)
    :param after_correction: whether to compute median SSIM before correction or after
    :return: median SSIM value
    :raises RuntimeError: if the table is empty, if all tie points are flagged as false-positives or if no
                          tie point carries a valid SSIM value
    """
    if self.CoRegPoints_table.empty:
        raise RuntimeError('Cannot compute the overall SSIM because no tie points were found at all.')

    tbl = self.CoRegPoints_table
    # the 'OUTLIER' column only exists if tie point filtering was performed
    if not include_outliers and 'OUTLIER' in tbl.columns:
        tbl = tbl[tbl['OUTLIER'] == 0]

    if not include_outliers and tbl.empty:
        raise RuntimeError('Cannot compute the overall SSIM because all tie points are flagged as false-positives.')

    ssim_vals = np.asarray(tbl['SSIM_AFTER' if after_correction else 'SSIM_BEFORE'], dtype=np.float64)
    ssim_vals = ssim_vals[ssim_vals != self.outFillVal]

    if ssim_vals.size == 0:
        raise RuntimeError('Cannot compute the overall SSIM because none of the tie points has a valid SSIM value.')

    # median of the SSIM values themselves (not of their squares)
    return float(np.median(ssim_vals))
def calc_overall_stats(self, include_outliers: bool = False) -> dict:
    """Calculate statistics like RMSE, MSE, MAE, ... from the tie point grid.

    Full list of returned statistics:

    - N_TP:                 number of tie points (with a valid shift)
    - N_VALID_TP:           number of valid tie points (i.e., the tie points used for the statistics below; this
                            includes false-positives if include_outliers is True)
    - N_INVALID_TP:         number of invalid tie points (false-positives)
    - PERC_VALID_TP:        percentage of valid tie points
    - RMSE_M:               root mean squared error of absolute shift vector length in map units
    - RMSE_X_M:             root mean squared error of shift vector length in x-direction in map units
    - RMSE_Y_M:             root mean squared error of shift vector length in y-direction in map units
    - RMSE_X_PX:            root mean squared error of shift vector length in x-direction in pixel units
    - RMSE_Y_PX:            root mean squared error of shift vector length in y-direction in pixel units
    - MSE_M:                mean squared error of absolute shift vector length in map units
    - MSE_X_M:              mean squared error of shift vector length in x-direction in map units
    - MSE_Y_M:              mean squared error of shift vector length in y-direction in map units
    - MSE_X_PX:             mean squared error of shift vector length in x-direction in pixel units
    - MSE_Y_PX:             mean squared error of shift vector length in y-direction in pixel units
    - MAE_M:                mean absolute error of absolute shift vector length in map units
    - MAE_X_M:              mean absolute error of shift vector length in x-direction in map units
    - MAE_Y_M:              mean absolute error of shift vector length in y-direction in map units
    - MAE_X_PX:             mean absolute error of shift vector length in x-direction in pixel units
    - MAE_Y_PX:             mean absolute error of shift vector length in y-direction in pixel units
    - MEAN_ABS_SHIFT:       mean absolute shift vector length in map units
    - MEAN_X_SHIFT_M:       mean shift vector length in x-direction in map units
    - MEAN_Y_SHIFT_M:       mean shift vector length in y-direction in map units
    - MEAN_X_SHIFT_PX:      mean shift vector length in x-direction in pixel units
    - MEAN_Y_SHIFT_PX:      mean shift vector length in y-direction in pixel units
    - MEAN_ANGLE:           mean direction of the shift vectors in degrees from north
    - MEAN_SSIM_BEFORE:     mean structural similarity index within each matching window before co-registration
    - MEAN_SSIM_AFTER:      mean structural similarity index within each matching window after co-registration
    - MEAN_RELIABILITY:     mean tie point reliability in percent
    - MEDIAN_ABS_SHIFT:     median absolute shift vector length in map units
    - MEDIAN_X_SHIFT_M:     median shift vector length in x-direction in map units
    - MEDIAN_Y_SHIFT_M:     median shift vector length in y-direction in map units
    - MEDIAN_X_SHIFT_PX:    median shift vector length in x-direction in pixel units
    - MEDIAN_Y_SHIFT_PX:    median shift vector length in y-direction in pixel units
    - MEDIAN_ANGLE:         median direction of the shift vectors in degrees from north
    - MEDIAN_SSIM_BEFORE:   median structural similarity index within each matching window before co-registration
    - MEDIAN_SSIM_AFTER:    median structural similarity index within each matching window after co-registration
    - MEDIAN_RELIABILITY:   median tie point reliability in percent
    - STD_ABS_SHIFT:        standard deviation of absolute shift vector length in map units
    - STD_X_SHIFT_M:        standard deviation of shift vector length in x-direction in map units
    - STD_Y_SHIFT_M:        standard deviation of shift vector length in y-direction in map units
    - STD_X_SHIFT_PX:       standard deviation of shift vector length in x-direction in pixel units
    - STD_Y_SHIFT_PX:       standard deviation of shift vector length in y-direction in pixel units
    - STD_ANGLE:            standard deviation of direction of the shift vectors in degrees from north
    - STD_SSIM_BEFORE:      standard deviation of structural similarity index within each matching window before
                            co-registration
    - STD_SSIM_AFTER:       standard deviation of structural similarity index within each matching window after
                            co-registration
    - STD_RELIABILITY:      standard deviation of tie point reliability in percent
    - MIN_ABS_SHIFT:        minimal absolute shift vector length in map units
    - MIN_X_SHIFT_M:        minimal shift vector length in x-direction in map units
    - MIN_Y_SHIFT_M:        minimal shift vector length in y-direction in map units
    - MIN_X_SHIFT_PX:       minimal shift vector length in x-direction in pixel units
    - MIN_Y_SHIFT_PX:       minimal shift vector length in y-direction in pixel units
    - MIN_ANGLE:            minimal direction of the shift vectors in degrees from north
    - MIN_SSIM_BEFORE:      minimal structural similarity index within each matching window before co-registration
    - MIN_SSIM_AFTER:       minimal structural similarity index within each matching window after co-registration
    - MIN_RELIABILITY:      minimal tie point reliability in percent
    - MAX_ABS_SHIFT:        maximal absolute shift vector length in map units
    - MAX_X_SHIFT_M:        maximal shift vector length in x-direction in map units
    - MAX_Y_SHIFT_M:        maximal shift vector length in y-direction in map units
    - MAX_X_SHIFT_PX:       maximal shift vector length in x-direction in pixel units
    - MAX_Y_SHIFT_PX:       maximal shift vector length in y-direction in pixel units
    - MAX_ANGLE:            maximal direction of the shift vectors in degrees from north
    - MAX_SSIM_BEFORE:      maximal structural similarity index within each matching window before co-registration
    - MAX_SSIM_AFTER:       maximal structural similarity index within each matching window after co-registration
    - MAX_RELIABILITY:      maximal tie point reliability in percent

    :param include_outliers: whether to include tie points that have been marked as false-positives (if present)
    :return: dictionary of the statistics listed above
    :raises RuntimeError: if the table is empty, if no tie point has a valid shift or if all tie points are
                          flagged as false-positives
    """
    if self.CoRegPoints_table.empty:
        raise RuntimeError('Cannot compute overall statistics because no tie points were found at all.')

    tbl = self.CoRegPoints_table
    # the 'OUTLIER' column only exists if tie point filtering was performed
    has_outlier_col = 'OUTLIER' in tbl.columns

    n_tiepoints = int((tbl['ABS_SHIFT'] != self.outFillVal).sum())
    n_outliers = int((tbl['OUTLIER'] == 1).sum()) if has_outlier_col else 0

    if n_tiepoints == 0:
        # avoids divisions by zero and reductions (min/max) of empty arrays below
        raise RuntimeError('Cannot compute overall statistics because none of the tie points has a valid shift.')

    if not include_outliers and has_outlier_col:
        tbl = tbl[tbl['OUTLIER'] == 0]
    tbl = tbl.copy().replace(self.outFillVal, np.nan)

    if not include_outliers and tbl.empty:
        raise RuntimeError('Cannot compute overall statistics '
                           'because all tie points are flagged as false-positives.')

    def RMSE(shifts):
        return np.sqrt(np.mean(shifts ** 2))

    def MSE(shifts):
        return np.mean(shifts ** 2)

    def MAE(shifts):
        return np.mean(np.abs(shifts))

    # lowercase names (used for the statistics keys) -> table columns
    columns = dict(abs_shift='ABS_SHIFT', x_shift_m='X_SHIFT_M', y_shift_m='Y_SHIFT_M',
                   x_shift_px='X_SHIFT_PX', y_shift_px='Y_SHIFT_PX', angle='ANGLE',
                   ssim_before='SSIM_BEFORE', ssim_after='SSIM_AFTER', reliability='RELIABILITY')
    # the table may have dtype object (due to former None values) -> cast to float for vectorized math
    vals = {name: tbl[col].dropna().to_numpy(dtype=np.float64) for name, col in columns.items()}

    stats = dict(
        N_TP=n_tiepoints,
        N_VALID_TP=len(vals['abs_shift']),
        N_INVALID_TP=n_outliers,
        PERC_VALID_TP=(n_tiepoints - n_outliers) / n_tiepoints * 100,

        RMSE_M=RMSE(vals['abs_shift']),
        RMSE_X_M=RMSE(vals['x_shift_m']),
        RMSE_Y_M=RMSE(vals['y_shift_m']),
        RMSE_X_PX=RMSE(vals['x_shift_px']),
        RMSE_Y_PX=RMSE(vals['y_shift_px']),

        MSE_M=MSE(vals['abs_shift']),
        MSE_X_M=MSE(vals['x_shift_m']),
        MSE_Y_M=MSE(vals['y_shift_m']),
        MSE_X_PX=MSE(vals['x_shift_px']),
        MSE_Y_PX=MSE(vals['y_shift_px']),

        MAE_M=MAE(vals['abs_shift']),
        MAE_X_M=MAE(vals['x_shift_m']),
        MAE_Y_M=MAE(vals['y_shift_m']),
        MAE_X_PX=MAE(vals['x_shift_px']),
        MAE_Y_PX=MAE(vals['y_shift_px']),
    )

    for stat, func in zip(['mean', 'median', 'std', 'min', 'max'],
                          [np.mean, np.median, np.std, np.min, np.max]):
        for name in columns:
            stats[f'{stat}_{name}'.upper()] = func(vals[name])

    return stats
def plot_shift_distribution(self,
                            include_outliers: bool = True,
                            unit: str = 'm',
                            interactive: bool = False,
                            figsize: tuple = None,
                            xlim: list = None,
                            ylim: list = None,
                            fontsize: int = 12,
                            title: str = 'shift distribution',
                            savefigPath: str = '',
                            savefigDPI: int = 96,
                            showFig: bool = True,
                            return_fig: bool = False
                            ) -> tuple:
    """Create a 2D scatterplot containing the distribution of calculated X/Y-shifts.

    :param include_outliers: whether to include tie points that have been marked as false-positives
    :param unit:             'm' for meters or 'px' for pixels (default: 'm')
    :param interactive:      whether to use interactive mode (uses plotly for visualization)
    :param figsize:          (xdim, ydim)
    :param xlim:             [xmin, xmax]
    :param ylim:             [ymin, ymax]
    :param fontsize:         size of all used fonts
    :param title:            the title to be plotted above the figure
    :param savefigPath:      path where to save the figure
    :param savefigDPI:       DPI resolution of the output figure when saved to disk
    :param showFig:          whether to show or to hide the figure
    :param return_fig:       whether to return the figure and axis objects
    :return: (fig, ax) if return_fig is True, (None, None) in interactive mode, otherwise None
    :raises ValueError:   if 'unit' is neither 'm' nor 'px'
    :raises RuntimeError: if the tie point table is empty
    """
    if unit not in ['m', 'px']:
        raise ValueError("Parameter 'unit' must have the value 'm' (meters) or 'px' (pixels)! Got %s." % unit)

    if self.CoRegPoints_table.empty:
        raise RuntimeError('Shift distribution cannot be plotted because no tie points were found at all.')

    tbl = self.CoRegPoints_table
    tbl = tbl[tbl['ABS_SHIFT'] != self.outFillVal]
    has_outlier_col = 'OUTLIER' in tbl.columns
    tbl_il = tbl[tbl['OUTLIER'] == 0].copy() if has_outlier_col else tbl
    # compare with 1 instead of using the column itself as mask (it may contain the fill value)
    tbl_ol = tbl[tbl['OUTLIER'] == 1].copy() if has_outlier_col else None
    x_attr = 'X_SHIFT_M' if unit == 'm' else 'X_SHIFT_PX'
    y_attr = 'Y_SHIFT_M' if unit == 'm' else 'Y_SHIFT_PX'
    unit_label = 'meters' if unit == 'm' else 'pixels'
    rmse = self.calc_rmse(include_outliers=False)  # always exclude outliers when calculating RMSE
    figsize = figsize if figsize else (10, 10)

    if interactive:
        from plotly.offline import iplot, init_notebook_mode
        import plotly.graph_objs as go
        # FIXME outliers are not plotted

        init_notebook_mode(connected=True)

        # Create a trace
        trace = go.Scatter(
            x=tbl_il[x_attr],
            y=tbl_il[y_attr],
            mode='markers'
        )

        data = [trace]

        # Plot and embed in ipython notebook!
        iplot(data, filename='basic-scatter')

        return None, None

    else:
        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111)

        if include_outliers and has_outlier_col:
            ax.scatter(tbl_ol[x_attr], tbl_ol[y_attr], marker='+', c='r', label='false-positives')
        ax.scatter(tbl_il[x_attr], tbl_il[y_attr], marker='+', c='g', label='valid tie points')

        # set axis limits (symmetric around zero, based on the valid tie points)
        if not xlim:
            xmax = np.abs(tbl_il[x_attr]).max()
            xlim = [-xmax, xmax]
        if not ylim:
            ymax = np.abs(tbl_il[y_attr]).max()
            ylim = [-ymax, ymax]
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)

        # add text box containing RMSE of plotted shifts
        xlim, ylim = ax.get_xlim(), ax.get_ylim()
        ax.text(xlim[1] - (xlim[1] / 20), -ylim[1] + (ylim[1] / 20),
                'RMSE:  %s m / %s px' % (np.round(rmse, 2), np.round(rmse / self.shift.xgsd, 2)),
                ha='right', va='bottom', fontsize=fontsize, bbox=dict(facecolor='w', pad=None, alpha=0.8))

        # add grid and increase linewidth of middle line
        ax.grid(visible=True)
        ax.spines["right"].set_visible(True)
        ax.spines["top"].set_visible(True)
        xgl = ax.get_xgridlines()
        middle_xgl = xgl[int(np.median(np.array(range(len(xgl)))))]
        middle_xgl.set_linewidth(2)
        middle_xgl.set_linestyle('-')
        ygl = ax.get_ygridlines()
        middle_ygl = ygl[int(np.median(np.array(range(len(ygl)))))]
        middle_ygl.set_linewidth(2)
        middle_ygl.set_linestyle('-')

        # set title and adjust tick labels
        ax.set_title(title, fontsize=fontsize)
        [tick.label1.set_fontsize(fontsize) for tick in ax.xaxis.get_major_ticks()]
        [tick.label1.set_fontsize(fontsize) for tick in ax.yaxis.get_major_ticks()]
        # NOTE: the unit label is formatted into the string as a whole (the former conditional expression
        #       bound weaker than '%' and produced the label 'pixels' only for unit='px')
        ax.set_xlabel('x-shift [%s]' % unit_label, fontsize=fontsize)
        ax.set_ylabel('y-shift [%s]' % unit_label, fontsize=fontsize)

        # add legend with labels in the right order
        handles, labels = ax.get_legend_handles_labels()
        leg = ax.legend(reversed(handles), reversed(labels), fontsize=fontsize, loc='upper right', scatterpoints=3)
        leg.get_frame().set_edgecolor('black')

        # remove white space around the figure
        fig.subplots_adjust(top=.94, bottom=.06, right=.96, left=.09)

        if savefigPath:
            fig.savefig(savefigPath, dpi=savefigDPI, pad_inches=0.3, bbox_inches='tight')

        if return_fig:
            return fig, ax

        if showFig and not self.q:
            plt.show(block=True)
        else:
            plt.close(fig)
def dump_CoRegPoints_table(self, path_out=None):
    """Write the tie point table to disk as a pickle file.

    :param path_out: output path; if not given, a generic file name (containing the grid resolution, the
                     matching window size and the basenames of target and reference image) is passed to
                     get_generic_outpath() together with self.dir_out
    :raises RuntimeError: if the tie point table is empty
    """
    if self.CoRegPoints_table.empty:
        raise RuntimeError('Cannot dump tie points table because it is empty.')

    if not path_out:
        win_x, win_y = self.COREG_obj.win_size_XY[0], self.COREG_obj.win_size_XY[1]
        fName = "CoRegPoints_table_grid%s_ws(%s_%s)__T_%s__R_%s.pkl" \
                % (self.grid_res, win_x, win_y, self.shift.basename, self.ref.basename)
        path_out = get_generic_outpath(dir_out=self.dir_out, fName_out=fName)

    if not self.q:
        print('Writing %s ...' % path_out)
    self.CoRegPoints_table.to_pickle(path_out)
def to_GCPList(self):
    """Convert the valid tie points into a list of GDAL ground control points (GCPs).

    Tie points without a valid shift and tie points flagged as outliers are skipped. If more than 7000 tie
    points are available, 7000 of them are chosen randomly (a warning is issued). The result is also stored
    in self.GCPList.

    :return: list of osgeo.gdal.GCP objects (empty if no tie point is usable)
    """
    tbl = self.CoRegPoints_table

    # the 'ABS_SHIFT' column is missing if all points have been excluded before the shift calculation
    if 'ABS_SHIFT' not in tbl.columns:
        return []

    # get copy of tie points grid without no data
    GDF = tbl.loc[tbl['ABS_SHIFT'] != self.outFillVal, :].copy()

    # exclude all points flagged as outliers
    if 'OUTLIER' in GDF.columns:
        GDF = GDF[GDF['OUTLIER'].eq(False)].copy()

    avail_TP = len(GDF)
    if not avail_TP:
        # no point passed all validity checks
        return []

    if avail_TP > 7000:
        GDF = GDF.sample(7000)
        warnings.warn('By far not more than 7000 tie points can be used for warping within a limited '
                      'computation time (due to a GDAL bottleneck). Thus these 7000 points are randomly chosen '
                      'out of the %s available tie points.' % avail_TP)

    # calculate GCPs: gdal.GCP(x, y, z, pixel, line) - the map coordinates corrected by the calculated shift are
    # paired with the unchanged image coordinates (explicit float conversion as the table may have dtype object)
    X_MAP_new = GDF['X_MAP'] + GDF['X_SHIFT_M']
    Y_MAP_new = GDF['Y_MAP'] + GDF['Y_SHIFT_M']
    self.GCPList = [gdal.GCP(float(x_map), float(y_map), 0, float(x_im), float(y_im))
                    for x_map, y_map, x_im, y_im in zip(X_MAP_new, Y_MAP_new, GDF['X_IM'], GDF['Y_IM'])]

    return self.GCPList
def test_if_singleprocessing_equals_multiprocessing_result(self):
    """Check whether two tie point computations with different CPU settings give identical tables.

    The tie point table is computed twice via self.get_CoRegPoints_table(): first with
    self.CPUs = None (the "multiprocessing" run), then with self.CPUs = 1 (the "singleprocessing" run).
    Tie point filtering is restricted to level 1 beforehand, because RANSAC filtering includes random
    sampling and would always make the results differ.

    NOTE: This permanently sets self.tieP_filter_level = 1 and leaves self.CPUs = 1.

    :return: True if both tables contain identical values, else False
    """
    # RANSAC filtering always produces different results because it includes random sampling
    self.tieP_filter_level = 1

    snapshots = []
    for cpu_setting in (None, 1):
        self.CPUs = cpu_setting
        snapshots.append(self.get_CoRegPoints_table().values.copy())
    mp_out, sp_out = snapshots

    return np.array_equal(sp_out, mp_out)
791 def _get_line_by_PID(self, PID):
792 return self.CoRegPoints_table.loc[PID, :]
794 def _get_lines_by_PIDs(self, PIDs):
795 assert isinstance(PIDs, list)
796 lines = np.zeros((len(PIDs), self.CoRegPoints_table.shape[1]))
797 for i, PID in enumerate(PIDs):
798 lines[i, :] = self.CoRegPoints_table[self.CoRegPoints_table['POINT_ID'] == PID]
799 return lines
def to_PointShapefile(self,
                      path_out: str = None,
                      skip_nodata: bool = True,
                      skip_nodata_col: str = 'ABS_SHIFT',
                      skip_outliers: bool = False
                      ) -> None:
    """Write the calculated tie point grid to a point shapefile (e.g., for visualization by a GIS software).

    NOTE: The shapefile uses Tie_Point_Grid.CoRegPoints_table as attribute table. All conversions needed
          for writing are applied to a copy, i.e., the table of this instance is not modified.

    :param path_out:        <str> the output path. If not given, it is automatically defined.
    :param skip_nodata:     <bool> whether to skip all points where no valid match could be found
    :param skip_nodata_col: <str> determines which column of Tie_Point_Grid.CoRegPoints_table is used to
                            identify points where no valid match could be found
    :param skip_outliers:   <bool> whether to exclude all tie points that have been flagged as outlier
                            (false-positive)
    :raises RuntimeError:   if the tie point table is empty
    """
    if self.CoRegPoints_table.empty:
        raise RuntimeError('Cannot save a point shapefile because no tie points were found at all.')

    # work on a copy - the conversions below must not alter the tie point table of this instance
    GDF2pass = self.CoRegPoints_table.copy()

    if skip_nodata:
        GDF2pass = GDF2pass[GDF2pass[skip_nodata_col] != self.outFillVal].copy()
    else:
        # use the error representation (including error type) instead of only the error message
        GDF2pass['LAST_ERR'] = GDF2pass['LAST_ERR'].map(repr)

    if skip_outliers:
        GDF2pass = GDF2pass[GDF2pass['OUTLIER'].ne(True)].copy()

    # replace boolean values (cannot be written)
    GDF2pass = GDF2pass.replace(False, 0)  # replaces booleans where column dtype is not bool but object
    GDF2pass = GDF2pass.replace(True, 1)
    for col in GDF2pass.columns:
        if GDF2pass[col].dtype == bool:
            GDF2pass[col] = GDF2pass[col].astype(int)

    path_out = path_out if path_out else \
        get_generic_outpath(dir_out=os.path.join(self.dir_out, 'CoRegPoints'),
                            fName_out="CoRegPoints_grid%s_ws(%s_%s)__T_%s__R_%s.shp"
                                      % (self.grid_res, self.COREG_obj.win_size_XY[0],
                                         self.COREG_obj.win_size_XY[1], self.shift.basename, self.ref.basename))
    if not self.q:
        print('Writing %s ...' % path_out)
    GDF2pass.to_file(path_out)
def to_vectorfield(self, path_out: str = None, fmt: str = None, mode: str = 'md') -> GeoArray:
    """Save the calculated X-/Y-shifts to a 2-band raster file that can be used to visualize a vectorfield.

    NOTE: For example ArcGIS is able to visualize such 2-band raster files as a vectorfield.

    :param path_out: the output path. If not given, it is automatically defined.
    :param fmt:      output raster format string (default: 'Gtiff')
    :param mode:     The mode how the output is written ('uv' or 'md'; default: 'md')
                     - 'uv': outputs X-/Y shifts (table columns X_SHIFT_M / Y_SHIFT_M)
                     - 'md': outputs magnitude and direction (table columns ABS_SHIFT / ANGLE)
    :return: the written 2-band GeoArray (band 1: X-shift or magnitude, band 2: Y-shift or direction)
    :raises AssertionError: if mode is neither 'uv' nor 'md'
    :raises RuntimeError:   if the tie point table is empty
    """
    assert mode in ['uv', 'md'], "'mode' must be either 'uv' (outputs X-/Y shifts) or 'md' " \
                                 "(outputs magnitude and direction). Got %s." % mode
    band_attrs = ('X_SHIFT_M', 'Y_SHIFT_M') if mode == 'uv' else ('ABS_SHIFT', 'ANGLE')

    if self.CoRegPoints_table.empty:
        raise RuntimeError('Cannot save the vector field because no tie points were found at all.')

    # rasterize both attributes with a pixel size of self.shift.xgsd * self.grid_res
    # (geotransform and projection are identical for both bands)
    bands, gt, prj = [], None, None
    for attr in band_attrs:
        band_arr, gt, prj = points_to_raster(points=self.CoRegPoints_table['geometry'],
                                             values=self.CoRegPoints_table[attr],
                                             tgt_res=self.shift.xgsd * self.grid_res,
                                             prj=self.CoRegPoints_table.crs.to_wkt(),
                                             fillVal=self.outFillVal)
        bands.append(band_arr)

    out_GA = GeoArray(np.dstack(bands), gt, prj, nodata=self.outFillVal)

    path_out = path_out if path_out else \
        get_generic_outpath(dir_out=os.path.join(self.dir_out, 'CoRegPoints'),
                            fName_out="CoRegVectorfield%s_ws(%s_%s)__T_%s__R_%s.tif"
                                      % (self.grid_res, self.COREG_obj.win_size_XY[0],
                                         self.COREG_obj.win_size_XY[1], self.shift.basename, self.ref.basename))
    if not self.q:
        print('Writing %s ...' % path_out)
    out_GA.save(path_out, fmt=fmt if fmt else 'Gtiff')

    return out_GA
def to_interpolated_raster(self,
                           metric: str = 'ABS_SHIFT',
                           method: str = 'RBF',
                           plot_result: bool = False,
                           lowres_spacing: int = 5,
                           v: bool = False
                           ) -> np.ndarray:
    """Spatially interpolate one tie point metric to the image grid of the co-registration target image.

    Thin wrapper around Tie_Point_Grid_Interpolator(self, v=v).interpolate(...).

    :param metric:          name of the metric to interpolate, i.e., one of the column names of
                            Tie_Point_Grid.CoRegPoints_table, e.g., 'ABS_SHIFT'.
    :param method:          interpolation algorithm
                            - 'RBF' (Radial Basis Function)
                            - 'GPR' (Gaussian Process Regression; equivalent to Simple Kriging)
                            - 'Kriging' (Ordinary Kriging based on pykrige)
    :param plot_result:     plot the result to assess the interpolation quality
    :param lowres_spacing:  RBF, GPR and Kriging are run on a coarser grid that is afterwards linearly
                            interpolated to the full output image resolution; this is the number of pixels
                            between the coarse grid points (higher values are faster but less accurate,
                            default: 5)
    :param v:               enable verbose mode
    :return: interpolation result as numpy array in the X/Y dimension of the target image of the co-registration
    """
    interpolator = Tie_Point_Grid_Interpolator(self, v=v)
    return interpolator.interpolate(metric=metric,
                                    method=method,
                                    plot_result=plot_result,
                                    lowres_spacing=lowres_spacing)
class Tie_Point_Refiner(object):
    """A class for performing outlier detection on a table of tie points."""

    def __init__(self, GDF,
                 min_reliability=60,
                 rs_max_outlier: float = 10,
                 rs_tolerance: float = 2.5,
                 rs_max_iter: int = 15,
                 rs_exclude_previous_outliers: bool = True,
                 rs_timeout: float = 20,
                 rs_random_state: Optional[int] = 0,
                 q: bool = False):
        """Get an instance of Tie_Point_Refiner.

        :param GDF:                             GeoDataFrame like TiePointGrid.CoRegPoints_table containing all
                                                tie points to be filtered and the corresponding metadata
                                                (a copy is stored, the input is not modified)
        :param min_reliability:                 minimum threshold for previously computed tie X/Y shift
                                                reliability (default: 60%)
        :param rs_max_outlier:                  RANSAC: maximum percentage of outliers to be detected
                                                (default: 10%)
        :param rs_tolerance:                    RANSAC: percentage tolerance for max_outlier_percentage
                                                (default: 2.5%)
        :param rs_max_iter:                     RANSAC: maximum iterations for finding the best RANSAC threshold
                                                (default: 15)
        :param rs_exclude_previous_outliers:    RANSAC: whether to exclude points that have been flagged as
                                                outlier by earlier filtering (default:True)
        :param rs_timeout:                      RANSAC: timeout for iteration loop in seconds (default: 20)
        :param rs_random_state:                 RANSAC random state (an integer corresponds to a fixed/pseudo-random
                                                state, None randomizes the result)
        :param q:                               quiet mode - suppress status messages printed to stdout
                                                (default: False)
        """
        self.GDF = GDF.copy()
        self.min_reliability = min_reliability
        self.rs_max_outlier_percentage = rs_max_outlier
        self.rs_tolerance = rs_tolerance
        self.rs_max_iter = rs_max_iter
        self.rs_exclude_previous_outliers = rs_exclude_previous_outliers
        self.rs_timeout = rs_timeout
        self.rs_random_state = rs_random_state
        self.q = q
        self.new_cols = []  # names of the flag columns added to self.GDF by run_filtering()
        self.ransac_model_robust = None  # model fitted by the last call of _RANSAC_outlier_detection()

    def run_filtering(self, level=3):
        """Filter tie points used for shift correction.

        :param level: tie point filter level (default: 3).
                      NOTE: lower levels are also included if a higher level is chosen

                        - Level 0: no tie point filtering
                        - Level 1: Reliablity filtering
                                   - filter all tie points out that have a low reliability according to
                                     internal tests (column RELIABILITY < min_reliability)
                        - Level 2: SSIM filtering
                                   - filters all tie points out where shift correction does not increase image
                                     similarity within matching window (column SSIM_IMPROVED is False)
                        - Level 3: RANSAC outlier detection

        :return: tuple of (self.GDF including the added flag columns, list of the names of the added columns).
                 Depending on the level, the columns L1_OUTLIER, L2_OUTLIER and L3_OUTLIER are added; the column
                 OUTLIER (True if any of these flags is set) is always added.
        """
        # RELIABILITY filtering
        if level > 0:
            marked_recs = self._reliability_thresholding()  # type: Series
            self.GDF['L1_OUTLIER'] = marked_recs
            self.new_cols.append('L1_OUTLIER')

            n_flagged = len(marked_recs[marked_recs])

            # the length check avoids a division by zero and a percentile of an empty array
            if len(self.GDF) and n_flagged / len(self.GDF) > .7:
                perc40 = np.percentile(self.GDF.RELIABILITY, 40)
                warnings.warn(r"More than 70%% of the found tie points have a reliability lower than %s%% and are "
                              r"therefore marked as false-positives. Consider relaxing the minimum reliability "
                              r"(parameter 'min_reliability') to avoid that. For example min_reliability=%d would only "
                              r"flag 40%% of the tie points in case of your input data."
                              % (self.min_reliability, perc40))

            if not self.q:
                print('%s tie points flagged by level 1 filtering (reliability).'
                      % n_flagged)

        # SSIM filtering
        if level > 1:
            marked_recs = self._SSIM_filtering()
            self.GDF['L2_OUTLIER'] = marked_recs  # type: Series
            self.new_cols.append('L2_OUTLIER')

            if not self.q:
                print('%s tie points flagged by level 2 filtering (SSIM).' % (len(marked_recs[marked_recs])))

        # RANSAC filtering
        if level > 2:
            # exclude previous outliers
            ransacInGDF = self.GDF[~self.GDF[self.new_cols].any(axis=1)].copy() \
                if self.rs_exclude_previous_outliers else self.GDF

            if len(ransacInGDF) > 4:
                # RANSAC is only run with at least five tie points here
                # (_RANSAC_outlier_detection() itself needs more than six points to fit a model)
                marked_recs = self._RANSAC_outlier_detection(ransacInGDF)  # type: Series
                # we need to join a list here because otherwise it's merged by the 'index' column
                self.GDF['L3_OUTLIER'] = marked_recs.tolist()

                if not self.q:
                    print('%s tie points flagged by level 3 filtering (RANSAC)'
                          % (len(marked_recs[marked_recs])))
            else:
                if not self.q:
                    print('RANSAC skipped because too less valid tie points have been found.')
                self.GDF['L3_OUTLIER'] = False

            self.new_cols.append('L3_OUTLIER')

        self.GDF['OUTLIER'] = self.GDF[self.new_cols].any(axis=1)
        self.new_cols.append('OUTLIER')

        return self.GDF, self.new_cols

    def _reliability_thresholding(self):
        """Flag all records where the estimated reliability of the calculated shifts is below the threshold.

        :return: boolean Series, True where RELIABILITY < self.min_reliability
        """
        return self.GDF.RELIABILITY < self.min_reliability

    def _SSIM_filtering(self):
        """Flag all records where the shift correction did not improve the SSIM.

        :return: boolean Series, the negation of the column SSIM_IMPROVED
        """
        return ~self.GDF.SSIM_IMPROVED

    def _RANSAC_outlier_detection(self, inGDF):
        """Detect geometric outliers between point cloud of source and estimated coordinates using RANSAC algorithm.

        An affine transform is robustly fitted (skimage.measure.ransac) between the map coordinates
        (X_MAP, Y_MAP) and the shift-corrected coordinates (X_MAP + X_SHIFT_M, Y_MAP + Y_SHIFT_M). The residual
        threshold is adapted iteratively so that the inlier percentage ends up within
        (100 - rs_max_outlier) +/- rs_tolerance, or until the iteration limit (rs_max_iter) or the timeout
        (rs_timeout) is reached.

        :param inGDF:   tie points to be checked (a subset of, or identical to, self.GDF); needs the columns
                        X_MAP, Y_MAP, X_SHIFT_M, Y_SHIFT_M and POINT_ID
        :return:        boolean Series with one value per record of self.GDF (True = outlier); records not
                        contained in inGDF are never flagged
        :raises ValueError: if rs_max_outlier is not within the open interval (0, 100)
        """
        src_coords = np.array(inGDF[['X_MAP', 'Y_MAP']])
        xyShift = np.array(inGDF[['X_SHIFT_M', 'Y_SHIFT_M']])
        est_coords = src_coords + xyShift

        for co, n in zip([src_coords, est_coords], ['src_coords', 'est_coords']):
            assert co.ndim == 2 and co.shape[1] == 2, "'%s' must have shape [Nx2]. Got shape %s." % (n, co.shape)

        if not 0 < self.rs_max_outlier_percentage < 100:
            raise ValueError("The maximum outlier percentage for RANSAC must be between 0 and 100 (exclusive). "
                             "Got %s." % self.rs_max_outlier_percentage)
        min_inlier_percentage = 100 - self.rs_max_outlier_percentage

        # robustly estimate affine transform model with RANSAC
        # eliminates not more than the given maximum outlier percentage of the tie points

        model_robust, inliers = None, None
        count_inliers = None
        th = 5  # start RANSAC threshold
        th_checked = {}  # dict of thresholds that already have been tried + calculated inlier percentage
        th_substract = 2
        count_iter = 0
        time_start = time()
        ideal_count = min_inlier_percentage * src_coords.shape[0] / 100

        # optimize RANSAC threshold so that it marks not much more or less than the given outlier percentage
        while True:
            if th_checked:
                th_too_strict = count_inliers < ideal_count  # True if too few inliers remaining

                # calculate new threshold using old increment
                # (but ensure th_new>0 by adjusting increment if needed)
                th_new = 0
                while th_new <= 0:
                    th_new = th + th_substract if th_too_strict else th - th_substract
                    if th_new <= 0:
                        th_substract /= 2

                # check if calculated new threshold has been used before
                th_already_checked = th_new in th_checked.keys()

                # if yes, decrease increment and recalculate new threshold
                th_substract = \
                    th_substract if not th_already_checked else \
                    th_substract / 2
                th = th_new if not th_already_checked else \
                    (th + th_substract if th_too_strict else
                     th - th_substract)

            ###############
            # RANSAC call #
            ###############
            if src_coords.size and \
                    est_coords.size and \
                    src_coords.shape[0] > 6:
                # import here to avoid static TLS ImportError
                from skimage.measure import ransac
                from skimage.transform import AffineTransform

                model_robust, inliers = \
                    ransac((src_coords, est_coords),
                           AffineTransform,
                           min_samples=6,
                           residual_threshold=th,
                           max_trials=2000,
                           stop_sample_num=int(
                               (min_inlier_percentage - self.rs_tolerance) /
                               100 * src_coords.shape[0]
                           ),
                           stop_residuals_sum=int(
                               (self.rs_max_outlier_percentage - self.rs_tolerance) /
                               100 * src_coords.shape[0]),
                           rng=self.rs_random_state
                           )
            else:
                warnings.warn('RANSAC filtering could not be applied '
                              'because there were too few tie points to fit a model.')
                inliers = np.array([])
                break

            count_inliers = np.count_nonzero(inliers)

            th_checked[th] = count_inliers / src_coords.shape[0] * 100

            if min_inlier_percentage - self.rs_tolerance < \
                    th_checked[th] < \
                    min_inlier_percentage + self.rs_tolerance:
                break  # inlier percentage is within the tolerance

            if count_iter > self.rs_max_iter or \
                    time() - time_start > self.rs_timeout:
                break  # keep last values and break while loop

            count_iter += 1

        # 'inliers' may be None if no model could be fitted
        outliers = np.logical_not(inliers) if inliers is not None and inliers.size else np.array([])

        if inGDF.empty or not outliers.size:
            outseries = Series([False] * len(self.GDF))

        elif len(inGDF) < len(self.GDF):
            # inGDF only contains a subset of the tie points (previous outliers have been excluded)
            # -> map the results back to all records of self.GDF via POINT_ID. A left join keeps the row order
            #    of self.GDF (the result is assigned positionally by the caller); records missing in inGDF get
            #    NaN, which .eq(True) turns into False.
            inGDF['outliers'] = outliers
            fullGDF = self.GDF[['POINT_ID']].merge(inGDF[['POINT_ID', 'outliers']],
                                                   on='POINT_ID',
                                                   how='left')
            outseries = fullGDF['outliers'].eq(True)

        else:
            outseries = Series(outliers)

        assert len(outseries) == len(self.GDF), \
            'RANSAC output validation failed.'

        self.ransac_model_robust = model_robust

        return outseries
class Tie_Point_Grid_Interpolator(object):
    """Class to interpolate tie point data into space."""

    def __init__(self, tiepointgrid: 'Tie_Point_Grid', v: bool = False) -> None:
        """Get an instance of Tie_Point_Grid_Interpolator.

        :param tiepointgrid:    instance of Tie_Point_Grid after computing spatial shifts
        :param v:               enable verbose mode
        """
        self.tpg = tiepointgrid
        self.v = v

    def interpolate(self,
                    metric: str,
                    method: str = 'RBF',
                    plot_result: bool = False,
                    lowres_spacing: int = 5
                    ) -> np.ndarray:
        """Interpolate the point data of the given metric into space.

        :param metric:          metric name to interpolate, i.e., one of the column names of
                                Tie_Point_Grid.CoRegPoints_table, e.g., 'ABS_SHIFT'.
        :param method:          interpolation algorithm
                                - 'RBF' (Radial Basis Function)
                                - 'GPR' (Gaussian Process Regression; equivalent to Simple Kriging)
                                - 'Kriging' (Ordinary Kriging based on pykrige)
        :param plot_result:     plot the result to assess the interpolation quality
        :param lowres_spacing:  by default, RBF, GPR, and Kriging run at a lower resolution which is then linearly
                                interpolated to the full output image resolution. lowres_spacing defines the number
                                of pixels between the low resolution grid points
                                (higher values are faster but less accurate, default: 5; 1 disables the
                                low resolution step)
        :return:                interpolation result as numpy array in the X/Y dimension of the target image of the
                                co-registration
        :raises ValueError:     if the method is unknown or lowres_spacing is smaller than 1
        """
        t0 = time()

        interpolators = {
            'RBF': self._interpolate_via_rbf,
            'GPR': self._interpolate_via_gpr,
            'Kriging': self._interpolate_via_kriging,
        }
        if method not in interpolators:
            raise ValueError(method)
        if lowres_spacing < 1:
            raise ValueError("'lowres_spacing' must be at least 1. Got %s." % lowres_spacing)

        rows, cols, data = self._get_pointdata(metric)
        nrows_out, ncols_out = self.tpg.shift.shape[:2]

        # low resolution grid; its last node may lie beyond the image border so that the subsequent
        # linear interpolation to full resolution covers the entire image extent
        rows_lowres = np.arange(0, nrows_out + lowres_spacing, lowres_spacing)
        cols_lowres = np.arange(0, ncols_out + lowres_spacing, lowres_spacing)
        data_lowres = interpolators[method](rows, cols, data, rows_lowres, cols_lowres)

        if lowres_spacing > 1:
            rows_full = np.arange(nrows_out)
            cols_full = np.arange(ncols_out)
            data_full = self._interpolate_regulargrid(rows_lowres, cols_lowres, data_lowres, rows_full, cols_full)
        else:
            # the low resolution grid already has full resolution, but np.arange() above added one extra
            # row/column beyond the image border -> crop to the target image dimensions
            data_full = data_lowres[:nrows_out, :ncols_out]

        if self.v:
            print('interpolation runtime: %.2fs' % (time() - t0))
        if plot_result:
            self._plot_interpolation_result(data_full, rows, cols, data, metric)

        return data_full

    def _get_pointdata(self, metric: str):
        """Get the point data for the given metric from Tie_Point_Grid.CoRegPoints_table while ignoring outliers.

        Only records with OUTLIER == False are used.

        :return: tuple of numpy arrays (rows, cols, data) taken from the columns Y_IM, X_IM and metric
        """
        table = self.tpg.CoRegPoints_table
        tiepoints = table[table['OUTLIER'].eq(False)]

        rows = np.array(tiepoints.Y_IM)
        cols = np.array(tiepoints.X_IM)
        data = np.array(tiepoints[metric])

        return rows, cols, data

    @staticmethod
    def _plot_interpolation_result(data_full: np.ndarray,
                                   rows: np.ndarray,
                                   cols: np.ndarray,
                                   data: np.ndarray,
                                   metric: str
                                   ):
        """Plot the interpolation result together with the input point data."""
        plt.figure(figsize=(7, 7))
        im = plt.imshow(data_full)
        plt.colorbar(im)
        plt.scatter(cols, rows, c=data, edgecolors='black')
        plt.title(metric)
        plt.show()

    @staticmethod
    def _interpolate_regulargrid(rows: np.ndarray,
                                 cols: np.ndarray,
                                 data: np.ndarray,
                                 rows_full: np.ndarray,
                                 cols_full: np.ndarray
                                 ):
        """Run linear regular grid interpolation.

        :param rows:        row coordinates of the input grid
        :param cols:        column coordinates of the input grid
        :param data:        input grid values of shape (len(rows), len(cols))
        :param rows_full:   row coordinates of the output grid
        :param cols_full:   column coordinates of the output grid
        :return:            array of shape (len(rows_full), len(cols_full)); positions outside the input grid
                            are not extrapolated (bounds_error=False, default fill value)
        """
        RGI = RegularGridInterpolator(points=[cols, rows],
                                      values=data.T,  # must be in shape [x, y]
                                      method='linear',
                                      bounds_error=False)
        data_full = RGI(np.dstack(np.meshgrid(cols_full, rows_full)))
        return data_full

    @staticmethod
    def _interpolate_via_rbf(rows: np.ndarray,
                             cols: np.ndarray,
                             data: np.ndarray,
                             rows_full: np.ndarray,
                             cols_full: np.ndarray
                             ):
        """Run Radial Basis Function (RBF) interpolation.

        -> https://github.com/agile-geoscience/xlines/blob/master/notebooks/11_Gridding_map_data.ipynb
        -> documents the legacy scipy.interpolate.Rbf

        :return: array of shape (len(rows_full), len(cols_full))
        """
        rbf = RBFInterpolator(
            np.column_stack([cols, rows]), data,
            kernel="linear",
            # kernel="thin_plate_spline",
        )
        cols_grid, rows_grid = np.meshgrid(cols_full, rows_full)
        data_full = \
            rbf(np.column_stack([cols_grid.ravel(), rows_grid.ravel()])) \
            .reshape(rows_grid.shape)

        return data_full

    @staticmethod
    def _interpolate_via_gpr(rows: np.ndarray,
                             cols: np.ndarray,
                             data: np.ndarray,
                             rows_full: np.ndarray,
                             cols_full: np.ndarray
                             ):
        """Run Gaussian Process Regression (GPR) interpolation.

        -> https://stackoverflow.com/questions/24978052/interpolation-over-regular-grid-in-python

        :return: array of shape (len(rows_full), len(cols_full))
        :raises ModuleNotFoundError: if the optional package scikit-learn is not installed
        """
        try:
            import sklearn  # noqa F401
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "GPR interpolation requires the optional package 'scikit-learn' to be installed. You may install it "
                "with Conda (conda install -c conda-forge scikit-learn) or Pip (pip install scikit-learn)."
            )

        from sklearn.gaussian_process.kernels import RBF
        from sklearn.gaussian_process import GaussianProcessRegressor

        gp = GaussianProcessRegressor(
            normalize_y=False,
            alpha=0.001,  # Larger values imply more input noise and result in smoother grids; default: 1e-10
            kernel=RBF(length_scale=100))
        gp.fit(np.column_stack([cols, rows]), data)

        cols_grid, rows_grid = np.meshgrid(cols_full, rows_full)
        data_full = \
            gp.predict(np.column_stack([cols_grid.ravel(), rows_grid.ravel()])) \
            .reshape(rows_grid.shape)

        return data_full

    @staticmethod
    def _interpolate_via_kriging(rows: np.ndarray,
                                 cols: np.ndarray,
                                 data: np.ndarray,
                                 rows_full: np.ndarray,
                                 cols_full: np.ndarray
                                 ):
        """Run Ordinary Kriging interpolation based on pykrige.

        Reference: P.K. Kitanidis, Introduction to Geostatistics: Applications in Hydrogeology,
                   (Cambridge University Press, 1997) 272 p.

        :return: interpolated grid as returned by pykrige for the output grid (rows_full x cols_full)
        :raises ModuleNotFoundError: if the optional package pykrige is not installed
        """
        try:
            import pykrige  # noqa F401
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "Ordinary Kriging requires the optional package 'pykrige' to be installed. You may install it with "
                "Conda (conda install -c conda-forge pykrige) or Pip (pip install pykrige)."
            )

        from pykrige.ok import OrdinaryKriging

        OK = OrdinaryKriging(cols.astype(float), rows.astype(float), data.astype(float),
                             variogram_model='spherical',
                             verbose=False)

        data_full, _sigmasq = \
            OK.execute('grid',
                       cols_full.astype(float),
                       rows_full.astype(float),
                       backend='C',
                       # n_closest_points=12
                       )

        return data_full