Coverage for arosics/DeShifter.py: 86%
165 statements
« prev ^ index » next coverage.py v7.3.2, created at 2024-04-03 14:59 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2024-04-03 14:59 +0000
1# -*- coding: utf-8 -*-
3# AROSICS - Automated and Robust Open-Source Image Co-Registration Software
4#
5# Copyright (C) 2017-2024
6# - Daniel Scheffler (GFZ Potsdam, daniel.scheffler@gfz-potsdam.de)
7# - Helmholtz Centre Potsdam - GFZ German Research Centre for Geosciences Potsdam,
8# Germany (https://www.gfz-potsdam.de/)
9#
10# This software was developed within the context of the GeoMultiSens project funded
11# by the German Federal Ministry of Education and Research
12# (project grant code: 01 IS 14 010 A-C).
13#
14# Licensed under the Apache License, Version 2.0 (the "License");
15# you may not use this file except in compliance with the License.
16# You may obtain a copy of the License at
17#
18# https://www.apache.org/licenses/LICENSE-2.0
19#
20# Unless required by applicable law or agreed to in writing, software
21# distributed under the License is distributed on an "AS IS" BASIS,
22# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23# See the License for the specific language governing permissions and
24# limitations under the License.
26import collections
27import time
28import warnings
29import numpy as np
30from typing import Union
32# internal modules
33from geoarray import GeoArray
34from py_tools_ds.geo.map_info import mapinfo2geotransform, geotransform2mapinfo
35from py_tools_ds.geo.coord_grid import is_coord_grid_equal
36from py_tools_ds.geo.projection import prj_equal
37from py_tools_ds.geo.raster.reproject import warp_ndarray
38from py_tools_ds.numeric.vector import find_nearest
__author__ = 'Daniel Scheffler'

# Bidirectional lookup table of the supported resampling algorithms:
#   algorithm name -> integer code (the value handed to warp_ndarray() as 'rspAlg') and
#   integer code -> algorithm name.
# NOTE(review): GDAL offers a 'q3' (third quartile) but no 'q2' resampling method - 'q2' may be meant as 'q3';
#               verify the integer mapping expected by py_tools_ds' warp_ndarray() before renaming it.
_dict_rspAlg_rsp_Int = {
    alg_name: alg_code
    for alg_code, alg_name in enumerate(('nearest', 'bilinear', 'cubic', 'cubic_spline', 'lanczos', 'average',
                                         'mode', 'max', 'min', 'med', 'q1', 'q2'))
}
# append the reverse direction (integer code -> algorithm name) after all name keys
_dict_rspAlg_rsp_Int.update({alg_code: alg_name for alg_name, alg_code in _dict_rspAlg_rsp_Int.items()})
class DESHIFTER(object):
    """
    Class to deshift an image array or one of its products by applying the previously computed coregistration info.

    See help(DESHIFTER) for documentation.
    """

    def __init__(self,
                 im2shift: Union[GeoArray, str],
                 coreg_results: dict,
                 **kwargs
                 ) -> None:
        """Get an instance of DESHIFTER.

        :param im2shift:
            path of an image to be de-shifted or alternatively a GeoArray object

        :param dict coreg_results:
            the results of the co-registration as given by COREG.coreg_info or COREG_LOCAL.coreg_info
            (the dict is only read, not modified)

        :keyword str path_out:
            /output/directory/filename for coregistered results

        :keyword str fmt_out:
            raster file format for output file. ignored if path_out is None. can be any GDAL
            compatible raster file format (e.g. 'ENVI', 'GTIFF'; default: ENVI)

        :keyword list out_crea_options:
            GDAL creation options for the output image, e.g., ["QUALITY=20", "REVERSIBLE=YES", "WRITE_METADATA=YES"]

        :keyword int band2process:
            The index of the band to be processed within the given array (starts with 1),
            default = None (all bands are processed)

        :keyword float nodata:
            no data value of the image to be de-shifted

        :keyword float out_gsd:
            output pixel size in units of the reference coordinate system (default = pixel size of the input array),
            either a single number (used for X and Y) or an [X, Y] pair; if given, it takes precedence over
            match_gsd=True

        :keyword bool align_grids:
            True: align the input coordinate grid to the reference (does not affect the output pixel size as long as
            input and output pixel sizes are compatible (5:30 or 10:30 but not 4:30)), default = False

        :keyword bool match_gsd:
            True: match the input pixel size to the reference pixel size, default = False

        :keyword list target_xyGrid:
            a list with an x-grid and a y-grid like [[15,45], [15,45]].
            This overrides 'out_gsd', 'align_grids' and 'match_gsd'.

        :keyword int min_points_local_corr:
            number of valid tie points, below which a global shift correction is performed instead of a local
            correction (global X/Y shift is then computed as the mean shift of the remaining points)
            (default: 5 tie points)

        :keyword str resamp_alg:
            the resampling algorithm to be used if necessary (default: 'cubic')
            (valid algorithms: nearest, bilinear, cubic, cubic_spline, lanczos, average, mode, max, min, med, q1, q2;
            the corresponding integer codes 0-11 are accepted as well)

        :keyword bool cliptoextent:
            True: clip the input image to its actual bounds while deleting possible no data areas outside the actual
            bounds, default = False

        :keyword list clipextent:
            xmin, ymin, xmax, ymax - if given the calculation of the actual bounds is skipped.
            The given coordinates are automatically snapped to the output grid.

        :keyword int CPUs:
            number of CPUs to use (default: None, which means 'all CPUs available')

        :keyword bool progress:
            show progress bars (default: True)

        :keyword bool v:
            verbose mode (default: False)

        :keyword bool q:
            quiet mode (default: False)
        """
        # private attributes
        self._grids_alignable = None

        # store args / kwargs
        # NOTE: this must stay before any other local variable is created because it captures locals()
        self.init_args = dict([x for x in locals().items() if x[0] != "self" and not x[0].startswith('__')])
        self.init_kwargs = self.init_args['kwargs']

        # unpack args
        self.im2shift = im2shift if isinstance(im2shift, GeoArray) else GeoArray(im2shift)
        self.GCPList = coreg_results.get('GCPList')
        self.ref_gt = coreg_results['reference geotransform']
        self.ref_grid = coreg_results['reference grid']
        self.ref_prj = coreg_results['reference projection']

        # unpack kwargs
        self.path_out = kwargs.get('path_out', None)
        self.fmt_out = kwargs.get('fmt_out', 'ENVI')
        self.out_creaOpt = kwargs.get('out_crea_options', [])
        band2process = kwargs.get('band2process', None)  # given with 1-based indexing
        self.band2process = band2process - 1 if band2process is not None else None  # internally a 0-based index
        self.nodata = kwargs.get('nodata', self.im2shift.nodata)
        self.align_grids = kwargs.get('align_grids', False)
        self.min_points_local_corr = kwargs.get('min_points_local_corr', 5)
        self.rspAlg = kwargs.get('resamp_alg', 'cubic')
        if isinstance(self.rspAlg, (int, np.integer)) and int(self.rspAlg) in _dict_rspAlg_rsp_Int:
            # normalize integer codes to algorithm names so that the lookup in correct_shifts()
            # always yields the integer code expected by warp_ndarray()
            self.rspAlg = _dict_rspAlg_rsp_Int[int(self.rspAlg)]
        self.cliptoextent = kwargs.get('cliptoextent', False)
        self.clipextent = kwargs.get('clipextent', None)
        self.CPUs = kwargs.get('CPUs', None)
        self.v = kwargs.get('v', False)
        self.q = kwargs.get('q', False) if not self.v else False  # overridden by v
        self.progress = kwargs.get('progress', True) if not self.q else False  # overridden by q

        self.im2shift.nodata = self.nodata
        self.im2shift.q = self.q
        self.shift_prj = self.im2shift.projection
        self.shift_gt = list(self.im2shift.geotransform)

        # defaults - filled below in case of a global correction, otherwise by correct_shifts()
        self.updated_map_info = None
        self.updated_gt = None
        self.original_map_info = None
        self.updated_projection = self.ref_prj

        # in case of local shift correction and local coreg results contain fewer points than min_points_local_corr:
        # force global correction based on mean X/Y shifts
        mapinfo_key = 'updated map info'
        if 'GCPList' in coreg_results and len(coreg_results['GCPList']) < self.min_points_local_corr:
            warnings.warn('Only %s valid tie point(s) could be identified. A local shift correction is therefore not '
                          'reasonable and could cause artifacts in the output image. The target image is '
                          'corrected globally with the mean X/Y shift of %.3f/%.3f pixels.'
                          % (len(self.GCPList), coreg_results['mean_shifts_px']['x'],
                             coreg_results['mean_shifts_px']['y']))
            self.GCPList = None
            # use the map info derived from the mean shifts (without modifying the caller's coreg_results dict)
            mapinfo_key = 'updated map info means'

        # in case of global shift correction -> the updated map info from coreg_results already has the final map info
        # BUT: this will be updated in correct_shifts() if clipextent is given or warping is needed
        if not self.GCPList:
            mapI = coreg_results[mapinfo_key]
            self.updated_map_info = mapI or geotransform2mapinfo(self.shift_gt, self.shift_prj)
            self.updated_gt = mapinfo2geotransform(self.updated_map_info) or self.shift_gt
            self.original_map_info = coreg_results['original map info']

        self.out_grid = self._get_out_grid()  # needs self.ref_grid, self.im2shift
        self.out_gsd = [abs(self.out_grid[0][1] - self.out_grid[0][0]),
                        abs(self.out_grid[1][1] - self.out_grid[1][0])]  # xgsd, ygsd

        # assertions
        assert self.rspAlg in _dict_rspAlg_rsp_Int.keys(), \
            "'%s' is not a supported resampling algorithm." % self.rspAlg
        if self.band2process is not None:
            assert self.im2shift.bands - 1 >= self.band2process >= 0, \
                "The %s '%s' has %s %s. So 'band2process' must be %s%s. Got %s." \
                % (self.im2shift.__class__.__name__, self.im2shift.basename, self.im2shift.bands,
                   'bands' if self.im2shift.bands > 1 else 'band', 'between 1 and ' if self.im2shift.bands > 1 else '',
                   self.im2shift.bands, self.band2process + 1)

        # set defaults for general class attributes
        self.is_shifted = False  # this is not included in COREG.coreg_info
        self.is_resampled = False  # this is not included in COREG.coreg_info
        self.tracked_errors = []
        self.arr_shifted = None  # set by self.correct_shifts
        self.GeoArray_shifted = None  # set by self.correct_shifts

    def _get_out_grid(self):
        """Compute the output coordinate grid as [[x0, x1], [y0, y1]].

        Precedence of the keyword arguments: 'target_xyGrid' > 'out_gsd' > 'match_gsd' > default
        (input pixel size, optionally aligned to the reference grid if align_grids=True).
        """
        # parse given params
        out_gsd = self.init_kwargs.get('out_gsd', None)
        match_gsd = self.init_kwargs.get('match_gsd', False)
        out_grid = self.init_kwargs.get('target_xyGrid', None)

        # assertions (out_gsd may be a single number or an [X, Y] pair)
        scalar_types = (int, float, np.integer, np.floating)
        assert out_grid is None or (isinstance(out_grid, (list, tuple)) and len(out_grid) == 2)
        assert out_gsd is None or isinstance(out_gsd, scalar_types) or \
            (isinstance(out_gsd, (tuple, list)) and len(out_gsd) == 2)

        ref_xgsd = abs(self.ref_grid[0][1] - self.ref_grid[0][0])
        ref_ygsd = abs(self.ref_grid[1][1] - self.ref_grid[1][0])

        def get_grid(gt, xgsd, ygsd):
            """Return a two-point grid starting at the origin of gt (y decreasing downwards)."""
            return [[gt[0], gt[0] + xgsd], [gt[3], gt[3] - ygsd]]

        # get out_grid
        if out_grid:
            # output grid is given
            pass

        elif out_gsd:
            out_xgsd, out_ygsd = [out_gsd, out_gsd] if isinstance(out_gsd, scalar_types) else out_gsd

            if match_gsd and (out_xgsd, out_ygsd) != (ref_xgsd, ref_ygsd):
                warnings.warn("\nThe parameter 'match_gsd' is ignored because another output ground sampling distance "
                              "was explicitly given.")
            if self.align_grids and \
                    self._are_grids_alignable(self.im2shift.xgsd, self.im2shift.ygsd, out_xgsd, out_ygsd):
                # use grid of reference image with the given output gsd
                out_grid = get_grid(self.ref_gt, out_xgsd, out_ygsd)
            else:  # no grid alignment
                # use grid of input image with the given output gsd
                out_grid = get_grid(self.im2shift.geotransform, out_xgsd, out_ygsd)

        elif match_gsd:
            if self.align_grids:
                # use reference grid
                out_grid = self.ref_grid
            else:
                # use grid of input image and reference gsd
                out_grid = get_grid(self.im2shift.geotransform, ref_xgsd, ref_ygsd)

        else:
            if self.align_grids and \
                    self._are_grids_alignable(self.im2shift.xgsd, self.im2shift.ygsd, ref_xgsd, ref_ygsd):
                # use origin of reference image and gsd of input image
                out_grid = get_grid(self.ref_gt, self.im2shift.xgsd, self.im2shift.ygsd)
            else:
                if not self.GCPList:
                    # in case of global co-registration:
                    # -> use the target image grid but update the origin (shift-correction without resampling)
                    out_grid = get_grid(self.updated_gt, self.im2shift.xgsd, self.im2shift.ygsd)
                else:
                    # in case of local co-registration:
                    # -> use input image grid
                    out_grid = get_grid(self.im2shift.geotransform, self.im2shift.xgsd, self.im2shift.ygsd)

        return out_grid

    @property
    def warping_needed(self):
        """Return True if image warping is needed in consideration of the input parameters of DESHIFTER."""
        assert self.out_grid, 'Output grid must be calculated before.'
        equal_prj = prj_equal(self.ref_prj, self.shift_prj)
        # only a global correction with equal projections onto the (shifted) input grid can skip resampling
        return not (equal_prj and not self.GCPList and is_coord_grid_equal(self.updated_gt, *self.out_grid))

    def _are_grids_alignable(self, in_xgsd, in_ygsd, out_xgsd, out_ygsd):
        """Check if the input image pixel grid is alignable to the output grid.

        The result of the first call is cached in self._grids_alignable and returned by all later calls.

        :param in_xgsd:     input pixel size in X direction
        :param in_ygsd:     input pixel size in Y direction
        :param out_xgsd:    desired output pixel size in X direction
        :param out_ygsd:    desired output pixel size in Y direction
        :return:            True if the pixel sizes are integer multiples of each other in X and Y
        """
        if self._grids_alignable is None:
            def is_alignable(gsd1, gsd2):
                """Check if pixel sizes are divisible (tolerant to floating point representation errors)."""
                gsd1, gsd2 = abs(gsd1), abs(gsd2)
                ratio = max(gsd1, gsd2) / min(gsd1, gsd2)
                # float modulo fails for e.g. 0.3 % 0.1 -> compare the ratio with the nearest integer instead
                return bool(np.isclose(ratio, round(ratio), rtol=1e-9, atol=0))

            self._grids_alignable = is_alignable(in_xgsd, out_xgsd) and is_alignable(in_ygsd, out_ygsd)

            if self._grids_alignable is False and not self.q:
                warnings.warn("\nThe coordinate grid of %s cannot be aligned to the desired grid because their pixel "
                              "sizes are not exact multiples of each other (input [X/Y]: %s/%s; desired [X/Y]: %s/%s). "
                              "Therefore the original grid is chosen for the resampled output image. If you don´t like "
                              "that you can use the 'out_gsd' or 'match_gsd' parameters to set an appropriate output "
                              "pixel size or to allow changing the pixel size.\n"
                              % (self.im2shift.basename, in_xgsd, in_ygsd, out_xgsd, out_ygsd))

        return self._grids_alignable

    def _get_out_extent(self):
        """Return the output extent (xmin, ymin, xmax, ymax) snapped to the output grid.

        If no clip extent was given, it is derived from the image footprint (cliptoextent=True) or from the outer
        image bounds and stored in self.clipextent.
        """
        if self.clipextent is None:
            # no clip extent has been given
            if self.cliptoextent:
                # use actual image corners as clip extent
                self.clipextent = self.im2shift.footprint_poly.envelope.bounds
            else:
                # use outer bounds of the image as clip extent
                xmin, xmax, ymin, ymax = self.im2shift.box.boundsMap
                self.clipextent = xmin, ymin, xmax, ymax

        # snap clipextent to output grid
        # (in case of odd input coords the output coords are moved INSIDE the input array)
        xmin, ymin, xmax, ymax = self.clipextent
        # snapping tolerance: 1/2000 of the output pixel size
        x_tol, y_tol = float(np.ptp(self.out_grid[0]) / 2000), float(np.ptp(self.out_grid[1]) / 2000)
        xmin = find_nearest(self.out_grid[0], xmin, roundAlg='on', extrapolate=True, tolerance=x_tol)
        ymin = find_nearest(self.out_grid[1], ymin, roundAlg='on', extrapolate=True, tolerance=y_tol)
        xmax = find_nearest(self.out_grid[0], xmax, roundAlg='off', extrapolate=True, tolerance=x_tol)
        ymax = find_nearest(self.out_grid[1], ymax, roundAlg='off', extrapolate=True, tolerance=y_tol)
        return xmin, ymin, xmax, ymax

    def _get_input_array(self):
        """Return the input array, restricted to band2process if a band is selected and the image has a band axis."""
        if self.band2process is not None and self.im2shift.ndim == 3:
            return self.im2shift[:, :, self.band2process]
        return self.im2shift[:]

    def _get_output_metadata(self):
        """Return the metadata of the input image, restricted to band2process if a band is selected."""
        if self.band2process is not None:
            return self.im2shift.metadata[[self.band2process]]
        return self.im2shift.metadata

    def correct_shifts(self) -> collections.OrderedDict:
        """Correct the geometric shifts of the target image.

        If no warping is needed, only the map info is updated (and the image is optionally clipped to the output
        extent). Otherwise, the image is resampled with warp_ndarray() (using the tie points in case of a local
        correction). The result is written to self.path_out if given.

        :return:    the deshift results (see the 'deshift_results' property)
        :raises RuntimeError:   if the output geotransform does not match the desired target grid
        """
        if not self.q:
            print('Correcting geometric shifts...')

        t_start = time.time()

        if not self.warping_needed:
            # NO RESAMPLING NEEDED
            self.is_shifted = True
            self.is_resampled = False
            xmin, ymin, xmax, ymax = self._get_out_extent()

            if not self.q:
                print("NOTE: The detected shift is corrected by updating the map info of the target image only, i.e., "
                      "without any resampling. Set the 'align_grids' parameter to True if you need the target and the "
                      "reference coordinate grids to be aligned.")

            if self.cliptoextent:
                # TODO validate results
                # TODO -> output extent does not seem to be the requested one! (only relevant if align_grids=False)
                # get shifted array
                shifted_geoArr = GeoArray(self.im2shift[:], tuple(self.updated_gt), self.shift_prj)

                # clip with target extent
                # NOTE: get_mapPos() does not perform any resampling as long as source and target projection are equal
                self.arr_shifted, self.updated_gt, self.updated_projection = \
                    shifted_geoArr.get_mapPos((xmin, ymin, xmax, ymax),
                                              self.shift_prj,
                                              fillVal=self.nodata,
                                              band2get=self.band2process)

                self.updated_map_info = geotransform2mapinfo(self.updated_gt, self.updated_projection)

            else:
                # array keeps the same; updated gt and prj are taken from coreg_info
                self.arr_shifted = self._get_input_array()

            out_geoArr = GeoArray(self.arr_shifted, self.updated_gt, self.updated_projection, q=self.q)
            out_geoArr.nodata = self.nodata  # equals self.im2shift.nodata after __init__()
            out_geoArr.metadata = self._get_output_metadata()

            self.GeoArray_shifted = out_geoArr

        else:
            # RESAMPLING NEEDED
            # FIXME: the case equal_prj == False is NOT implemented yet
            # FIXME: avoid reading the whole band if clip_extent is passed
            in_arr = self._get_input_array()

            if not self.GCPList:
                # apply XY-shifts to input image gt 'shift_gt' in order to correct the shifts before warping
                self.shift_gt[0], self.shift_gt[3] = self.updated_gt[0], self.updated_gt[3]

            # get resampled array
            out_arr, out_gt, out_prj = \
                warp_ndarray(in_arr, self.shift_gt, self.shift_prj, self.ref_prj,
                             rspAlg=_dict_rspAlg_rsp_Int[self.rspAlg],
                             in_nodata=self.nodata,
                             out_nodata=self.nodata,
                             out_gsd=self.out_gsd,
                             out_bounds=self._get_out_extent(),  # always returns an extent snapped to the target grid
                             gcpList=self.GCPList,
                             CPUs=self.CPUs,
                             progress=self.progress,
                             q=self.q)

            out_geoArr = GeoArray(out_arr, out_gt, out_prj, q=self.q)
            out_geoArr.nodata = self.nodata  # equals self.im2shift.nodata after __init__()
            out_geoArr.metadata = self._get_output_metadata()

            self.arr_shifted = out_arr
            self.updated_gt = out_gt
            self.updated_projection = out_prj
            self.updated_map_info = geotransform2mapinfo(out_gt, out_prj)
            self.GeoArray_shifted = out_geoArr
            self.is_shifted = True
            self.is_resampled = True

        if self.path_out:
            out_geoArr.save(self.path_out, fmt=self.fmt_out, creationOptions=self.out_creaOpt)

        # validation
        # NOTE(review): a tolerance of 1.e8 map units effectively disables the grid-origin check (only the pixel
        #               sizes are really compared) - possibly 1.e-8 was intended; confirm before tightening it.
        if not is_coord_grid_equal(self.updated_gt, *self.out_grid, tolerance=1.e8):
            raise RuntimeError('DESHIFTER output dataset has not the desired target pixel grid. Target grid '
                               'was %s. Output geotransform is %s.' % (str(self.out_grid), str(self.updated_gt)))
        # TODO to be continued (extent, map info, ...)

        if self.v:
            print('Time for shift correction: %.2fs' % (time.time() - t_start))
        return self.deshift_results

    @property
    def deshift_results(self):
        """Return the current deshifting results as an OrderedDict."""
        deshift_results = collections.OrderedDict()
        deshift_results.update({
            'band': self.band2process,
            'is shifted': self.is_shifted,
            'is resampled': self.is_resampled,
            'updated map info': self.updated_map_info,
            'updated geotransform': self.updated_gt,
            'updated projection': self.updated_projection,
            'arr_shifted': self.arr_shifted,
            'GeoArray_shifted': self.GeoArray_shifted
        })
        return deshift_results
def deshift_image_using_coreg_info(im2shift: Union[GeoArray, str],
                                   coreg_results: dict,
                                   path_out: str = None,
                                   fmt_out: str = 'ENVI',
                                   q: bool = False):
    """Correct a geometrically distorted image using previously calculated coregistration info.

    This function can be used for example to correct spatial shifts of mask files using the same transformation
    parameters that have been used to correct their source images.

    :param im2shift:        path of an image to be de-shifted or alternatively a GeoArray object
    :param coreg_results:   the results of the co-registration as given by COREG.coreg_info or
                            COREG_LOCAL.coreg_info respectively
    :param path_out:        /output/directory/filename for coregistered results. If None, no output is written - only
                            the shift corrected results are returned.
    :param fmt_out:         raster file format for output file. ignored if path_out is None. can be any GDAL
                            compatible raster file format (e.g. 'ENVI', 'GTIFF'; default: ENVI)
    :param q:               quiet mode (default: False)
    :return:                the deshift results as returned by DESHIFTER.correct_shifts()
    """
    # forward quiet mode so that DESHIFTER does not print anything if q=True
    deshift_results = DESHIFTER(im2shift, coreg_results, q=q).correct_shifts()

    if path_out:
        # GeoArray.save() takes the output format as 'fmt' (same call pattern as in DESHIFTER.correct_shifts());
        # the previously passed 'fmt_out'/'q' keyword arguments are not part of that signature
        deshift_results['GeoArray_shifted'].save(path_out, fmt=fmt_out)

    return deshift_results