# -*- coding: UTF-8 -*-
"""
@Project :microproduct
@File    :Imagehandle.py
@Function :Read SAR data to be processed, standardise its format, and save
           the processed results (GeoTIFF read/write and quick-view export).
@Author  :SHJ
@Date    :2021/10/15
@Version :1.0.0
"""
import logging
import os

from osgeo import gdal
import numpy as np
from PIL import Image
import cv2

logger = logging.getLogger("mylog")

# numpy dtype name -> GDAL raster data type used when writing GeoTIFFs.
# Any dtype not listed here is written as GDT_Float32.
# NOTE: 'int8' stays mapped to GDT_Byte (unsigned) for backward compatibility;
# negative int8 values cannot be represented in that type.
_GDAL_DTYPES = {
    'int8': gdal.GDT_Byte,
    'uint8': gdal.GDT_Byte,
    'uint16': gdal.GDT_UInt16,
    'int16': gdal.GDT_Int16,
    'uint32': gdal.GDT_UInt32,
    'int32': gdal.GDT_Int32,
    'float32': gdal.GDT_Float32,
    'float64': gdal.GDT_Float64,
}


class ImageHandler:
    """
    Read, edit and save raster images through GDAL.
    """

    def __init__(self):
        pass

    @staticmethod
    def _open(filename):
        """
        Open a raster with GDAL.

        :param filename: raster (tif) path
        :return: GDAL dataset, or None if GDAL could not open the file
        """
        gdal.AllRegister()
        return gdal.Open(filename)

    @staticmethod
    def get_dataset(filename):
        """
        :param filename: tif path
        :return: GDAL dataset handle, or None if the file cannot be opened
        """
        return ImageHandler._open(filename)

    def get_scope(self, filename):
        """
        :param filename: tif path
        :return: geographic extent as returned by cal_img_scope, or None
        """
        dataset = self._open(filename)
        if dataset is None:
            return None
        im_scope = self.cal_img_scope(dataset)
        del dataset
        return im_scope

    @staticmethod
    def get_projection(filename):
        """
        :param filename: tif path
        :return: map projection string from GDAL GetProjection(), or None
        """
        dataset = ImageHandler._open(filename)
        if dataset is None:
            return None
        im_proj = dataset.GetProjection()
        del dataset
        return im_proj

    @staticmethod
    def get_geotransform(filename):
        """
        :param filename: tif path
        :return: six-parameter affine transform from image space (pixel, line)
                 to georeferenced space (projected or geographic), or None
        """
        dataset = ImageHandler._open(filename)
        if dataset is None:
            return None
        geotransform = dataset.GetGeoTransform()
        del dataset
        return geotransform

    @staticmethod
    def get_bands(filename):
        """
        :param filename: tif path
        :return: number of bands in the image, or None
        """
        dataset = ImageHandler._open(filename)
        if dataset is None:
            return None
        bands = dataset.RasterCount
        del dataset
        return bands

    @staticmethod
    def get_band_array(filename, num=1):
        """
        :param filename: tif path
        :param num: band number (GDAL band numbers start at 1)
        :return: 2-D array of that band, or None if the file or band is missing
        """
        dataset = ImageHandler._open(filename)
        if dataset is None:
            return None
        band = dataset.GetRasterBand(num)
        if band is None:
            del dataset
            return None
        array = band.ReadAsArray(0, 0, band.XSize, band.YSize)
        del dataset
        return array

    @staticmethod
    def get_data(filename):
        """
        :param filename: tif path
        :return: data of all bands (2-D for a single band, 3-D (band, row, col)
                 for several), or None
        """
        dataset = ImageHandler._open(filename)
        if dataset is None:
            return None
        im_width = dataset.RasterXSize
        im_height = dataset.RasterYSize
        im_data = dataset.ReadAsArray(0, 0, im_width, im_height)
        del dataset
        return im_data

    @staticmethod
    def get_img_width(filename):
        """
        :param filename: tif path
        :return: image width (number of columns), or None
        """
        dataset = ImageHandler._open(filename)
        if dataset is None:
            return None
        width = dataset.RasterXSize
        del dataset
        return width

    @staticmethod
    def get_img_height(filename):
        """
        :param filename: tif path
        :return: image height (number of rows), or None
        """
        dataset = ImageHandler._open(filename)
        if dataset is None:
            return None
        height = dataset.RasterYSize
        del dataset
        return height

    @staticmethod
    def read_img(filename):
        """
        Read an image.

        :param filename: tif path
        :return: (projection, geotransform, array); (None, None, None) if the
                 file cannot be opened (an error is logged in that case)
        """
        img_dataset = ImageHandler._open(filename)
        if img_dataset is None:
            msg = 'Could not open ' + filename
            logger.error(msg)
            return None, None, None
        im_proj = img_dataset.GetProjection()  # map projection
        # GDAL usually reports "no projection" as an empty string, not None, so
        # this guard rarely fires; unprojected images are deliberately accepted.
        if im_proj is None:
            return None, None, None
        im_geotrans = img_dataset.GetGeoTransform()  # affine transform
        im_width = img_dataset.RasterXSize  # number of columns
        im_height = img_dataset.RasterYSize  # number of rows
        im_arr = img_dataset.ReadAsArray(0, 0, im_width, im_height)
        del img_dataset
        return im_proj, im_geotrans, im_arr

    def cal_img_scope(self, dataset):
        """
        Compute the geographic extent of an image.

        Image coordinates (column/row) are converted to projected or
        geographic coordinates (depending on the data's coordinate system)
        with GDAL's six-parameter model. The corners are taken at the outer
        pixel edges (column = width, row = height).

        :param dataset: GDAL dataset
        :return: list[point_upleft, point_upright, point_downleft, point_downright]
        """
        if dataset is None:
            return None
        img_geotrans = dataset.GetGeoTransform()
        if img_geotrans is None:
            return None

        width = dataset.RasterXSize  # number of columns
        height = dataset.RasterYSize  # number of rows

        point_upleft = self.trans_rowcol2geo(img_geotrans, 0, 0)
        point_upright = self.trans_rowcol2geo(img_geotrans, width, 0)
        point_downleft = self.trans_rowcol2geo(img_geotrans, 0, height)
        point_downright = self.trans_rowcol2geo(img_geotrans, width, height)

        return [point_upleft, point_upright, point_downleft, point_downright]

    @staticmethod
    def trans_rowcol2geo(img_geotrans, img_col, img_row):
        """
        Convert image coordinates to projected or geographic coordinates
        (depending on the data's coordinate system) with GDAL's
        six-parameter affine transform.

        :param img_geotrans: affine transform (GetGeoTransform() tuple)
        :param img_col: image column (pixel / x index)
        :param img_row: image row (line / y index)
        :return: [geo_x, geo_y]
        """
        geo_x = img_geotrans[0] + img_geotrans[1] * img_col + img_geotrans[2] * img_row
        geo_y = img_geotrans[3] + img_geotrans[4] * img_col + img_geotrans[5] * img_row
        return [geo_x, geo_y]

    @staticmethod
    def _write_gtiff(filename, im_proj, im_geotrans, im_data, rpc_dict=None):
        """
        Shared GeoTIFF writer used by write_img and write_img_rpc.

        :param filename: output tif path (missing parent directories are created)
        :param im_proj: projection string passed to SetProjection
        :param im_geotrans: six-parameter affine transform
        :param im_data: 2-D (row, col) or 3-D (band, row, col) numpy array
        :param rpc_dict: optional mapping written into the 'RPC' metadata domain
        :raises OSError: if GDAL cannot create the output file
        """
        datatype = _GDAL_DTYPES.get(im_data.dtype.name, gdal.GDT_Float32)

        if im_data.ndim == 3:
            im_bands, im_height, im_width = im_data.shape
        else:
            im_bands, (im_height, im_width) = 1, im_data.shape

        out_dir = os.path.dirname(filename)
        # os.makedirs('') raises, so only create a directory when there is one.
        if out_dir and not os.path.exists(out_dir):
            os.makedirs(out_dir, exist_ok=True)

        driver = gdal.GetDriverByName("GTiff")
        # The data type is mandatory: GDAL uses it to size the output file.
        dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
        if dataset is None:
            raise OSError('Could not create ' + filename)

        dataset.SetGeoTransform(im_geotrans)  # affine transform
        dataset.SetProjection(im_proj)  # projection

        if rpc_dict:
            for key, value in rpc_dict.items():
                dataset.SetMetadataItem(key, value, 'RPC')

        if im_data.ndim == 3:
            # WriteArray needs 2-D input, so even a (1, H, W) array is written
            # band by band.
            for i in range(im_bands):
                dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
        else:
            dataset.GetRasterBand(1).WriteArray(im_data)
        del dataset  # releasing the dataset flushes and closes the GeoTIFF

    @staticmethod
    def write_img(filename, im_proj, im_geotrans, im_data):
        """
        Save an image as GeoTIFF.

        :param filename: output tif path (missing parent directories are created)
        :param im_proj: projection string
        :param im_geotrans: six-parameter affine transform
        :param im_data: 2-D (row, col) or 3-D (band, row, col) numpy array;
                        dtypes without a GDAL mapping are written as Float32
        """
        ImageHandler._write_gtiff(filename, im_proj, im_geotrans, im_data)

    @staticmethod
    def write_img_rpc(filename, im_proj, im_geotrans, im_data, rpc_dict):
        """
        Save an image as GeoTIFF and write RPC information into it.

        :param filename: output tif path
        :param im_proj: projection string
        :param im_geotrans: six-parameter affine transform
        :param im_data: 2-D (row, col) or 3-D (band, row, col) numpy array
        :param rpc_dict: mapping of RPC metadata keys to values, stored in the
                         'RPC' domain via SetMetadataItem (values presumably
                         already strings — TODO confirm with callers)
        """
        ImageHandler._write_gtiff(filename, im_proj, im_geotrans, im_data, rpc_dict)

    def transtif2mask(self, out_tif_path, in_tif_path, threshold):
        """
        Write a 0/1 mask that is 1 where the input is below ``threshold``.

        :param out_tif_path: output path
        :param in_tif_path: input path
        :param threshold: threshold value
        """
        im_proj, im_geotrans, im_arr = self.read_img(in_tif_path)
        if im_arr is None:
            logger.error('Could not read ' + in_tif_path)
            return
        im_arr_mask = (im_arr < threshold).astype(int)
        self.write_img(out_tif_path, im_proj, im_geotrans, im_arr_mask)

    def write_quick_view(self, tif_path, color_img=False, quick_view_path=None,
                         max_size=1024):
        """
        Generate a quick-view image. By default it is saved next to the image
        with the same name and a .jpg extension.

        :param tif_path: image path
        :param color_img: if true, write a random pseudo-colour image (one
                          random colour per distinct non-zero value, 0 stays
                          black); otherwise a min-max stretched grayscale image
        :param quick_view_path: quick-view output path
        :param max_size: longest side of the quick view in pixels; smaller
                         images keep their original size
        """
        if quick_view_path is None:
            quick_view_path = os.path.splitext(tif_path)[0] + '.jpg'

        t_data = self.get_data(tif_path)
        if t_data is None:
            logger.error('Could not open ' + tif_path)
            return
        if t_data.ndim == 3:
            # Multi-band: combine bands 0 and 1 as sqrt(b0^2 + b1^2). Presumably
            # these are the real/imaginary parts — TODO confirm. Other bands
            # are ignored.
            t_data = t_data.astype(float)
            t_data = np.sqrt(t_data[0] ** 2 + t_data[1] ** 2)

        t_r, t_c = t_data.shape  # (rows, cols)
        q_r, q_c = t_r, t_c
        if t_r > max_size or t_c > max_size:
            # Scale the longer side to max_size, keep aspect ratio, never 0 px.
            if t_r > t_c:
                q_r, q_c = max_size, max(1, int(t_c / t_r * max_size))
            else:
                q_r, q_c = max(1, int(t_r / t_c * max_size)), max_size

        if color_img:
            # Random pseudo-colour: one random colour per distinct value,
            # built with a single palette lookup instead of a pass per value.
            values, inverse = np.unique(t_data, return_inverse=True)
            palette = np.random.randint(0, 255, size=(values.size, 3)).astype(np.uint8)
            palette[values == 0] = 0  # zero-valued pixels stay black
            img = palette[inverse.reshape(t_data.shape)]  # (height, width, 3)
            img = cv2.resize(img, (q_c, q_r))  # cv2 takes (width, height)
            cv2.imwrite(quick_view_path, img)
        else:
            # Grayscale: min-max stretch to 0..255. Work in float so integer
            # inputs cannot overflow in the subtraction.
            t_data = t_data.astype(np.float64)
            t_min, t_max = np.min(t_data), np.max(t_data)
            if t_max > t_min:
                t_data = (t_data - t_min) / (t_max - t_min) * 255
            else:
                # Constant image: avoid 0/0 (NaN) and produce an all-black view.
                t_data = np.zeros(t_data.shape)
            out_img = Image.fromarray(t_data.astype(np.float32))
            out_img = out_img.resize((q_c, q_r))  # PIL takes (width, height)
            out_img = out_img.convert("L")  # convert to 8-bit grayscale
            out_img.save(quick_view_path)


if __name__ == '__main__':
    ih = ImageHandler()
    path = r'D:\Dual1_1_feature1.tif'
    # ih.write_quick_view(path, color_img=False)
    print('done')