Coverage for arosics/DeShifter.py: 86%

165 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2024-04-03 14:59 +0000

1# -*- coding: utf-8 -*- 

2 

3# AROSICS - Automated and Robust Open-Source Image Co-Registration Software 

4# 

5# Copyright (C) 2017-2024 

6# - Daniel Scheffler (GFZ Potsdam, daniel.scheffler@gfz-potsdam.de) 

7# - Helmholtz Centre Potsdam - GFZ German Research Centre for Geosciences Potsdam, 

8# Germany (https://www.gfz-potsdam.de/) 

9# 

10# This software was developed within the context of the GeoMultiSens project funded 

11# by the German Federal Ministry of Education and Research 

12# (project grant code: 01 IS 14 010 A-C). 

13# 

14# Licensed under the Apache License, Version 2.0 (the "License"); 

15# you may not use this file except in compliance with the License. 

16# You may obtain a copy of the License at 

17# 

18# https://www.apache.org/licenses/LICENSE-2.0 

19# 

20# Unless required by applicable law or agreed to in writing, software 

21# distributed under the License is distributed on an "AS IS" BASIS, 

22# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

23# See the License for the specific language governing permissions and 

24# limitations under the License. 

25 

26import collections 

27import time 

28import warnings 

29import numpy as np 

30from typing import Union 

31 

32# internal modules 

33from geoarray import GeoArray 

34from py_tools_ds.geo.map_info import mapinfo2geotransform, geotransform2mapinfo 

35from py_tools_ds.geo.coord_grid import is_coord_grid_equal 

36from py_tools_ds.geo.projection import prj_equal 

37from py_tools_ds.geo.raster.reproject import warp_ndarray 

38from py_tools_ds.numeric.vector import find_nearest 

39 

__author__ = 'Daniel Scheffler'

# Names of the supported resampling algorithms. The position of a name within this tuple is the integer code
# it is mapped to (and vice versa) in the lookup table below.
_RSP_ALG_NAMES = ('nearest', 'bilinear', 'cubic', 'cubic_spline', 'lanczos', 'average',
                  'mode', 'max', 'min', 'med', 'q1', 'q2')

# Bidirectional lookup table: algorithm name -> integer code AND integer code -> algorithm name.
_dict_rspAlg_rsp_Int = {
    **{alg_name: alg_code for alg_code, alg_name in enumerate(_RSP_ALG_NAMES)},
    **{alg_code: alg_name for alg_code, alg_name in enumerate(_RSP_ALG_NAMES)},
}

46 

47 

class DESHIFTER(object):
    """
    Class to deshift an image array or one of its products by applying the previously computed coregistration info.

    See help(DESHIFTER) for documentation.
    """

    def __init__(self,
                 im2shift: 'Union[GeoArray, str]',
                 coreg_results: dict,
                 **kwargs
                 ) -> None:
        """Get an instance of DESHIFTER.

        :param im2shift:
            path of an image to be de-shifted or alternatively a GeoArray object

        :param dict coreg_results:
            the results of the co-registration as given by COREG.coreg_info or COREG_LOCAL.coreg_info.
            The keys 'reference geotransform', 'reference grid', 'reference projection' and 'original map info'
            are always read; 'updated map info' is read if no (or too few) tie points are available;
            'GCPList' is optional. If 'GCPList' holds fewer than min_points_local_corr entries, the keys
            'mean_shifts_px' and 'updated map info means' are read as well.
            NOTE: In that case the entry 'updated map info' of the passed dictionary is overwritten in-place.

        :keyword str path_out:
            /output/directory/filename for coregistered results

        :keyword str fmt_out:
            raster file format for output file. ignored if path_out is None. can be any GDAL
            compatible raster file format (e.g. 'ENVI', 'GTIFF'; default: ENVI)

        :keyword list out_crea_options:
            GDAL creation options for the output image, e.g., ["QUALITY=20", "REVERSIBLE=YES", "WRITE_METADATA=YES"]

        :keyword int band2process:
            The index of the band to be processed within the given array (starts with 1),
            default = None (all bands are processed)

        :keyword float nodata:
            no data value of the image to be de-shifted

        :keyword float out_gsd:
            output pixel size in units of the reference coordinate system (default = pixel size of the input array),
            either a single number (used for X and Y) or a list/tuple like [xgsd, ygsd].
            If given, 'match_gsd' is ignored.

        :keyword bool align_grids:
            True: align the input coordinate grid to the reference (does not affect the output pixel size as long as
            input and output pixel sizes are compatible (5:30 or 10:30 but not 4:30)), default = False

        :keyword bool match_gsd:
            True: match the input pixel size to the reference pixel size, default = False

        :keyword list target_xyGrid:
            a list with an x-grid and a y-grid like [[15,45], [15,45]].
            This overrides 'out_gsd', 'align_grids' and 'match_gsd'.

        :keyword int min_points_local_corr:
            number of valid tie points, below which a global shift correction is performed instead of a local
            correction (global X/Y shift is then computed as the mean shift of the remaining points)
            (default: 5 tie points)

        :keyword str resamp_alg:
            the resampling algorithm to be used if necessary (default: cubic)
            (valid algorithms: nearest, bilinear, cubic, cubic_spline, lanczos, average, mode, max, min, med, q1, q2)

        :keyword bool cliptoextent:
            True: clip the input image to its actual bounds while deleting possible no data areas outside the actual
            bounds, default = False

        :keyword list clipextent:
            xmin, ymin, xmax, ymax - if given the calculation of the actual bounds is skipped.
            The given coordinates are automatically snapped to the output grid.

        :keyword int CPUs:
            number of CPUs to use (default: None, which means 'all CPUs available')

        :keyword bool progress:
            show progress bars (default: True)

        :keyword bool v:
            verbose mode (default: False)

        :keyword bool q:
            quiet mode (default: False)
        """
        # private attributes
        self._grids_alignable = None  # cache for the result of self._are_grids_alignable()

        # store args / kwargs
        self.init_args = dict(im2shift=im2shift, coreg_results=coreg_results, kwargs=kwargs)
        self.init_kwargs = self.init_args['kwargs']

        # unpack args
        self.im2shift = im2shift if isinstance(im2shift, GeoArray) else GeoArray(im2shift)
        self.GCPList = coreg_results.get('GCPList')
        self.ref_gt = coreg_results['reference geotransform']
        self.ref_grid = coreg_results['reference grid']
        self.ref_prj = coreg_results['reference projection']

        # unpack kwargs
        self.path_out = kwargs.get('path_out', None)
        self.fmt_out = kwargs.get('fmt_out', 'ENVI')
        self.out_creaOpt = kwargs.get('out_crea_options', [])
        band2process = kwargs.get('band2process', None)  # starts with 1  # FIXME why?
        self.band2process = band2process - 1 if band2process is not None else None  # internally handled as band index
        self.nodata = kwargs.get('nodata', self.im2shift.nodata)
        self.align_grids = kwargs.get('align_grids', False)
        self.min_points_local_corr = kwargs.get('min_points_local_corr', 5)
        self.rspAlg = kwargs.get('resamp_alg', 'cubic')  # TODO accept also integers
        self.cliptoextent = kwargs.get('cliptoextent', False)
        self.clipextent = kwargs.get('clipextent', None)
        self.CPUs = kwargs.get('CPUs', None)
        self.v = kwargs.get('v', False)
        self.q = kwargs.get('q', False) if not self.v else False  # overridden by v
        self.progress = kwargs.get('progress', True) if not self.q else False  # overridden by q

        self.im2shift.nodata = self.nodata
        self.im2shift.q = self.q
        self.shift_prj = self.im2shift.projection
        self.shift_gt = list(self.im2shift.geotransform)

        # in case of local shift correction and local coreg results contain fewer points than min_points_local_corr:
        # force global correction based on mean X/Y shifts
        if self.GCPList is not None and len(self.GCPList) < self.min_points_local_corr:
            warnings.warn('Only %s valid tie point(s) could be identified. A local shift correction is therefore not '
                          'reasonable and could cause artifacts in the output image. The target image is '
                          'corrected globally with the mean X/Y shift of %.3f/%.3f pixels.'
                          % (len(self.GCPList), coreg_results['mean_shifts_px']['x'],
                             coreg_results['mean_shifts_px']['y']))
            self.GCPList = None
            # NOTE: this modifies the dictionary passed by the caller
            coreg_results['updated map info'] = coreg_results['updated map info means']

        # in case of global shift correction -> the updated map info from coreg_results already has the final map info
        # BUT: this will be updated in correct_shifts() if clipextent is given or warping is needed
        if not self.GCPList:
            mapI = coreg_results['updated map info']
            self.updated_map_info = mapI or geotransform2mapinfo(self.shift_gt, self.shift_prj)
            self.updated_gt = mapinfo2geotransform(self.updated_map_info) or self.shift_gt
        self.original_map_info = coreg_results['original map info']
        self.updated_projection = self.ref_prj

        self.out_grid = self._get_out_grid()  # needs self.ref_grid, self.im2shift
        self.out_gsd = [abs(self.out_grid[0][1] - self.out_grid[0][0]),
                        abs(self.out_grid[1][1] - self.out_grid[1][0])]  # xgsd, ygsd

        # assertions
        assert self.rspAlg in _dict_rspAlg_rsp_Int, \
            "'%s' is not a supported resampling algorithm." % self.rspAlg
        if self.band2process is not None:
            assert self.im2shift.bands - 1 >= self.band2process >= 0, \
                "The %s '%s' has %s %s. So 'band2process' must be %s%s. Got %s." \
                % (self.im2shift.__class__.__name__, self.im2shift.basename, self.im2shift.bands,
                   'bands' if self.im2shift.bands > 1 else 'band', 'between 1 and ' if self.im2shift.bands > 1 else '',
                   self.im2shift.bands, self.band2process + 1)

        # set defaults for general class attributes
        self.is_shifted = False  # this is not included in COREG.coreg_info
        self.is_resampled = False  # this is not included in COREG.coreg_info
        self.tracked_errors = []
        self.arr_shifted = None  # set by self.correct_shifts
        self.GeoArray_shifted = None  # set by self.correct_shifts

    def _get_out_grid(self):
        """Return the output coordinate grid as [[x0, x1], [y0, y1]] (two adjacent grid coordinates per axis).

        The grid is derived from the keyword arguments passed to __init__ with the following priority:
        'target_xyGrid' > 'out_gsd' > 'match_gsd' > input pixel size (optionally aligned to the reference grid).

        :return:  list of the X-grid and the Y-grid
        """
        # parse given params
        out_gsd = self.init_kwargs.get('out_gsd', None)
        match_gsd = self.init_kwargs.get('match_gsd', False)
        out_grid = self.init_kwargs.get('target_xyGrid', None)

        # a single number as out_gsd is used for both, X- and Y-direction
        scalar_types = (int, float, np.integer, np.floating)
        gsd_is_scalar = isinstance(out_gsd, scalar_types)

        # assertions
        assert out_grid is None or (isinstance(out_grid, (list, tuple)) and len(out_grid) == 2), \
            "'target_xyGrid' must be a list or tuple with two elements (X-grid and Y-grid)."
        # NOTE: len() must only be called on sequences, otherwise a scalar out_gsd raises a TypeError
        assert out_gsd is None or gsd_is_scalar or \
            (isinstance(out_gsd, (tuple, list)) and len(out_gsd) == 2), \
            "'out_gsd' must be a number or a list or tuple with two elements (X- and Y-pixel size)."

        # output grid is given
        if out_grid:
            return out_grid

        ref_xgsd, ref_ygsd = (self.ref_grid[0][1] - self.ref_grid[0][0], abs(self.ref_grid[1][1] - self.ref_grid[1][0]))
        in_xgsd, in_ygsd = self.im2shift.xgsd, self.im2shift.ygsd

        def get_grid(gt, xgsd, ygsd):
            """Build a grid from the upper-left coordinate of the given geotransform and the given pixel sizes."""
            return [[gt[0], gt[0] + xgsd], [gt[3], gt[3] - ygsd]]

        if out_gsd:
            out_xgsd, out_ygsd = (out_gsd, out_gsd) if gsd_is_scalar else out_gsd

            if match_gsd and (out_xgsd, out_ygsd) != (ref_xgsd, ref_ygsd):
                warnings.warn("\nThe parameter 'match_gsd is ignored because another output ground sampling distance "
                              "was explicitly given.")

            if self.align_grids and self._are_grids_alignable(in_xgsd, in_ygsd, out_xgsd, out_ygsd):
                # use grid of reference image with the given output gsd
                return get_grid(self.ref_gt, out_xgsd, out_ygsd)

            # no grid alignment -> use grid of input image with the given output gsd
            return get_grid(self.im2shift.geotransform, out_xgsd, out_ygsd)

        if match_gsd:
            if self.align_grids:
                # use reference grid
                return self.ref_grid

            # use grid of input image and reference gsd
            return get_grid(self.im2shift.geotransform, ref_xgsd, ref_ygsd)

        if self.align_grids and self._are_grids_alignable(in_xgsd, in_ygsd, ref_xgsd, ref_ygsd):
            # use origin of reference image and gsd of input image
            return get_grid(self.ref_gt, in_xgsd, in_ygsd)

        if not self.GCPList:
            # in case of global co-registration:
            # -> use the target image grid but update the origin (shift-correction without resampling)
            return get_grid(self.updated_gt, in_xgsd, in_ygsd)

        # in case of local co-registration:
        # -> use input image grid
        return get_grid(self.im2shift.geotransform, in_xgsd, in_ygsd)

    @property
    def warping_needed(self):
        """Return True if image warping is needed in consideration of the input parameters of DESHIFTER."""
        assert self.out_grid, 'Output grid must be calculated before.'
        equal_prj = prj_equal(self.ref_prj, self.shift_prj)

        # NOTE: self.updated_gt is only evaluated if there is no GCP list (short-circuit evaluation)
        return not (equal_prj and not self.GCPList and is_coord_grid_equal(self.updated_gt, *self.out_grid))

    def _are_grids_alignable(self, in_xgsd, in_ygsd, out_xgsd, out_ygsd):
        """Check if the input image pixel grid is alignable to the output grid.

        The result of the first call is cached in self._grids_alignable and returned by all subsequent calls,
        regardless of the arguments passed then.

        :param in_xgsd:     X-pixel size of the input image
        :param in_ygsd:     Y-pixel size of the input image
        :param out_xgsd:    X-pixel size of the desired output grid
        :param out_ygsd:    Y-pixel size of the desired output grid
        :return:            True if the larger pixel size is an exact multiple of the smaller one (in X and Y)
        """
        if self._grids_alignable is None:
            def is_alignable(gsd1, gsd2):
                """Check if pixel sizes are divisible."""
                return max(gsd1, gsd2) % min(gsd1, gsd2) == 0

            # bool() ensures a native Python boolean also in case numpy scalars are passed
            self._grids_alignable = bool(is_alignable(in_xgsd, out_xgsd) and is_alignable(in_ygsd, out_ygsd))

            if not self._grids_alignable and not self.q:
                warnings.warn("\nThe coordinate grid of %s cannot be aligned to the desired grid because their pixel "
                              "sizes are not exact multiples of each other (input [X/Y]: %s/%s; desired [X/Y]: %s/%s). "
                              "Therefore the original grid is chosen for the resampled output image. If you don´t like "
                              "that you can use the 'out_gsd' or 'match_gsd' parameters to set an appropriate output "
                              "pixel size or to allow changing the pixel size.\n"
                              % (self.im2shift.basename, in_xgsd, in_ygsd, out_xgsd, out_ygsd))

        return self._grids_alignable

    def _get_out_extent(self):
        """Return the output extent (xmin, ymin, xmax, ymax) snapped to the output grid.

        If no clip extent was passed to __init__, self.clipextent is set here, either to the bounds of the actual
        image footprint (cliptoextent=True) or to the outer bounds of the input image.
        """
        if self.clipextent is None:
            # no clip extent has been given
            if self.cliptoextent:
                # use actual image corners as clip extent
                self.clipextent = self.im2shift.footprint_poly.envelope.bounds
            else:
                # use outer bounds of the image as clip extent
                xmin, xmax, ymin, ymax = self.im2shift.box.boundsMap
                self.clipextent = xmin, ymin, xmax, ymax

        # snap clipextent to output grid
        # (in case of odd input coords the output coords are moved INSIDE the input array)
        xmin, ymin, xmax, ymax = self.clipextent
        x_grid, y_grid = self.out_grid[0], self.out_grid[1]
        x_tol, y_tol = float(np.ptp(x_grid) / 2000), float(np.ptp(y_grid) / 2000)  # 2.000th pix
        xmin = find_nearest(x_grid, xmin, roundAlg='on', extrapolate=True, tolerance=x_tol)
        ymin = find_nearest(y_grid, ymin, roundAlg='on', extrapolate=True, tolerance=y_tol)
        xmax = find_nearest(x_grid, xmax, roundAlg='off', extrapolate=True, tolerance=x_tol)
        ymax = find_nearest(y_grid, ymax, roundAlg='off', extrapolate=True, tolerance=y_tol)
        return xmin, ymin, xmax, ymax

    def correct_shifts(self) -> collections.OrderedDict:
        """Correct the shifts of the input image and return the results (see DESHIFTER.deshift_results).

        If no warping is needed, only the geocoding of the image is updated. Otherwise, the image is resampled
        using warp_ndarray(). The result is written to disk in case 'path_out' was passed to __init__.

        :raises RuntimeError:  if the output dataset does not have the desired target pixel grid
        """
        if not self.q:
            print('Correcting geometric shifts...')

        t_start = time.time()

        # only select a single band from 3D arrays
        select_band = self.band2process is not None and self.im2shift.ndim == 3

        if not self.warping_needed:
            # NO RESAMPLING NEEDED

            self.is_shifted = True
            self.is_resampled = False
            xmin, ymin, xmax, ymax = self._get_out_extent()

            if not self.q:
                print("NOTE: The detected shift is corrected by updating the map info of the target image only, i.e., "
                      "without any resampling. Set the 'align_grids' parameter to True if you need the target and the "
                      "reference coordinate grids to be aligned.")

            if self.cliptoextent:
                # TODO validate results
                # TODO -> output extent does not seem to be the requested one! (only relevant if align_grids=False)
                # get shifted array
                shifted_geoArr = GeoArray(self.im2shift[:], tuple(self.updated_gt), self.shift_prj)

                # clip with target extent
                # NOTE: get_mapPos() does not perform any resampling as long as source and target projection are equal
                self.arr_shifted, self.updated_gt, self.updated_projection = \
                    shifted_geoArr.get_mapPos((xmin, ymin, xmax, ymax),
                                              self.shift_prj,
                                              fillVal=self.nodata,
                                              band2get=self.band2process)

                self.updated_map_info = geotransform2mapinfo(self.updated_gt, self.updated_projection)

            else:
                # array keeps the same; updated gt and prj are taken from coreg_info
                self.arr_shifted = self.im2shift[:, :, self.band2process] if select_band else self.im2shift[:]

            out_geoArr = GeoArray(self.arr_shifted, self.updated_gt, self.updated_projection, q=self.q)
            out_geoArr.nodata = self.nodata  # equals self.im2shift.nodata after __init__()
            out_geoArr.metadata = self.im2shift.metadata[[self.band2process]] \
                if self.band2process is not None else self.im2shift.metadata

            self.GeoArray_shifted = out_geoArr

        else:  # FIXME equal_prj==False is NOT implemented yet
            # RESAMPLING NEEDED
            # FIXME avoid reading the whole band if clip_extent is passed

            in_arr = self.im2shift[:, :, self.band2process] if select_band else self.im2shift[:]

            if not self.GCPList:
                # apply XY-shifts to input image gt 'shift_gt' in order to correct the shifts before warping
                self.shift_gt[0], self.shift_gt[3] = self.updated_gt[0], self.updated_gt[3]

            # get resampled array
            out_arr, out_gt, out_prj = \
                warp_ndarray(in_arr, self.shift_gt, self.shift_prj, self.ref_prj,
                             rspAlg=_dict_rspAlg_rsp_Int[self.rspAlg],
                             in_nodata=self.nodata,
                             out_nodata=self.nodata,
                             out_gsd=self.out_gsd,
                             out_bounds=self._get_out_extent(),  # always returns an extent snapped to the target grid
                             gcpList=self.GCPList,
                             CPUs=self.CPUs,
                             progress=self.progress,
                             q=self.q)

            out_geoArr = GeoArray(out_arr, out_gt, out_prj, q=self.q)
            out_geoArr.nodata = self.nodata  # equals self.im2shift.nodata after __init__()
            out_geoArr.metadata = self.im2shift.metadata[[self.band2process]] \
                if self.band2process is not None else self.im2shift.metadata

            self.arr_shifted = out_arr
            self.updated_gt = out_gt
            self.updated_projection = out_prj
            self.updated_map_info = geotransform2mapinfo(out_gt, out_prj)
            self.GeoArray_shifted = out_geoArr
            self.is_shifted = True
            self.is_resampled = True

        if self.path_out:
            out_geoArr.save(self.path_out, fmt=self.fmt_out, creationOptions=self.out_creaOpt)

        # validation
        # NOTE(review): a tolerance of 1.e8 looks like a typo for 1.e-8 and presumably makes this check pass for
        #               nearly any grid - confirm against is_coord_grid_equal() before tightening it.
        if not is_coord_grid_equal(self.updated_gt, *self.out_grid, tolerance=1.e8):
            raise RuntimeError('DESHIFTER output dataset has not the desired target pixel grid. Target grid '
                               'was %s. Output geotransform is %s.' % (str(self.out_grid), str(self.updated_gt)))
        # TODO to be continued (extent, map info, ...)

        if self.v:
            print('Time for shift correction: %.2fs' % (time.time() - t_start))
        return self.deshift_results

    @property
    def deshift_results(self):
        """Return an ordered dictionary containing the current state of the shift correction results."""
        return collections.OrderedDict([
            ('band', self.band2process),
            ('is shifted', self.is_shifted),
            ('is resampled', self.is_resampled),
            ('updated map info', self.updated_map_info),
            ('updated geotransform', self.updated_gt),
            ('updated projection', self.updated_projection),
            ('arr_shifted', self.arr_shifted),
            ('GeoArray_shifted', self.GeoArray_shifted),
        ])

438 

439 

def deshift_image_using_coreg_info(im2shift: Union[GeoArray, str],
                                   coreg_results: dict,
                                   path_out: str = None,
                                   fmt_out: str = 'ENVI',
                                   q: bool = False):
    """Correct a geometrically distorted image using previously calculated coregistration info.

    This function can be used for example to correct spatial shifts of mask files using the same transformation
    parameters that have been used to correct their source images.

    :param im2shift:      path of an image to be de-shifted or alternatively a GeoArray object
    :param coreg_results: the results of the co-registration as given by COREG.coreg_info or
                          COREG_LOCAL.coreg_info respectively
    :param path_out:      /output/directory/filename for coregistered results. If None, no output is written - only
                          the shift corrected results are returned.
    :param fmt_out:       raster file format for output file. ignored if path_out is None. can be any GDAL
                          compatible raster file format (e.g. 'ENVI', 'GTIFF'; default: ENVI)
    :param q:             quiet mode (default: False)
    :return:              the results of the shift correction as returned by DESHIFTER.correct_shifts()
    """
    # DESHIFTER.correct_shifts() already writes the result to disk if 'path_out' is given, so the output path,
    # the output format and the quiet mode are simply handed over instead of saving the result a second time here.
    deshifter = DESHIFTER(im2shift, coreg_results,
                          path_out=path_out,
                          fmt_out=fmt_out,
                          q=q)

    return deshifter.correct_shifts()