"""
@Project :microproduct
@File    :ImageHandle.py
@Function :Read SAR data to be processed, standardize its format and save the processed results.
@Author  :LMM
@Date    :2021/10/19 14:39
@Version :1.0.0
"""
import logging
import math
import os

import cv2
import numpy as np
from osgeo import gdal
from osgeo import osr
from PIL import Image

logger = logging.getLogger("mylog")

# numpy dtype name -> GDAL band data type used by every writer in this module.
# dtypes not listed here (int64, bool, complex, ...) fall back to GDT_Float32.
_GDAL_DTYPES = {
    'uint8': gdal.GDT_Byte,
    # GDT_Int8 only exists in GDAL >= 3.7; Int16 stores every int8 value losslessly.
    'int8': getattr(gdal, 'GDT_Int8', gdal.GDT_Int16),
    'uint16': gdal.GDT_UInt16,
    'int16': gdal.GDT_Int16,
    'uint32': gdal.GDT_UInt32,
    'int32': gdal.GDT_Int32,
    'float32': gdal.GDT_Float32,
    'float64': gdal.GDT_Float64,
}


def _open_dataset(filename):
    """Register all GDAL drivers and open ``filename``; returns None if GDAL cannot open it."""
    gdal.AllRegister()
    return gdal.Open(filename)


def _gdal_datatype(array):
    """Return the GDAL data type used to store ``array`` (see ``_GDAL_DTYPES``)."""
    return _GDAL_DTYPES.get(array.dtype.name, gdal.GDT_Float32)


def _create_dataset(driver_name, filename, im_proj, im_geotrans, width, height, bands, datatype):
    """
    Create an output raster with the given driver, georeferencing and size.

    The parent directory is created when needed.
    :raises OSError: if the GDAL driver cannot create the file
    """
    parent = os.path.dirname(filename)
    if parent:  # a bare file name has no directory part to create
        os.makedirs(parent, exist_ok=True)
    driver = gdal.GetDriverByName(driver_name)
    # The data type is mandatory: GDAL needs it to size the output.
    dataset = driver.Create(filename, width, height, bands, datatype)
    if dataset is None:
        raise OSError('Could not create ' + filename)
    dataset.SetGeoTransform(im_geotrans)  # affine transform parameters
    dataset.SetProjection(im_proj)  # projection
    return dataset


def _write_bands(dataset, layers, no_data='null'):
    """
    Write each 2-D array in ``layers`` to bands 1..len(layers) of ``dataset``.

    :param no_data: nodata value applied to every band; the string 'null' means "do not set".
    """
    for index, layer in enumerate(layers, start=1):
        band = dataset.GetRasterBand(index)
        band.WriteArray(layer)
        if no_data != 'null':
            # SetNoDataValue needs a number; callers historically pass strings such as '0'.
            band.SetNoDataValue(np.double(no_data))
        band.FlushCache()


class ImageHandler:
    """
    Read, edit and save raster images.
    """

    def __init__(self):
        pass

    @staticmethod
    def get_dataset(filename):
        """
        :param filename: tif path
        :return: GDAL dataset handle, or None if the file cannot be opened
        """
        return _open_dataset(filename)

    def get_scope(self, filename):
        """
        :param filename: tif path
        :return: image corner coordinates as returned by cal_img_scope, or None
        """
        dataset = _open_dataset(filename)
        if dataset is None:
            return None
        im_scope = self.cal_img_scope(dataset)
        del dataset
        return im_scope

    @staticmethod
    def get_projection(filename):
        """
        :param filename: tif path
        :return: map projection (WKT string from GetProjection), or None
        """
        dataset = _open_dataset(filename)
        if dataset is None:
            return None
        im_proj = dataset.GetProjection()
        del dataset
        return im_proj

    @staticmethod
    def get_geotransform(filename):
        """
        :param filename: tif path
        :return: affine transform from image space (pixel, line) to georeferenced
                 space (projected or geographic coordinates), or None
        """
        dataset = _open_dataset(filename)
        if dataset is None:
            return None
        geotransform = dataset.GetGeoTransform()
        del dataset
        return geotransform

    @staticmethod
    def get_invgeotransform(filename):
        """
        :param filename: tif path
        :return: inverse affine transform (georeferenced space -> image row/column),
                 i.e. the result of gdal.InvGeoTransform, or None if the file cannot be opened
        """
        dataset = _open_dataset(filename)
        if dataset is None:
            return None
        geotransform = gdal.InvGeoTransform(dataset.GetGeoTransform())
        del dataset
        return geotransform

    @staticmethod
    def get_bands(filename):
        """
        :param filename: tif path
        :return: number of bands, or None
        """
        dataset = _open_dataset(filename)
        if dataset is None:
            return None
        bands = dataset.RasterCount
        del dataset
        return bands

    @staticmethod
    def geo2lonlat(dataset, x, y):
        """
        Convert projected coordinates to lon/lat (the projection is taken from ``dataset``).
        :param dataset: GDAL dataset
        :param x: projected x
        :param y: projected y
        :return: first two components of the transformed point
        """
        # NOTE(review): with GDAL >= 3 the output axis order follows the CRS definition
        # (often lat, lon for geographic CRSs) — confirm which order callers expect.
        prosrs = osr.SpatialReference()
        prosrs.ImportFromWkt(dataset.GetProjection())
        geosrs = prosrs.CloneGeogCS()
        ct = osr.CoordinateTransformation(prosrs, geosrs)
        coords = ct.TransformPoint(x, y)
        return coords[:2]

    @staticmethod
    def get_band_array(filename, num=1):
        """
        :param filename: tif path
        :param num: 1-based band index
        :return: the band as a 2-D array, or None
        """
        dataset = _open_dataset(filename)
        if dataset is None:
            return None
        band = dataset.GetRasterBand(num)
        array = band.ReadAsArray(0, 0, band.XSize, band.YSize)
        del dataset
        return array

    @staticmethod
    def get_data(filename):
        """
        :param filename: tif path
        :return: all bands as read by ReadAsArray ((rows, cols) for one band,
                 (bands, rows, cols) otherwise), or None
        """
        dataset = _open_dataset(filename)
        if dataset is None:
            return None
        im_data = dataset.ReadAsArray(0, 0, dataset.RasterXSize, dataset.RasterYSize)
        del dataset
        return im_data

    @staticmethod
    def get_all_band_array(filename):
        """
        (Atmospheric-delay algorithm) Read all bands of an ERA-5 image into one array
        with the band index on the LAST axis:
        get_data() -> (37, 8, 8), get_all_band_array() -> (8, 8, 37).
        A single-band image is returned as a 2-D array in its native dtype.
        :param filename: image path
        :return: image array, or None if the file cannot be opened
        """
        dataset = _open_dataset(filename)
        if dataset is None:
            return None
        x_size = dataset.RasterXSize
        y_size = dataset.RasterYSize
        nums = dataset.RasterCount
        if nums == 1:
            array = dataset.GetRasterBand(1).ReadAsArray(0, 0, x_size, y_size)
        else:
            array = np.zeros((y_size, x_size, nums), dtype=float)
            for i in range(nums):
                array[:, :, i] = dataset.GetRasterBand(i + 1).ReadAsArray(0, 0, x_size, y_size)
        del dataset
        return array

    @staticmethod
    def get_img_width(filename):
        """
        :param filename: tif path
        :return: image width (number of columns), or None
        """
        dataset = _open_dataset(filename)
        if dataset is None:
            return None
        width = dataset.RasterXSize
        del dataset
        return width

    @staticmethod
    def get_img_height(filename):
        """
        :param filename: tif path
        :return: image height (number of rows), or None
        """
        dataset = _open_dataset(filename)
        if dataset is None:
            return None
        height = dataset.RasterYSize
        del dataset
        return height

    @staticmethod
    def read_img(filename):
        """
        Read an image.
        :param filename: image path
        :return: (projection, geotransform, data array); (None, None, None) if the file
                 cannot be opened
        """
        img_dataset = _open_dataset(filename)
        if img_dataset is None:
            logger.error('Could not open %s', filename)
            return None, None, None
        im_proj = img_dataset.GetProjection()  # map projection
        if im_proj is None:
            return None, None, None
        im_geotrans = img_dataset.GetGeoTransform()  # affine matrix
        im_width = img_dataset.RasterXSize  # number of columns
        im_height = img_dataset.RasterYSize  # number of rows
        im_arr = img_dataset.ReadAsArray(0, 0, im_width, im_height)
        del img_dataset
        return im_proj, im_geotrans, im_arr

    def cal_img_scope(self, dataset):
        """
        Compute the geographic extent of an image by mapping its corner pixel
        coordinates through the GDAL six-parameter geotransform (the result is in the
        dataset's own coordinate system, projected or geographic).
        :param dataset: GDAL dataset
        :return: [point_upleft, point_upright, point_downleft, point_downright], or None
        """
        if dataset is None:
            return None
        img_geotrans = dataset.GetGeoTransform()
        if img_geotrans is None:
            return None
        width = dataset.RasterXSize  # number of columns
        height = dataset.RasterYSize  # number of rows
        point_upleft = self.trans_rowcol2geo(img_geotrans, 0, 0)
        point_upright = self.trans_rowcol2geo(img_geotrans, width, 0)
        point_downleft = self.trans_rowcol2geo(img_geotrans, 0, height)
        point_downright = self.trans_rowcol2geo(img_geotrans, width, height)
        return [point_upleft, point_upright, point_downleft, point_downright]

    @staticmethod
    def get_scope_ori_sim(filename):
        """
        Compute the corner coordinates of a two-band coordinate image: band 1 holds the
        x coordinate of every pixel and band 2 the y coordinate (band_merge writes
        longitude and latitude in that order).
        Corners are ordered so that "left" has the smaller band-1 value and "up" has the
        larger band-2 value.
        :param filename: path of the coordinate image
        :return: [point_upleft, point_upright, point_downleft, point_downright], each
                 [band1, band2]; None if the file cannot be opened
        """
        dataset = _open_dataset(filename)
        if dataset is None:
            return None
        last_col = dataset.RasterXSize - 1
        last_row = dataset.RasterYSize - 1
        band1 = dataset.GetRasterBand(1)
        array1 = band1.ReadAsArray(0, 0, band1.XSize, band1.YSize)
        band2 = dataset.GetRasterBand(2)
        array2 = band2.ReadAsArray(0, 0, band2.XSize, band2.YSize)

        top_first = [array1[0, 0], array2[0, 0]]
        top_last = [array1[0, last_col], array2[0, last_col]]
        if array1[0, 0] < array1[0, last_col]:
            point_upleft, point_upright = top_first, top_last
        else:
            point_upleft, point_upright = top_last, top_first

        bottom_first = [array1[last_row, 0], array2[last_row, 0]]
        bottom_last = [array1[last_row, last_col], array2[last_row, last_col]]
        if array1[last_row, 0] < array1[last_row, last_col]:
            point_downleft, point_downright = bottom_first, bottom_last
        else:
            point_downleft, point_downright = bottom_last, bottom_first

        if array2[0, 0] < array2[last_row, 0]:
            # The first row has the smaller band-2 value: swap the top and bottom corners.
            point_upleft, point_downleft = point_downleft, point_upleft
            point_upright, point_downright = point_downright, point_upright
        return [point_upleft, point_upright, point_downleft, point_downright]

    @staticmethod
    def trans_rowcol2geo(img_geotrans, img_col, img_row):
        """
        Map image coordinates to projected/geographic coordinates with the GDAL
        six-parameter affine geotransform.
        :param img_geotrans: geotransform (6-tuple)
        :param img_col: column index (pixel / x direction)
        :param img_row: row index (line / y direction)
        :return: [geo_x, geo_y]
        """
        geo_x = img_geotrans[0] + img_geotrans[1] * img_col + img_geotrans[2] * img_row
        geo_y = img_geotrans[3] + img_geotrans[4] * img_col + img_geotrans[5] * img_row
        return [geo_x, geo_y]

    @staticmethod
    def write_era_into_img(filename, im_proj, im_geotrans, im_data):
        """
        Save an image whose bands are stored along the LAST axis, i.e. shape
        (rows, cols, bands) or (rows, cols), as a GeoTIFF.
        :param filename: output path
        :param im_proj: projection (WKT)
        :param im_geotrans: geotransform
        :param im_data: image array
        """
        if im_data.ndim == 3:
            im_height, im_width, im_bands = im_data.shape
            layers = [im_data[:, :, i] for i in range(im_bands)]
        else:
            im_bands, (im_height, im_width) = 1, im_data.shape
            layers = [im_data]
        dataset = _create_dataset("GTiff", filename, im_proj, im_geotrans,
                                  im_width, im_height, im_bands, _gdal_datatype(im_data))
        _write_bands(dataset, layers)
        del dataset  # closes and writes the GeoTIFF

    @staticmethod
    def lat_lon_to_pixel(raster_dataset_path, location):
        """
        Convert a geographic location to (fractional) pixel coordinates of a raster.
        From zacharybears.com/using-python-to-translate-latlon-locations-to-pixels-on-a-geotiff/.
        :param raster_dataset_path: raster path
        :param location: 2-element location; element 1 is passed first to TransformPoint
        :return: (Xpixel, Yline), or None if the raster cannot be opened
        """
        # NOTE(review): the location/axis order this relies on depends on the GDAL
        # version's axis-order convention (GDAL >= 3 uses authority order) — confirm.
        raster_dataset = _open_dataset(raster_dataset_path)
        if raster_dataset is None:
            return None
        gt = raster_dataset.GetGeoTransform()
        srs = osr.SpatialReference()
        srs.ImportFromWkt(raster_dataset.GetProjection())
        srs_lat_lon = srs.CloneGeogCS()
        ct = osr.CoordinateTransformation(srs_lat_lon, srs)
        # Move the point into the geotransform's coordinate space.
        new_location = [None, None]
        (new_location[1], new_location[0], _) = ct.TransformPoint(location[1], location[0])
        # Invert the affine geotransform to get pixel/line values.
        x_geo = new_location[0]
        y_geo = new_location[1]
        det = gt[1] * gt[5] - gt[2] * gt[4]
        x_pixel = (gt[5] * (x_geo - gt[0]) - gt[2] * (y_geo - gt[3])) / det
        y_line = (gt[1] * (y_geo - gt[3]) - gt[4] * (x_geo - gt[0])) / det
        del raster_dataset
        return (x_pixel, y_line)

    @staticmethod
    def write_img(filename, im_proj, im_geotrans, im_data, no_data='0'):
        """
        Save an image as a GeoTIFF.
        :param filename: output path
        :param im_proj: projection (WKT)
        :param im_geotrans: geotransform
        :param im_data: (rows, cols) or (bands, rows, cols) array
        :param no_data: nodata value set on every band; 'null' leaves it unset
        """
        if im_data.ndim == 3:
            im_bands, im_height, im_width = im_data.shape
            layers = im_data
        else:
            im_bands, (im_height, im_width) = 1, im_data.shape
            layers = [im_data]
        dataset = _create_dataset("GTiff", filename, im_proj, im_geotrans,
                                  im_width, im_height, im_bands, _gdal_datatype(im_data))
        _write_bands(dataset, layers, no_data)
        del dataset  # closes and writes the GeoTIFF

    @staticmethod
    def write_img_envi(filename, im_proj, im_geotrans, im_data, no_data='null'):
        """
        Save an image in ENVI format.
        :param filename: output path
        :param im_proj: projection (WKT)
        :param im_geotrans: geotransform
        :param im_data: (rows, cols) or (bands, rows, cols) array
        :param no_data: nodata value set on every band; 'null' leaves it unset
        """
        if im_data.ndim == 3:
            im_bands, im_height, im_width = im_data.shape
            layers = im_data
        else:
            im_bands, (im_height, im_width) = 1, im_data.shape
            layers = [im_data]
        dataset = _create_dataset("ENVI", filename, im_proj, im_geotrans,
                                  im_width, im_height, im_bands, _gdal_datatype(im_data))
        _write_bands(dataset, layers, no_data)
        del dataset

    @staticmethod
    def write_img_rpc(filename, im_proj, im_geotrans, im_data, rpc_dict):
        """
        Save a GeoTIFF and store ``rpc_dict`` entries in its 'RPC' metadata domain.
        :param im_data: (rows, cols) or (bands, rows, cols) array
        :param rpc_dict: mapping of RPC metadata keys to (string) values
        """
        if im_data.ndim == 3:
            im_bands, im_height, im_width = im_data.shape
            layers = im_data
        else:
            im_bands, (im_height, im_width) = 1, im_data.shape
            layers = [im_data]
        dataset = _create_dataset("GTiff", filename, im_proj, im_geotrans,
                                  im_width, im_height, im_bands, _gdal_datatype(im_data))
        for key, value in rpc_dict.items():
            dataset.SetMetadataItem(key, value, 'RPC')
        _write_bands(dataset, layers)
        del dataset

    def transtif2mask(self, out_tif_path, in_tif_path, threshold):
        """
        Write a 0/1 mask that is 1 where the input is below ``threshold``.
        :param out_tif_path: output path
        :param in_tif_path: input path
        :param threshold: threshold value
        """
        im_proj, im_geotrans, im_arr = self.read_img(in_tif_path)
        im_arr_mask = (im_arr < threshold).astype(int)
        self.write_img(out_tif_path, im_proj, im_geotrans, im_arr_mask)

    def _read_view_data(self, tif_path):
        """
        Load ``tif_path`` as a 2-D float array for rendering a preview.

        Complex samples are replaced by their magnitude. Multi-band rasters are
        reduced to intensity sqrt(b0**2 + b1**2) of the first two bands. NaN and
        -9999 pixels are set to 0.
        :raises OSError: if the image cannot be opened
        """
        data = self.get_data(tif_path)
        if data is None:
            raise OSError('Could not open ' + tif_path)
        if np.iscomplexobj(data):
            data = np.abs(data)
        data = data.astype(float)
        if data.ndim == 3:
            data = np.sqrt(data[0] ** 2 + data[1] ** 2)
        data[np.isnan(data)] = 0
        data[data == -9999] = 0
        return data

    @staticmethod
    def _quick_view_size(rows, cols):
        """
        Return the (rows, cols) size of a quick view.

        Images larger than 10000 on either side are shrunk 10x. Images larger than
        1024 are scaled so their longer side is 1024. Other images keep their size.
        Every dimension is at least 1.
        """
        if rows > 10000 or cols > 10000:
            return max(1, int(rows / 10)), max(1, int(cols / 10))
        if rows > 1024 or cols > 1024:
            if rows > cols:
                return 1024, max(1, int(cols / rows * 1024))
            return max(1, int(rows / cols * 1024)), 1024
        return rows, cols

    @staticmethod
    def _render_view(data, color_img, out_path, size):
        """
        Render a 2-D array to an image file at ``size`` = (rows, cols).

        With ``color_img is True``, each distinct non-zero value gets a random RGB
        colour. Otherwise a grey-scale image is produced, stretched between the 2nd
        and 98th percentiles.
        """
        rows, cols = size
        if color_img is True:
            img = np.zeros(data.shape + (3,), dtype=np.uint8)  # (height, width, channels)
            for value in np.unique(data):
                if value != 0:
                    mask = data == value
                    # One random colour per value; narrow the randint range to restrict the palette.
                    img[mask, 0] = np.random.randint(0, 255)
                    img[mask, 1] = np.random.randint(0, 255)
                    img[mask, 2] = np.random.randint(0, 255)
            img = cv2.resize(img, (cols, rows))  # cv2 takes (width, height)
            cv2.imwrite(out_path, img)
        else:
            low = np.percentile(data, 2)
            high = np.percentile(data, 98)
            if high > low:
                data = (data - low) / (high - low) * 255
            else:
                # Constant image: avoid 0/0 and render it black.
                data = np.zeros_like(data)
            out_img = Image.fromarray(np.clip(data, 0, 255).astype(np.float32))
            out_img = out_img.resize((cols, rows))  # resample; PIL takes (width, height)
            out_img = out_img.convert("L")  # 8-bit grey scale
            out_img.save(out_path)

    def write_quick_view(self, tif_path, color_img=False, quick_view_path=None):
        """
        Generate a down-sampled quick view; by default it is saved next to the image
        with the same name and a .jpg extension.
        :param tif_path: image path
        :param color_img: True to render a random pseudo-colour image
        :param quick_view_path: quick view path
        :return: the quick view path
        """
        if quick_view_path is None:
            quick_view_path = os.path.splitext(tif_path)[0] + '.jpg'
        data = self._read_view_data(tif_path)
        size = self._quick_view_size(*data.shape)
        self._render_view(data, color_img, quick_view_path, size)
        return quick_view_path

    def limit_field(self, out_path, in_path, min_value, max_value):
        """
        Clamp band 1 of ``in_path`` to [min_value, max_value] and save it.
        :param out_path: output path
        :param in_path: input (main mask) path; its georeferencing is used for the output
        :param min_value: lower bound
        :param max_value: upper bound
        :return: True
        """
        proj = self.get_projection(in_path)
        geotrans = self.get_geotransform(in_path)
        array = self.get_band_array(in_path, 1)
        array[array < min_value] = min_value
        array[array > max_value] = max_value
        self.write_img(out_path, proj, geotrans, array)
        return True

    def band_merge(self, lon, lat, ori_sim):
        """
        Stack the longitude and latitude rasters into a two-band float image
        (band 1 = lon, band 2 = lat) with an identity geotransform and no projection.
        """
        lon_arr = self.get_data(lon)
        lat_arr = self.get_data(lat)
        temp = np.stack((lon_arr, lat_arr)).astype(float)
        self.write_img(ori_sim, '', [0.0, 1.0, 0.0, 0.0, 0.0, 1.0], temp, '0')

    @staticmethod
    def _valid_extent(values):
        """Return (min, max) over entries that are neither 0 nor NaN (inf/-inf if none)."""
        valid = (values != 0) & ~np.isnan(values)
        return np.min(np.where(valid, values, np.inf)), np.max(np.where(valid, values, -np.inf))

    def get_scopes(self, ori_sim):
        """
        Bounding box of the valid (non-zero, non-NaN) lon/lat values of a two-band
        coordinate image (band 1 = lon, band 2 = lat).
        :return: [[min_lon, max_lat], [max_lon, max_lat], [min_lon, min_lat], [max_lon, min_lat]]
        """
        ori_sim_data = self.get_data(ori_sim)
        min_lon, max_lon = self._valid_extent(ori_sim_data[0, :, :])
        min_lat, max_lat = self._valid_extent(ori_sim_data[1, :, :])
        return [[min_lon, max_lat], [max_lon, max_lat], [min_lon, min_lat], [max_lon, min_lat]]

    def get_center_scopes(self, dataset):
        """
        Corner coordinates of a sub-window spanning 1/5 to 3/5 of the image's width and
        height, mapped through the geotransform.
        :param dataset: GDAL dataset
        :return: [point_upleft, point_upright, point_downleft, point_downright], or None
        """
        # NOTE(review): 1/5..3/5 is not symmetric about the image centre — confirm intent.
        if dataset is None:
            return None
        img_geotrans = dataset.GetGeoTransform()
        if img_geotrans is None:
            return None
        x_split = int(dataset.RasterXSize / 5)  # columns
        y_split = int(dataset.RasterYSize / 5)  # rows
        img_col_start, img_col_end = x_split * 1, x_split * 3
        img_row_start, img_row_end = y_split * 1, y_split * 3
        point_upleft = self.trans_rowcol2geo(img_geotrans, img_col_start, img_row_start)
        point_upright = self.trans_rowcol2geo(img_geotrans, img_col_end, img_row_start)
        point_downleft = self.trans_rowcol2geo(img_geotrans, img_col_start, img_row_end)
        point_downright = self.trans_rowcol2geo(img_geotrans, img_col_end, img_row_end)
        return [point_upleft, point_upright, point_downleft, point_downright]

    def write_view(self, tif_path, color_img=False, quick_view_path=None):
        """
        Generate a full-resolution view; by default it is saved next to the image with
        the same name and a .jpg extension.
        :param tif_path: image path
        :param color_img: True to render a random pseudo-colour image
        :param quick_view_path: view path
        :return: the view path
        """
        if quick_view_path is None:
            quick_view_path = os.path.splitext(tif_path)[0] + '.jpg'
        data = self._read_view_data(tif_path)
        self._render_view(data, color_img, quick_view_path, data.shape)
        return quick_view_path


if __name__ == '__main__':
    fn = r"C:\Users\sxwcc\Downloads\HJ2E_MYC_QPS_001752_E118.0_N37.7_20230204_SLC_AHV_L10000010458-cal-SMC\HJ2E_MYC_QPS_001752_E118.0_N37.7_20230204_SLC_AHV_L10000010458-cal-SMC.tif"
    ImageHandler().write_quick_view(fn)