Coverage for geoarray/baseclasses.py: 82%
971 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-14 11:57 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-14 11:57 +0000
1# -*- coding: utf-8 -*-
3# geoarray, A fast Python interface for image geodata - either on disk or in memory.
4#
5# Copyright (C) 2017-2023
6# - Daniel Scheffler (GFZ Potsdam, daniel.scheffler@gfz-potsdam.de)
7# - Helmholtz Centre Potsdam - GFZ German Research Centre for Geosciences Potsdam,
8# Germany (https://www.gfz-potsdam.de/)
9#
10# This software was developed within the context of the GeoMultiSens project funded
11# by the German Federal Ministry of Education and Research
12# (project grant code: 01 IS 14 010 A-C).
13#
14# Licensed under the Apache License, Version 2.0 (the "License");
15# you may not use this file except in compliance with the License.
16# You may obtain a copy of the License at
17#
18# http://www.apache.org/licenses/LICENSE-2.0
19#
20# Unless required by applicable law or agreed to in writing, software
21# distributed under the License is distributed on an "AS IS" BASIS,
22# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23# See the License for the specific language governing permissions and
24# limitations under the License.
27import os
28import warnings
29from pkgutil import find_loader
30from collections import OrderedDict
31from copy import copy, deepcopy
32from numbers import Number
33from typing import Union, Optional, Sequence, List, Tuple, Iterable, TYPE_CHECKING # noqa F401
35import numpy as np
36from osgeo import gdal, gdal_array # noqa
37from shapely.geometry import Polygon
38from shapely.wkt import loads as shply_loads
39# dill -> imported when dumping GeoArray
41from py_tools_ds.convenience.object_oriented import alias_property
42from py_tools_ds.geo.coord_calc import get_corner_coordinates
43from py_tools_ds.geo.coord_grid import snap_bounds_to_pixGrid
44from py_tools_ds.geo.coord_trafo import mapXY2imXY, imXY2mapXY, transform_any_prj, reproject_shapelyGeometry
45from py_tools_ds.geo.projection import prj_equal, WKT2EPSG, EPSG2WKT, isLocal, CRS
46from py_tools_ds.geo.raster.conversion import raster2polygon
47from py_tools_ds.geo.vector.topology \
48 import get_footprint_polygon, polyVertices_outside_poly, fill_holes_within_poly
49from py_tools_ds.geo.vector.geometry import boxObj
50from py_tools_ds.io.raster.gdal import get_GDAL_ds_inmem
51from py_tools_ds.numeric.numbers import is_number
52from py_tools_ds.numeric.array import get_array_tilebounds
54# internal imports
55from .subsetting import get_array_at_mapPos
56from .metadata import GDAL_Metadata
58if TYPE_CHECKING:
59 from matplotlib.colors import Colormap
60 from matplotlib import axis, figure
61 from matplotlib.image import AxesImage
62 from holoviews import HoloMap
63 from .masks import NoDataMask, BadDataMask
# Module-level author attribute.
__author__ = 'Daniel Scheffler'
class _GeneratorLen(object):
    """Wrap an iterable (typically a generator) and give it a fixed ``len()``.

    Generators do not support ``len()``; this wrapper stores the length separately while still allowing lazy
    iteration over the wrapped items.
    """

    def __init__(self, gen, length):
        """Store the wrapped iterable and its length.

        :param gen:     generator or any other iterable yielding the items
        :param length:  number of items yielded by ``gen`` (returned by ``len()``; not validated)
        """
        self.gen = gen
        self.length = length

    def __len__(self):
        """Return the length given at instantiation."""
        return self.length

    def __iter__(self):
        """Return an iterator over the wrapped items.

        ``iter()`` is applied so that plain iterables (e.g., lists) are supported as well. For a generator, this
        returns the generator itself, i.e., it can only be consumed once.
        """
        return iter(self.gen)


class GeoArrayTiles(_GeneratorLen):
    """A class representing tiles of a GeoArray."""
86class GeoArray(object):
87 """A class providing a fast Python interface for geodata - either on disk or in memory.
89 GeoArray can be instanced with a file path or with a numpy array and the corresponding geoinformation. Instances
90 can always be indexed and sliced like normal numpy arrays, no matter if it has been instanced from file or from an
91 in-memory array. GeoArray provides a wide range of geo-related attributes belonging to the dataset as well as
92 some functions for quickly visualizing the data as a map, a simple image or an interactive image.
93 """
def __init__(self,
             path_or_array: Union[str, np.ndarray, 'GeoArray'],
             geotransform: tuple = None,
             projection: str = None,
             bandnames: list = None,
             nodata: float = None,
             basename: str = '',
             progress: bool = True,
             q: bool = False
             ) -> None:
    """Get an instance of GeoArray.

    :param path_or_array:   a numpy.ndarray (rows, columns, bands), a valid file path or another GeoArray instance
                            (whose attributes are copied and may be overridden by the other parameters)
    :param geotransform:    GDAL geotransform of the given array or file on disk
    :param projection:      projection of the given array or file on disk as WKT string
                            (only needed if GeoArray is instanced with an array)
    :param bandnames:       names of the bands within the input array, e.g. ['mask_1bit', 'mask_clouds'],
                            (default: ['B1', 'B2', 'B3', ...])
    :param nodata:          nodata value
    :param basename:        a short base name of the dataset (e.g., used in some status messages)
    :param progress:        show progress bars (default: True)
    :param q:               quiet mode (default: False)
    :raises ValueError:         if path_or_array is None or not a string, numpy array or GeoArray instance, or if
                                an explicitly given nodata value is out of range for the data type
    :raises FileNotFoundError:  if a file path is given that does not exist (GDAL virtual file paths starting
                                with '/vsi' and dataset names starting with 'HDF' or 'NETCDF' are not checked)
    """
    # Snapshot of the constructor arguments. It is taken before any other local variable is created so that it
    # only contains the parameters (used later, e.g., to check whether a nodata value was explicitly given).
    initParams = {key: val for key, val in locals().items() if key != "self"}

    # checked first - otherwise None would only hit the generic type check below
    if path_or_array is None:
        raise ValueError("The %s parameter 'path_or_array' must not be None!" % self.__class__.__name__)

    # isinstance() also matches subclasses of GeoArray
    if not isinstance(path_or_array, (str, np.ndarray, GeoArray)):
        raise ValueError("%s parameter 'path_or_array' takes only string, np.ndarray or GeoArray(and subclass) "
                         "instances. Got %s." % (self.__class__.__name__, type(path_or_array)))

    if isinstance(path_or_array, str):
        assert ' ' not in path_or_array, "The given path contains whitespaces. This is not supported by GDAL."

        # GDAL virtual file paths and driver-prefixed dataset names do not exist on the local file system
        if not os.path.exists(path_or_array) and not path_or_array.startswith(('/vsi', 'HDF', 'NETCDF')):
            raise FileNotFoundError(path_or_array)

    if isinstance(path_or_array, GeoArray):
        # copy constructor: take over all attributes of the given instance, then apply explicitly given overrides
        self.__dict__ = path_or_array.__dict__.copy()
        self._initParams = initParams
        self.geotransform = geotransform or self.geotransform
        self.projection = projection or self.projection
        self.bandnames = bandnames or list(self.bandnames.keys())
        self.basename = basename or self.basename
        self._nodata = nodata if nodata is not None else self._nodata
        self.progress = False if progress is False else self.progress
        self.q = q or self.q

    else:
        self._initParams = initParams
        self._arr = path_or_array if isinstance(path_or_array, np.ndarray) else None
        self.filePath = path_or_array if isinstance(path_or_array, str) and path_or_array else None
        basename_default = os.path.splitext(os.path.basename(self.filePath))[0] if not self.is_inmem else 'IN_MEM'
        self.basename = basename or basename_default
        self.progress = progress
        self.q = q
        self._arr_cache = None  # dict containing key 'pos' and 'arr_cached'
        self._geotransform = None
        self._projection = None
        self._shape = None
        self._dtype = None
        self._nodata = nodata
        self._mask_nodata = None
        self._mask_baddata = None
        self._footprint_poly = None
        self._gdalDataset_meta_already_set = False
        self._metadata = None
        self._bandnames = None

        if bandnames:
            self.bandnames = bandnames  # use property in order to validate given value
        if geotransform:
            self.geotransform = geotransform  # use property in order to validate given value
        if projection:
            self.projection = projection  # use property in order to validate given value

        if self.filePath:
            self.set_gdalDataset_meta()

    if 'nodata' in self._initParams and self._initParams['nodata'] is not None:
        self._validate_nodataVal()
def _validate_nodataVal(self) -> None:
    """Ensure that the nodata value passed to the constructor fits into the value range of the data type.

    Only integer and floating point data types are checked; for any other data type the check is skipped.
    A NaN nodata value is accepted for floating point data.

    :raises ValueError: if the nodata value lies outside the representable range of ``self.dtype``
    """
    nodata_given = self._initParams['nodata']
    dtype = self.dtype

    if np.issubdtype(dtype, np.integer):
        type_info = np.iinfo(dtype)
    elif np.issubdtype(dtype, np.floating):
        type_info = np.finfo(dtype)
    else:
        return  # no range check for other data types

    if type_info.min <= nodata_given <= type_info.max:
        return

    # NaN never lies within a range but is a legitimate nodata value for floating point data
    if np.issubdtype(dtype, np.floating) and np.isnan(nodata_given):
        return

    raise ValueError("The given no-data value (%s) is out range for data type %s."
                     % (self._initParams['nodata'], str(np.dtype(dtype))))
@property
def arr(self) -> Optional[np.ndarray]:
    """Return the in-memory image array (None if the data of a file-based instance have not been read yet)."""
    return self._arr

@arr.setter
def arr(self, ndarray: np.ndarray):
    """Replace the in-memory image array.

    Changing the array shape is deliberately allowed, e.g., to support ``geoArr.arr, geoArr.gt, geoArr.prj =
    warp(...)``. If the shape changes, the array cache is flushed because it no longer matches the data.
    """
    assert isinstance(ndarray, np.ndarray), "'arr' can only be set to a numpy array! Got %s." % type(ndarray)

    shape_changed = ndarray.shape != self.shape
    if shape_changed:
        self.flush_cache()  # the cached array is not useful anymore

    self._arr, self._dtype, self._shape = ndarray, ndarray.dtype, ndarray.shape
@property
def bandnames(self) -> dict:
    """Return an ordered mapping of band names to band indices, e.g., OrderedDict([('B1', 0), ('B2', 1)]).

    If no names are set or their number does not match the current number of bands, the default names
    ('B1', 'B2', ...) are generated first.
    """
    names = self._bandnames
    if not names or len(names) != self.bands:
        del self.bandnames  # the deleter resets the names to their defaults
    return self._bandnames

@bandnames.setter
def bandnames(self, list_bandnames: list):
    """Set the band names from a list with one unique name per band; a falsy value restores the defaults.

    :raises TypeError:  if something other than a list is given
    :raises ValueError: if the number of names does not match the number of bands, the names have mixed or
                        unsupported types, or contain duplicates
    """
    if not list_bandnames:
        del self.bandnames
        return

    if not isinstance(list_bandnames, list):
        raise TypeError("A list must be given when setting the 'bandnames' attribute. "
                        "Received %s." % type(list_bandnames))
    if len(list_bandnames) != self.bands:
        raise ValueError('Number of given bandnames does not match number of bands in array.')
    if len({type(name) for name in list_bandnames}) != 1:
        raise ValueError('Multiple data types of the band names are not supported.')
    if not isinstance(list_bandnames[0], (str, int, float)):
        raise ValueError(f'Band names must be a set of strings, integers, or floats. '
                         f'Got {type(list_bandnames[0])}')

    name2idx = OrderedDict()
    for idx, name in enumerate(list_bandnames):
        name2idx[name] = idx

    # duplicate names collapse into a single key
    if len(name2idx) != self.bands:
        raise ValueError('Bands must have unique names. Received band list: %s' % list_bandnames)

    self._bandnames = name2idx

    # Keep the metadata in sync. NOTE: accessing self.metadata creates the metadata if they do not exist yet;
    # objects without usable metadata raise an AttributeError, which is ignored here.
    try:
        self.metadata.band_meta['band_names'] = list_bandnames
    except AttributeError:
        pass

@bandnames.deleter
def bandnames(self):
    """Reset the band names to 'B1', 'B2', ... (and update already existing metadata accordingly)."""
    self._bandnames = OrderedDict(('B%s' % (idx + 1), idx) for idx in range(self.bands))
    if self._metadata is not None:
        self.metadata.band_meta['band_names'] = list(self._bandnames.keys())
@property
def is_inmem(self) -> bool:
    """Return True if the image data are held in memory as a numpy array."""
    return isinstance(self.arr, np.ndarray)

@property
def shape(self) -> tuple:
    """Return the shape of the image array.

    For in-memory data, the shape of the array is returned; otherwise it is taken from the file metadata
    (read via set_gdalDataset_meta() if not yet available).
    """
    if self.is_inmem:
        return self.arr.shape
    if not self._shape:
        self.set_gdalDataset_meta()
    return self._shape
@property
def ndim(self) -> int:
    """Return the number of dimensions of the image array (i.e., ``len(self.shape)``)."""
    dims = self.shape
    return len(dims)

@property
def rows(self) -> int:
    """Return the number of rows (size of the first array dimension)."""
    dims = self.shape
    return dims[0]

@property
def columns(self) -> int:
    """Return the number of columns (size of the second array dimension)."""
    dims = self.shape
    return dims[1]
# 'cols' is an alias for the 'columns' property (created via py_tools_ds' alias_property)
cols = alias_property('columns')
@property
def bands(self) -> int:
    """Return the number of bands; arrays with fewer than three dimensions count as a single band."""
    dims = self.shape
    return dims[2] if len(dims) > 2 else 1

@property
def dtype(self) -> np.dtype:
    """Return the numpy data type of the image array.

    For in-memory data, the array's dtype is returned; otherwise the data type is taken from the file metadata
    (read via set_gdalDataset_meta() if not yet available).
    """
    if self.is_inmem:
        return self.arr.dtype
    if not self._dtype:
        self.set_gdalDataset_meta()
    return self._dtype
@property
def geotransform(self) -> Union[tuple, list]:
    """Return the GDAL GeoTransform of the image, e.g., (283500.0, 5.0, 0.0, 4464500.0, 0.0, -5.0).

    If no geotransform has been set, it is read from the file on disk (file-based instances) or the default
    [0, 1, 0, 0, 0, -1] is returned (in-memory instances).
    """
    if not self._geotransform:
        if self.is_inmem:
            return [0, 1, 0, 0, 0, -1]
        self.set_gdalDataset_meta()
    return self._geotransform

@geotransform.setter
def geotransform(self, gt: Union[tuple, list]):
    """Set the geotransform (a list or tuple of exactly 6 numbers)."""
    assert isinstance(gt, (list, tuple)) and len(gt) == 6, \
        'geotransform must be a list with 6 numbers. Got %s.' % str(gt)

    for value in gt:
        assert is_number(value), \
            "geotransform must contain only numbers. Got '%s' (type: %s)." % (value, type(value))

    self._geotransform = gt
# 'gt' is an alias for the 'geotransform' property (created via py_tools_ds' alias_property)
gt = alias_property('geotransform')
@property
def xgsd(self) -> float:
    """Return the pixel size in X direction (geotransform[1]) in units of the projection."""
    gt = self.geotransform
    return gt[1]

@property
def ygsd(self) -> float:
    """Return the pixel size in Y direction as a positive value (absolute value of geotransform[5])."""
    gt = self.geotransform
    return abs(gt[5])

@property
def xygrid_specs(self) -> Sequence:
    """Return the X/Y coordinate grid specification as [[x0, x0 + xgsd], [y0, y0 - ygsd]].

    x0/y0 are geotransform[0] and geotransform[3]. For example, an image with its origin at X/Y = (15, 0) and a
    GSD of X/Y = (15, 30) yields [[15, 30], [0, -30]].
    """
    gt = self.geotransform
    x_res, y_res = self.xgsd, self.ygsd
    return [[gt[0], gt[0] + x_res],
            [gt[3], gt[3] - y_res]]
@property
def projection(self) -> str:
    """Return the projection of the image as WKT string.

    If no projection has been set, it is read from the file on disk (file-based instances) or an empty string
    is returned (in-memory instances). Setting the projection is intended for in-memory instances or files
    without projection (this is not enforced by the setter).
    """
    if not self._projection:
        if self.is_inmem:
            return ''  # '"LOCAL_CS[\"MAP\"]"
        self.set_gdalDataset_meta()
    return self._projection

@projection.setter
def projection(self, prj: str):
    """Set the projection (expected as WKT string); the value is stored without validation."""
    self._projection = prj
# 'prj' is an alias for the 'projection' property (created via py_tools_ds' alias_property)
prj = alias_property('projection')
@property
def epsg(self) -> int:
    """Return the EPSG code of the projection (derived from the WKT string via py_tools_ds' WKT2EPSG)."""
    wkt = self.projection
    return WKT2EPSG(wkt)

@epsg.setter
def epsg(self, epsg_code: int):
    """Set the projection from an EPSG code (converted to WKT via py_tools_ds' EPSG2WKT)."""
    wkt = EPSG2WKT(epsg_code)
    self.projection = wkt

@property
def box(self) -> boxObj:
    """Return a boxObj for the image extent, built from the geotransform, the projection and the image size."""
    corner_coords = get_corner_coordinates(gt=self.geotransform, cols=self.columns, rows=self.rows)
    outline = get_footprint_polygon(corner_coords)
    return boxObj(gt=self.geotransform, prj=self.projection, mapPoly=outline)
@property
def is_map_geo(self) -> bool:
    """Return True if the image has valid map geoinformation instead of plain image coordinates.

    This requires a non-empty geotransform that differs from the default [0, 1, 0, 0, 0, -1] and a non-empty
    projection.
    """
    gt = self.gt
    criteria = [
        gt,
        list(gt) != [0, 1, 0, 0, 0, -1],
        self.prj,
    ]
    return all(criteria)

@property
def is_rotated(self) -> bool:
    """Return True if the geotransform contains rotation terms (i.e., the image is pseudo-projected)."""
    gt = self.gt
    return gt[2] != 0 or gt[4] != 0
@property
def nodata(self) -> Optional[Union[bool, int, float]]:
    """Get the nodata value of the GeoArray instance.

    If GeoArray has been instanced with a file path, the metadata of the file on disk are checked for an existing
    nodata value. If there is none (and no value has been explicitly given during object instantiation), an
    automatic detection is run (see find_noDataVal()) that analyzes the mean and standard deviation of 3x3
    windows at each image corner. An ambiguous detection result is replaced by None (with a warning).
    """
    if self._nodata is not None:
        return self._nodata

    # try to get the nodata value from the file
    if not self.is_inmem:
        self.set_gdalDataset_meta()

    if self._nodata is None:
        # fall back to the automatic detection based on the image corners
        self.find_noDataVal()
        if self._nodata == 'ambiguous':
            warnings.warn('Nodata value could not be clearly identified. It has been set to None.')
            self._nodata = None
        elif self._nodata is not None and not self.q:
            print("Automatically detected nodata value for %s '%s': %s"
                  % (self.__class__.__name__, self.basename, self._nodata))

    return self._nodata

@nodata.setter
def nodata(self, value: Optional[Union[bool, int, float]]):
    """Set the nodata value; numpy scalars are converted to the corresponding Python built-in types.

    If metadata already exist, a non-None value is also stored there as 'data_ignore_value'.
    """
    if isinstance(value, np.bool_):
        value = bool(value)
    elif isinstance(value, np.integer):
        value = int(value)
    elif isinstance(value, np.floating):
        value = float(value)

    self._nodata = value

    # compare with None explicitly instead of relying on the truthiness of the metadata object
    if self._metadata is not None and value is not None:
        self.metadata.global_meta.update({'data_ignore_value': str(value)})
@property
def mask_nodata(self) -> 'NoDataMask':
    """Get the nodata mask of the image (True=data, False=nodata).

    If no mask has been set yet, it is computed from all image bands via calc_mask_nodata().
    """
    if self._mask_nodata is None:
        self.calc_mask_nodata()  # sets self._mask_nodata
    # return the stored mask directly (re-entering this property would recurse endlessly if no mask got set)
    return self._mask_nodata

@mask_nodata.setter
def mask_nodata(self, mask: Union[np.ndarray, 'GeoArray', 'NoDataMask']):
    """Set the nodata mask.

    :param mask:    a file path, a numpy array or an instance of GeoArray/NoDataMask (True=data, False=nodata);
                    None removes an existing mask. A mask without own geoinformation inherits the geotransform
                    and projection of this image.
    :raises AssertionError: if the mask does not have exactly one band or does not match this image regarding
                            rows/columns, geotransform or projection
    """
    if mask is None:
        del self.mask_nodata
        return

    from .masks import NoDataMask
    geoArr_mask = NoDataMask(mask, progress=self.progress, q=self.q)

    # Geotransforms are compared as lists because they may be stored as tuples or lists, and a tuple never
    # compares equal to a list with the same values.
    mask_gt = geoArr_mask.gt
    if mask_gt is None or list(mask_gt) == [0, 1, 0, 0, 0, -1]:
        mask_gt = self.gt
    geoArr_mask.gt = mask_gt
    geoArr_mask.prj = geoArr_mask.prj if geoArr_mask.prj else self.prj
    imName = "the %s '%s'" % (self.__class__.__name__, self.basename)

    assert geoArr_mask.bands == 1, \
        'Expected one single band as nodata mask for %s. Got %s bands.' % (self.basename, geoArr_mask.bands)
    assert geoArr_mask.shape[:2] == self.shape[:2], 'The provided nodata mask must have the same number of ' \
                                                    'rows and columns as %s itself.' % imName
    assert list(geoArr_mask.gt) == list(self.gt), \
        'The geotransform of the given nodata mask for %s must match the geotransform of the %s itself. ' \
        'Got %s.' % (imName, self.__class__.__name__, geoArr_mask.gt)
    assert not geoArr_mask.prj or prj_equal(geoArr_mask.prj, self.prj), \
        'The projection of the given nodata mask for %s must match the projection of the %s itself.' \
        % (imName, self.__class__.__name__)

    self._mask_nodata = geoArr_mask

@mask_nodata.deleter
def mask_nodata(self):
    """Remove the nodata mask (it is recomputed on the next access)."""
    self._mask_nodata = None
@property
def mask_baddata(self) -> 'BadDataMask':
    """Return the bad data mask.

    Note: The mask must be explicitly set to a file path or a numpy array before; otherwise None is returned.
    """
    return self._mask_baddata

@mask_baddata.setter
def mask_baddata(self, mask: Union[np.ndarray, 'GeoArray', 'BadDataMask']):
    """Set the bad data mask.

    :param mask:    a file path, a numpy array or an instance of GeoArray/BadDataMask; None removes an existing
                    mask. A mask without own geoinformation inherits the geotransform and projection of this image.
    :raises AssertionError: if the mask does not have exactly one band or does not match this image regarding
                            rows/columns, geotransform or projection
    """
    if mask is None:
        del self.mask_baddata
        return

    from .masks import BadDataMask
    geoArr_mask = BadDataMask(mask, progress=self.progress, q=self.q)

    # Geotransforms are compared as lists because they may be stored as tuples or lists, and a tuple never
    # compares equal to a list with the same values.
    mask_gt = geoArr_mask.gt
    if mask_gt is None or list(mask_gt) == [0, 1, 0, 0, 0, -1]:
        mask_gt = self.gt
    geoArr_mask.gt = mask_gt
    geoArr_mask.prj = geoArr_mask.prj if geoArr_mask.prj else self.prj
    imName = "the %s '%s'" % (self.__class__.__name__, self.basename)

    assert geoArr_mask.bands == 1, \
        'Expected one single band as bad data mask for %s. Got %s bands.' % (self.basename, geoArr_mask.bands)
    assert geoArr_mask.shape[:2] == self.shape[:2], 'The provided bad data mask must have the same number of ' \
                                                    'rows and columns as %s itself.' % imName
    assert list(geoArr_mask.gt) == list(self.gt), \
        'The geotransform of the given bad data mask for %s must match the geotransform of the %s itself. ' \
        'Got %s.' % (imName, self.__class__.__name__, geoArr_mask.gt)
    assert not geoArr_mask.prj or prj_equal(geoArr_mask.prj, self.prj), \
        'The projection of the given bad data mask for %s must match the projection of the %s itself.' \
        % (imName, self.__class__.__name__)

    self._mask_baddata = geoArr_mask

@mask_baddata.deleter
def mask_baddata(self):
    """Remove the bad data mask."""
    self._mask_baddata = None
@property
def footprint_poly(self) -> Polygon:
    """Get the footprint polygon of the associated image array (shapely.geometry.Polygon).

    The polygon is computed once from the nodata mask and cached in ``self._footprint_poly``:

    - if the mask contains no nodata pixels, the outer image bounds (``self.box.mapPoly``) are used
    - if the mask contains only nodata pixels, a RuntimeError is raised
    - otherwise the mask is vectorized via raster2polygon() and holes are filled; if that fails with a
      RuntimeError or TimeoutError, the outer image bounds are used as fallback (with a warning unless quiet)

    After computing, the polygon is validated to lie within the image bounds.

    :raises RuntimeError:   if the image only contains nodata values
    :raises AssertionError: if the computed polygon has vertices outside of the image bounds
    """
    # FIXME should return polygon in image coordinates if no projection is available
    if self._footprint_poly is None:
        assert self.mask_nodata is not None, 'A nodata mask is needed for calculating the footprint polygon. '
        if False not in self.mask_nodata[:]:
            # do not run raster2polygon if whole image is filled with data
            self._footprint_poly = self.box.mapPoly
        elif True not in self.mask_nodata[:]:
            raise RuntimeError("Unable to compute a footprint polygon for %s '%s' "
                               "because the dataset only contains nodata values."
                               % (self.__class__.__name__, self.basename))
        else:
            try:
                # NOTE: '.astype' is not defined by GeoArray itself; it is presumably forwarded to the numpy
                #       array via __getattr__. min_npx/timeout are passed through to py_tools_ds' raster2polygon.
                multipolygon = raster2polygon(self.mask_nodata.astype(np.uint8), self.gt, self.prj, exact=False,
                                              min_npx=10, progress=self.progress, q=self.q, timeout=15)
                self._footprint_poly = fill_holes_within_poly(multipolygon)
            except (RuntimeError, TimeoutError):
                if not self.q:
                    warnings.warn("\nCalculation of footprint polygon failed for %s '%s'. Using outer bounds. One "
                                  "reason could be that the nodata value appears within the actual image (not only "
                                  "as fill value). To avoid this use another nodata value. Current nodata value is "
                                  "%s." % (self.__class__.__name__, self.basename, self.nodata))
                # fallback: use the outer image bounds
                self._footprint_poly = self.box.mapPoly

        # validation
        assert not polyVertices_outside_poly(self._footprint_poly, self.box.mapPoly, tolerance=1e-5), \
            "Computing footprint polygon for %s '%s' failed. The resulting polygon is partly or completely " \
            "outside of the image bounds." % (self.__class__.__name__, self.basename)
        # assert self._footprint_poly
        # for XY in self.corner_coord:
        #     assert self.GeoArray.box.mapPoly.contains(Point(XY)) or self.GeoArray.box.mapPoly.touches(Point(XY)), \
        #         "The corner position '%s' is outside of the %s." % (XY, self.imName)

    return self._footprint_poly

@footprint_poly.setter
def footprint_poly(self, poly: Union[Polygon, str]):
    """Set the footprint polygon from a shapely Polygon or a WKT string (parsed via shapely.wkt.loads).

    :raises ValueError: if any other type is given
    """
    if isinstance(poly, Polygon):
        self._footprint_poly = poly
    elif isinstance(poly, str):
        self._footprint_poly = shply_loads(poly)
    else:
        raise ValueError("'footprint_poly' can only be set from a shapely polygon or a WKT string.")
@property
def metadata(self) -> GDAL_Metadata:
    """Return the metadata of the GeoArray as GDAL_Metadata instance (read from file if available).

    Use 'metadata[band_index].to_dict()' to get a metadata dictionary for a specific band.
    Use 'metadata.loc[row_name].to_dict()' to get all metadata values of the same key for all bands as dictionary.
    Use 'metadata.loc[row_name, band_index] = value' to set a new value.

    :return: instance of GDAL_Metadata
    """
    if self._metadata is None:
        # start with default metadata; for file-based instances, set_gdalDataset_meta() is run afterwards
        self._metadata = GDAL_Metadata(nbands=self.bands, nodata_allbands=self._nodata)
        if not self.is_inmem:
            self.set_gdalDataset_meta()
    return self._metadata

@metadata.setter
def metadata(self, meta: GDAL_Metadata):
    """Replace the metadata by a GDAL_Metadata instance with the same number of bands as the image.

    :raises ValueError: if ``meta`` is not a GDAL_Metadata instance or its band number does not match
    """
    is_compatible = isinstance(meta, GDAL_Metadata) and meta.bands == self.bands
    if not is_compatible:
        raise ValueError("%s.metadata can only be set with an instance of geoarray.metadata.GDAL_Metadata of "
                         "which the band number corresponds to the band number of %s."
                         % (self.__class__.__name__, self.__class__.__name__))
    self._metadata = meta
# 'meta' is an alias for the 'metadata' property (created via py_tools_ds' alias_property)
meta = alias_property('metadata')  # type: GDAL_Metadata
def __getitem__(self, given: Union[int, float, slice, np.integer, np.floating, str, tuple, list]) -> np.ndarray:
    """Return a subset of the image data, indexed like a numpy array (rows, columns, bands).

    Besides numpy-style indexing, the following shortcuts are supported:

    - a single int/float/slice on a 3D GeoArray indexes the band dimension (e.g., ``geoArr[0]``)
    - a band name (string) returns that band (e.g., ``geoArr['B1']``); band names may also be used as the third
      element of a 3-tuple, combined with row/column indices (e.g., ``geoArr[0:10, 0:10, 'B1']``)
    - a third index of 0 is ignored for 2D GeoArrays

    Data that are not in memory are read via from_path(). Indices containing lists or tuples load the whole array
    into memory first, unless only a single band is requested.

    :raises ValueError: if an unknown band name is given or band names are not available
    """
    if isinstance(given, (int, float, slice, np.integer, np.floating)) and self.ndim == 3:
        # handle 'given' as index for 3rd (bands) dimension
        if self.is_inmem:
            return self.arr[:, :, given]
        else:
            return self.from_path(self.filePath, [given])

    elif isinstance(given, str):
        # behave like a dictionary and return the corresponding band
        if self.bandnames:
            if given not in self.bandnames:
                raise ValueError("'%s' is not a known band. Known bands are: %s"
                                 % (given, ', '.join(list(self.bandnames.keys()))))
            if self.is_inmem:
                return self.arr if self.ndim == 2 else self.arr[:, :, self.bandnames[given]]
            else:
                return self.from_path(self.filePath, [self.bandnames[given]])
        else:
            raise ValueError('String indices are only supported if %s has been instanced with bandnames given.'
                             % self.__class__.__name__)

    elif isinstance(given, (tuple, list)):
        # handle requests like geoArr[[1,2],[3,4] -> not implemented in from_path if array is not in mem
        types = [type(i) for i in given]

        if list in types or tuple in types:

            # avoid that the whole cube is read if only data from a single band is requested
            if not self.is_inmem \
                    and len(given) == 3 \
                    and isinstance(given[2], (int, float, np.integer, np.floating)):
                band_subset = GeoArray(self.filePath)[:, :, given[2]]
                return band_subset[given[:2]]

            self.to_mem()

        if len(given) == 3:

            # handle strings in the 3rd dim of 'given' -> convert them to a band index
            if isinstance(given[2], str):
                if self.bandnames:
                    if given[2] not in self.bandnames:
                        raise ValueError("'%s' is not a known band. Known bands are: %s"
                                         % (given[2], ', '.join(list(self.bandnames.keys()))))

                    band_idx = self.bandnames[given[2]]
                    # Apply the row/column part of 'given' to the selected band - in memory AND on disk.
                    # NOTE: the string in the 3rd dim is ignored if ndim==2 and band_idx==0
                    rowcol_idx = tuple(given[:2])  # tuple() also allows 'given' to be a list
                    getitem_params = \
                        rowcol_idx if (self.ndim == 2 and band_idx == 0) else rowcol_idx + (band_idx,)
                    if self.is_inmem:
                        return self.arr[getitem_params]
                    else:
                        return self.from_path(self.filePath, getitem_params)
                else:
                    raise ValueError(
                        'String indices are only supported if %s has been instanced with bandnames given.'
                        % self.__class__.__name__)

            # in case a third dim is requested from 2D-array -> ignore 3rd dim if 3rd dim is 0
            elif self.ndim == 2 and given[2] == 0:
                if self.is_inmem:
                    return self.arr[given[:2]]
                else:
                    return self.from_path(self.filePath, given[:2])

    # if nothing has been returned until here -> behave like a numpy array
    if self.is_inmem:
        return self.arr[given]
    else:
        getitem_params = [given] if isinstance(given, slice) else given
        return self.from_path(self.filePath, getitem_params)
def __setitem__(self, idx: Union[int, list, slice], array2set: Union[np.ndarray, Number]):
    """Overwrite pixel values of the in-memory image array at the given index position.

    :param idx:         the index position to overwrite (any index accepted by the underlying numpy array)
    :param array2set:   value(s) to be set; must be compatible to the given index position
    :raises NotImplementedError: if the image data are not held in memory
    """
    if not self.is_inmem:
        raise NotImplementedError('Item assignment for %s instances that are not in memory is not yet supported.'
                                  % self.__class__.__name__)
    self.arr[idx] = array2set
def __getattr__(self, attr: str):
    """Fallback attribute lookup (Python calls this only if the normal attribute lookup failed).

    Attributes listed by ``self.__dir__()`` are returned via __getattribute__. Names of numpy.ndarray attributes
    (except those in ``attrsNot2Link2np``) are forwarded to the full image array ``self[:]``, which reads the
    data via from_path() if they are not in memory. Any other name raises an AttributeError.
    """
    # check if the requested attribute can not be present because GeoArray has been instanced with an array
    attrsNot2Link2np = ['__deepcopy__']  # attributes we don't want to inherit from numpy.ndarray

    # NOTE(review): 'shape', 'dtype', 'geotransform' and 'projection' are properties and thus presumably always
    #               listed by __dir__(), so this branch may never trigger - confirm.
    if attr not in self.__dir__() and not self.is_inmem and attr in ['shape', 'dtype', 'geotransform',
                                                                     'projection']:
        self.set_gdalDataset_meta()

    if attr in self.__dir__():  # __dir__() includes also methods and properties
        return self.__getattribute__(attr)  # __getattribute__ avoids infinite loop
    elif attr not in attrsNot2Link2np and hasattr(np.array([]), attr):
        # forward to the numpy array (NOTE: self[:] returns the whole array, i.e., may read all data from disk)
        return self[:].__getattribute__(attr)
    else:
        raise AttributeError("%s object has no attribute '%s'." % (self.__class__.__name__, attr))
def __getstate__(self) -> dict:
    """Return the instance state to be pickled (e.g., by multiprocessing.Pool).

    The array cache is flushed first so that it is not pickled along with the instance.
    """
    self.flush_cache()
    return vars(self)

def __setstate__(self, state: dict):
    """Restore the instance state when unpickling (e.g., by multiprocessing.Pool).

    NOTE: Without this method, unpickled instances show recursion errors within __getattr__ when requesting any
    attribute.
    """
    self.__dict__ = state
def calc_mask_nodata(self, fromBand: int = None, overwrite: bool = False, flag: str = 'all') -> np.ndarray:
    """Calculate a no data mask with values False (=nodata) and True (=data).

    The result is stored in ``self.mask_nodata``. If a mask already exists and ``overwrite`` is False, nothing is
    recomputed and the existing mask is returned as array.

    Bands that consist entirely of nodata values are not taken into account, unless all bands do. In that case,
    the whole mask is False.

    :param fromBand:    index of the band to be used (if None, all bands are used)
    :param overwrite:   whether to overwrite existing nodata mask that has already been calculated
    :param flag:        algorithm how to flag pixels (default: 'all')
                        'all': flag those pixels as nodata that contain the nodata value in ALL bands
                        'any': flag those pixels as nodata that contain the nodata value in ANY band
    :return:            boolean mask array with the number of rows and columns of the image
    :raises ValueError: if ``flag`` is neither 'all' nor 'any'
    """
    def _mask_from_nodata_flags(is_nodata: np.ndarray) -> np.ndarray:
        """Combine a 2D/3D boolean array flagging nodata pixels into a 2D data mask (True=data)."""
        # A band only counts as 'nodata band' if ALL its pixels are nodata. Testing the band mean for equality
        # with the nodata value is not sufficient, since valid data may have exactly that mean.
        nodata_bands = np.all(is_nodata, axis=(0, 1))

        if np.all(nodata_bands):
            return np.full(is_nodata.shape[:2], False)
        if is_nodata.ndim == 2:
            return ~is_nodata

        databand_idxs = np.flatnonzero(~nodata_bands)
        mask = ~is_nodata[:, :, databand_idxs[0]]  # True where 1st data band has data
        is_nodata_remain = is_nodata[:, :, databand_idxs[1:]]

        if flag == 'all':
            # ALL bands need to contain nodata to flag the mask as such
            # -> overwrite the mask at nodata positions (False) with True in case there is data in ANY band
            mask[~mask] = np.any(~is_nodata_remain[~mask], axis=1)
        else:
            # ANY band needs to contain nodata to flag the mask as such
            # -> overwrite the mask at data positions (True) with False in case there is nodata in ANY band
            mask[mask] = ~np.any(is_nodata_remain[mask], axis=1)

        return mask

    if self._mask_nodata is not None and not overwrite:
        return self._mask_nodata[:]

    if flag not in ['all', 'any']:
        raise ValueError(flag)

    assert self.ndim in [2, 3], "Only 2D or 3D arrays are supported. Got a %sD array." % self.ndim
    arr = self[:, :, fromBand] if self.ndim == 3 and fromBand is not None else self[:]
    nodata = self.nodata  # read once - the property may trigger an automatic nodata value detection

    if nodata is None:
        mask = np.ones((self.rows, self.cols), bool)
    else:
        # NaN needs np.isnan() since NaN never compares equal to anything
        is_nodata = np.isnan(arr) if np.isnan(nodata) else arr == nodata
        mask = _mask_from_nodata_flags(is_nodata)

    self.mask_nodata = mask

    return mask
def find_noDataVal(self, bandIdx: int = 0, sz: int = 3) -> Union[int, float]:
    """Try to derive the nodata value from homogeneous image corners (windows of sz x sz pixels).

    A corner window counts as homogeneous if its standard deviation is zero (or NaN); its mean is then a
    candidate value. The result is also assigned to ``self.nodata``:

    - no homogeneous corner:                        None
    - homogeneous corners with one common value:    that value (with a warning if found in only 1 or 2 corners)
    - at least one NaN among the candidate values:  np.nan
    - different candidate values:                   'ambiguous'

    :param bandIdx: index of the band to be analysed
    :param sz:      window size in which corner pixels are analysed
    :return:        the derived nodata value (see above)
    """
    head, tail = slice(0, sz), slice(-sz, None)
    corners = [(head, head), (head, tail), (tail, tail), (tail, head)]  # UL, UR, LR, LL

    candidates = []
    for row_slice, col_slice in corners:
        window = self[row_slice, col_slice, bandIdx]
        win_mean, win_std = np.mean(window), np.std(window)
        if win_std == 0 or np.isnan(win_std):
            candidates.append(win_mean)

    if not candidates:
        # all corners are filled with data
        nodata = None
    else:
        spread = np.std(candidates)
        if spread == 0:
            # nodata value clearly identified
            if len(candidates) <= 2:
                warnings.warn("\nAutomatic nodata value detection returned the value %s for GeoArray '%s' but this "
                              "seems to be unreliable (occurs in only %s). To avoid automatic detection, just pass "
                              "the correct nodata value."
                              % (candidates[0], self.basename, ('2 image corners' if len(candidates) == 2 else
                                                                '1 image corner')))
            nodata = candidates[0]
        elif np.isnan(spread):
            # at least one of the possible values is np.nan
            nodata = np.nan
        else:
            # different possible nodata values have been found in the image corners
            nodata = 'ambiguous'

    self.nodata = nodata
    return nodata
def set_gdalDataset_meta(self) -> None:
    """Retrieve GDAL metadata from file.

    This is only executed once to avoid overwriting of user defined attributes
    that are defined after object instantiation.

    Sets the private attributes _shape, _dtype, _geotransform, _projection and _metadata. It also sets the
    nodata value from the first band (only if no nodata value was passed at initialization). If the file
    metadata contain band names, it sets those as well.

    :raises Exception: if the file cannot be opened by GDAL
    """
    if self._gdalDataset_meta_already_set:
        return

    assert self.filePath
    ds = gdal.Open(self.filePath)
    if not ds:
        raise Exception('Error reading file: ' + gdal.GetLastErrorMsg())

    try:
        # set private class variables (in order to avoid recursion error)
        self._shape = tuple([ds.RasterYSize, ds.RasterXSize] + ([ds.RasterCount] if ds.RasterCount > 1 else []))
        self._dtype = gdal_array.GDALTypeCodeToNumericTypeCode(ds.GetRasterBand(1).DataType)
        self._geotransform = list(ds.GetGeoTransform())

        # for some reason GDAL reads arbitrary geotransforms as (0, 1, 0, 0, 0, 1) instead of (0, 1, 0, 0, 0, -1)
        self._geotransform[5] = -abs(self._geotransform[5])  # => force ygsd to be negative

        # consequently use WKT1 strings here as GDAL always exports transformation results as WKT1
        wkt = ds.GetProjection()
        self._projection = CRS(wkt).to_wkt(version="WKT1_GDAL") if not isLocal(wkt) else ''

        # only read the nodata value from file if the user did not explicitly provide one
        if 'nodata' not in self._initParams or self._initParams['nodata'] is None:
            band = ds.GetRasterBand(1)
            # FIXME this does not support different nodata values within the same file
            self.nodata = band.GetNoDataValue()

        # set metadata attribute
        if self.is_inmem or not self.filePath:
            # metadata cannot be read from disk -> set it to the default
            self._metadata = GDAL_Metadata(nbands=self.bands, nodata_allbands=self._nodata)
        else:
            self._metadata = GDAL_Metadata(filePath=self.filePath)

        # copy over the band names
        if 'band_names' in self.metadata.band_meta and self.metadata.band_meta['band_names']:
            self.bandnames = self.metadata.band_meta['band_names']

    finally:
        # release the GDAL dataset handle, also if an error occurred above
        # noinspection PyUnusedLocal
        ds = None  # noqa: F841

    self._gdalDataset_meta_already_set = True
def from_path(self, path: str, getitem_params: list = None) -> np.ndarray:
    """Read a GDAL compatible raster image from disk, with respect to the given image position.

    NOTE: If the requested array position is already in cache, it is returned from there.

    :param path:            the file path of the image to read
    :param getitem_params:  a list of slices in the form [row_slice, col_slice, band_slice]
                            (a single-element list selects bands only; None or an empty list reads the whole
                            image)
    :return out_arr:        the output array
    :raises Exception:      if GDAL fails to open or read the file
    :raises ValueError:     if the requested position is outside the image dimensions
    """
    ds = gdal.Open(path)
    if not ds:
        raise Exception('Error reading file: ' + gdal.GetLastErrorMsg())

    R, C, B = ds.RasterYSize, ds.RasterXSize, ds.RasterCount
    del ds

    # convert getitem_params to subset area to be read #
    rS, rE, cS, cE, bS, bE, bL = [None] * 7

    # populate rS, rE, cS, cE, bS, bE, bL
    if getitem_params:
        # populate rS, rE, cS, cE
        if len(getitem_params) >= 2:
            givenR, givenC = getitem_params[:2]
            if isinstance(givenR, slice):
                rS = givenR.start
                rE = givenR.stop - 1 if givenR.stop is not None else None
            elif isinstance(givenR, (int, np.integer)):
                rS = givenR
                rE = givenR
            if isinstance(givenC, slice):
                cS = givenC.start
                cE = givenC.stop - 1 if givenC.stop is not None else None
            elif isinstance(givenC, (int, np.integer)):
                cS = givenC
                cE = givenC

        # populate bS, bE, bL
        if len(getitem_params) in [1, 3]:
            givenB = getitem_params[2] if len(getitem_params) == 3 else getitem_params[0]
            if isinstance(givenB, slice):
                bS = givenB.start
                bE = givenB.stop - 1 if givenB.stop is not None else None
            elif isinstance(givenB, (int, np.integer)):
                bS = givenB
                bE = givenB
            elif isinstance(givenB, (tuple, list)):
                typesInGivenB = [type(i) for i in givenB]
                assert len(list(set(typesInGivenB))) == 1, \
                    'Mixed data types within the list of bands are not supported.'
                if isinstance(givenB[0], (int, np.integer)):
                    bL = list(givenB)
                elif isinstance(givenB[0], str):
                    bL = [self.bandnames[i] for i in givenB]
            elif isinstance(givenB, str):
                # isinstance() also accepts str subclasses such as numpy.str_
                bL = [self.bandnames[givenB]]

    # set defaults for not given values
    rS = rS if rS is not None else 0
    rE = rE if rE is not None else R - 1
    cS = cS if cS is not None else 0
    cE = cE if cE is not None else C - 1
    bS = bS if bS is not None else 0
    bE = bE if bE is not None else B - 1

    # convert negative to positive indices (relative to the dimensions of the file that is actually read)
    rS = rS if rS >= 0 else R + rS
    rE = rE if rE >= 0 else R + rE
    cS = cS if cS >= 0 else C + cS
    cE = cE if cE >= 0 else C + cE
    bS = bS if bS >= 0 else B + bS
    bE = bE if bE >= 0 else B + bE

    # derive the band list from the (already non-negative) band range unless bands were given explicitly
    bL = list(range(bS, bE + 1)) if not bL else [b if b >= 0 else (B + b) for b in bL]

    # validate subset area bounds to be read
    def msg(v, idx, sz):
        # FIXME numpy raises that error ONLY for the 2nd axis
        return '%s is out of bounds for axis %s with size %s' % (v, idx, sz)

    for val, axIdx, axSize in zip([rS, rE, cS, cE, bS, bE], [0, 0, 1, 1, 2, 2], [R, R, C, C, B, B]):
        if not 0 <= val <= axSize - 1:
            raise ValueError(msg(val, axIdx, axSize))
    for bIdx in bL:
        if not 0 <= bIdx <= B - 1:
            raise ValueError(msg(bIdx, 2, B))

    # summarize requested array position in arr_pos
    # NOTE: # bandlist must be string because truth value of an array with more than one element is ambiguous
    arr_pos = dict(rS=rS, rE=rE, cS=cS, cE=cE, bS=bS, bE=bE, bL=bL)

    def _ensure_np_shape_consistency_3D_2D(arr: np.ndarray) -> np.ndarray:
        """Ensure numpy output shape consistency according to the given indexing parameters.

        This may require 3D to 2D conversion in case out_arr can be represented by a 2D array AND index has been
        provided as integer (avoids shapes like (1,2,2). It also may require 2D to 3D conversion in case only one
        band has been requested and the 3rd dimension has been provided as a slice.

        NOTE: -> numpy also returns a 2D array in that case
        NOTE: if array is indexed with a slice -> keep it a 3D array
        """
        # a single value -> return as float/int
        if arr.ndim == 2 and arr.size == 1:
            arr = arr[0, 0]

        # 2D -> 3D
        if arr.ndim == 2 and isinstance(getitem_params, (tuple, list)) and len(getitem_params) == 3 and \
                isinstance(getitem_params[2], slice):
            arr = arr[:, :, np.newaxis]

        # 3D -> 2D (without indexing parameters, i.e., when reading the whole image, the shape is kept)
        if getitem_params and len(getitem_params) != 1 and 1 in arr.shape:
            outshape = []
            for i, sh in enumerate(arr.shape):
                if sh == 1 and i < len(getitem_params) and \
                        isinstance(getitem_params[i], (int, np.integer, float, np.floating)):
                    pass
                else:
                    outshape.append(sh)

            arr = arr.reshape(*outshape)

        return arr

    # check if the requested array position is already in cache -> if yes, return it from there
    if self._arr_cache is not None and self._arr_cache['pos'] == arr_pos:
        out_arr = self._arr_cache['arr_cached']
        out_arr = _ensure_np_shape_consistency_3D_2D(out_arr)

    else:
        # TODO insert a multiprocessing.Lock here in order to prevent IO bottlenecks?
        # read subset area from disk
        xoff, yoff, xsize, ysize = cS, rS, cE - cS + 1, rE - rS + 1

        if bL == list(range(0, B)):
            tempArr = gdal_array.LoadFile(path, xoff, yoff, xsize, ysize)
            if tempArr is None:
                raise Exception('Error reading file: ' + gdal.GetLastErrorMsg())
            # GDAL returns multi-band arrays as (bands, rows, columns) -> move the band axis to the end
            out_arr = np.swapaxes(np.swapaxes(tempArr, 0, 2), 0, 1) if B > 1 else tempArr
        else:
            ds = gdal.Open(path)
            if not ds:
                raise Exception('Error reading file: ' + gdal.GetLastErrorMsg())

            def _read_band(bandIdx: int) -> np.ndarray:
                """Read the requested window of the given zero-based band index."""
                band_arr = ds.GetRasterBand(bandIdx + 1).ReadAsArray(xoff, yoff, xsize, ysize)
                if band_arr is None:
                    raise Exception('Error reading file: ' + gdal.GetLastErrorMsg())
                return band_arr

            try:
                if len(bL) == 1:
                    out_arr = _read_band(bL[0])
                else:
                    # stack the bands along a 3rd axis while keeping the native data type of the file
                    out_arr = np.dstack([_read_band(bIdx) for bIdx in bL])
            finally:
                # release the GDAL dataset handle
                ds = None

        out_arr = _ensure_np_shape_consistency_3D_2D(out_arr)

        # only set self.arr if the whole cube has been read (in order to avoid sudden shape changes)
        if out_arr.shape == self.shape:
            self.arr = out_arr

    # write _arr_cache
    self._arr_cache = dict(pos=arr_pos, arr_cached=out_arr)

    return out_arr  # TODO implement check of returned datatype (e.g. NoDataMask should always return bool
    # TODO -> would be np.int8 if an int8 file is read from disk
def save(self,
         out_path: str,
         fmt: str = 'ENVI',
         creationOptions: list = None
         ) -> None:
    """Write the raster data to disk.

    :param out_path:        output path (missing parent directories are created)
    :param fmt:             the output format / GDAL driver code to be used for output creation, e.g. 'ENVI'
                            Refer to https://gdal.org/drivers/raster/index.html to get a full list of supported
                            formats.
    :param creationOptions: GDAL creation options,
                            e.g., ["QUALITY=80", "REVERSIBLE=YES", "WRITE_METADATA=YES"]
    :raises Exception:      if the GDAL driver is unknown, the source data cannot be opened or the output file
                            has not been written
    """
    if not self.q:
        print('Writing GeoArray of size %s to %s.' % (self.shape, out_path))
    assert self.ndim in [2, 3], 'Only 2D- or 3D arrays are supported.'

    driver = gdal.GetDriverByName(fmt)
    if driver is None:
        raise Exception("'%s' is not a supported GDAL driver. Refer to https://gdal.org/drivers/raster/index.html "
                        "for full list of GDAL driver codes." % fmt)

    # create the output directory if needed
    # NOTE: os.path.dirname() returns '' for bare file names and os.makedirs('') would raise FileNotFoundError
    out_dir = os.path.dirname(out_path)
    if out_dir and not os.path.isdir(out_dir):
        os.makedirs(out_dir, exist_ok=True)

    envi_metadict = self.metadata.to_ENVI_metadict()

    ###########################
    # get source GDAL dataset #
    ###########################

    ds_src: gdal.Dataset
    ds_out: gdal.Dataset

    if self.is_inmem:
        ds_src = get_GDAL_ds_inmem(self.arr,  # expects rows,columns,bands
                                   self.geotransform, self.projection,
                                   self._nodata)  # avoid to compute the nodata value here, so use private attrib.

    else:
        ds_src = gdal.Open(self.filePath)
        # metadomains = {dom: src_ds.GetMetadata(dom) for dom in src_ds.GetMetadataDomainList()}

    if not ds_src:
        raise Exception('Error reading file: ' + gdal.GetLastErrorMsg())

    #########################################
    # write output dataset and set metadata #
    #########################################

    # remember a possibly user-defined GDAL_PAM_ENABLED setting in order to restore it afterwards
    pam_enabled_orig = os.environ.get('GDAL_PAM_ENABLED')

    try:
        # ENVI #
        ########
        if fmt == 'ENVI':
            # NOTE: The dataset has to be written BEFORE metadata are added. Otherwise, metadata are not written.

            # write ds_src to disk and re-open it to add the metadata
            gdal.Translate(out_path, ds_src, format=fmt, creationOptions=creationOptions)
            del ds_src
            ds_out = gdal.Open(out_path, gdal.GA_Update)

            for bidx in range(self.bands):
                band = ds_out.GetRasterBand(bidx + 1)

                if 'band_names' in envi_metadict:
                    bandname = str(self.metadata.band_meta['band_names'][bidx]).strip()
                    band.SetDescription(bandname)
                    assert band.GetDescription() == bandname

                del band

            # avoid that band names are written to global meta
            if 'band_names' in envi_metadict:
                del envi_metadict['band_names']

            # the expected key name is 'data_ignore_value', see below
            if 'nodata' in envi_metadict:
                del envi_metadict['nodata']

            # set data_ignore_value in case self.metadata.band_meta contains a unique nodata value
            if 'nodata' in self.metadata.band_meta:
                if len(set(self.metadata.band_meta['nodata'])) == 1:
                    envi_metadict['data_ignore_value'] = str(self.metadata.band_meta['nodata'][0])
                else:
                    warnings.warn("Band-specific nodata values are not supported by the ENVI header format.")

            ds_out.SetMetadata(envi_metadict, 'ENVI')

            if 'description' in envi_metadict:
                ds_out.SetDescription(envi_metadict['description'])

            ds_out.FlushCache()

            # NOTE: In case of ENVI format and GDAL_PAM_ENABLED=NO, the metadata is not written
            gdal.Unlink(out_path + '.aux.xml')

        else:
            # disable to write separate metadata XML files
            os.environ['GDAL_PAM_ENABLED'] = 'NO'

            ds_out = ds_src

            # set metadata
            if self.metadata.all_meta:

                # set global domain metadata
                if self.metadata.global_meta:
                    ds_out.SetMetadata(dict((k, repr(v)) for k, v in self.metadata.global_meta.items()))

                if 'description' in envi_metadict:
                    ds_out.SetDescription(envi_metadict['description'])

                # set band domain metadata
                bandmeta_dict = self.metadata.to_DataFrame().astype(str).to_dict()

                for bidx in range(self.bands):
                    band = ds_out.GetRasterBand(bidx + 1)
                    bandmeta = bandmeta_dict[bidx].copy()

                    # filter global metadata out
                    bandmeta = {k: v for k, v in bandmeta.items() if k not in self.metadata.global_meta}

                    if 'band_names' in bandmeta:
                        bandname = str(self.metadata.band_meta['band_names'][bidx]).strip()
                        band.SetDescription(bandname)
                        del bandmeta['band_names']

                    if 'nodata' in bandmeta:
                        band.SetNoDataValue(self.metadata.band_meta['nodata'][bidx])
                        del bandmeta['nodata']

                    if bandmeta:
                        band.SetMetadata(bandmeta)

                    band.FlushCache()
                    del band

            ds_out.FlushCache()

            # write ds_out to disk,
            # -> writes the in-memory array or transforms the linked dataset into the target format
            gdal.Translate(out_path, ds_out, format=fmt, creationOptions=creationOptions)
            del ds_src

    finally:
        # restore the original GDAL_PAM_ENABLED state instead of unconditionally deleting the variable
        if pam_enabled_orig is None:
            os.environ.pop('GDAL_PAM_ENABLED', None)
        else:
            os.environ['GDAL_PAM_ENABLED'] = pam_enabled_orig

    if not os.path.exists(out_path):
        raise Exception(gdal.GetLastErrorMsg())

    del ds_out
def dump(self, out_path: str) -> None:
    """Pickle the complete GeoArray instance to ``out_path`` with dill.

    dill is imported at call time, not at module import.

    :param out_path: path of the binary output file (overwritten if existing)
    """
    import dill

    with open(out_path, 'wb') as out_file:
        dill.dump(self, out_file)
def _get_plottable_image(self,
                         xlim: Union[tuple, list] = None,
                         ylim: Union[tuple, list] = None,
                         band: int = None,
                         boundsMap: tuple = None,
                         boundsMapPrj: str = None,
                         res_factor: Union[int, float] = None,
                         nodataVal: Union[int, float] = None,
                         out_prj: Union[str, int] = None,
                         ignore_rotation: bool = False
                         ) -> Tuple[np.ndarray, tuple, str]:
    """Return the (sub)image to be plotted together with its geotransform and projection.

    The image is cast to an integer type if the output no data value does not fit its data type (bool images
    are always cast). Large images are downsampled, and rotated images are resampled (unless
    ignore_rotation is True). The same happens if an output projection is requested.

    :param xlim:            [start_column, end_column] (ignored if boundsMap is given)
    :param ylim:            [start_row, end_row] (ignored if boundsMap is given)
    :param band:            band index to be returned (None returns all bands)
    :param boundsMap:       xmin, ymin, xmax, ymax map bounds of the subset to be returned
    :param boundsMapPrj:    projection of boundsMap (default: the projection of the GeoArray)
    :param res_factor:      downsampling factor (None -> downsampling of large images to about 1000x1000)
    :param nodataVal:       no data value (default: self.nodata)
    :param out_prj:         output projection
    :param ignore_rotation: whether to ignore a rotation contained in the geotransform
    :return:                (image, geotransform, projection)
    """
    def _nodata_fits_dtype(value, dtype) -> bool:
        """Return True if the no data value can be represented by the given data type."""
        try:
            return bool(np.can_cast(value, dtype))
        except TypeError:
            # NumPy>=2 (NEP 50) no longer supports value-based casting checks for Python int/float/complex
            return bool(np.can_cast(np.min_scalar_type(value), dtype))

    # handle limits
    if boundsMap:
        boundsMapPrj = boundsMapPrj or self.prj
        image2plot, gt, prj = self.get_mapPos(boundsMap, boundsMapPrj, band2get=band,
                                              fillVal=nodataVal if nodataVal is not None else self.nodata)
    else:
        cS, cE = xlim if isinstance(xlim, (tuple, list)) else (0, self.columns)
        rS, rE = ylim if isinstance(ylim, (tuple, list)) else (0, self.rows)

        image2plot = self[rS:rE, cS:cE, band] if band is not None else self[rS:rE, cS:cE]
        gt, prj = self.geotransform, self.projection

    transOpt = ['SRC_METHOD=NO_GEOTRANSFORM'] if tuple(gt) == (0, 1, 0, 0, 0, -1) else None
    xdim, ydim = None, None
    in_nodata = nodataVal if nodataVal is not None else self.nodata
    out_nodata = in_nodata if in_nodata is not None else -9999
    if not _nodata_fits_dtype(out_nodata, image2plot.dtype):
        image2plot = image2plot.astype(np.int32)
    if image2plot.dtype == bool:
        image2plot = image2plot.astype(int)

    # rotated images always have to be resampled for plotting
    if not ignore_rotation and self.is_rotated:
        out_prj = out_prj or self.projection

    rows2plot, cols2plot = image2plot.shape[:2]
    if res_factor != 1. and rows2plot * cols2plot > 1e6:  # shape > 1000*1000
        # sample image down / normalize (based on the size of the image actually plotted, not the full image)
        if res_factor:
            xdim, ydim = cols2plot * res_factor, rows2plot * res_factor
        else:
            xdim, ydim = np.array([cols2plot, rows2plot]) / (max(cols2plot, rows2plot) / 1000)
        xdim, ydim = int(xdim), int(ydim)

    if xdim or ydim or out_prj:
        from py_tools_ds.geo.raster.reproject import warp_ndarray
        # use the geotransform/projection of the extracted image (differs from self for boundsMap subsets)
        image2plot, gt, prj = warp_ndarray(image2plot, gt, prj,
                                           out_XYdims=(xdim, ydim),
                                           in_nodata=in_nodata,
                                           out_nodata=out_nodata,
                                           transformerOptions=transOpt,
                                           out_prj=out_prj,
                                           q=True)
        if transOpt and 'NO_GEOTRANSFORM' in ','.join(transOpt):
            image2plot = np.flipud(image2plot)
            gt = list(gt)
            gt[3] = 0

        if xdim or ydim:
            print('Note: array has been downsampled to %s x %s for faster visualization.' % (xdim, ydim))

    return image2plot, gt, prj
@staticmethod
def _get_cmap_vmin_vmax(cmap: Union[str, 'Colormap'],
                        vmin: float,
                        vmax: float,
                        pmin: float,
                        pmax: float,
                        image2plot: np.ndarray,
                        nodataVal: Union[int, float]
                        ):
    """Return the color palette and the value range to be used for plotting.

    :param cmap:        colormap name, Colormap instance or None (-> 'gray')
    :param vmin:        lower value limit (None -> the pmin percentile of the image values)
    :param vmax:        upper value limit (None -> the pmax percentile of the image values)
    :param pmin:        percentile used for an automatic vmin
    :param pmax:        percentile used for an automatic vmax
    :param image2plot:  the image to be plotted
    :param nodataVal:   no data value excluded from the percentile computation (if the image is not constant)
    :return:            (palette, vmin, vmax)
    """
    from matplotlib import pyplot as plt

    # select the color palette: named colormap, given Colormap instance or gray as fallback
    if not cmap:
        palette = plt.get_cmap('gray')
    elif isinstance(cmap, str):
        palette = plt.get_cmap(cmap)
    else:
        palette = cmap
    # work on a copy, in-place modifications of registered colormaps are not allowed anymore
    palette = copy(palette)

    # float is needed to avoid an overflow error when computing the standard deviation
    exclude_nodata = nodataVal is not None and np.std(image2plot.astype(float)) != 0

    if exclude_nodata:
        valid_values = np.ma.masked_equal(image2plot, nodataVal).compressed()
        vmin_auto = np.nanpercentile(valid_values, pmin)
        vmax_auto = np.nanpercentile(valid_values, pmax)
        palette.set_bad('aqua', 0)
    else:
        vmin_auto = np.nanpercentile(image2plot, pmin)
        vmax_auto = np.nanpercentile(image2plot, pmax)

    if vmin is None:
        vmin = vmin_auto
    if vmax is None:
        vmax = vmax_auto

    palette.set_over('1')
    palette.set_under('0')

    return palette, vmin, vmax
def show(self,
         xlim: Union[tuple, list] = None,
         ylim: Union[tuple, list] = None,
         band: int = None,
         boundsMap: tuple = None,
         boundsMapPrj: str = None,
         figsize: tuple = None,
         interpolation: Optional[str] = 'none',
         vmin: float = None,
         vmax: float = None,
         pmin: float = 2,
         pmax: float = 98,
         cmap: Union[str, 'Colormap'] = None,
         nodataVal: float = None,
         res_factor: float = None,
         interactive: bool = False,
         ax: 'axis' = None,
         ignore_rotation: bool = False
         ) -> Union['AxesImage', 'HoloMap']:
    """Plot the desired array position into a figure.

    :param xlim:            [start_column, end_column]
    :param ylim:            [start_row, end_row]
    :param band:            the band index of the band to be plotted (if None and interactive==True all bands are
                            shown, otherwise the first band is chosen)
    :param boundsMap:       xmin, ymin, xmax, ymax
    :param boundsMapPrj:    projection of boundsMap
    :param figsize:         figure size (tuple)
    :param interpolation:   interpolation passed to matplotlib's imshow
    :param vmin:            darkest pixel value to be included in stretching
    :param vmax:            brightest pixel value to be included in stretching
    :param pmin:            percentage to be used for excluding the darkest pixels from stretching (default: 2)
    :param pmax:            percentage to be used for excluding the brightest pixels from stretching (default: 98)
    :param cmap:            colormap name or instance (default: gray)
    :param nodataVal:       no data value (default: self.nodata, or -9999 if that is None)
    :param res_factor:      resolution factor for downsampling of the image to be plotted in order to save
                            plotting time and memory (default=None -> downsampling is performed to 1000x1000)
    :param interactive:     activates interactive plotting based on 'holoviews' library.
                            NOTE: this deactivates the magic '% matplotlib inline' in Jupyter Notebook
    :param ax:              only usable in non-interactive mode
    :param ignore_rotation: whether to ignore an image rotation angle included in the GDAL GeoTransform tuple for
                            plotting (default: False)
    :return:                a holoviews HoloMap (interactive multi-band mode) or the matplotlib AxesImage
    """
    # NOTE: importlib.util.find_spec() replaces pkgutil.find_loader() (deprecated since Python 3.12)
    from importlib.util import find_spec
    from matplotlib import pyplot as plt

    band = (band if band is not None else 0) if not interactive else band

    # get image to plot
    nodataVal = nodataVal if nodataVal is not None else self.nodata if self.nodata is not None else -9999
    image2plot, gt, prj = \
        self._get_plottable_image(xlim, ylim, band,
                                  boundsMap=boundsMap,
                                  boundsMapPrj=boundsMapPrj,
                                  res_factor=res_factor,
                                  nodataVal=nodataVal,
                                  ignore_rotation=ignore_rotation)

    palette, vmin, vmax = self._get_cmap_vmin_vmax(cmap, vmin, vmax, pmin, pmax, image2plot, nodataVal)
    if nodataVal is not None and np.std(image2plot.astype(float)) != 0:
        image2plot = np.ma.masked_equal(image2plot, nodataVal)

    # check availability of holoviews
    if interactive and find_spec('holoviews') is None:
        warnings.warn("Interactive mode requires holoviews. Install it by running, e.g., "
                      "'conda install -c conda-forge holoviews'. Using non-interactive mode.")
        interactive = False

    if interactive and image2plot.ndim == 3:
        import holoviews as hv
        from skimage.exposure import rescale_intensity
        hv.extension('matplotlib')

        cS, cE = xlim if isinstance(xlim, (tuple, list)) else (0, self.columns - 1)
        rS, rE = ylim if isinstance(ylim, (tuple, list)) else (0, self.rows - 1)

        # noinspection PyTypeChecker
        image2plot: np.ndarray = rescale_intensity(image2plot, in_range=(vmin, vmax))

        def get_hv_image(b):
            # FIXME ylabels have the wrong order
            hv_image = hv.Image(image2plot[:, :, b] if b is not None else image2plot,
                                bounds=(cS, rS, cE, rE))
            return hv_image.options(cmap='gray',
                                    fig_inches=4 if figsize is None else figsize,
                                    show_grid=True)

        # one holoviews image per band, selectable via the 'band' key dimension
        hmap = hv.HoloMap([(bIdx, get_hv_image(bIdx))
                           for bIdx in range(image2plot.shape[2])],
                          kdims=['band'])

        return hmap

    else:
        if interactive:
            warnings.warn('Currently there is no interactive mode for single-band arrays. '
                          'Switching to standard matplotlib figure..')  # TODO implement zoomable fig

        # show image
        if not ax:
            plt.figure(figsize=figsize)
            ax = plt.gca()

        rows, cols = image2plot.shape[:2]
        im = ax.imshow(image2plot,
                       palette,
                       interpolation=interpolation,
                       extent=(0, cols, rows, 0),
                       vmin=vmin,
                       vmax=vmax
                       )  # compressed excludes nodata values
        plt.show()

        return im
def show_map(self,
             xlim: Union[tuple, list] = None,
             ylim: Union[tuple, list] = None,
             band: int = 0,
             boundsMap: tuple = None,
             boundsMapPrj: str = None,
             out_epsg: int = None,
             figsize: tuple = None,
             interpolation: Optional[str] = 'none',
             vmin: float = None,
             vmax: float = None,
             pmin: float = 2,
             pmax: float = 98,
             cmap: Union[str, 'Colormap'] = None,
             draw_gridlines: bool = True,
             nodataVal: float = None,
             res_factor: float = None,
             return_map: bool = False
             ) -> Optional[tuple]:
    """Show a cartopy map of the associated image data (requires geocoding and projection information).

    :param xlim:            [start_column, end_column]
    :param ylim:            [start_row, end_row]
    :param band:            band index (starting with 0)
    :param boundsMap:       xmin, ymin, xmax, ymax
    :param boundsMapPrj:    projection of boundsMap
    :param out_epsg:        EPSG code of the output projection
    :param figsize:         figure size (tuple)
    :param interpolation:   interpolation passed to imshow
    :param vmin:            darkest pixel value to be included in stretching
    :param vmax:            brightest pixel value to be included in stretching
    :param pmin:            percentage to be used for excluding the darkest pixels from stretching (default: 2)
    :param pmax:            percentage to be used for excluding the brightest pixels from stretching (default: 98)
    :param cmap:            colormap name or instance
    :param draw_gridlines:  whether to draw gridlines into the map (default: True)
    :param nodataVal:       no data value (default: self.nodata)
    :param res_factor:      <float> resolution factor for downsampling of the image to be plotted in order to save
                            plotting time and memory (default=None -> downsampling is performed to 1000x1000)
    :param return_map:      whether to return (figure, axes) instead of showing the map
    :return:                (figure, axes) if return_map is True, otherwise None
    """
    from matplotlib import pyplot as plt
    from cartopy.crs import epsg as ccrs_from_epsg, PlateCarree

    assert self.geotransform and tuple(self.geotransform) != (0, 1, 0, 0, 0, -1), \
        'A valid geotransform is needed for a map visualization. Got %s.' % list(self.geotransform)
    assert self.projection, "A projection is needed for a map visualization. Got '%s'." % self.projection

    def _cartopy_crs(epsg_code):
        """Return the cartopy CRS of the given EPSG code (PlateCarree as fallback for EPSG:4326)."""
        if not epsg_code:
            raise ValueError(f'Expected a valid EPSG code. Got {epsg_code}.')
        try:
            return ccrs_from_epsg(epsg_code)
        except ValueError:
            if epsg_code != 4326:
                raise NotImplementedError('The show_map() method currently does not support the given '
                                          'projection.')
            return PlateCarree()

    # get image to plot
    # (reproject to LonLat as workaround in case self.epsg is None because cartopy relies on an existing EPSG code)
    if nodataVal is None:
        nodataVal = self.nodata
    plottable = self._get_plottable_image(xlim, ylim, band,
                                          boundsMap=boundsMap,
                                          boundsMapPrj=boundsMapPrj,
                                          res_factor=res_factor,
                                          nodataVal=nodataVal,
                                          # FIXME EPSG:4326 fails for extraterrestrial data
                                          out_prj=self.epsg or 4326)
    gA2plot = GeoArray(*plottable, nodata=nodataVal)
    image2plot = gA2plot[:]

    # create map
    crs_in = _cartopy_crs(gA2plot.epsg)
    crs_out = _cartopy_crs(gA2plot.epsg if out_epsg is None else out_epsg)

    fig = plt.figure(figsize=figsize)
    ax = plt.axes(projection=crs_out)
    ax.set_extent(gA2plot.box.boundsMap, crs=crs_in)

    palette, vmin, vmax = gA2plot._get_cmap_vmin_vmax(cmap, vmin, vmax, pmin, pmax, image2plot, nodataVal)
    hide_nodata = nodataVal is not None and np.std(image2plot) != 0
    if hide_nodata:
        image2plot = np.ma.masked_equal(image2plot, nodataVal)
    ax.imshow(image2plot, cmap=palette, interpolation=interpolation, vmin=vmin, vmax=vmax,
              origin='upper', transform=crs_in, extent=list(gA2plot.box.boundsMap))

    # draw grid lines
    if draw_gridlines:
        ax.gridlines(draw_labels=True, linewidth=2, color='gray', alpha=0.5, linestyle='--')  # cartopy>=0.18.0 only

    if not return_map:
        plt.show()
        return None

    return fig, ax
def show_footprint(self):
    """Show a web map containing the computed footprint of the GeoArray instance in a Jupyter notebook.

    :return:            a folium.Map centered on the footprint, with the footprint added as GeoJSON layer
    :raises ImportError: if 'folium' or 'geojson' is not installed
    """
    # NOTE: importlib.util.find_spec() replaces pkgutil.find_loader() (deprecated since Python 3.12)
    from importlib.util import find_spec

    if find_spec('folium') is None or find_spec('geojson') is None:
        raise ImportError(
            "This method requires the libraries 'folium' and 'geojson'. They can be installed with "
            "the shell command 'pip install folium geojson'.")

    import folium
    import geojson

    lonlatPoly = reproject_shapelyGeometry(self.footprint_poly, self.prj, 4326)

    # folium expects (lat, lon) whereas the shapely centroid coordinates are (lon, lat) -> reverse them
    m = folium.Map(location=tuple(np.array(lonlatPoly.centroid.coords.xy).flatten())[::-1])
    gjs = geojson.Feature(geometry=lonlatPoly, properties={})
    folium.GeoJson(gjs).add_to(m)
    return m
def show_histogram(self,
                   band: int = 1,
                   bins: int = 200,
                   normed: bool = False,
                   exclude_nodata: bool = True,
                   vmin: float = None,
                   vmax: float = None,
                   figsize: tuple = None
                   ) -> None:
    """Show a histogram of a given band.

    :param band:            the band to be used to plot the histogram (ignored for single-band images)
    :param bins:            number of bins to plot (default: 200)
    :param normed:          whether to normalize the y-axis or not (default: False)
    :param exclude_nodata:  whether to exclude the nodata value from the histogram
    :param vmin:            minimum value for the x-axis of the histogram (default: 1st percentile)
    :param vmax:            maximum value for the x-axis of the histogram (default: 99th percentile)
    :param figsize:         figure size (tuple)
    """
    from matplotlib import pyplot as plt

    if self.nodata is not None and exclude_nodata:
        data = np.ma.masked_equal(self[band] if not self.bands == 1 else self[:], self.nodata)
        data = data.compressed()
    else:
        data = self[band] if not self.bands == 1 else self[:]

    vmin = vmin if vmin is not None else np.nanpercentile(data, 1)
    vmax = vmax if vmax is not None else np.nanpercentile(data, 99)

    plt.figure(figsize=figsize)
    # pass a flat array instead of a Python list (avoids creating one Python object per pixel)
    plt.hist(np.ravel(data), density=normed, bins=bins, color='gray', range=[vmin, vmax])
    plt.xlabel('Pixel value')
    plt.ylabel('Probability' if normed else 'Count')
    plt.show()

    if not self.q:
        print('STD:', np.std(data))
        print('MEAN:', np.mean(data))
        print('2 % percentile:', np.nanpercentile(data, 2))
        print('98 % percentile:', np.nanpercentile(data, 98))
def _show_profile(self,
                  x: Union[int, Iterable],
                  y: Union[int, Iterable],
                  xlabel: str,
                  ylabel: str,
                  title: str,
                  xlim: Union[tuple, list],
                  ylim: Union[tuple, list],
                  figsize: tuple,
                  show_nodata: bool,
                  return_fig: bool
                  ) -> Optional['figure']:
    """Plot the profile values ``y`` over ``x`` (shared implementation of the show_*profile methods).

    If ``y`` contains the no data value, those values are excluded from the black profile line. If
    ``show_nodata`` is True, they are drawn separately in red. The title is extended by a corresponding note.

    :return: the figure if return_fig is True, otherwise None (the figure is shown directly)
    """
    from matplotlib import pyplot as plt

    nodata = self._nodata
    nodata_values = None

    if nodata is not None and nodata in y:
        if show_nodata:
            nodata_values = np.ma.masked_not_equal(y, nodata)
            title += ' (no-data is indicated in red)'
        else:
            title += ' (no-data is not shown)'

        y = np.ma.masked_equal(y, nodata)

    fig = plt.figure(figsize=figsize)
    plt.plot(x, y, 'k')
    if nodata_values is not None:
        plt.plot(x, nodata_values, 'r')

    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.xlim(*(xlim or (min(x), max(x))))
    if ylim:
        plt.ylim(*ylim)
    plt.grid()
    plt.title(title)

    if not return_fig:
        plt.show()
        return None

    return fig
def show_xprofile(self,
                  row: int,
                  band: int,
                  xlim: Union[tuple, list] = None,
                  ylim: Union[tuple, list] = None,
                  title: str = None,
                  figsize: tuple = (10, 5),
                  show_nodata: bool = True,
                  return_fig: bool = False
                  ) -> Optional['figure']:
    """Plot the pixel values along one image row (x-profile) of a single band.

    :param row:         image row number (counts from 0, negative values count from the end)
    :param band:        image band number (counts from 0, negative values count from the end)
    :param xlim:        x-axis limits to be used in the plot
    :param ylim:        y-axis limits to be used in the plot
    :param title:       a custom plot title
    :param figsize:     figure size (tuple)
    :param show_nodata: whether to show no-data values in the plot
    :param return_fig:  whether to return the figure instead of showing it directly
    :return:            plt.figure
    """
    profile = self[row, :, band]
    if not title:
        # indexing a range normalizes negative indices for the default title
        title = f'X-Profile at row {range(self.rows)[row]}, band {range(self.bands)[band]}'

    return self._show_profile(range(self.columns), profile, 'column', 'value', title,
                              xlim, ylim, figsize, show_nodata, return_fig)
def show_yprofile(self,
                  column: int,
                  band: int,
                  xlim: Union[tuple, list] = None,
                  ylim: Union[tuple, list] = None,
                  title: str = None,
                  figsize: tuple = (10, 5),
                  show_nodata: bool = True,
                  return_fig: bool = False
                  ) -> Optional['figure']:
    """Plot the pixel values along one image column (y-profile) of a single band.

    :param column:      image column number (counts from 0, negative values count from the end)
    :param band:        image band number (counts from 0, negative values count from the end)
    :param xlim:        x-axis limits to be used in the plot
    :param ylim:        y-axis limits to be used in the plot
    :param title:       a custom plot title
    :param figsize:     figure size (tuple)
    :param show_nodata: whether to show no-data values in the plot
    :param return_fig:  whether to return the figure instead of showing it directly
    :return:            plt.figure
    """
    profile = self[:, column, band]
    if not title:
        # indexing a range normalizes negative indices for the default title
        title = f'Y-Profile at column {range(self.columns)[column]}, band {range(self.bands)[band]}'

    return self._show_profile(range(self.rows), profile, 'row', 'value', title,
                              xlim, ylim, figsize, show_nodata, return_fig)
def show_zprofile(self,
                  row: int,
                  column: int,
                  xlim: Union[tuple, list] = None,
                  ylim: Union[tuple, list] = None,
                  title: str = None,
                  figsize: tuple = (10, 5),
                  show_nodata: bool = True,
                  return_fig: bool = False
                  ) -> Optional['figure']:
    """Plot the pixel values across all bands (z-profile) at one image position.

    The x-axis shows the wavelengths if the band metadata contain 'wavelength', otherwise the band indices.

    :param row:         image row number (counts from 0, negative values count from the end)
    :param column:      image column number (counts from 0, negative values count from the end)
    :param xlim:        x-axis limits to be used in the plot
    :param ylim:        y-axis limits to be used in the plot
    :param title:       a custom plot title
    :param figsize:     figure size (tuple)
    :param show_nodata: whether to show no-data values in the plot
    :param return_fig:  whether to return the figure instead of showing it directly
    :return:            plt.figure
    :raises RuntimeError: if the array has less than 3 dimensions
    """
    if not self.ndim > 2:
        raise RuntimeError(f'Plotting a z-profile is not possible for a {self.ndim}D array.')

    band_meta = self.meta.band_meta
    if 'wavelength' not in band_meta:
        x, x_label = range(self.bands), 'band'
    else:
        x, x_label = band_meta['wavelength'], 'wavelength'

    profile = self[row, column, :]
    if not title:
        # indexing a range normalizes negative indices for the default title
        title = f'Z-Profile at row {range(self.rows)[row]}, column {range(self.columns)[column]}'

    return self._show_profile(x, profile, x_label, 'value', title, xlim, ylim, figsize, show_nodata, return_fig)
def clip_to_footprint(self) -> None:
    """Clip the GeoArray instance in place to the outer bounds of its footprint polygon.

    The clipping itself is delegated to clip_to_poly().
    """
    footprint = self.footprint_poly
    self.clip_to_poly(footprint)
def clip_to_poly(self, poly: 'Polygon') -> None:
    """Clip the GeoArray instance in place to the outer bounds of a given shapely polygon.

    The image array and the no-data mask are cut to poly.bounds via get_mapPos(). The bad-data mask is only
    clipped if it already exists. A previously set footprint polygon is trimmed to the new image extent.

    :param poly:    instance of shapely.geometry.Polygon with coordinates in the projection of this GeoArray
                    (self.prj)
    """
    bounds = poly.bounds
    # Interpret the bounds in the same projection for the image and both masks. Otherwise a mask would fall
    # back to its own (possibly different or unset) projection when interpreting the bounds.
    bounds_prj = self.prj

    self.arr, self.gt, self.projection = self.get_mapPos(mapBounds=bounds, mapBounds_prj=bounds_prj)
    self.mask_nodata.arr, self.mask_nodata.gt, self.mask_nodata.projection = \
        self.mask_nodata.get_mapPos(mapBounds=bounds, mapBounds_prj=bounds_prj)
    assert self.shape[:2] == self.mask_nodata.shape

    # only clip the bad-data mask if it has been set (checked via the private attribute)
    if self._mask_baddata is not None:
        self.mask_baddata.arr, self.mask_baddata.gt, self.mask_baddata.projection = \
            self.mask_baddata.get_mapPos(mapBounds=bounds, mapBounds_prj=bounds_prj)
        assert self.shape[:2] == self.mask_baddata.shape

    # update footprint polygon
    if self._footprint_poly:
        if not (self.footprint_poly.within(self.box.mapPoly) or self.footprint_poly.equals(self.box.mapPoly)):
            self.footprint_poly = self.footprint_poly.intersection(self.box.mapPoly)
def tiles(self, tilesize: tuple = (100, 100)) -> GeoArrayTiles:
    """Split the full dataset into tiles of the given size.

    The tile data are read lazily, i.e., only when the returned GeoArrayTiles object is iterated.

    :param tilesize:    target size of the tiles (rows, columns)
                        NOTE: If rows or columns are None, all rows/columns are returned
    :return:            GeoArrayTiles with elements like: (((rowStart, rowEnd), (colStart, colEnd)), tiledata)
    """
    tile_bounds = get_array_tilebounds(self.shape, tilesize)
    is_3d = self.ndim == 3

    def _read_tile(row_start, row_end, col_start, col_end):
        # the tile end indices are inclusive, hence the +1 for the slice stop
        if is_3d:
            return self[row_start: row_end + 1, col_start: col_end + 1, :]
        return self[row_start: row_end + 1, col_start: col_end + 1]

    tile_gen = ((((rS, rE), (cS, cE)), _read_tile(rS, rE, cS, cE))
                for (rS, rE), (cS, cE) in tile_bounds)

    return GeoArrayTiles(tile_gen, length=len(tile_bounds))
def get_mapPos(self,
               mapBounds: tuple,
               mapBounds_prj: Union[str, int] = None,
               band2get: int = None,
               out_prj: Union[str, int] = None,
               out_gsd: tuple = None,
               arr_gt: tuple = None,
               arr_prj: str = None,
               fillVal: Union[int, float] = None,
               rspAlg: str = 'near',
               progress: bool = None,
               v: bool = False
               ) -> Tuple[np.ndarray, tuple, str]:
    # TODO implement slice for indexing bands
    """Return the array data of GeoArray at a given geographic position.

    NOTE: The given mapBounds are snapped to the pixel grid of GeoArray. If the given mapBounds include areas
          outside of the extent of GeoArray, these areas are filled with the fill value of GeoArray.

    :param mapBounds:       xmin, ymin, xmax, ymax
    :param mapBounds_prj:   WKT projection string or EPSG code corresponding to mapBounds (default: self.prj)
    :param band2get:        band index of the band to be returned (full array if not given)
    :param out_prj:         output projection as WKT string or EPSG code. If not given, arr_prj is used
                            (i.e., self.projection by default).
    :param out_gsd:         output spatial resolution in map units of the output projection (XGSD, YGSD)
                            (default: (self.xgsd, self.ygsd))
    :param arr_gt:          GDAL GeoTransform (taken from self if not given)
    :param arr_prj:         WKT projection string (taken from self if not given)
    :param fillVal:         nodata value (default: self.nodata)
    :param rspAlg:          <str> Resampling method to use. Available methods are:
                            near, bilinear, cubic, cubicspline, lanczos, average, mode, max, min, med, q1, q2
    :param progress:        whether to show progress bars or not (default: self.progress)
    :param v:               verbose mode (not related to GeoArray.v; must be explicitly set)
    :return:                tuple (sub_arr, sub_gt, sub_prj) as returned by get_array_at_mapPos()
    :raises ValueError:     if the instance is in memory and no geotransform or projection is available
    """
    mapBounds_prj = mapBounds_prj if mapBounds_prj is not None else self.prj
    arr_gt = arr_gt or self.geotransform
    arr_prj = arr_prj or self.projection
    out_prj = out_prj or arr_prj
    out_gsd = out_gsd or (self.xgsd, self.ygsd)
    fillVal = fillVal if fillVal is not None else self.nodata
    progress = progress if progress is not None else self.progress

    if self.is_inmem and (not arr_gt or not arr_prj):
        raise ValueError('In case of in-mem arrays the respective geotransform and projection of the array '
                         'has to be passed.')

    if v:
        print('%s.get_mapPos() input parameters:' % self.__class__.__name__)
        print('\tmapBounds', mapBounds, '<==>', self.box.boundsMap)
        print('\tEPSG', WKT2EPSG(mapBounds_prj), self.epsg)
        print('\tarr_gt', arr_gt, self.gt)
        print('\tarr_prj', WKT2EPSG(arr_prj), self.epsg)
        print('\tfillVal', fillVal, self.nodata, '\n')

    sub_arr, sub_gt, sub_prj = get_array_at_mapPos(self, arr_gt, arr_prj,
                                                   out_prj=out_prj,
                                                   mapBounds=mapBounds,
                                                   mapBounds_prj=mapBounds_prj,
                                                   fillVal=fillVal,
                                                   rspAlg=rspAlg,
                                                   out_gsd=out_gsd,
                                                   band2get=band2get,
                                                   progress=progress)
    return sub_arr, sub_gt, sub_prj
def get_subset(self,
               xslice: slice = None,
               yslice: slice = None,
               zslice: Union[slice, list] = None,
               return_GeoArray: bool = True,
               reset_bandnames: bool = False
               ) -> Union['GeoArray', Tuple[np.ndarray, tuple, str]]:
    """Return a new GeoArray instance representing a subset of the initial one with respect to given array position.

    :param xslice:          a slice providing the X-position for the subset in the form slice(xstart, xend, xstep)
    :param yslice:          a slice providing the Y-position for the subset in the form slice(ystart, yend, ystep)
    :param zslice:          a slice providing the Z-position for the subset in the form slice(zstart, zend, zstep)
                            or a list containing the indices of the bands to extract
    :param return_GeoArray: whether to return an instance of GeoArray (default) or a tuple(np.ndarray, gt, prj)
    :param reset_bandnames: whether band names of subset should be copied from source GeoArray or reset to
                            'B1', 'B2', 'B3', ...
    :return:                the subset as a new GeoArray instance or as a tuple (sub_arr, sub_gt, self.prj)
    :raises ValueError:     if a zslice is given for a 2D GeoArray or if no array can be returned for the slices
    """
    xslice, yslice, zslice = xslice or slice(None), yslice or slice(None), zslice or slice(None)
    xslicing = xslice.start is not None or xslice.stop is not None or xslice.step is not None  # type: bool
    yslicing = yslice.start is not None or yslice.stop is not None or yslice.step is not None  # type: bool
    zslicing = isinstance(zslice, list) or \
        zslice.start is not None or zslice.stop is not None or zslice.step is not None  # type: bool

    # Resolve the x/y slices against the full image dimensions before self is temporarily modified below.
    # This way, negative starts (e.g., slice(-10, None)) and steps are correctly reflected in the geotransform.
    x_start, _, x_step = xslice.indices(self.columns)
    y_start, _, y_step = yslice.indices(self.rows)

    # get array subset #
    ####################

    # get sub_arr
    if zslicing:
        # validation
        if self.ndim == 2:
            raise ValueError('Invalid zslice. A 2D GeoArray is not slicable in z-direction.')

        sub_arr = self[yslice, xslice, zslice]  # row, col, band
    else:
        sub_arr = self[yslice, xslice]  # row, col

    if sub_arr is None:
        raise ValueError('Unable to return an array for the given slice parameters.')

    # copy GeoArray instance #
    ##########################

    # get deepcopy of self (but without slowly copying the full-size self.arr)
    # -> cache self.arr, overwrite with subset, quickly create sub_gA and recreate self.arr
    # -> do the same with attributes 'mask_nodata' and 'mask_baddata'
    from .masks import NoDataMask, BadDataMask
    full_arr = self.arr
    full_mask_nodata = self._mask_nodata
    full_mask_baddata = self._mask_baddata

    self.arr = sub_arr
    try:
        if isinstance(self._mask_nodata, NoDataMask):  # avoid computing it here by using private
            self._mask_nodata = self._mask_nodata.get_subset(xslice=xslice, yslice=yslice)
        if isinstance(self._mask_baddata, BadDataMask):  # avoid computing it here by using private
            self._mask_baddata = self._mask_baddata.get_subset(xslice=xslice, yslice=yslice)

        sub_gA = deepcopy(self)  # do not copy any references, otherwise numpy arrays would be copied as views
    finally:
        # always restore the full-size state of self, even if subsetting the masks or copying failed
        self._arr = full_arr
        if isinstance(self._mask_nodata, NoDataMask):
            self._mask_nodata = full_mask_nodata
        if isinstance(self._mask_baddata, BadDataMask):
            self._mask_baddata = full_mask_baddata

    # numpy array references need to be cleared separately (also called by self._mask_nodata.get_subset() above)
    sub_gA.deepcopy_array()

    # handle metadata #
    ###################

    # adapt geotransform:
    # This is the upper-left corner of the first output pixel. For reversed slices (negative step), the first
    # output pixel extends towards lower indices, so its upper-left corner is the far edge of the start pixel.
    ul_col = x_start + 1 if x_step < 0 else x_start
    ul_row = y_start + 1 if y_step < 0 else y_start
    sub_ulXY = imXY2mapXY((ul_col, ul_row), self.gt)
    # pixel sizes and rotation terms scale with the slice step along the respective image axis
    sub_gt = (sub_ulXY[0], self.gt[1] * x_step, self.gt[2] * y_step,
              sub_ulXY[1], self.gt[4] * x_step, self.gt[5] * y_step)

    # apply zslice to bandnames and metadata
    if zslicing:
        bNs_out = list(np.array(list(self._bandnames))[zslice]) if self._bandnames else None
        _meta_out = self.metadata.get_subset(bands2extract=zslice)
    else:
        bNs_out = list(self._bandnames) if self._bandnames else None
        # Copy the metadata. Sharing the object would let later changes to the subset's metadata propagate back
        # to self (e.g., resetting the band names below).
        _meta_out = deepcopy(self.meta)

    sub_gA.gt = sub_gt
    sub_gA.metadata = _meta_out
    sub_gA.bandnames = bNs_out
    sub_gA.filePath = self.filePath
    if xslicing or yslicing:
        sub_gA._footprint_poly = None  # reset footprint_poly -> has to be updated

    if reset_bandnames:
        del sub_gA.bandnames  # also updates bandnames within sub_gA.meta

    return sub_gA if return_GeoArray else (sub_arr, sub_gt, self.prj)
def reproject_to_new_grid(self,
                          prototype: 'GeoArray' = None,
                          tgt_prj: Union[str, int] = None,
                          tgt_xygrid: Sequence = None,
                          rspAlg: Union[str, int] = 'cubic',
                          CPUs: int = None
                          ) -> None:
    """Reproject all array-like attributes to a given target grid.

    The image array is warped in place. Already existing no-data/bad-data masks are warped using nearest
    neighbour resampling. A set footprint polygon is transformed to the target projection and clipped to
    the new image extent.

    :param prototype:   an instance of GeoArray to be used as pixel grid reference
    :param tgt_prj:     GDAL projection as WKT string or EPSG code ('epsg:1234' or <EPSG_int>)
    :param tgt_xygrid:  target XY grid, e.g. [[xmin,xmax], [ymax, ymin]] for the UL corner
    :param rspAlg:      GDAL compatible resampling algorithm code
    :param CPUs:        number of CPUs to use (default: None -> use all available CPUs)
    """
    assert (tgt_prj and tgt_xygrid) or prototype, "Provide either 'prototype' or 'tgt_prj' and 'tgt_xygrid'!"
    tgt_prj = tgt_prj or prototype.prj
    tgt_xygrid = tgt_xygrid or prototype.xygrid_specs
    assert tgt_xygrid[1][0] > tgt_xygrid[1][1]

    # set target GSD
    tgt_xgsd, tgt_ygsd = abs(tgt_xygrid[0][0] - tgt_xygrid[0][1]), abs(tgt_xygrid[1][0] - tgt_xygrid[1][1])

    # remember the source projection; it is also needed to transform the footprint polygon below
    src_prj = self.prj

    # set target bounds
    tgt_bounds = reproject_shapelyGeometry(self.box.mapPoly, src_prj, tgt_prj).bounds

    gt = (tgt_xygrid[0][0], tgt_xgsd, 0, max(tgt_xygrid[1]), 0, -tgt_ygsd)
    xmin, ymin, xmax, ymax = snap_bounds_to_pixGrid(tgt_bounds, gt, roundAlg='on')

    from py_tools_ds.geo.raster.reproject import warp_ndarray
    self.arr, self.gt, self.prj = \
        warp_ndarray(self[:], self.gt, src_prj, tgt_prj,
                     out_gsd=(tgt_xgsd, tgt_ygsd),
                     out_bounds=(xmin, ymin, xmax, ymax),
                     out_bounds_prj=tgt_prj,
                     rspAlg=rspAlg,
                     in_nodata=self.nodata,
                     CPUs=CPUs,
                     progress=self.progress,
                     q=self.q)

    if hasattr(self, '_mask_nodata') and self._mask_nodata is not None:
        self.mask_nodata.reproject_to_new_grid(prototype=prototype,
                                               tgt_prj=tgt_prj,
                                               tgt_xygrid=tgt_xygrid,
                                               rspAlg='near',
                                               CPUs=CPUs)

    if hasattr(self, '_mask_baddata') and self._mask_baddata is not None:
        self.mask_baddata.reproject_to_new_grid(prototype=prototype,
                                                tgt_prj=tgt_prj,
                                                tgt_xygrid=tgt_xygrid,
                                                rspAlg='near',
                                                CPUs=CPUs)

    # update footprint polygon
    # NOTE: The stored footprint is still in the source projection, whereas self.box.mapPoly now refers to the
    #       target projection. Therefore, transform the footprint before comparing both.
    if self._footprint_poly:
        footprint = self.footprint_poly
        if not prj_equal(prj1=src_prj, prj2=self.prj):
            footprint = reproject_shapelyGeometry(footprint, src_prj, self.prj)
        if not (footprint.within(self.box.mapPoly) or footprint.equals(self.box.mapPoly)):
            footprint = footprint.intersection(self.box.mapPoly)
        self.footprint_poly = footprint
def read_pointData(self,
                   mapXY_points: Union[np.ndarray, tuple],
                   mapXY_points_prj: Union[str, int] = None,
                   band: int = None,
                   offside_val: Union[float, int] = np.nan
                   ) -> Union[int, float, np.ndarray]:
    """Return the array values for the given set of X/Y coordinates.

    NOTE: If GeoArray has been instanced with a file path, the function will read the dataset into memory.

    :param mapXY_points:        X/Y coordinates of the points of interest. Either a single (X, Y) pair or an
                                array-like of X/Y pairs with the shape [Nx2].
    :param mapXY_points_prj:    WKT string or EPSG code of the projection corresponding to the given
                                coordinates (default: self.prj)
    :param band:                the band index of the band of interest. If None, the values of all bands are
                                returned.
    :param offside_val:         fill value in case input coordinates are geographically outside of the GeoArray
                                instance
    :return:    - the value(s) at that position (or offside_val) in case only a single coordinate is passed
                - np.ndarray with shape [Nx1] in case only one band is requested
                - np.ndarray with shape [Nx1xbands] in case all bands are requested
                If any of multiple points is outside the image, the returned array has a dtype that can hold
                offside_val (e.g., float64 for an integer image and offside_val=np.nan).
    """
    mapXY = np.asarray(mapXY_points)
    if mapXY.ndim == 1:
        # a single (X, Y) pair
        mapXY = mapXY.reshape(1, 2)
    prj = mapXY_points_prj if mapXY_points_prj else self.prj

    assert prj, 'A projection is needed for returning image DNs at specific map X/Y coordinates!'
    if not prj_equal(prj1=prj, prj2=self.prj):
        mapX, mapY = transform_any_prj(prj, self.prj, mapXY[:, 0], mapXY[:, 1])
        mapXY = np.hstack([mapX.reshape(-1, 1),
                           mapY.reshape(-1, 1)])

    imXY = np.asarray(mapXY2imXY(mapXY, self.geotransform))

    # get a mask of all map positions geographically outside the GeoArray instance
    mask_off = (np.any(imXY < 0, axis=1)) | \
               (imXY[:, 0] >= self.columns) | \
               (imXY[:, 1] >= self.rows)

    # Convert the fractional image coordinates to integer (row, column) pixel indices:
    # - np.floor() instead of truncation, so that positions slightly left of / above the image (e.g., -0.5) are
    #   not mapped onto the first column/row (consistent with mask_off)
    # - int64 instead of int16, which would overflow for images with more than 32767 rows/columns
    imYX = np.fliplr(np.floor(imXY)).astype(np.int64)

    if imYX.size == 2:  # only one coordinate pair
        Y, X = imYX[0].tolist()

        if X < 0 or X >= self.columns or Y < 0 or Y >= self.rows:
            pointdata = offside_val
        else:
            pointdata = self[Y, X, band]

    else:  # multiple coordinate pairs
        if mask_off.any():
            shape_exp = (imXY.shape[0], 1, self.bands) if band is None and self.bands > 1 else (imXY.shape[0], 1)

            # The image dtype may not be able to hold offside_val (e.g., NaN in an integer image or a negative value
            # in an unsigned image). In that case, promote the output dtype instead of storing a garbage value.
            fill = np.asarray(offside_val)
            out_dtype = np.dtype(self.dtype)
            with np.errstate(invalid='ignore', over='ignore'):
                fill_cast = fill.astype(out_dtype)
            if not (fill_cast == fill or (np.isnan(fill) and np.isnan(fill_cast))):
                out_dtype = np.promote_types(out_dtype, fill.dtype)

            pointdata = np.full(shape_exp, offside_val, dtype=out_dtype)
            imYX = imYX[~mask_off, :]

            if (~mask_off).any():  # nothing to read if all points are outside the image
                if band is None and self.bands > 1:
                    # multiple bands requested
                    pointdata[~mask_off, 0, :] = \
                        self[tuple(imYX.T.tolist() + [band])] \
                        .reshape(imYX.shape[0], self.bands)
                elif self.bands == 1:
                    # if there is only one band
                    pointdata[~mask_off, 0] = \
                        self[tuple(imYX.T.tolist())]
                else:
                    # one band out of multiple bands requested
                    pointdata[~mask_off, 0] = \
                        self[tuple(imYX.T.tolist() + [band])]
        else:
            pointdata = self[tuple(imYX.T.tolist() + [band])]

    return pointdata
def to_mem(self) -> 'GeoArray':
    """Load the whole dataset into memory.

    The data returned by self[:] are assigned to self.arr.

    :return:    the GeoArray instance itself (allows method chaining)
    """
    in_memory_data = self[:]
    self.arr = in_memory_data
    return self
def to_disk(self) -> 'GeoArray':
    """Switch the instance back to disk mode by dropping the in-memory array (self._arr = None).

    Note: This requires that the GeoArray was instanced with a file path. If self.filePath does not point to an
    existing file, a warning is emitted and the array is kept.

    :return:    the GeoArray instance itself (allows method chaining)
    """
    backed_by_file = bool(self.filePath) and os.path.isfile(self.filePath)

    if not backed_by_file:
        warnings.warn('GeoArray object cannot be turned into disk mode because this asserts that GeoArray.filePath '
                      'contains a valid file path. Got %s.' % self.filePath)
        return self

    self._arr = None
    return self
def deepcopy_array(self) -> None:
    """Replace an in-memory self.arr by a copy that owns its data.

    If self.arr is a view onto another numpy array, later changes to that array would otherwise show up in
    this instance as well. Nothing is done if the instance is not in memory.
    """
    if not self.is_inmem:
        return

    own_copy = np.empty_like(self.arr)
    own_copy[:] = self.arr
    self.arr = own_copy
def cache_array_subset(self, arr_pos: list) -> None:
    """Fill the array cache of the GeoArray instance with the given array subset to speed up later reads.

    :param arr_pos: a list of array indices as passed to __getitem__
    """
    if self.is_inmem:
        # the array is already held in memory, so no separate array cache is needed
        return

    # noinspection PyStatementEffect
    self[arr_pos]  # the result is discarded; __getitem__ sets self._arr_cache as a side effect
def flush_cache(self) -> None:
    """Clear the array cache of the GeoArray instance by resetting self._arr_cache to None."""
    self._arr_cache = None
def get_GeoArray_from_GDAL_ds(ds: gdal.Dataset) -> GeoArray:
    """Create an in-memory GeoArray from an opened GDAL dataset.

    :param ds:  GDAL dataset whose pixel data, geotransform and projection are read
    :return:    GeoArray instance holding the full dataset in memory
    """
    # TODO implement as class method of GeoArray
    data = gdal_array.DatasetReadAsArray(ds)

    if data.ndim == 3:
        # GDAL returns multi-band data band-first; GeoArray uses (rows, columns, bands)
        data = np.moveaxis(data, 0, -1)

    return GeoArray(data, ds.GetGeoTransform(), ds.GetProjection())
class MultiGeoArray(object):  # pragma: no cover
    """Container for several geographically overlapping GeoArray instances (not yet functional)."""

    def __init__(self, GeoArray_list: List[GeoArray]):
        """Get an instance of MultiGeoArray.

        :param GeoArray_list:   a list of GeoArray instances having a geographic overlap
        :raises NotImplementedError: always, because the class is not yet working
        """
        self._arrs = None

        self.arrs = GeoArray_list  # validated by the property setter

        raise NotImplementedError('This class is not yet working.')  # FIXME

    @property
    def arrs(self) -> List[GeoArray]:
        """Return the list of GeoArray instances held by this MultiGeoArray."""
        return self._arrs

    @arrs.setter
    def arrs(self, GeoArray_list: List[GeoArray]):
        assert all(isinstance(geoArr, GeoArray) for geoArr in GeoArray_list), \
            "'arrs' can only be set to a list of GeoArray instances."

        self._arrs = GeoArray_list