Coverage for py_tools_ds/io/raster/writer.py: 0%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

65 statements  

1# -*- coding: utf-8 -*- 

2 

3# py_tools_ds - A collection of geospatial data analysis tools that simplify standard 

4# operations when handling geospatial raster and vector data as well as projections. 

5# 

6# Copyright (C) 2016-2021 

7# - Daniel Scheffler (GFZ Potsdam, daniel.scheffler@gfz-potsdam.de) 

8# - Helmholtz Centre Potsdam - GFZ German Research Centre for Geosciences Potsdam, 

9# Germany (https://www.gfz-potsdam.de/) 

10# 

11# This software was developed within the context of the GeoMultiSens project funded 

12# by the German Federal Ministry of Education and Research 

13# (project grant code: 01 IS 14 010 A-C). 

14# 

15# Licensed under the Apache License, Version 2.0 (the "License"); 

16# you may not use this file except in compliance with the License. 

17# You may obtain a copy of the License at 

18# 

19# http://www.apache.org/licenses/LICENSE-2.0 

20# 

21# Unless required by applicable law or agreed to in writing, software 

22# distributed under the License is distributed on an "AS IS" BASIS, 

23# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

24# See the License for the specific language governing permissions and 

25# limitations under the License. 

26 

27import os 

28import multiprocessing 

29 

30from pyproj import CRS 

31from osgeo import gdal 

32 

33from ...dtypes.conversion import convertGdalNumpyDataType 

34from ...geo.map_info import geotransform2mapinfo 

35from ...numeric.array import get_array_tilebounds 

36 

37__author__ = "Daniel Scheffler" 

38 

39 

def write_numpy_to_image(array, path_out, outFmt='GTIFF', gt=None, prj=None):
    """Write a 2D or 3D numpy array to a raster file using GDAL.

    :param array:       array of shape (rows, cols) or (rows, cols, bands)
    :param path_out:    output file path
    :param outFmt:      GDAL driver short name (default: 'GTIFF')
    :param gt:          GDAL geotransform; only written if truthy
    :param prj:         projection passed to GDAL's SetProjection; only written if truthy
                        (converted to WKT1_GDAL via pyproj for GDAL versions < 3)
    :raises ValueError: if the array is not 2D/3D or the GDAL driver is not available
    :raises RuntimeError: if GDAL cannot create the output dataset
    """
    ndim = len(array.shape)
    if ndim not in (2, 3):
        raise ValueError('Expected a 2D or 3D array, got %d dimension(s).' % ndim)

    rows, cols = array.shape[:2]
    bands = 1 if ndim == 2 else array.shape[2]

    driver = gdal.GetDriverByName(outFmt)
    if driver is None:
        raise ValueError("GDAL driver '%s' is not available." % outFmt)

    gdal_dtype = gdal.GetDataTypeByName(convertGdalNumpyDataType(array.dtype))
    outDs = driver.Create(path_out, cols, rows, bands, gdal_dtype)
    if outDs is None:
        raise RuntimeError('GDAL could not create the output dataset %s.' % path_out)

    for b in range(bands):
        band = outDs.GetRasterBand(b + 1)  # GDAL band indices are 1-based
        band.WriteArray(array if ndim == 2 else array[:, :, b])
        del band

    if gt:
        outDs.SetGeoTransform(gt)

    if prj:
        # parse the full major version; indexing the first character breaks for two-digit majors
        if int(gdal.__version__.split('.')[0]) < 3:
            # noinspection PyTypeChecker
            prj = CRS(prj).to_wkt(version="WKT1_GDAL")

        outDs.SetProjection(prj)

    del outDs  # closing the dataset flushes it to disk

58 

59 

def write_envi(arr, outpath, gt=None, prj=None):
    """Write a 2D or 3D array to an ENVI BSQ image (header + '.bsq' data file) using spectral.

    :param arr:         array of shape (rows, cols) or (rows, cols, bands)
    :param outpath:     header path passed to spectral's envi.create_image
    :param gt:          GDAL geotransform; must be given together with prj
    :param prj:         projection WKT; must be given together with gt
    :raises AssertionError: if only one of gt / prj is given
    :raises ValueError: if arr is neither 2D nor 3D
    """
    from spectral.io import envi

    if (gt or prj) and not (gt and prj):
        # explicit raise instead of `assert` so the check survives `python -O`;
        # the exception type is kept for callers that relied on it
        raise AssertionError('gt and prj must be provided together or left out.')

    ndim = len(arr.shape)
    if ndim == 2:
        shape = (arr.shape[0], arr.shape[1], 1)  # single-band image
    elif ndim == 3:
        shape = arr.shape
    else:
        raise ValueError('Expected a 2D or 3D array, got %d dimension(s).' % ndim)

    meta = {'map info': geotransform2mapinfo(gt, prj), 'coordinate system string': prj} if gt else None
    out = envi.create_image(outpath, metadata=meta, shape=shape, dtype=arr.dtype, interleave='bsq', ext='.bsq',
                            force=True)

    # open_memmap returns a (rows, cols, bands) view by default, independent of the file interleave
    out_mm = out.open_memmap(writable=True)
    if ndim == 2:
        out_mm[:, :, 0] = arr
    else:
        out_mm[:, :, :] = arr
    out_mm.flush()
    del out_mm

72 

73 

# Writable memmap of the on-disk ENVI image used by the tile-wise conversion helpers.
# It stays None until init_SharedArray_on_disk() assigns it; fill_arr_on_disk() reads it.
shared_array_on_disk__memmap = None

75 

76 

def _get_envi_header_path(out_path):
    """Return the ENVI header path for out_path (a '.bsq' path is mapped to '.hdr'; other paths are unchanged)."""
    root, ext = os.path.splitext(out_path)
    return root + '.hdr' if ext == '.bsq' else out_path


def init_SharedArray_on_disk(out_path, dims, gt=None, prj=None, dtype='uint16'):
    """Create an ENVI BSQ image on disk and store a writable memmap of it in a module-level variable.

    The memmap is assigned to ``shared_array_on_disk__memmap``.

    :param out_path: output path; a '.bsq' path is redirected to the matching '.hdr' header path,
                     any other path is passed unchanged to spectral's envi.create_image
    :param dims:     image shape passed as ``shape`` to envi.create_image
                     (convert_gdal_to_bsq__mp passes (rows, cols))
    :param gt:       GDAL geotransform; written to the header only if prj is given as well
    :param prj:      projection WKT; written to the header only if gt is given as well
    :param dtype:    data type of the created image (default: 'uint16', the previously hard-coded value)
    """
    from spectral.io import envi

    global shared_array_on_disk__memmap

    meta = {}
    if gt and prj:
        meta['map info'] = geotransform2mapinfo(gt, prj)
        meta['coordinate system string'] = prj

    envi_img = envi.create_image(_get_envi_header_path(out_path), metadata=meta, shape=dims, dtype=dtype,
                                 interleave='bsq', ext='.bsq', force=True)
    shared_array_on_disk__memmap = envi_img.open_memmap(writable=True)


def fill_arr_on_disk(argDict):
    """Copy one tile of a GDAL raster band into band 0 of the on-disk ENVI image.

    :param argDict: dict with the keys
                    'pos':      ((row_start, row_end), (col_start, col_end)), end indices inclusive
                    'in_path':  path of the GDAL-readable input raster
                    'band':     1-based index of the input band to read
                    'out_path': (optional) output path as given to init_SharedArray_on_disk; if present,
                                the ENVI image is opened from disk in the calling process, otherwise the
                                module-level memmap set by init_SharedArray_on_disk is used
    :raises RuntimeError: if neither 'out_path' is given nor the module-level memmap is initialized
    """
    (rS, rE), (cS, cE) = argDict.get('pos')

    ds = gdal.Open(argDict.get('in_path'))
    rasterband = ds.GetRasterBand(argDict.get('band'))
    data = rasterband.ReadAsArray(cS, rS, cE - cS + 1, rE - rS + 1)
    del rasterband, ds

    out_path = argDict.get('out_path')
    if out_path is not None:
        # Opening the file here works regardless of the multiprocessing start method; the module-level
        # memmap is only inherited by worker processes under 'fork'.
        from spectral.io import envi

        hdr_path = _get_envi_header_path(out_path)
        img_path = os.path.splitext(hdr_path)[0] + '.bsq'  # create_image was called with ext='.bsq'
        out_mm = envi.open(hdr_path, img_path).open_memmap(writable=True)
        out_mm[rS:rE + 1, cS:cE + 1, 0] = data
        out_mm.flush()
        del out_mm

    elif shared_array_on_disk__memmap is not None:
        shared_array_on_disk__memmap[rS:rE + 1, cS:cE + 1, 0] = data

    else:
        raise RuntimeError("The output image is not initialized. Call init_SharedArray_on_disk() first or "
                           "provide 'out_path' within argDict.")


def convert_gdal_to_bsq__mp(in_path, out_path, band=1, dtype=None):
    """Convert one band of a GDAL-readable raster to an ENVI BSQ image, copying 512x512 tiles in parallel.

    Usage:
        if gdal.Open(path_im).GetDriver().ShortName != 'ENVI':
            path_tmp = get_tempfile(ext='.bsq')
            convert_gdal_to_bsq__mp(path_im, path_tmp)
            path_im = path_tmp

    :param in_path:  path of the GDAL-readable input raster
    :param out_path: output path ('.bsq' image path or '.hdr' header path)
    :param band:     1-based index of the band to convert (default: 1)
    :param dtype:    data type of the output image; None (default) uses the data type GDAL returns for
                     the input band (the output was previously always written as uint16, which truncated
                     e.g. float or signed integer data)
    :raises RuntimeError: if GDAL cannot open in_path
    """
    ds = gdal.Open(in_path)
    if ds is None:
        raise RuntimeError('GDAL could not open %s.' % in_path)

    dims = (ds.RasterYSize, ds.RasterXSize)
    gt, prj = ds.GetGeoTransform(), ds.GetProjection()
    if dtype is None:
        # read a single pixel to obtain the numpy dtype that the tile reads will return
        dtype = ds.GetRasterBand(band).ReadAsArray(0, 0, 1, 1).dtype
    del ds

    init_SharedArray_on_disk(out_path, dims, gt, prj, dtype=dtype)
    positions = get_array_tilebounds(array_shape=dims, tile_shape=[512, 512])

    # 'out_path' lets every worker open the output image itself (independent of the start method)
    argDicts = [{'pos': pos, 'in_path': in_path, 'band': band, 'out_path': out_path} for pos in positions]

    with multiprocessing.Pool() as pool:
        pool.map(fill_arr_on_disk, argDicts)
        pool.close()
        pool.join()