microproduct/tool/algorithm/image/ImageHandle.py

652 lines
22 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
@Project microproduct
@File ImageHandle.py
@Function 实现对待处理SAR数据的读取、格式标准化和处理完后保存文件功能
@Author LMM
@Date 2021/10/19 14:39
@Version 1.0.0
"""
import os
from PIL import Image
from osgeo import gdal
from osgeo import osr
import numpy as np
from PIL import Image
import cv2
import logging
import math
# Module-level logger, looked up by the name "mylog" in the logging registry.
logger = logging.getLogger("mylog")
class ImageHandler:
    """Read, edit and save raster imagery through GDAL.

    Reader methods take a file path, open it with ``gdal.Open`` and return
    ``None`` when the file cannot be opened. Writer methods create a new
    raster from a numpy array plus projection / geotransform information.
    """

    def __init__(self):
        pass

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _open(filename):
        """Register the GDAL drivers and open ``filename``.

        :param filename: raster path
        :return: the value returned by ``gdal.Open`` (``None`` on failure)
        """
        gdal.AllRegister()
        return gdal.Open(filename)

    @staticmethod
    def _gdal_dtype(im_data):
        """Map the numpy dtype of ``im_data`` to a GDAL data type.

        :param im_data: numpy array about to be written
        :return: GDAL type constant; ``gdal.GDT_Float32`` for unlisted dtypes
        """
        gdal_dtypes = {
            'int8': gdal.GDT_Byte,  # mapping kept from the original table
            'uint8': gdal.GDT_Byte,
            'uint16': gdal.GDT_UInt16,
            'int16': gdal.GDT_Int16,
            'uint32': gdal.GDT_UInt32,
            'int32': gdal.GDT_Int32,
            'float32': gdal.GDT_Float32,
            'float64': gdal.GDT_Float64,
        }
        return gdal_dtypes.get(im_data.dtype.name, gdal.GDT_Float32)

    @staticmethod
    def _ensure_parent_dir(filename):
        """Create the directory that will contain ``filename`` if needed.

        A bare file name has an empty directory part; ``os.makedirs('')``
        would raise, so that case is skipped.
        """
        parent = os.path.dirname(filename)
        if parent:
            os.makedirs(parent, exist_ok=True)

    @staticmethod
    def _write_raster(driver_name, filename, im_proj, im_geotrans, im_data, no_data):
        """Create ``filename`` with ``driver_name`` and write band-first data.

        :param driver_name: GDAL driver short name, e.g. ``"GTiff"``
        :param filename: output path
        :param im_proj: projection passed to ``SetProjection``
        :param im_geotrans: geotransform passed to ``SetGeoTransform``
        :param im_data: array shaped (rows, cols) or (bands, rows, cols)
        :param no_data: nodata value for every band; ``'null'`` means none
        """
        datatype = ImageHandler._gdal_dtype(im_data)
        # View a 2-D array as one band so a single loop also covers
        # (1, rows, cols) input, which must be written as a 2-D slice.
        band_stack = im_data if im_data.ndim == 3 else im_data[np.newaxis]
        im_bands, im_height, im_width = band_stack.shape

        ImageHandler._ensure_parent_dir(filename)
        driver = gdal.GetDriverByName(driver_name)
        dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
        dataset.SetGeoTransform(im_geotrans)
        dataset.SetProjection(im_proj)
        for index in range(im_bands):
            outband = dataset.GetRasterBand(index + 1)
            outband.WriteArray(band_stack[index])
            if no_data != 'null':
                outband.SetNoDataValue(np.double(no_data))
            outband.FlushCache()
        del dataset  # dropping the reference closes the file

    @staticmethod
    def _quick_view_size(rows, cols):
        """Return the ``(rows, cols)`` of the quick view for an image size.

        Images with a side above 10000 pixels are reduced tenfold; images
        with a side above 1024 are scaled so the longer side becomes 1024;
        smaller images keep their size.
        """
        if rows > 10000 or cols > 10000:
            return max(1, int(rows / 10)), max(1, int(cols / 10))
        if rows > 1024 or cols > 1024:
            if rows > cols:
                return 1024, max(1, int(cols / rows * 1024))
            return max(1, int(rows / cols * 1024)), 1024
        return rows, cols

    # ------------------------------------------------------------------
    # readers
    # ------------------------------------------------------------------
    @staticmethod
    def get_dataset(filename):
        """
        :param filename: tif path
        :return: opened GDAL dataset, or None if it cannot be opened
        """
        return ImageHandler._open(filename)

    def get_scope(self, filename):
        """
        :param filename: tif path
        :return: corner coordinates from ``cal_img_scope``, or None
        """
        dataset = self._open(filename)
        if dataset is None:
            return None
        im_scope = self.cal_img_scope(dataset)
        del dataset
        return im_scope

    @staticmethod
    def get_projection(filename):
        """
        :param filename: tif path
        :return: projection string of the dataset, or None
        """
        dataset = ImageHandler._open(filename)
        if dataset is None:
            return None
        im_proj = dataset.GetProjection()
        del dataset
        return im_proj

    @staticmethod
    def get_geotransform(filename):
        """
        :param filename: tif path
        :return: affine transform from image space (column, row) to
                 georeferenced space, or None
        """
        dataset = ImageHandler._open(filename)
        if dataset is None:
            return None
        geotransform = dataset.GetGeoTransform()
        del dataset
        return geotransform

    @staticmethod
    def get_invgeotransform(filename):
        """
        :param filename: tif path
        :return: inverse affine transform (georeferenced space to image
                 space) as returned by ``gdal.InvGeoTransform``, or None
        """
        dataset = ImageHandler._open(filename)
        if dataset is None:
            return None
        geotransform = gdal.InvGeoTransform(dataset.GetGeoTransform())
        del dataset
        return geotransform

    @staticmethod
    def get_bands(filename):
        """
        :param filename: tif path
        :return: number of bands, or None
        """
        dataset = ImageHandler._open(filename)
        if dataset is None:
            return None
        bands = dataset.RasterCount
        del dataset
        return bands

    @staticmethod
    def geo2lonlat(dataset, x, y):
        """
        Convert projected coordinates to the geographic coordinate system
        underlying the dataset's projection.

        :param dataset: GDAL dataset
        :param x: projected x
        :param y: projected y
        :return: first two components of ``TransformPoint(x, y)``
        """
        # NOTE(review): with GDAL >= 3 the axis order of the geographic SRS
        # may be (lat, lon) rather than (lon, lat) - confirm against the
        # GDAL version in use before relying on the order.
        prosrs = osr.SpatialReference()
        prosrs.ImportFromWkt(dataset.GetProjection())
        geosrs = prosrs.CloneGeogCS()
        ct = osr.CoordinateTransformation(prosrs, geosrs)
        coords = ct.TransformPoint(x, y)
        return coords[:2]

    @staticmethod
    def get_band_array(filename, num=1):
        """
        :param filename: tif path
        :param num: 1-based band index
        :return: array of that band, or None
        """
        dataset = ImageHandler._open(filename)
        if dataset is None:
            return None
        band = dataset.GetRasterBand(num)
        array = band.ReadAsArray(0, 0, band.XSize, band.YSize)
        del dataset
        return array

    @staticmethod
    def get_data(filename):
        """
        :param filename: tif path
        :return: data of all bands as returned by ``ReadAsArray``, or None
        """
        dataset = ImageHandler._open(filename)
        if dataset is None:
            return None
        im_data = dataset.ReadAsArray(0, 0, dataset.RasterXSize, dataset.RasterYSize)
        del dataset
        return im_data

    @staticmethod
    def get_all_band_array(filename):
        """
        (Atmospheric delay algorithm) Read every band of an image into one
        array with the band index on the third axis.

        :param filename: image path
        :return: the band's 2-D array for single-band images, otherwise a
                 float array shaped (rows, cols, bands); None if the file
                 cannot be opened
        """
        dataset = ImageHandler._open(filename)
        if dataset is None:
            return None
        x_size = dataset.RasterXSize
        y_size = dataset.RasterYSize
        nums = dataset.RasterCount
        if nums == 1:
            return dataset.GetRasterBand(1).ReadAsArray(0, 0, x_size, y_size)
        array = np.zeros((y_size, x_size, nums), dtype=float)
        for i in range(nums):
            array[:, :, i] = dataset.GetRasterBand(i + 1).ReadAsArray(0, 0, x_size, y_size)
        return array

    @staticmethod
    def get_img_width(filename):
        """
        :param filename: tif path
        :return: image width (columns), or None
        """
        dataset = ImageHandler._open(filename)
        if dataset is None:
            return None
        width = dataset.RasterXSize
        del dataset
        return width

    @staticmethod
    def get_img_height(filename):
        """
        :param filename: tif path
        :return: image height (rows), or None
        """
        dataset = ImageHandler._open(filename)
        if dataset is None:
            return None
        height = dataset.RasterYSize
        del dataset
        return height

    @staticmethod
    def read_img(filename):
        """
        Read an image together with its georeferencing.

        :param filename: image path
        :return: (projection, geotransform, array); three None values when
                 the file cannot be opened or has no projection object
        """
        img_dataset = ImageHandler._open(filename)
        if img_dataset is None:
            msg = 'Could not open ' + filename
            logger.error(msg)
            return None, None, None
        im_proj = img_dataset.GetProjection()
        if im_proj is None:
            return None, None, None
        im_geotrans = img_dataset.GetGeoTransform()
        im_width = img_dataset.RasterXSize  # number of columns
        im_height = img_dataset.RasterYSize  # number of rows
        im_arr = img_dataset.ReadAsArray(0, 0, im_width, im_height)
        del img_dataset
        return im_proj, im_geotrans, im_arr

    def cal_img_scope(self, dataset):
        """
        Compute the georeferenced corner coordinates of an image by applying
        the GDAL six-parameter geotransform to pixel positions (0, 0),
        (width, 0), (0, height) and (width, height).

        :param dataset: GDAL dataset
        :return: [point_upleft, point_upright, point_downleft, point_downright],
                 or None if the dataset or its geotransform is None
        """
        if dataset is None:
            return None
        img_geotrans = dataset.GetGeoTransform()
        if img_geotrans is None:
            return None
        width = dataset.RasterXSize  # number of columns
        height = dataset.RasterYSize  # number of rows
        corners = ((0, 0), (width, 0), (0, height), (width, height))
        return [self.trans_rowcol2geo(img_geotrans, col, row) for col, row in corners]

    @staticmethod
    def get_scope_ori_sim(filename):
        """
        Derive four corner points from the corner pixels of a two-band
        image: band 1 supplies the first coordinate of each point and band 2
        the second. Left/right is decided by comparing band-1 values along a
        row, up/down by comparing band-2 values in the first column.

        :param filename: path of an image with at least two bands
        :return: [point_upleft, point_upright, point_downleft, point_downright],
                 or None if the file cannot be opened or has fewer than two
                 bands
        """
        dataset = ImageHandler._open(filename)
        if dataset is None:
            return None
        if dataset.RasterCount < 2:
            return None
        last_col = dataset.RasterXSize - 1
        last_row = dataset.RasterYSize - 1
        band1 = dataset.GetRasterBand(1)
        array1 = band1.ReadAsArray(0, 0, band1.XSize, band1.YSize)
        band2 = dataset.GetRasterBand(2)
        array2 = band2.ReadAsArray(0, 0, band2.XSize, band2.YSize)
        del dataset

        def corner(row, col):
            return [array1[row, col], array2[row, col]]

        if array1[0, 0] < array1[0, last_col]:
            point_upleft, point_upright = corner(0, 0), corner(0, last_col)
        else:
            point_upright, point_upleft = corner(0, 0), corner(0, last_col)

        if array1[last_row, 0] < array1[last_row, last_col]:
            point_downleft, point_downright = corner(last_row, 0), corner(last_row, last_col)
        else:
            point_downright, point_downleft = corner(last_row, 0), corner(last_row, last_col)

        if array2[0, 0] < array2[last_row, 0]:
            # First row holds the smaller band-2 value: swap top and bottom.
            point_upleft, point_downleft = point_downleft, point_upleft
            point_upright, point_downright = point_downright, point_upright
        return [point_upleft, point_upright, point_downleft, point_downright]

    @staticmethod
    def trans_rowcol2geo(img_geotrans, img_col, img_row):
        """
        Apply the GDAL six-parameter affine geotransform to a pixel position.

        :param img_geotrans: six-element geotransform
        :param img_col: column index
        :param img_row: row index
        :return: [geo_x, geo_y]
        """
        geo_x = img_geotrans[0] + img_geotrans[1] * img_col + img_geotrans[2] * img_row
        geo_y = img_geotrans[3] + img_geotrans[4] * img_col + img_geotrans[5] * img_row
        return [geo_x, geo_y]

    # ------------------------------------------------------------------
    # writers
    # ------------------------------------------------------------------
    @staticmethod
    def write_era_into_img(filename, im_proj, im_geotrans, im_data):
        """
        Save an array as GeoTIFF; 3-D input is band-last (rows, cols, bands).

        :param filename: output path
        :param im_proj: projection
        :param im_geotrans: geotransform
        :param im_data: array shaped (rows, cols) or (rows, cols, bands)
        :return: None
        """
        datatype = ImageHandler._gdal_dtype(im_data)
        # A 2-D array becomes (rows, cols, 1) so one loop writes every
        # layout, including 3-D input with a single band.
        cube = im_data if im_data.ndim == 3 else im_data[:, :, np.newaxis]
        im_height, im_width, im_bands = cube.shape

        ImageHandler._ensure_parent_dir(filename)
        driver = gdal.GetDriverByName("GTiff")
        dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
        dataset.SetGeoTransform(im_geotrans)
        dataset.SetProjection(im_proj)
        for i in range(im_bands):
            dataset.GetRasterBand(i + 1).WriteArray(cube[:, :, i])
        del dataset

    @staticmethod
    def write_img(filename, im_proj, im_geotrans, im_data, no_data='null'):
        """
        Save an array as GeoTIFF; 3-D input is band-first (bands, rows, cols).

        :param filename: output path
        :param im_proj: projection
        :param im_geotrans: geotransform
        :param im_data: array shaped (rows, cols) or (bands, rows, cols)
        :param no_data: value recorded as nodata on every band; the default
                        'null' records none
        :return: None
        """
        ImageHandler._write_raster("GTiff", filename, im_proj, im_geotrans, im_data, no_data)

    @staticmethod
    def write_img_envi(filename, im_proj, im_geotrans, im_data, no_data='null'):
        """
        Save an array in ENVI format; 3-D input is band-first.

        :param filename: output path
        :param im_proj: projection
        :param im_geotrans: geotransform
        :param im_data: array shaped (rows, cols) or (bands, rows, cols)
        :param no_data: value recorded as nodata on every band; the default
                        'null' records none
        :return: None
        """
        ImageHandler._write_raster("ENVI", filename, im_proj, im_geotrans, im_data, no_data)

    @staticmethod
    def write_img_rpc(filename, im_proj, im_geotrans, im_data, rpc_dict):
        """
        Save an array as GeoTIFF and store ``rpc_dict`` as metadata items in
        the 'RPC' domain.

        :param filename: output path
        :param im_proj: projection
        :param im_geotrans: geotransform
        :param im_data: array shaped (rows, cols) or (bands, rows, cols)
        :param rpc_dict: mapping of RPC metadata keys to values
        """
        dtype_name = im_data.dtype.name
        if dtype_name in ('int8', 'uint8'):
            datatype = gdal.GDT_Byte
        elif dtype_name == 'uint16':
            # Unsigned 16-bit data must not be stored as signed Int16.
            datatype = gdal.GDT_UInt16
        elif dtype_name == 'int16':
            datatype = gdal.GDT_Int16
        else:
            datatype = gdal.GDT_Float32

        band_stack = im_data if im_data.ndim == 3 else im_data[np.newaxis]
        im_bands, im_height, im_width = band_stack.shape

        driver = gdal.GetDriverByName("GTiff")
        dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
        dataset.SetGeoTransform(im_geotrans)
        dataset.SetProjection(im_proj)
        for key, value in rpc_dict.items():
            dataset.SetMetadataItem(key, value, 'RPC')
        for i in range(im_bands):
            dataset.GetRasterBand(i + 1).WriteArray(band_stack[i])
        del dataset

    # ------------------------------------------------------------------
    # processing
    # ------------------------------------------------------------------
    def transtif2mask(self, out_tif_path, in_tif_path, threshold):
        """
        Write a mask that is 1 where the input is below ``threshold`` and 0
        elsewhere, reusing the input's projection and geotransform.

        :param out_tif_path: output path
        :param in_tif_path: input path
        :param threshold: threshold value
        """
        # read_img returns exactly three values.
        im_proj, im_geotrans, im_arr = self.read_img(in_tif_path)
        if im_arr is None:
            logger.error('Could not read ' + str(in_tif_path))
            return
        im_arr_mask = (im_arr < threshold).astype(int)
        self.write_img(out_tif_path, im_proj, im_geotrans, im_arr_mask)

    def write_quick_view(self, tif_path, color_img=False, quick_view_path=None):
        """
        Generate a quick-view image; by default it is written next to the
        input with the same base name and a ``.jpg`` extension.

        Multi-band input is reduced to sqrt(band0**2 + band1**2) first.

        :param tif_path: image path
        :param color_img: when True, paint every distinct non-zero value
                          with a random colour instead of a grayscale view
        :param quick_view_path: output path of the quick view
        """
        if quick_view_path is None:
            quick_view_path = os.path.splitext(tif_path)[0] + '.jpg'

        n = self.get_bands(tif_path)
        if n is None:
            logger.error('Could not open ' + str(tif_path))
            return
        t_data = self.get_data(tif_path)
        if n != 1:
            # Multi-band: use the magnitude of the first two bands.
            t_data = t_data.astype(float)
            t_data = np.sqrt(t_data[0] ** 2 + t_data[1] ** 2)

        t_r, t_c = t_data.shape
        q_r, q_c = self._quick_view_size(t_r, t_c)

        if color_img is True:
            img = np.zeros((t_r, t_c, 3), dtype=np.uint8)  # (height, width, channel)
            for value in np.unique(t_data):
                if value != 0:
                    mask = t_data == value
                    img[mask, 0] = np.random.randint(0, 255)
                    img[mask, 1] = np.random.randint(0, 255)
                    img[mask, 2] = np.random.randint(0, 255)
            img = cv2.resize(img, (q_c, q_r))  # cv2 takes (width, height)
            cv2.imwrite(quick_view_path, img)
        else:
            gray = t_data.astype(float)
            # NaN-aware percentiles: plain percentile returns NaN as soon as
            # the data holds a single NaN.
            low = np.nanpercentile(gray, 2)
            high = np.nanpercentile(gray, 98)
            gray[np.isnan(gray)] = high
            span = high - low
            # The zero check avoids a divide-by-zero on constant images.
            # NOTE(review): the "< 256" bound is kept from the original
            # condition - confirm wider ranges should really stay unstretched.
            if span != 0 and span < 256:
                gray = (gray - low) / span * 255
            out_img = Image.fromarray(gray)
            out_img = out_img.resize((q_c, q_r))  # PIL takes (width, height)
            out_img = out_img.convert("L")
            out_img.save(quick_view_path)

    def limit_field(self, out_path, in_path, min_value, max_value):
        """
        Clamp band 1 of the input to [min_value, max_value] and write it out
        with the input's projection and geotransform.

        :param out_path: output path
        :param in_path: input path
        :param min_value: lower bound
        :param max_value: upper bound
        :return: True
        """
        proj = self.get_projection(in_path)
        geotrans = self.get_geotransform(in_path)
        array = self.get_band_array(in_path, 1)
        array[array < min_value] = min_value
        array[array > max_value] = max_value
        self.write_img(out_path, proj, geotrans, array)
        return True

    def band_merge(self, lon, lat, ori_sim):
        """
        Stack two single-band rasters into one two-band GeoTIFF (band 1 from
        ``lon``, band 2 from ``lat``) with an empty projection, the
        geotransform [0, 1, 0, 0, 0, 1] and nodata 0.

        :param lon: path of the first raster
        :param lat: path of the second raster
        :param ori_sim: output path
        """
        lon_arr = self.get_data(lon)
        lat_arr = self.get_data(lat)
        temp = np.zeros((2, lon_arr.shape[0], lon_arr.shape[1]), dtype=float)
        temp[0, :, :] = lon_arr[:, :]
        temp[1, :, :] = lat_arr[:, :]
        self.write_img(ori_sim, '', [0.0, 1.0, 0.0, 0.0, 0.0, 1.0], temp, '0')

    def get_scopes(self, ori_sim):
        """
        Compute the bounding corners from a two-band raster, ignoring zero
        and NaN pixels in each band.

        :param ori_sim: path of the two-band raster
        :return: [[min0, max1], [max0, max1], [min0, min1], [max0, min1]]
                 where 0 / 1 denote the first / second band
        """
        ori_sim_data = self.get_data(ori_sim)

        def valid_min_max(values):
            valid = (values != 0) & ~np.isnan(values)
            lowest = np.nanmin(np.where(valid, values, np.inf))
            highest = np.nanmax(np.where(valid, values, -np.inf))
            return lowest, highest

        min_lon, max_lon = valid_min_max(ori_sim_data[0, :, :])
        min_lat, max_lat = valid_min_max(ori_sim_data[1, :, :])
        scopes = [[min_lon, max_lat], [max_lon, max_lat], [min_lon, min_lat], [max_lon, min_lat]]
        return scopes
# if __name__ == '__main__':
# path = r'D:\BaiduNetdiskDownload\GZ\lon.rdr'
# path2 = r'D:\BaiduNetdiskDownload\GZ\lat.rdr'
# path3 = r'D:\BaiduNetdiskDownload\GZ\lon_lat.tif'
# s = ImageHandler().band_merge(path, path2, path3)
# print(s)
# pass