# Source path: microproduct/deformation-sentiral/ImageHandle.py
# Listing metadata: 378 lines, 12 KiB, Python.
# (Repository-viewer navigation links: Raw, Permalink, Blame, History.)
#
# Repository-viewer notice: this file contains ambiguous Unicode characters
# that may be confused with others in your current locale. If that use is
# intentional and legitimate, the warning can safely be ignored; the viewer's
# Escape button highlights the affected characters.

# -*- coding: UTF-8 -*-
"""
@Project microproduct
@File ImageHandle.py
@Function Read the SAR data to be processed, standardize its format, and save the result files after processing
@Author SHJ
@Date 2021/10/15
@Version 1.0.0
"""
import logging
import os
from osgeo import gdal
import numpy as np
from PIL import Image
import cv2
# Module-level logger shared by this file; it is looked up by the fixed
# name "mylog" rather than by the module name.
logger = logging.getLogger("mylog")
class ImageHandler:
    """Read, inspect, and save raster images through GDAL.

    Every reader opens the file itself and returns ``None`` (or a tuple of
    ``None``) when GDAL cannot open it, so callers must check the result.
    """

    # NumPy dtype name -> name of the matching ``gdal.GDT_*`` constant.
    # Stored as names and resolved with ``getattr`` at call time so the
    # table itself does not depend on GDAL being importable.
    # NOTE: 'int8' keeps the historical mapping to GDT_Byte (an unsigned
    # type), so negative int8 values are not representable in the output.
    _GDAL_TYPE_NAMES = {
        'uint8': 'GDT_Byte',
        'int8': 'GDT_Byte',
        'uint16': 'GDT_UInt16',
        'int16': 'GDT_Int16',
        'uint32': 'GDT_UInt32',
        'int32': 'GDT_Int32',
        'float32': 'GDT_Float32',
        'float64': 'GDT_Float64',
    }

    def __init__(self):
        pass

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _open(filename):
        """Register the GDAL drivers and open *filename*.

        :param filename: path of the raster file
        :return: the GDAL dataset, or None when it cannot be opened
        """
        gdal.AllRegister()
        return gdal.Open(filename)

    @staticmethod
    def _gdal_datatype(dtype_name):
        """Return the GDAL data type for a NumPy dtype name.

        Dtypes missing from ``_GDAL_TYPE_NAMES`` fall back to GDT_Float32.

        :param dtype_name: value of ``ndarray.dtype.name``
        :return: a ``gdal.GDT_*`` constant
        """
        type_name = ImageHandler._GDAL_TYPE_NAMES.get(dtype_name, 'GDT_Float32')
        return getattr(gdal, type_name)

    @staticmethod
    def _quick_view_size(rows, cols, max_side=1024):
        """Compute the quick-view size, keeping the aspect ratio.

        Images whose sides are both within *max_side* keep their size.
        Otherwise the longer side becomes *max_side* and the other side is
        scaled by the same factor (truncated, but never below 1 pixel).

        :param rows: image height in pixels
        :param cols: image width in pixels
        :param max_side: largest allowed side of the quick view
        :return: (quick_rows, quick_cols)
        """
        if rows <= max_side and cols <= max_side:
            return rows, cols
        if rows > cols:
            return max_side, max(1, int(cols / rows * max_side))
        return max(1, int(rows / cols * max_side)), max_side

    @staticmethod
    def _save_geotiff(filename, im_proj, im_geotrans, im_data, rpc_dict=None):
        """Create a GeoTIFF from *im_data*; shared by the public writers.

        :param filename: output path; missing parent directories are created
        :param im_proj: projection string passed to SetProjection
        :param im_geotrans: six-parameter geotransform passed to SetGeoTransform
        :param im_data: 2-D (rows, cols) or 3-D (bands, rows, cols) array
        :param rpc_dict: optional items written to the 'RPC' metadata domain
        """
        datatype = ImageHandler._gdal_datatype(im_data.dtype.name)

        # A 3-D array is band-first; anything else is treated as one band.
        multi_band = len(im_data.shape) == 3
        if multi_band:
            im_bands, im_height, im_width = im_data.shape
        else:
            im_bands, (im_height, im_width) = 1, im_data.shape

        # A bare file name has an empty directory part, and os.makedirs('')
        # raises, so only create the directory when there is one.
        out_dir = os.path.dirname(filename)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        # The data type is required up front because Create allocates the file.
        driver = gdal.GetDriverByName("GTiff")
        dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
        dataset.SetGeoTransform(im_geotrans)
        dataset.SetProjection(im_proj)

        if rpc_dict:
            for key, value in rpc_dict.items():
                dataset.SetMetadataItem(key, value, 'RPC')

        # Always hand WriteArray a 2-D slice, including for a (1, rows, cols)
        # array, which must not be passed whole.
        band_arrays = im_data if multi_band else [im_data]
        for band_index, band_array in enumerate(band_arrays, start=1):
            dataset.GetRasterBand(band_index).WriteArray(band_array)

        # Dropping the last reference closes the dataset and flushes it to disk.
        del dataset

    # ------------------------------------------------------------------
    # readers
    # ------------------------------------------------------------------
    @staticmethod
    def get_dataset(filename):
        """Open a raster file.

        :param filename: path of the tif file
        :return: the GDAL dataset handle, or None when it cannot be opened
        """
        return ImageHandler._open(filename)

    def get_scope(self, filename):
        """Return the geographic extent of a raster file.

        :param filename: path of the tif file
        :return: the four corner points (see ``cal_img_scope``), or None
        """
        dataset = self._open(filename)
        if dataset is None:
            return None
        im_scope = self.cal_img_scope(dataset)
        del dataset
        return im_scope

    @staticmethod
    def get_projection(filename):
        """Return the map projection of a raster file.

        :param filename: path of the tif file
        :return: the projection string, or None when the file cannot be opened
        """
        dataset = ImageHandler._open(filename)
        if dataset is None:
            return None
        im_proj = dataset.GetProjection()
        del dataset
        return im_proj

    @staticmethod
    def get_geotransform(filename):
        """Return the geotransform of a raster file.

        :param filename: path of the tif file
        :return: the affine transform from image space (column, row) to
                 georeferenced coordinates, or None when the file cannot be
                 opened
        """
        dataset = ImageHandler._open(filename)
        if dataset is None:
            return None
        geotransform = dataset.GetGeoTransform()
        del dataset
        return geotransform

    @staticmethod
    def get_bands(filename):
        """Return the number of bands of a raster file.

        :param filename: path of the tif file
        :return: the band count, or None when the file cannot be opened
        """
        dataset = ImageHandler._open(filename)
        if dataset is None:
            return None
        bands = dataset.RasterCount
        del dataset
        return bands

    @staticmethod
    def get_band_array(filename, num=1):
        """Read a single band of a raster file.

        :param filename: path of the tif file
        :param num: 1-based band index
        :return: the band as an array, or None when the file cannot be opened
        """
        dataset = ImageHandler._open(filename)
        if dataset is None:
            return None
        band = dataset.GetRasterBand(num)
        array = band.ReadAsArray(0, 0, band.XSize, band.YSize)
        del dataset
        return array

    @staticmethod
    def get_data(filename):
        """Read all bands of a raster file.

        :param filename: path of the tif file
        :return: the full-extent array, or None when the file cannot be opened
        """
        dataset = ImageHandler._open(filename)
        if dataset is None:
            return None
        im_data = dataset.ReadAsArray(0, 0, dataset.RasterXSize, dataset.RasterYSize)
        del dataset
        return im_data

    @staticmethod
    def get_img_width(filename):
        """Return the width (number of columns) of a raster file.

        :param filename: path of the tif file
        :return: the width in pixels, or None when the file cannot be opened
        """
        dataset = ImageHandler._open(filename)
        if dataset is None:
            return None
        width = dataset.RasterXSize
        del dataset
        return width

    @staticmethod
    def get_img_height(filename):
        """Return the height (number of rows) of a raster file.

        :param filename: path of the tif file
        :return: the height in pixels, or None when the file cannot be opened
        """
        dataset = ImageHandler._open(filename)
        if dataset is None:
            return None
        height = dataset.RasterYSize
        del dataset
        return height

    @staticmethod
    def read_img(filename):
        """Read projection, geotransform, and pixel data in one call.

        :param filename: path of the tif file
        :return: (projection, geotransform, array); (None, None, None) when
                 the file cannot be opened or has no projection object
        """
        img_dataset = ImageHandler._open(filename)
        if img_dataset is None:
            msg = 'Could not open ' + filename
            logger.error(msg)
            return None, None, None

        im_proj = img_dataset.GetProjection()  # map projection
        if im_proj is None:
            return None, None, None

        im_geotrans = img_dataset.GetGeoTransform()  # affine transform
        im_width = img_dataset.RasterXSize  # number of columns
        im_height = img_dataset.RasterYSize  # number of rows
        im_arr = img_dataset.ReadAsArray(0, 0, im_width, im_height)
        del img_dataset
        return im_proj, im_geotrans, im_arr

    # ------------------------------------------------------------------
    # georeferencing
    # ------------------------------------------------------------------
    def cal_img_scope(self, dataset):
        """Compute the georeferenced corner coordinates of a dataset.

        The corners are obtained by applying the dataset's six-parameter
        geotransform to the image corners.

        :param dataset: GDAL dataset
        :return: [upper_left, upper_right, lower_left, lower_right], each a
                 [geo_x, geo_y] pair; None when the dataset or its
                 geotransform is None
        """
        if dataset is None:
            return None
        img_geotrans = dataset.GetGeoTransform()
        if img_geotrans is None:
            return None

        width = dataset.RasterXSize  # number of columns
        height = dataset.RasterYSize  # number of rows
        corners = ((0, 0), (width, 0), (0, height), (width, height))
        return [self.trans_rowcol2geo(img_geotrans, col, row) for col, row in corners]

    @staticmethod
    def trans_rowcol2geo(img_geotrans, img_col, img_row):
        """Convert an image (column, row) position to georeferenced coordinates.

        Applies the GDAL six-parameter affine model; whether the result is
        projected or geographic depends on the data's coordinate system.

        :param img_geotrans: six-parameter geotransform
        :param img_col: column index (x direction) in the image
        :param img_row: row index (y direction) in the image
        :return: [geo_x, geo_y]
        """
        origin_x, col_dx, row_dx, origin_y, col_dy, row_dy = img_geotrans[:6]
        geo_x = origin_x + col_dx * img_col + row_dx * img_row
        geo_y = origin_y + col_dy * img_col + row_dy * img_row
        return [geo_x, geo_y]

    # ------------------------------------------------------------------
    # writers
    # ------------------------------------------------------------------
    @staticmethod
    def write_img(filename, im_proj, im_geotrans, im_data):
        """Save an array as a GeoTIFF.

        The GDAL data type follows ``im_data.dtype``; dtypes without an entry
        in ``_GDAL_TYPE_NAMES`` are written as Float32. Missing parent
        directories of *filename* are created.

        :param filename: output path
        :param im_proj: projection string
        :param im_geotrans: six-parameter geotransform
        :param im_data: 2-D (rows, cols) or 3-D (bands, rows, cols) array
        :return: None
        """
        ImageHandler._save_geotiff(filename, im_proj, im_geotrans, im_data)

    @staticmethod
    def write_img_rpc(filename, im_proj, im_geotrans, im_data, rpc_dict):
        """Save an array as a GeoTIFF and attach RPC metadata.

        Uses the same dtype mapping as ``write_img``, so unsigned 16-bit data
        is no longer written into a signed 16-bit band.

        :param filename: output path
        :param im_proj: projection string
        :param im_geotrans: six-parameter geotransform
        :param im_data: 2-D (rows, cols) or 3-D (bands, rows, cols) array
        :param rpc_dict: mapping of RPC item name -> value, stored in the
                         'RPC' metadata domain
        :return: None
        """
        ImageHandler._save_geotiff(filename, im_proj, im_geotrans, im_data, rpc_dict)

    def transtif2mask(self, out_tif_path, in_tif_path, threshold):
        """Write a 0/1 mask marking the pixels below a threshold.

        :param out_tif_path: output path of the mask
        :param in_tif_path: path of the input tif
        :param threshold: pixels strictly below this value become 1
        :return: None; nothing is written when the input cannot be read
        """
        # read_img returns exactly three values.
        im_proj, im_geotrans, im_arr = self.read_img(in_tif_path)
        if im_arr is None:
            logger.error('Could not read ' + in_tif_path)
            return
        im_arr_mask = (im_arr < threshold).astype(int)
        self.write_img(out_tif_path, im_proj, im_geotrans, im_arr_mask)

    def write_quick_view(self, tif_path, color_img=False, quick_view_path=None):
        """Generate a quick-view image of a raster file.

        Multi-band input is reduced to ``sqrt(band0 ** 2 + band1 ** 2)``
        before rendering. Images larger than 1024 pixels on a side are
        shrunk, keeping the aspect ratio.

        :param tif_path: path of the raster file
        :param color_img: when True, paint every distinct non-zero value with
                          a random colour; otherwise write a min-max
                          stretched grayscale image
        :param quick_view_path: output path; defaults to *tif_path* with the
                                extension replaced by '.jpg'
        :return: None; nothing is written when the input cannot be opened
        """
        if quick_view_path is None:
            quick_view_path = os.path.splitext(tif_path)[0] + '.jpg'

        t_data = self.get_data(tif_path)
        if t_data is None:
            logger.error('Could not open ' + tif_path)
            return

        # ReadAsArray yields a 3-D (bands, rows, cols) array only for
        # multi-band files; collapse those to a single intensity layer.
        if t_data.ndim == 3:
            t_data = t_data.astype(float)
            t_data = np.sqrt(t_data[0] ** 2 + t_data[1] ** 2)

        t_r, t_c = t_data.shape
        q_r, q_c = self._quick_view_size(t_r, t_c)

        if color_img is True:
            # Pseudo-colour image, laid out as (height, width, channel).
            img = np.zeros((t_r, t_c, 3), dtype=np.uint8)
            for value in np.unique(t_data):
                if value != 0:
                    rows, cols = np.where(t_data == value)
                    for channel in range(3):
                        # The upper bound is exclusive; 256 allows 0..255.
                        img[rows, cols, channel] = np.random.randint(0, 256)
            img = cv2.resize(img, (q_c, q_r))  # cv2 expects (width, height)
            cv2.imwrite(quick_view_path, img)
        else:
            # Grayscale image stretched to 0..255.
            low = np.min(t_data)
            span = np.max(t_data) - low
            if span == 0:
                # A constant image would divide by zero; render it black.
                stretched = np.zeros((t_r, t_c), dtype=float)
            else:
                stretched = (t_data - low) / span * 255
            out_img = Image.fromarray(stretched)
            out_img = out_img.resize((q_c, q_r))  # PIL expects (width, height)
            out_img = out_img.convert("L")
            out_img.save(quick_view_path)
if __name__ == '__main__':
    # Manual smoke test; the quick-view call is left disabled because the
    # sample path only exists on the original author's machine.
    ih = ImageHandler()
    # Raw string: '\D' is an invalid escape sequence in a normal literal.
    path = r'D:\Dual1_1_feature1.tif'
    # ih.write_quick_view(path, color_img=False)
    print('done')